diff --git a/.github/codecov.yml b/.github/codecov.yml new file mode 100644 index 00000000..69cb7601 --- /dev/null +++ b/.github/codecov.yml @@ -0,0 +1 @@ +comment: false diff --git a/.github/workflows/benchmark-dev.yml b/.github/workflows/benchmark-dev.yml index 21f9559e..2c6eae4b 100644 --- a/.github/workflows/benchmark-dev.yml +++ b/.github/workflows/benchmark-dev.yml @@ -75,7 +75,7 @@ jobs: name: ${{ matrix.bench.title }} (Windows ${{matrix.arch}}) tool: "benchmarkluau" output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json + external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json github-token: ${{ secrets.GITHUB_TOKEN }} - name: Push benchmark results @@ -85,7 +85,7 @@ jobs: cd gh-pages git config user.name github-actions git config user.email github@users.noreply.github.com - git add ./dev/bench/data.json + git add ./dev/bench/data-${{ matrix.os }}.json git commit -m "Add benchmarks results for ${{ github.sha }}" git push cd .. @@ -156,7 +156,7 @@ jobs: name: ${{ matrix.bench.title }} tool: "benchmarkluau" output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json + external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Store ${{ matrix.bench.title }} result (CacheGrind) @@ -166,7 +166,7 @@ jobs: name: ${{ matrix.bench.title }} (CacheGrind) tool: "roblox" output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json + external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Push benchmark results @@ -176,7 +176,7 @@ jobs: cd gh-pages git config user.name github-actions git config user.email github@users.noreply.github.com - git add ./dev/bench/data.json + git add ./dev/bench/data-${{ matrix.os }}.json git commit -m "Add benchmarks results for ${{ github.sha }}" git push cd .. @@ -220,13 +220,13 @@ jobs: sudo apt-get install valgrind - name: Run Luau Analyze on static file - run: sudo python ./bench/measure_time.py ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee ${{ matrix.bench.script }}-output.txt + run: sudo python ./bench/measure_time.py ./build/release/luau-analyze bench/other/LuauPolyfillMap.lua | tee ${{ matrix.bench.script }}-output.txt - name: Run ${{ matrix.bench.title }} (Cold Cachegrind) - run: sudo ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt + run: sudo ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 ./build/release/luau-analyze bench/other/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt - name: Run ${{ matrix.bench.title }} (Warm Cachegrind) - run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt + run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}" 1 ./build/release/luau-analyze bench/other/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt - name: Checkout Benchmark Results repository uses: actions/checkout@v3 @@ -244,7 +244,7 @@ jobs: gh-pages-branch: "main" output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json + external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Store ${{ matrix.bench.title }} result (CacheGrind) @@ -254,7 +254,7 @@ jobs: tool: "roblox" gh-pages-branch: "main" output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json + external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Push benchmark results @@ -264,7 +264,7 @@ jobs: cd gh-pages git config user.name github-actions git config user.email github@users.noreply.github.com - git add ./dev/bench/data.json + git add ./dev/bench/data-${{ matrix.os }}.json git commit -m "Add benchmarks results for ${{ github.sha }}" git push cd .. diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 48296823..9d26186e 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -32,11 +32,33 @@ jobs: sudo apt-get install valgrind - name: Build Luau - run: CXX=${{ matrix.compiler }} make config=release CALLGRIND=1 luau + run: CXX=${{ matrix.compiler }} make config=release CALLGRIND=1 luau luau-analyze - - name: Run benchmark + - name: Run benchmark (bench) run: | - python bench/bench.py --callgrind --vm "./luau -O2" | tee output.txt + python bench/bench.py --callgrind --vm "./luau -O2" | tee -a bench-output.txt + + - name: Run benchmark (analyze) + run: | + filter() { + awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau-analyze"}' + } + valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-nonstrict | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-strict | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/regex.lua 2>&1 | filter regex-nonstrict | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/regex.lua 2>&1 | filter regex-strict | tee -a analyze-output.txt + + - name: Run benchmark (compile) + run: | + filter() { + awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau --compile"}' + } + valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O0 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt - name: Checkout benchmark results uses: actions/checkout@v3 @@ -46,13 +68,29 @@ jobs: token: ${{ secrets.BENCH_GITHUB_TOKEN }} path: "./gh-pages" - - name: Store results + - name: Store results (bench) uses: Roblox/rhysd-github-action-benchmark@v-luau with: name: callgrind ${{ matrix.compiler }} tool: "benchmarkluau" - output-file-path: ./output.txt - external-data-json-path: ./gh-pages/bench/data.json + output-file-path: ./bench-output.txt + external-data-json-path: ./gh-pages/bench.json + + - name: Store results (analyze) + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: luau-analyze + tool: "benchmarkluau" + output-file-path: ./analyze-output.txt + external-data-json-path: ./gh-pages/analyze.json + + - name: Store results (compile) + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: luau --compile + tool: "benchmarkluau" + output-file-path: ./compile-output.txt + external-data-json-path: ./gh-pages/compile.json - name: Push benchmark results if: github.event_name == 'push' @@ -61,7 +99,7 @@ jobs: cd gh-pages git config user.name github-actions git config user.email github@users.noreply.github.com - git add ./bench/data.json + git add *.json git commit -m "Add benchmarks results for ${{ github.sha }}" git push cd .. diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 93b92645..da525ad8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -76,20 +76,10 @@ jobs: - name: make coverage run: | CXX=clang++-10 make -j2 config=coverage coverage - - name: debug coverage - run: | - git status - git log -5 - echo SHA: $GITHUB_SHA - name: upload coverage - uses: coverallsapp/github-action@master + uses: codecov/codecov-action@v3 with: - path-to-lcov: ./coverage.info - github-token: ${{ secrets.GITHUB_TOKEN }} - - uses: actions/upload-artifact@v2 - with: - name: coverage - path: coverage + files: ./coverage.info web: runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 5688dff5..af77f73c 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ /default.prof* /fuzz-* /luau +/luau-tests +/luau-analyze +__pycache__ diff --git a/Analysis/include/Luau/ApplyTypeFunction.h b/Analysis/include/Luau/ApplyTypeFunction.h new file mode 100644 index 00000000..8da3bc42 --- /dev/null +++ b/Analysis/include/Luau/ApplyTypeFunction.h @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Substitution.h" +#include "Luau/TxnLog.h" +#include "Luau/TypeVar.h" + +namespace Luau +{ + +// A substitution which replaces the type parameters of a type function by arguments +struct ApplyTypeFunction : Substitution +{ + ApplyTypeFunction(TypeArena* arena) + : Substitution(TxnLog::empty(), arena) + , encounteredForwardedType(false) + { + } + + // Never set under deferred constraint resolution. + bool encounteredForwardedType; + std::unordered_map typeArguments; + std::unordered_map typePackArguments; + bool ignoreChildren(TypeId ty) override; + bool ignoreChildren(TypePackId tp) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/JsonEncoder.h b/Analysis/include/Luau/AstJsonEncoder.h similarity index 67% rename from Analysis/include/Luau/JsonEncoder.h rename to Analysis/include/Luau/AstJsonEncoder.h index aa00390b..e79d9e62 100644 --- a/Analysis/include/Luau/JsonEncoder.h +++ b/Analysis/include/Luau/AstJsonEncoder.h @@ -2,12 +2,15 @@ #pragma once #include +#include namespace Luau { class AstNode; +struct Comment; std::string toJson(AstNode* node); +std::string toJson(AstNode* node, const std::vector& commentLocations); } // namespace Luau diff --git a/Analysis/include/Luau/AstQuery.h b/Analysis/include/Luau/AstQuery.h index dfe373a5..950a19da 100644 --- a/Analysis/include/Luau/AstQuery.h +++ b/Analysis/include/Luau/AstQuery.h @@ -63,6 +63,7 @@ private: AstLocal* local = nullptr; }; +std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos); std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos); AstNode* findNodeAtPosition(const SourceModule& source, Position pos); AstExpr* findExprAtPosition(const SourceModule& source, Position pos); diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h index 65b788d3..5e8d6605 100644 --- a/Analysis/include/Luau/Autocomplete.h +++ b/Analysis/include/Luau/Autocomplete.h @@ -78,16 +78,6 @@ struct AutocompleteResult using ModuleName = std::string; using StringCompletionCallback = std::function(std::string tag, std::optional ctx)>; -struct OwningAutocompleteResult -{ - AutocompleteResult result; - ModulePtr module; - std::unique_ptr sourceModule; -}; - AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); -// Deprecated, do not use in new work. -OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback); - } // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index dcfb14b4..5b737146 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -4,6 +4,7 @@ #include "Luau/Ast.h" // Used for some of the enumerations #include "Luau/NotNull.h" #include "Luau/Variant.h" +#include "Luau/TypeVar.h" #include #include @@ -12,7 +13,8 @@ namespace Luau { -struct Scope2; +struct Scope; + struct TypeVar; using TypeId = const TypeVar*; @@ -38,7 +40,7 @@ struct GeneralizationConstraint { TypeId generalizedType; TypeId sourceType; - Scope2* scope; + Scope* scope; }; // subType ~ inst superType @@ -70,8 +72,15 @@ struct NameConstraint std::string name; }; +// target ~ inst target +struct TypeAliasExpansionConstraint +{ + // Must be a PendingExpansionTypeVar. + TypeId target; +}; + using ConstraintV = Variant; + BinaryConstraint, NameConstraint, TypeAliasExpansionConstraint>; using ConstraintPtr = std::unique_ptr; struct Constraint diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index a49e8594..69f35d46 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -17,21 +17,22 @@ namespace Luau { -struct Scope2; +struct Scope; +using ScopePtr = std::shared_ptr; struct ConstraintGraphBuilder { // A list of all the scopes in the module. This vector holds ownership of the // scope pointers; the scopes themselves borrow pointers to other scopes to // define the scope hierarchy. - std::vector>> scopes; + std::vector> scopes; ModuleName moduleName; SingletonTypes& singletonTypes; const NotNull arena; // The root scope of the module we're generating constraints for. // This is null when the CGB is initially constructed. - Scope2* rootScope; + Scope* rootScope; // A mapping of AST node to TypeId. DenseHashMap astTypes{nullptr}; // A mapping of AST node to TypePackId. @@ -41,6 +42,8 @@ struct ConstraintGraphBuilder DenseHashMap astResolvedTypes{nullptr}; // Type packs resolved from type annotations. Analogous to astTypePacks. DenseHashMap astResolvedTypePacks{nullptr}; + // Defining scopes for AST nodes. + DenseHashMap astTypeAliasDefiningScopes{nullptr}; int recursionCount = 0; @@ -50,42 +53,42 @@ struct ConstraintGraphBuilder // Occasionally constraint generation needs to produce an ICE. const NotNull ice; - NotNull globalScope; + NotNull globalScope; - ConstraintGraphBuilder(const ModuleName& moduleName, TypeArena* arena, NotNull ice, NotNull globalScope); + ConstraintGraphBuilder(const ModuleName& moduleName, TypeArena* arena, NotNull ice, NotNull globalScope); /** * Fabricates a new free type belonging to a given scope. * @param scope the scope the free type belongs to. */ - TypeId freshType(NotNull scope); + TypeId freshType(const ScopePtr& scope); /** * Fabricates a new free type pack belonging to a given scope. * @param scope the scope the free type pack belongs to. */ - TypePackId freshTypePack(NotNull scope); + TypePackId freshTypePack(const ScopePtr& scope); /** * Fabricates a scope that is a child of another scope. * @param location the lexical extent of the scope in the source code. * @param parent the parent scope of the new scope. Must not be null. */ - NotNull childScope(Location location, NotNull parent); + ScopePtr childScope(Location location, const ScopePtr& parent); /** * Adds a new constraint with no dependencies to a given scope. * @param scope the scope to add the constraint to. * @param cv the constraint variant to add. */ - void addConstraint(NotNull scope, ConstraintV cv); + void addConstraint(const ScopePtr& scope, ConstraintV cv); /** * Adds a constraint to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param c the constraint to add. */ - void addConstraint(NotNull scope, std::unique_ptr c); + void addConstraint(const ScopePtr& scope, std::unique_ptr c); /** * The entry point to the ConstraintGraphBuilder. This will construct a set @@ -94,22 +97,26 @@ struct ConstraintGraphBuilder */ void visit(AstStatBlock* block); - void visitBlockWithoutChildScope(NotNull scope, AstStatBlock* block); + void visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); - void visit(NotNull scope, AstStat* stat); - void visit(NotNull scope, AstStatBlock* block); - void visit(NotNull scope, AstStatLocal* local); - void visit(NotNull scope, AstStatLocalFunction* function); - void visit(NotNull scope, AstStatFunction* function); - void visit(NotNull scope, AstStatReturn* ret); - void visit(NotNull scope, AstStatAssign* assign); - void visit(NotNull scope, AstStatIf* ifStatement); - void visit(NotNull scope, AstStatTypeAlias* alias); + void visit(const ScopePtr& scope, AstStat* stat); + void visit(const ScopePtr& scope, AstStatBlock* block); + void visit(const ScopePtr& scope, AstStatLocal* local); + void visit(const ScopePtr& scope, AstStatFor* for_); + void visit(const ScopePtr& scope, AstStatLocalFunction* function); + void visit(const ScopePtr& scope, AstStatFunction* function); + void visit(const ScopePtr& scope, AstStatReturn* ret); + void visit(const ScopePtr& scope, AstStatAssign* assign); + void visit(const ScopePtr& scope, AstStatIf* ifStatement); + void visit(const ScopePtr& scope, AstStatTypeAlias* alias); + void visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); + void visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); + void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); - TypePackId checkExprList(NotNull scope, const AstArray& exprs); + TypePackId checkExprList(const ScopePtr& scope, const AstArray& exprs); - TypePackId checkPack(NotNull scope, AstArray exprs); - TypePackId checkPack(NotNull scope, AstExpr* expr); + TypePackId checkPack(const ScopePtr& scope, AstArray exprs); + TypePackId checkPack(const ScopePtr& scope, AstExpr* expr); /** * Checks an expression that is expected to evaluate to one type. @@ -117,13 +124,13 @@ struct ConstraintGraphBuilder * @param expr the expression to check. * @return the type of the expression. */ - TypeId check(NotNull scope, AstExpr* expr); + TypeId check(const ScopePtr& scope, AstExpr* expr); - TypeId checkExprTable(NotNull scope, AstExprTable* expr); - TypeId check(NotNull scope, AstExprIndexName* indexName); - TypeId check(NotNull scope, AstExprIndexExpr* indexExpr); - TypeId check(NotNull scope, AstExprUnary* unary); - TypeId check(NotNull scope, AstExprBinary* binary); + TypeId checkExprTable(const ScopePtr& scope, AstExprTable* expr); + TypeId check(const ScopePtr& scope, AstExprIndexName* indexName); + TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); + TypeId check(const ScopePtr& scope, AstExprUnary* unary); + TypeId check(const ScopePtr& scope, AstExprBinary* binary); struct FunctionSignature { @@ -132,28 +139,29 @@ struct ConstraintGraphBuilder // The scope that encompasses the function's signature. May be nullptr // if there was no need for a signature scope (the function has no // generics). - Scope2* signatureScope; + ScopePtr signatureScope; // The scope that encompasses the function's body. Is a child scope of // signatureScope, if present. - NotNull bodyScope; + ScopePtr bodyScope; }; - FunctionSignature checkFunctionSignature(NotNull parent, AstExprFunction* fn); + FunctionSignature checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn); /** * Checks the body of a function expression. * @param scope the interior scope of the body of the function. * @param fn the function expression to check. */ - void checkFunctionBody(NotNull scope, AstExprFunction* fn); + void checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn); /** * Resolves a type from its AST annotation. * @param scope the scope that the type annotation appears within. * @param ty the AST annotation to resolve. + * @param topLevel whether the annotation is a "top-level" annotation. * @return the type of the AST annotation. **/ - TypeId resolveType(NotNull scope, AstType* ty); + TypeId resolveType(const ScopePtr& scope, AstType* ty, bool topLevel = false); /** * Resolves a type pack from its AST annotation. @@ -161,14 +169,14 @@ struct ConstraintGraphBuilder * @param tp the AST annotation to resolve. * @return the type pack of the AST annotation. **/ - TypePackId resolveTypePack(NotNull scope, AstTypePack* tp); + TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp); - TypePackId resolveTypePack(NotNull scope, const AstTypeList& list); + TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list); - std::vector> createGenerics(NotNull scope, AstArray generics); - std::vector> createGenericPacks(NotNull scope, AstArray packs); + std::vector> createGenerics(const ScopePtr& scope, AstArray generics); + std::vector> createGenericPacks(const ScopePtr& scope, AstArray packs); - TypeId flattenPack(NotNull scope, Location location, TypePackId tp); + TypeId flattenPack(const ScopePtr& scope, Location location, TypePackId tp); void reportError(Location location, TypeErrorData err); void reportCodeTooComplex(Location location); @@ -179,7 +187,7 @@ struct ConstraintGraphBuilder * real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an * initial scan of the AST and note what globals are defined. */ - void prepopulateGlobalScope(NotNull globalScope, AstStatBlock* program); + void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); }; /** @@ -192,6 +200,6 @@ struct ConstraintGraphBuilder * @return a list of pointers to constraints contained within the scope graph. * None of these pointers should be null. */ -std::vector> collectConstraints(NotNull rootScope); +std::vector> collectConstraints(NotNull rootScope); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index cf88efb6..9cc0e4cb 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -17,15 +17,35 @@ namespace Luau // never dereference this pointer. using BlockedConstraintId = const void*; +struct InstantiationSignature +{ + TypeFun fn; + std::vector arguments; + std::vector packArguments; + + bool operator==(const InstantiationSignature& rhs) const; + bool operator!=(const InstantiationSignature& rhs) const + { + return !((*this) == rhs); + } +}; + +struct HashInstantiationSignature +{ + size_t operator()(const InstantiationSignature& signature) const; +}; + struct ConstraintSolver { TypeArena* arena; InternalErrorReporter iceReporter; - // The entire set of constraints that the solver is trying to resolve. It - // is important to not add elements to this vector, lest the underlying - // storage that we retain pointers to be mutated underneath us. - const std::vector> constraints; - NotNull rootScope; + // The entire set of constraints that the solver is trying to resolve. + std::vector> constraints; + NotNull rootScope; + + // Constraints that the solver has generated, rather than sourcing from the + // scope tree. + std::vector> solverConstraints; // This includes every constraint that has not been fully solved. // A constraint can be both blocked and unsolved, for instance. @@ -37,10 +57,12 @@ struct ConstraintSolver std::unordered_map, size_t> blockedConstraints; // A mapping of type/pack pointers to the constraints they block. std::unordered_map>> blocked; + // Memoized instantiations of type aliases. + DenseHashMap instantiatedAliases{{}}; ConstraintSolverLogger logger; - explicit ConstraintSolver(TypeArena* arena, NotNull rootScope); + explicit ConstraintSolver(TypeArena* arena, NotNull rootScope); /** * Attempts to dispatch all pending constraints and reach a type solution @@ -62,6 +84,7 @@ struct ConstraintSolver bool tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force); bool tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force); bool tryDispatch(const NameConstraint& c, NotNull constraint); + bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint); void block(NotNull target, NotNull constraint); /** @@ -102,6 +125,11 @@ struct ConstraintSolver */ void unify(TypePackId subPack, TypePackId superPack); + /** Pushes a new solver constraint to the solver. + * @param cv the body of the constraint. + **/ + void pushConstraint(ConstraintV cv); + private: /** * Marks a constraint as being blocked on a type or type pack. The constraint @@ -121,6 +149,6 @@ private: void unblock_(BlockedConstraintId progressed); }; -void dump(NotNull rootScope, struct ToStringOptions& opts); +void dump(NotNull rootScope, struct ToStringOptions& opts); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolverLogger.h b/Analysis/include/Luau/ConstraintSolverLogger.h index 55336a23..fe2177c4 100644 --- a/Analysis/include/Luau/ConstraintSolverLogger.h +++ b/Analysis/include/Luau/ConstraintSolverLogger.h @@ -15,8 +15,8 @@ namespace Luau struct ConstraintSolverLogger { std::string compileOutput(); - void captureBoundarySnapshot(const Scope2* rootScope, std::vector>& unsolvedConstraints); - void prepareStepSnapshot(const Scope2* rootScope, NotNull current, std::vector>& unsolvedConstraints); + void captureBoundarySnapshot(const Scope* rootScope, std::vector>& unsolvedConstraints); + void prepareStepSnapshot(const Scope* rootScope, NotNull current, std::vector>& unsolvedConstraints); void commitPreparedStepSnapshot(); private: diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index f0d43090..9b8ec19e 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -127,13 +127,6 @@ struct Frontend CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); - /** Lint some code that has no associated DataModel object - * - * Since this source fragment has no name, we cannot cache its AST. Instead, - * we return it to the caller to use as they wish. - */ - std::pair lintFragment(std::string_view source, std::optional enabledLintWarnings = {}); - LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; @@ -159,7 +152,9 @@ struct Frontend void registerBuiltinDefinition(const std::string& name, std::function); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); - NotNull getGlobalScope2(); + LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName); + + NotNull getGlobalScope(); private: ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope); @@ -176,7 +171,7 @@ private: std::unordered_map environments; std::unordered_map> builtinDefinitions; - std::unique_ptr globalScope2; + ScopePtr globalScope; public: FileResolver* fileResolver; @@ -187,6 +182,7 @@ public: ConfigResolver* configResolver; FrontendOptions options; InternalErrorReporter iceHandler; + TypeArena globalTypes; TypeArena arenaForAutocomplete; std::unordered_map sourceNodes; diff --git a/Analysis/include/Luau/JsonEmitter.h b/Analysis/include/Luau/JsonEmitter.h new file mode 100644 index 00000000..0bf3327a --- /dev/null +++ b/Analysis/include/Luau/JsonEmitter.h @@ -0,0 +1,235 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include +#include +#include + +#include "Luau/NotNull.h" + +namespace Luau::Json +{ + +struct JsonEmitter; + +/// Writes a value to the JsonEmitter. Note that this can produce invalid JSON +/// if you do not insert commas or appropriate object / array syntax. +template +void write(JsonEmitter&, T) = delete; + +/// Writes a boolean to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param b the boolean to write. +void write(JsonEmitter& emitter, bool b); + +/// Writes an integer to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param i the integer to write. +void write(JsonEmitter& emitter, int i); + +/// Writes an integer to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param i the integer to write. +void write(JsonEmitter& emitter, long i); + +/// Writes an integer to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param i the integer to write. +void write(JsonEmitter& emitter, long long i); + +/// Writes an integer to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param i the integer to write. +void write(JsonEmitter& emitter, unsigned int i); + +/// Writes an integer to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param i the integer to write. +void write(JsonEmitter& emitter, unsigned long i); + +/// Writes an integer to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param i the integer to write. +void write(JsonEmitter& emitter, unsigned long long i); + +/// Writes a double to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param d the double to write. +void write(JsonEmitter& emitter, double d); + +/// Writes a string to a JsonEmitter. The string will be escaped. +/// @param emitter the emitter to write to. +/// @param sv the string to write. +void write(JsonEmitter& emitter, std::string_view sv); + +/// Writes a character to a JsonEmitter as a single-character string. The +/// character will be escaped. +/// @param emitter the emitter to write to. +/// @param c the string to write. +void write(JsonEmitter& emitter, char c); + +/// Writes a string to a JsonEmitter. The string will be escaped. +/// @param emitter the emitter to write to. +/// @param str the string to write. +void write(JsonEmitter& emitter, const char* str); + +/// Writes a string to a JsonEmitter. The string will be escaped. +/// @param emitter the emitter to write to. +/// @param str the string to write. +void write(JsonEmitter& emitter, const std::string& str); + +/// Writes null to a JsonEmitter. +/// @param emitter the emitter to write to. +void write(JsonEmitter& emitter, std::nullptr_t); + +/// Writes null to a JsonEmitter. +/// @param emitter the emitter to write to. +void write(JsonEmitter& emitter, std::nullopt_t); + +struct ObjectEmitter; +struct ArrayEmitter; + +struct JsonEmitter +{ + JsonEmitter(); + + /// Converts the current contents of the JsonEmitter to a string value. This + /// does not invalidate the emitter, but it does not clear it either. + std::string str(); + + /// Returns the current comma state and resets it to false. Use popComma to + /// restore the old state. + /// @returns the previous comma state. + bool pushComma(); + + /// Restores a previous comma state. + /// @param c the comma state to restore. + void popComma(bool c); + + /// Writes a raw sequence of characters to the buffer, without escaping or + /// other processing. + /// @param sv the character sequence to write. + void writeRaw(std::string_view sv); + + /// Writes a character to the buffer, without escaping or other processing. + /// @param c the character to write. + void writeRaw(char c); + + /// Writes a comma if this wasn't the first time writeComma has been + /// invoked. Otherwise, sets the comma state to true. + /// @see pushComma + /// @see popComma + void writeComma(); + + /// Begins writing an object to the emitter. + /// @returns an ObjectEmitter that can be used to write key-value pairs. + ObjectEmitter writeObject(); + + /// Begins writing an array to the emitter. + /// @returns an ArrayEmitter that can be used to write values. + ArrayEmitter writeArray(); + +private: + bool comma = false; + std::vector chunks; + + void newChunk(); +}; + +/// An interface for writing an object into a JsonEmitter instance. +/// @see JsonEmitter::writeObject +struct ObjectEmitter +{ + ObjectEmitter(NotNull emitter); + ~ObjectEmitter(); + + NotNull emitter; + bool comma; + bool finished; + + /// Writes a key-value pair to the associated JsonEmitter. Keys will be escaped. + /// @param name the name of the key-value pair. + /// @param value the value to write. + template + void writePair(std::string_view name, T value) + { + if (finished) + { + return; + } + + emitter->writeComma(); + write(*emitter, name); + emitter->writeRaw(':'); + write(*emitter, value); + } + + /// Finishes writing the object, appending a closing `}` character and + /// resetting the comma state of the associated emitter. This can only be + /// called once, and once called will render the emitter unusable. This + /// method is also called when the ObjectEmitter is destructed. + void finish(); +}; + +/// An interface for writing an array into a JsonEmitter instance. Array values +/// do not need to be the same type. +/// @see JsonEmitter::writeArray +struct ArrayEmitter +{ + ArrayEmitter(NotNull emitter); + ~ArrayEmitter(); + + NotNull emitter; + bool comma; + bool finished; + + /// Writes a value to the array. + /// @param value the value to write. + template + void writeValue(T value) + { + if (finished) + { + return; + } + + emitter->writeComma(); + write(*emitter, value); + } + + /// Finishes writing the object, appending a closing `]` character and + /// resetting the comma state of the associated emitter. This can only be + /// called once, and once called will render the emitter unusable. This + /// method is also called when the ArrayEmitter is destructed. + void finish(); +}; + +/// Writes a vector as an array to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param vec the vector to write. +template +void write(JsonEmitter& emitter, const std::vector& vec) +{ + ArrayEmitter a = emitter.writeArray(); + + for (const T& value : vec) + a.writeValue(value); + + a.finish(); +} + +/// Writes an optional to a JsonEmitter. Will write the contained value, if +/// present, or null, if no value is present. +/// @param emitter the emitter to write to. +/// @param v the value to write. +template +void write(JsonEmitter& emitter, const std::optional& v) +{ + if (v.has_value()) + write(emitter, *v); + else + emitter.writeRaw("null"); +} + +} // namespace Luau::Json diff --git a/Analysis/include/Luau/Linter.h b/Analysis/include/Luau/Linter.h index 6c7ce47f..2cd91d54 100644 --- a/Analysis/include/Luau/Linter.h +++ b/Analysis/include/Luau/Linter.h @@ -52,6 +52,7 @@ struct LintWarning Code_DuplicateCondition = 24, Code_MisleadingAndOr = 25, Code_CommentDirective = 26, + Code_IntegerParsing = 27, Code__Count }; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index b3105b78..6f4c6098 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -68,8 +68,7 @@ struct Module std::shared_ptr allocator; std::shared_ptr names; - std::vector> scopes; // never empty - std::vector>> scope2s; // never empty + std::vector> scopes; // never empty DenseHashMap astTypes{nullptr}; DenseHashMap astTypePacks{nullptr}; @@ -86,7 +85,6 @@ struct Module bool timeout = false; ScopePtr getModuleScope() const; - Scope2* getModuleScope2() const; // Once a module has been typechecked, we clone its public interface into a separate arena. // This helps us to force TypeVar ownership into a DAG rather than a DCG. diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index f46f0cb5..7edf23b8 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -7,9 +7,9 @@ namespace Luau { struct TypeArena; -struct Scope2; +struct Scope; void quantify(TypeId ty, TypeLevel level); -TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope); +TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope); } // namespace Luau diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 0eaecf1d..55ca54c6 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -32,10 +32,16 @@ struct Scope explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. const ScopePtr parent; // null for the root + + // All the children of this scope. + std::vector> children; std::unordered_map bindings; + std::unordered_map typeBindings; + std::unordered_map typePackBindings; TypePackId returnType; - bool breakOk = false; std::optional varargPack; + // All constraints belonging to this scope. + std::vector constraints; TypeLevel level; @@ -45,7 +51,9 @@ struct Scope std::unordered_map> importedTypeBindings; - std::optional lookup(const Symbol& name); + std::optional lookup(Symbol sym); + std::optional lookupTypeBinding(const Name& name); + std::optional lookupTypePackBinding(const Name& name); std::optional lookupType(const Name& name); std::optional lookupImportedType(const Name& moduleAlias, const Name& name); @@ -66,24 +74,4 @@ struct Scope std::unordered_map typeAliasTypePackParameters; }; -struct Scope2 -{ - // The parent scope of this scope. Null if there is no parent (i.e. this - // is the module-level scope). - Scope2* parent = nullptr; - // All the children of this scope. - std::vector> children; - std::unordered_map bindings; // TODO: I think this can be a DenseHashMap - std::unordered_map typeBindings; - std::unordered_map typePackBindings; - TypePackId returnType; - std::optional varargPack; - // All constraints belonging to this scope. - std::vector constraints; - - std::optional lookup(Symbol sym); - std::optional lookupTypeBinding(const Name& name); - std::optional lookupTypePackBinding(const Name& name); -}; - } // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index f3c3ae9a..6ad38f9d 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -139,6 +139,8 @@ struct FindDirty : Tarjan { std::vector dirty; + void clearTarjan(); + // Get/set the dirty bit for an index (grows the vector if needed) bool getDirty(int index); void setDirty(int index, bool d); @@ -176,6 +178,8 @@ public: TypeArena* arena; DenseHashMap newTypes{nullptr}; DenseHashMap newPacks{nullptr}; + DenseHashSet replacedTypes{nullptr}; + DenseHashSet replacedTypePacks{nullptr}; std::optional substitute(TypeId ty); std::optional substitute(TypePackId tp); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 455654d9..c50b2c8c 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -65,28 +65,6 @@ struct Anyification : Substitution } }; -// A substitution which replaces the type parameters of a type function by arguments -struct ApplyTypeFunction : Substitution -{ - ApplyTypeFunction(TypeArena* arena, TypeLevel level) - : Substitution(TxnLog::empty(), arena) - , level(level) - , encounteredForwardedType(false) - { - } - - TypeLevel level; - bool encounteredForwardedType; - std::unordered_map typeArguments; - std::unordered_map typePackArguments; - bool ignoreChildren(TypeId ty) override; - bool ignoreChildren(TypePackId tp) override; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; -}; - struct GenericTypeDefinitions { std::vector genericTypes; @@ -153,7 +131,7 @@ struct TypeChecker const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); TypeId checkBinaryOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); - WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional expectedType = std::nullopt); WithPredicate checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprError& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); @@ -180,8 +158,12 @@ struct TypeChecker const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector& argLocations); WithPredicate checkExprPack(const ScopePtr& scope, const AstExpr& expr); - WithPredicate checkExprPack(const ScopePtr& scope, const AstExprCall& expr); + + WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr); + WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr); + std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); + std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); @@ -236,10 +218,11 @@ struct TypeChecker void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location); - std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); - std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); + std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors); + std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors); std::optional getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); + std::optional getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); // Reduces the union to its simplest possible shape. // (A | B) | B | C yields A | B | C @@ -316,11 +299,12 @@ private: TypeIdPredicate mkTruthyPredicate(bool sense); - // Returns nullopt if the predicate filters down the TypeId to 0 options. - std::optional filterMap(TypeId type, TypeIdPredicate predicate); + // TODO: Return TypeId only. + std::optional filterMapImpl(TypeId type, TypeIdPredicate predicate); + std::pair, bool> filterMap(TypeId type, TypeIdPredicate predicate); public: - std::optional pickTypesFromSense(TypeId type, bool sense); + std::pair, bool> pickTypesFromSense(TypeId type, bool sense); private: TypeId unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes = true); @@ -345,6 +329,7 @@ private: TypePackId freshTypePack(TypeLevel level); TypeId resolveType(const ScopePtr& scope, const AstType& annotation); + TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, @@ -412,8 +397,12 @@ public: const TypeId booleanType; const TypeId threadType; const TypeId anyType; + const TypeId unknownType; + const TypeId neverType; const TypePackId anyTypePack; + const TypePackId neverTypePack; + const TypePackId uninhabitableTypePack; private: int checkRecursionCount = 0; diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index c1de242f..b17003b1 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -173,5 +173,6 @@ std::pair, std::optional> flatten(TypePackId tp, bool isVariadic(TypePackId tp); bool isVariadic(TypePackId tp, const TxnLog& log); +bool containsNever(TypePackId tp); } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 6ad6b927..6a13b11c 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -24,7 +24,7 @@ namespace Luau { struct TypeArena; -struct Scope2; +struct Scope; /** * There are three kinds of type variables: @@ -143,7 +143,7 @@ struct ConstrainedTypeVar std::vector parts; TypeLevel level; - Scope2* scope = nullptr; + Scope* scope = nullptr; }; // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md @@ -223,12 +223,16 @@ struct GenericTypeDefinition { TypeId ty; std::optional defaultValue; + + bool operator==(const GenericTypeDefinition& rhs) const; }; struct GenericTypePackDefinition { TypePackId tp; std::optional defaultValue; + + bool operator==(const GenericTypePackDefinition& rhs) const; }; struct FunctionArgument @@ -275,7 +279,7 @@ struct FunctionTypeVar std::optional defn = {}, bool hasSelf = false); TypeLevel level; - Scope2* scope = nullptr; + Scope* scope = nullptr; /// These should all be generic std::vector generics; std::vector genericPacks; @@ -344,7 +348,7 @@ struct TableTypeVar TableState state = TableState::Unsealed; TypeLevel level; - Scope2* scope = nullptr; + Scope* scope = nullptr; std::optional name; // Sometimes we throw a type on a name to make for nicer error messages, but without creating any entry in the type namespace @@ -426,6 +430,12 @@ struct TypeFun TypeId type; TypeFun() = default; + + explicit TypeFun(TypeId ty) + : type(ty) + { + } + TypeFun(std::vector typeParams, TypeId type) : typeParams(std::move(typeParams)) , type(type) @@ -438,6 +448,27 @@ struct TypeFun , type(type) { } + + bool operator==(const TypeFun& rhs) const; +}; + +/** Represents a pending type alias instantiation. + * + * In order to afford (co)recursive type aliases, we need to reason about a + * partially-complete instantiation. This requires encoding more information in + * a type variable than a BlockedTypeVar affords, hence this. Each + * PendingExpansionTypeVar has a corresponding TypeAliasExpansionConstraint + * enqueued in the solver to convert it to an actual instantiated type + */ +struct PendingExpansionTypeVar +{ + PendingExpansionTypeVar(TypeFun fn, std::vector typeArguments, std::vector packArguments); + TypeFun fn; + std::vector typeArguments; + std::vector packArguments; + size_t index; + + static size_t nextIndex; }; // Anything! All static checking is off. @@ -460,10 +491,20 @@ struct LazyTypeVar std::function thunk; }; +struct UnknownTypeVar +{ +}; + +struct NeverTypeVar +{ +}; + using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = + Unifiable::Variant; + struct TypeVar final { @@ -575,8 +616,12 @@ struct SingletonTypes const TypeId trueType; const TypeId falseType; const TypeId anyType; + const TypeId unknownType; + const TypeId neverType; const TypePackId anyTypePack; + const TypePackId neverTypePack; + const TypePackId uninhabitableTypePack; SingletonTypes(); ~SingletonTypes(); @@ -632,12 +677,30 @@ T* getMutable(TypeId tv) return get_if(&asMutable(tv)->ty); } -/* Traverses the UnionTypeVar yielding each TypeId. - * If the iterator encounters a nested UnionTypeVar, it will instead yield each TypeId within. - * - * Beware: the iterator does not currently filter for unique TypeIds. This may change in the future. +const std::vector& getTypes(const UnionTypeVar* utv); +const std::vector& getTypes(const IntersectionTypeVar* itv); +const std::vector& getTypes(const ConstrainedTypeVar* ctv); + +template +struct TypeIterator; + +using UnionTypeVarIterator = TypeIterator; +UnionTypeVarIterator begin(const UnionTypeVar* utv); +UnionTypeVarIterator end(const UnionTypeVar* utv); + +using IntersectionTypeVarIterator = TypeIterator; +IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv); +IntersectionTypeVarIterator end(const IntersectionTypeVar* itv); + +using ConstrainedTypeVarIterator = TypeIterator; +ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv); +ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv); + +/* Traverses the type T yielding each TypeId. + * If the iterator encounters a nested type T, it will instead yield each TypeId within. */ -struct UnionTypeVarIterator +template +struct TypeIterator { using value_type = Luau::TypeId; using pointer = value_type*; @@ -645,33 +708,116 @@ struct UnionTypeVarIterator using difference_type = size_t; using iterator_category = std::input_iterator_tag; - explicit UnionTypeVarIterator(const UnionTypeVar* utv); + explicit TypeIterator(const T* t) + { + LUAU_ASSERT(t); - UnionTypeVarIterator& operator++(); - UnionTypeVarIterator operator++(int); - bool operator!=(const UnionTypeVarIterator& rhs); - bool operator==(const UnionTypeVarIterator& rhs); + const std::vector& types = getTypes(t); + if (!types.empty()) + stack.push_front({t, 0}); - const TypeId& operator*(); + seen.insert(t); + } - friend UnionTypeVarIterator end(const UnionTypeVar* utv); + TypeIterator& operator++() + { + advance(); + descend(); + return *this; + } + + TypeIterator operator++(int) + { + TypeIterator copy = *this; + ++copy; + return copy; + } + + bool operator==(const TypeIterator& rhs) const + { + if (!stack.empty() && !rhs.stack.empty()) + return stack.front() == rhs.stack.front(); + + return stack.empty() && rhs.stack.empty(); + } + + bool operator!=(const TypeIterator& rhs) const + { + return !(*this == rhs); + } + + const TypeId& operator*() + { + LUAU_ASSERT(!stack.empty()); + + descend(); + + auto [t, currentIndex] = stack.front(); + LUAU_ASSERT(t); + const std::vector& types = getTypes(t); + LUAU_ASSERT(currentIndex < types.size()); + + const TypeId& ty = types[currentIndex]; + LUAU_ASSERT(!get(follow(ty))); + return ty; + } + + // Normally, we'd have `begin` and `end` be a template but there's too much trouble + // with templates portability in this area, so not worth it. Thanks MSVC. + friend UnionTypeVarIterator end(const UnionTypeVar*); + friend IntersectionTypeVarIterator end(const IntersectionTypeVar*); + friend ConstrainedTypeVarIterator end(const ConstrainedTypeVar*); private: - UnionTypeVarIterator() = default; + TypeIterator() = default; - // (UnionTypeVar* utv, size_t currentIndex) - using SavedIterInfo = std::pair; + // (T* t, size_t currentIndex) + using SavedIterInfo = std::pair; std::deque stack; - std::unordered_set seen; // Only needed to protect the iterator from hanging the thread. + std::unordered_set seen; // Only needed to protect the iterator from hanging the thread. - void advance(); - void descend(); + void advance() + { + while (!stack.empty()) + { + auto& [t, currentIndex] = stack.front(); + ++currentIndex; + + const std::vector& types = getTypes(t); + if (currentIndex >= types.size()) + stack.pop_front(); + else + break; + } + } + + void descend() + { + while (!stack.empty()) + { + auto [current, currentIndex] = stack.front(); + const std::vector& types = getTypes(current); + if (auto inner = get(follow(types[currentIndex]))) + { + // If we're about to descend into a cyclic type, we should skip over this. + // Ideally this should never happen, but alas it does from time to time. :( + if (seen.find(inner) != seen.end()) + advance(); + else + { + seen.insert(inner); + stack.push_front({inner, 0}); + } + + continue; + } + + break; + } + } }; -UnionTypeVarIterator begin(const UnionTypeVar* utv); -UnionTypeVarIterator end(const UnionTypeVar* utv); - using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 4ff91714..e5eb4198 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -8,7 +8,7 @@ namespace Luau { -struct Scope2; +struct Scope; /** * The 'level' of a TypeVar is an indirect way to talk about the scope that it 'belongs' too. @@ -84,11 +84,11 @@ using Name = std::string; struct Free { explicit Free(TypeLevel level); - explicit Free(Scope2* scope); + explicit Free(Scope* scope); int index; TypeLevel level; - Scope2* scope = nullptr; + Scope* scope = nullptr; // True if this free type variable is part of a mutually // recursive type alias whose definitions haven't been // resolved yet. @@ -115,13 +115,13 @@ struct Generic Generic(); explicit Generic(TypeLevel level); explicit Generic(const Name& name); - explicit Generic(Scope2* scope); + explicit Generic(Scope* scope); Generic(TypeLevel level, const Name& name); - Generic(Scope2* scope, const Name& name); + Generic(Scope* scope, const Name& name); int index; TypeLevel level; - Scope2* scope = nullptr; + Scope* scope = nullptr; Name name; bool explicitName = false; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 4af324cb..f460dc87 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -79,6 +79,7 @@ private: void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); + void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 5fd43f0b..7e5d71d6 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -9,7 +9,7 @@ #include "Luau/TypeVar.h" LUAU_FASTINT(LuauVisitRecursionLimit) -LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) +LUAU_FASTFLAG(LuauCompleteVisitor); namespace Luau { @@ -129,6 +129,14 @@ struct GenericTypeVarVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const UnknownTypeVar& utv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const NeverTypeVar& ntv) + { + return visit(ty); + } virtual bool visit(TypeId ty, const UnionTypeVar& utv) { return visit(ty); @@ -137,6 +145,18 @@ struct GenericTypeVarVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const BlockedTypeVar& btv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const PendingExpansionTypeVar& petv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const SingletonTypeVar& stv) + { + return visit(ty); + } virtual bool visit(TypePackId tp) { @@ -182,16 +202,12 @@ struct GenericTypeVarVisitor if (visit(ty, *btv)) traverse(btv->boundTo); } - else if (auto ftv = get(ty)) visit(ty, *ftv); - else if (auto gtv = get(ty)) visit(ty, *gtv); - else if (auto etv = get(ty)) visit(ty, *etv); - else if (auto ctv = get(ty)) { if (visit(ty, *ctv)) @@ -200,10 +216,8 @@ struct GenericTypeVarVisitor traverse(part); } } - else if (auto ptv = get(ty)) visit(ty, *ptv); - else if (auto ftv = get(ty)) { if (visit(ty, *ftv)) @@ -212,7 +226,6 @@ struct GenericTypeVarVisitor traverse(ftv->retTypes); } } - else if (auto ttv = get(ty)) { // Some visitors want to see bound tables, that's why we traverse the original type @@ -235,7 +248,6 @@ struct GenericTypeVarVisitor } } } - else if (auto mtv = get(ty)) { if (visit(ty, *mtv)) @@ -244,7 +256,6 @@ struct GenericTypeVarVisitor traverse(mtv->metatable); } } - else if (auto ctv = get(ty)) { if (visit(ty, *ctv)) @@ -259,10 +270,8 @@ struct GenericTypeVarVisitor traverse(*ctv->metatable); } } - else if (auto atv = get(ty)) visit(ty, *atv); - else if (auto utv = get(ty)) { if (visit(ty, *utv)) @@ -271,7 +280,6 @@ struct GenericTypeVarVisitor traverse(optTy); } } - else if (auto itv = get(ty)) { if (visit(ty, *itv)) @@ -280,6 +288,53 @@ struct GenericTypeVarVisitor traverse(partTy); } } + else if (get(ty)) + { + // Visiting into LazyTypeVar may necessarily cause infinite expansion, so we don't do that on purpose. + // Asserting also makes no sense, because the type _will_ happen here, most likely as a property of some ClassTypeVar + // that doesn't need to be expanded. + } + else if (auto stv = get(ty)) + visit(ty, *stv); + else if (auto btv = get(ty)) + visit(ty, *btv); + else if (auto utv = get(ty)) + visit(ty, *utv); + else if (auto ntv = get(ty)) + visit(ty, *ntv); + else if (auto petv = get(ty)) + { + if (visit(ty, *petv)) + { + traverse(petv->fn.type); + + for (const GenericTypeDefinition& p : petv->fn.typeParams) + { + traverse(p.ty); + + if (p.defaultValue) + traverse(*p.defaultValue); + } + + for (const GenericTypePackDefinition& p : petv->fn.typePackParams) + { + traverse(p.tp); + + if (p.defaultValue) + traverse(*p.defaultValue); + } + + for (TypeId a : petv->typeArguments) + traverse(a); + + for (TypePackId a : petv->packArguments) + traverse(a); + } + } + else if (!FFlag::LuauCompleteVisitor) + return visit_detail::unsee(seen, ty); + else + LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypeId) is not exhaustive!"); visit_detail::unsee(seen, ty); } @@ -310,7 +365,7 @@ struct GenericTypeVarVisitor else if (auto pack = get(tp)) { bool res = visit(tp, *pack); - if (!FFlag::LuauNormalizeFlagIsConservative || res) + if (res) { for (TypeId ty : pack->head) traverse(ty); @@ -322,7 +377,7 @@ struct GenericTypeVarVisitor else if (auto pack = get(tp)) { bool res = visit(tp, *pack); - if (!FFlag::LuauNormalizeFlagIsConservative || res) + if (res) traverse(pack->ty); } else diff --git a/Analysis/src/ApplyTypeFunction.cpp b/Analysis/src/ApplyTypeFunction.cpp new file mode 100644 index 00000000..c6ac3e19 --- /dev/null +++ b/Analysis/src/ApplyTypeFunction.cpp @@ -0,0 +1,60 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ApplyTypeFunction.h" + +namespace Luau +{ + +bool ApplyTypeFunction::isDirty(TypeId ty) +{ + if (typeArguments.count(ty)) + return true; + else if (const FreeTypeVar* ftv = get(ty)) + { + if (ftv->forwardedTypeAlias) + encounteredForwardedType = true; + return false; + } + else + return false; +} + +bool ApplyTypeFunction::isDirty(TypePackId tp) +{ + if (typePackArguments.count(tp)) + return true; + else + return false; +} + +bool ApplyTypeFunction::ignoreChildren(TypeId ty) +{ + if (get(ty)) + return true; + else + return false; +} + +bool ApplyTypeFunction::ignoreChildren(TypePackId tp) +{ + if (get(tp)) + return true; + else + return false; +} + +TypeId ApplyTypeFunction::clean(TypeId ty) +{ + TypeId& arg = typeArguments[ty]; + LUAU_ASSERT(arg); + return arg; +} + +TypePackId ApplyTypeFunction::clean(TypePackId tp) +{ + TypePackId& arg = typePackArguments[tp]; + LUAU_ASSERT(arg); + return arg; +} + +} // namespace Luau diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp similarity index 87% rename from Analysis/src/JsonEncoder.cpp rename to Analysis/src/AstJsonEncoder.cpp index 829ffa02..2897875d 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -1,7 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/JsonEncoder.h" +#include "Luau/AstJsonEncoder.h" #include "Luau/Ast.h" +#include "Luau/ParseResult.h" #include "Luau/StringUtils.h" #include "Luau/Common.h" @@ -75,6 +76,11 @@ struct AstJsonEncoder : public AstVisitor writeRaw(std::string_view{&c, 1}); } + void writeType(std::string_view propValue) + { + write("type", propValue); + } + template void write(std::string_view propName, const T& value) { @@ -97,8 +103,8 @@ struct AstJsonEncoder : public AstVisitor void write(double d) { - char b[256]; - sprintf(b, "%g", d); + char b[32]; + snprintf(b, sizeof(b), "%.17g", d); writeRaw(b); } @@ -111,8 +117,12 @@ struct AstJsonEncoder : public AstVisitor { if (c == '"') writeRaw("\\\""); - else if (c == '\0') - writeRaw("\\\0"); + else if (c == '\\') + writeRaw("\\\\"); + else if (c < ' ') + writeRaw(format("\\u%04x", c)); + else if (c == '\n') + writeRaw("\\n"); else writeRaw(c); } @@ -189,10 +199,11 @@ struct AstJsonEncoder : public AstVisitor writeRaw("{"); bool c = pushComma(); if (local->annotation != nullptr) - write("type", local->annotation); + write("luauType", local->annotation); else - write("type", nullptr); + write("luauType", nullptr); write("name", local->name); + writeType("AstLocal"); write("location", local->location); popComma(c); writeRaw("}"); @@ -208,7 +219,7 @@ struct AstJsonEncoder : public AstVisitor { writeRaw("{"); bool c = pushComma(); - write("type", name); + writeType(name); writeNode(node); f(); popComma(c); @@ -358,6 +369,7 @@ struct AstJsonEncoder : public AstVisitor { writeRaw("{"); bool c = pushComma(); + writeType("AstTypeList"); write("types", typeList.types); if (typeList.tailType) write("tailType", typeList.tailType); @@ -369,9 +381,10 @@ struct AstJsonEncoder : public AstVisitor { writeRaw("{"); bool c = pushComma(); + writeType("AstGenericType"); write("name", genericType.name); if (genericType.defaultValue) - write("type", genericType.defaultValue); + write("luauType", genericType.defaultValue); popComma(c); writeRaw("}"); } @@ -380,9 +393,10 @@ struct AstJsonEncoder : public AstVisitor { writeRaw("{"); bool c = pushComma(); + writeType("AstGenericTypePack"); write("name", genericTypePack.name); if (genericTypePack.defaultValue) - write("type", genericTypePack.defaultValue); + write("luauType", genericTypePack.defaultValue); popComma(c); writeRaw("}"); } @@ -404,6 +418,7 @@ struct AstJsonEncoder : public AstVisitor { writeRaw("{"); bool c = pushComma(); + writeType("AstExprTableItem"); write("kind", item.kind); switch (item.kind) { @@ -419,6 +434,17 @@ struct AstJsonEncoder : public AstVisitor writeRaw("}"); } + void write(class AstExprIfElse* node) + { + writeNode(node, "AstExprIfElse", [&]() { + PROP(condition); + PROP(hasThen); + PROP(trueExpr); + PROP(hasElse); + PROP(falseExpr); + }); + } + void write(class AstExprTable* node) { writeNode(node, "AstExprTable", [&]() { @@ -431,11 +457,11 @@ struct AstJsonEncoder : public AstVisitor switch (op) { case AstExprUnary::Not: - return writeString("not"); + return writeString("Not"); case AstExprUnary::Minus: - return writeString("minus"); + return writeString("Minus"); case AstExprUnary::Len: - return writeString("len"); + return writeString("Len"); } } @@ -541,7 +567,7 @@ struct AstJsonEncoder : public AstVisitor void write(class AstStatWhile* node) { - writeNode(node, "AtStatWhile", [&]() { + writeNode(node, "AstStatWhile", [&]() { PROP(condition); PROP(body); PROP(hasDo); @@ -684,7 +710,8 @@ struct AstJsonEncoder : public AstVisitor writeRaw("{"); bool c = pushComma(); write("name", prop.name); - write("type", prop.ty); + writeType("AstDeclaredClassProp"); + write("luauType", prop.ty); popComma(c); writeRaw("}"); } @@ -731,8 +758,9 @@ struct AstJsonEncoder : public AstVisitor bool c = pushComma(); write("name", prop.name); + writeType("AstTableProp"); write("location", prop.location); - write("type", prop.type); + write("propType", prop.type); popComma(c); writeRaw("}"); @@ -746,6 +774,24 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(struct AstTableIndexer* indexer) + { + if (indexer) + { + writeRaw("{"); + bool c = pushComma(); + write("location", indexer->location); + write("indexType", indexer->indexType); + write("resultType", indexer->resultType); + popComma(c); + writeRaw("}"); + } + else + { + writeRaw("null"); + } + } + void write(class AstTypeFunction* node) { writeNode(node, "AstTypeFunction", [&]() { @@ -836,6 +882,12 @@ struct AstJsonEncoder : public AstVisitor return false; } + bool visit(class AstExprIfElse* node) override + { + write(node); + return false; + } + bool visit(class AstExprLocal* node) override { write(node); @@ -1093,6 +1145,41 @@ struct AstJsonEncoder : public AstVisitor write(node); return false; } + + void writeComments(std::vector commentLocations) + { + bool commentComma = false; + for (Comment comment : commentLocations) + { + if (commentComma) + { + writeRaw(","); + } + else + { + commentComma = true; + } + writeRaw("{"); + bool c = pushComma(); + switch (comment.type) + { + case Lexeme::Comment: + writeType("Comment"); + break; + case Lexeme::BlockComment: + writeType("BlockComment"); + break; + case Lexeme::BrokenComment: + writeType("BrokenComment"); + break; + default: + break; + } + write("location", comment.location); + popComma(c); + writeRaw("}"); + } + } }; std::string toJson(AstNode* node) @@ -1102,4 +1189,15 @@ std::string toJson(AstNode* node) return encoder.str(); } +std::string toJson(AstNode* node, const std::vector& commentLocations) +{ + AstJsonEncoder encoder; + encoder.writeRaw(R"({"root":)"); + node->visit(&encoder); + encoder.writeRaw(R"(,"commentLocations":[)"); + encoder.writeComments(commentLocations); + encoder.writeRaw("]}"); + return encoder.str(); +} + } // namespace Luau diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 0522b1fa..50299704 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -17,6 +17,104 @@ namespace Luau namespace { + +struct AutocompleteNodeFinder : public AstVisitor +{ + const Position pos; + std::vector ancestry; + + explicit AutocompleteNodeFinder(Position pos, AstNode* root) + : pos(pos) + { + } + + bool visit(AstExpr* expr) override + { + if (expr->location.begin < pos && pos <= expr->location.end) + { + ancestry.push_back(expr); + return true; + } + return false; + } + + bool visit(AstStat* stat) override + { + if (stat->location.begin < pos && pos <= stat->location.end) + { + ancestry.push_back(stat); + return true; + } + return false; + } + + bool visit(AstType* type) override + { + if (type->location.begin < pos && pos <= type->location.end) + { + ancestry.push_back(type); + return true; + } + return false; + } + + bool visit(AstTypeError* type) override + { + // For a missing type, match the whole range including the start position + if (type->isMissing && type->location.containsClosed(pos)) + { + ancestry.push_back(type); + return true; + } + return false; + } + + bool visit(class AstTypePack* typePack) override + { + return true; + } + + bool visit(AstStatBlock* block) override + { + // If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite. + if (ancestry.empty()) + { + ancestry.push_back(block); + return true; + } + + // AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes. + // ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz} + if (!ancestry.empty() && ancestry.back()->is()) + return false; + + // Type annotation error might intersect the block statement when the function header is being written, + // annotation takes priority + if (!ancestry.empty() && ancestry.back()->is()) + return false; + + // If the cursor is at the end of an expression or type and simultaneously at the beginning of a block, + // the expression or type wins out. + // The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to + // be within the block. + if (block->location.begin == pos && !ancestry.empty()) + { + if (ancestry.back()->asExpr() && !ancestry.back()->is()) + return false; + + if (ancestry.back()->asType()) + return false; + } + + if (block->location.begin <= pos && pos <= block->location.end) + { + ancestry.push_back(block); + return true; + } + return false; + } +}; + struct FindNode : public AstVisitor { const Position pos; @@ -102,6 +200,13 @@ struct FindFullAncestry final : public AstVisitor } // namespace +std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos) +{ + AutocompleteNodeFinder finder{pos, source.root}; + source.root->visit(&finder); + return finder.ancestry; +} + std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos) { const Position end = source.root->location.end; @@ -110,7 +215,7 @@ std::vector findAstAncestryOfPosition(const SourceModule& source, Posi FindFullAncestry finder(pos, end); source.root->visit(&finder); - return std::move(finder.nodes); + return finder.nodes; } AstNode* findNodeAtPosition(const SourceModule& source, Position pos) @@ -209,7 +314,7 @@ std::optional findBindingAtPosition(const Module& module, const SourceM auto iter = currentScope->bindings.find(name); if (iter != currentScope->bindings.end() && iter->second.location.begin <= pos) { - /* Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope */ + // Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope std::optional bindingStatement = findBindingLocalStatement(source, iter->second); if (!bindingStatement || !(*bindingStatement)->location.contains(pos)) return iter->second; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 8a63901f..a57a789f 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -7,13 +7,12 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/Parser.h" // TODO: only needed for autocompleteSource which is deprecated #include #include #include -LUAU_FASTFLAG(LuauSelfCallAutocompleteFix2) +LUAU_FASTFLAG(LuauSelfCallAutocompleteFix3) static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -21,102 +20,6 @@ static const std::unordered_set kStatementStartingKeywords = { namespace Luau { -struct NodeFinder : public AstVisitor -{ - const Position pos; - std::vector ancestry; - - explicit NodeFinder(Position pos, AstNode* root) - : pos(pos) - { - } - - bool visit(AstExpr* expr) override - { - if (expr->location.begin < pos && pos <= expr->location.end) - { - ancestry.push_back(expr); - return true; - } - return false; - } - - bool visit(AstStat* stat) override - { - if (stat->location.begin < pos && pos <= stat->location.end) - { - ancestry.push_back(stat); - return true; - } - return false; - } - - bool visit(AstType* type) override - { - if (type->location.begin < pos && pos <= type->location.end) - { - ancestry.push_back(type); - return true; - } - return false; - } - - bool visit(AstTypeError* type) override - { - // For a missing type, match the whole range including the start position - if (type->isMissing && type->location.containsClosed(pos)) - { - ancestry.push_back(type); - return true; - } - return false; - } - - bool visit(class AstTypePack* typePack) override - { - return true; - } - - bool visit(AstStatBlock* block) override - { - // If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite. - if (ancestry.empty()) - { - ancestry.push_back(block); - return true; - } - - // AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes. - // ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz} - if (!ancestry.empty() && ancestry.back()->is()) - return false; - - // Type annotation error might intersect the block statement when the function header is being written, - // annotation takes priority - if (!ancestry.empty() && ancestry.back()->is()) - return false; - - // If the cursor is at the end of an expression or type and simultaneously at the beginning of a block, - // the expression or type wins out. - // The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to - // be within the block. - if (block->location.begin == pos && !ancestry.empty()) - { - if (ancestry.back()->asExpr() && !ancestry.back()->is()) - return false; - - if (ancestry.back()->asType()) - return false; - } - - if (block->location.begin <= pos && pos <= block->location.end) - { - ancestry.push_back(block); - return true; - } - return false; - } -}; static bool alreadyHasParens(const std::vector& nodes) { @@ -246,7 +149,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ ty = follow(ty); auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2); + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix3); InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); @@ -265,7 +168,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*typeAtPosition); auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) { - if (FFlag::LuauSelfCallAutocompleteFix2) + if (FFlag::LuauSelfCallAutocompleteFix3) { if (std::optional firstRetTy = first(ftv->retTypes)) return checkTypeMatch(typeArena, *firstRetTy, expectedType); @@ -306,7 +209,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } } - if (FFlag::LuauSelfCallAutocompleteFix2) + if (FFlag::LuauSelfCallAutocompleteFix3) return checkTypeMatch(typeArena, ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; else return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; @@ -323,7 +226,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, std::optional containingClass = std::nullopt) { - if (FFlag::LuauSelfCallAutocompleteFix2) + if (FFlag::LuauSelfCallAutocompleteFix3) rootTy = follow(rootTy); ty = follow(ty); @@ -333,7 +236,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId seen.insert(ty); auto isWrongIndexer_DEPRECATED = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2); + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix3); if (indexType == PropIndexType::Key) return false; @@ -366,7 +269,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } }; auto isWrongIndexer = [typeArena, rootTy, indexType](Luau::TypeId type) { - LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix2); + LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix3); if (indexType == PropIndexType::Key) return false; @@ -374,21 +277,20 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId bool calledWithSelf = indexType == PropIndexType::Colon; auto isCompatibleCall = [typeArena, rootTy, calledWithSelf](const FunctionTypeVar* ftv) { - if (get(rootTy)) - { - // Calls on classes require strict match between how function is declared and how it's called - return calledWithSelf == ftv->hasSelf; - } + // Strong match with definition is a success + if (calledWithSelf == ftv->hasSelf) + return true; - // If a call is made with ':', it is invalid if a function has incompatible first argument or no arguments at all - // If a call is made with '.', but it was declared with 'self', it is considered invalid if first argument is compatible - if (calledWithSelf || ftv->hasSelf) + // Calls on classes require strict match between how function is declared and how it's called + if (get(rootTy)) + return false; + + // When called with ':', but declared without 'self', it is invalid if a function has incompatible first argument or no arguments at all + // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible + if (std::optional firstArgTy = first(ftv->argTypes)) { - if (std::optional firstArgTy = first(ftv->argTypes)) - { - if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) - return calledWithSelf; - } + if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) + return calledWithSelf; } return !calledWithSelf; @@ -430,7 +332,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryKind::Property, type, prop.deprecated, - FFlag::LuauSelfCallAutocompleteFix2 ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), + FFlag::LuauSelfCallAutocompleteFix3 ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), typeCorrect, containingClass, &prop, @@ -473,7 +375,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId { autocompleteProps(module, typeArena, rootTy, mt->table, indexType, nodes, result, seen); - if (FFlag::LuauSelfCallAutocompleteFix2) + if (FFlag::LuauSelfCallAutocompleteFix3) { if (auto mtable = get(mt->metatable)) fillMetatableProps(mtable); @@ -539,7 +441,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryMap inner; std::unordered_set innerSeen; - if (!FFlag::LuauSelfCallAutocompleteFix2) + if (!FFlag::LuauSelfCallAutocompleteFix3) innerSeen = seen; if (isNil(*iter)) @@ -565,7 +467,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId ++iter; } } - else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix2) + else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix3) { if (pt->metatable) { @@ -573,7 +475,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId fillMetatableProps(mtable); } } - else if (FFlag::LuauSelfCallAutocompleteFix2 && get(get(ty))) + else if (FFlag::LuauSelfCallAutocompleteFix3 && get(get(ty))) { autocompleteProps(module, typeArena, rootTy, getSingletonTypes().stringType, indexType, nodes, result, seen); } @@ -905,7 +807,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi } AstNode* parent = nullptr; - AstType* topType = nullptr; + AstType* topType = nullptr; // TODO: rename? for (auto it = ancestry.rbegin(), e = ancestry.rend(); it != e; ++it) { @@ -1477,21 +1379,20 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (isWithinComment(sourceModule, position)) return {}; - NodeFinder finder{position, sourceModule.root}; - sourceModule.root->visit(&finder); - LUAU_ASSERT(!finder.ancestry.empty()); - AstNode* node = finder.ancestry.back(); + std::vector ancestry = findAncestryAtPositionForAutocomplete(sourceModule, position); + LUAU_ASSERT(!ancestry.empty()); + AstNode* node = ancestry.back(); AstExprConstantNil dummy{Location{}}; - AstNode* parent = finder.ancestry.size() >= 2 ? finder.ancestry.rbegin()[1] : &dummy; + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; // If we are inside a body of a function that doesn't have a completed argument list, ignore the body node if (auto exprFunction = parent->as(); exprFunction && !exprFunction->argLocation && node == exprFunction->body) { - finder.ancestry.pop_back(); + ancestry.pop_back(); - node = finder.ancestry.back(); - parent = finder.ancestry.size() >= 2 ? finder.ancestry.rbegin()[1] : &dummy; + node = ancestry.back(); + parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; } if (auto indexName = node->as()) @@ -1503,48 +1404,48 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - if (!FFlag::LuauSelfCallAutocompleteFix2 && isString(ty)) - return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry), - finder.ancestry}; + if (!FFlag::LuauSelfCallAutocompleteFix3 && isString(ty)) + return { + autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), ancestry}; else - return {autocompleteProps(*module, typeArena, ty, indexType, finder.ancestry), finder.ancestry}; + return {autocompleteProps(*module, typeArena, ty, indexType, ancestry), ancestry}; } else if (auto typeReference = node->as()) { if (typeReference->prefix) - return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), finder.ancestry}; + return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), ancestry}; else - return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry}; + return {autocompleteTypeNames(*module, position, ancestry), ancestry}; } else if (node->is()) { - return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry}; + return {autocompleteTypeNames(*module, position, ancestry), ancestry}; } else if (AstStatLocal* statLocal = node->as()) { if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) - return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; else return {}; } - else if (AstStatFor* statFor = extractStat(finder.ancestry)) + else if (AstStatFor* statFor = extractStat(ancestry)) { if (!statFor->hasDo || position < statFor->doLocation.begin) { if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || (statFor->step && statFor->step->location.containsClosed(position))) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; return {}; } - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; } else if (AstStatForIn* statForIn = parent->as(); statForIn && (node->is() || isIdentifier(node))) @@ -1560,7 +1461,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {}; } - return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; } if (!statForIn->hasDo || position <= statForIn->doLocation.begin) @@ -1569,58 +1470,58 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; if (lastExpr->location.containsClosed(position)) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; if (position > lastExpr->location.end) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; return {}; // Not sure what this means } } - else if (AstStatForIn* statForIn = extractStat(finder.ancestry)) + else if (AstStatForIn* statForIn = extractStat(ancestry)) { // The AST looks a bit differently if the cursor is at a position where only the "do" keyword is allowed. // ex "for f in f do" if (!statForIn->hasDo) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; } else if (AstStatWhile* statWhile = parent->as(); node->is() && statWhile) { if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; if (!statWhile->hasDo || position < statWhile->doLocation.begin) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; if (statWhile->hasDo && position > statWhile->doLocation.end) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; } - else if (AstStatWhile* statWhile = extractStat(finder.ancestry); statWhile && !statWhile->hasDo) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + else if (AstStatWhile* statWhile = extractStat(ancestry); statWhile && !statWhile->hasDo) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) { - return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, - finder.ancestry}; + return { + {{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; } else if (AstStatIf* statIf = parent->as(); statIf && node->is()) { if (statIf->condition->is()) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) - return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; } - else if (AstStatIf* statIf = extractStat(finder.ancestry); + else if (AstStatIf* statIf = extractStat(ancestry); statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position))) - return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; - else if (AstStatRepeat* statRepeat = extractStat(finder.ancestry); statRepeat) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; + else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is())) { for (const auto& [kind, key, value] : exprTable->items) @@ -1630,7 +1531,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (auto it = module->astExpectedTypes.find(exprTable)) { - auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, finder.ancestry); + auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, ancestry); // Remove keys that are already completed for (const auto& item : exprTable->items) @@ -1644,9 +1545,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // If we know for sure that a key is being written, do not offer general expression suggestions if (!key) - autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position, result); + autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position, result); - return {result, finder.ancestry}; + return {result, ancestry}; } break; @@ -1654,11 +1555,11 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } } else if (isIdentifier(node) && (parent->is() || parent->is())) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; - if (std::optional ret = autocompleteStringParams(sourceModule, module, finder.ancestry, position, callback)) + if (std::optional ret = autocompleteStringParams(sourceModule, module, ancestry, position, callback)) { - return {*ret, finder.ancestry}; + return {*ret, ancestry}; } else if (node->is()) { @@ -1667,14 +1568,14 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto it = module->astExpectedTypes.find(node->asExpr())) autocompleteStringSingleton(*it, false, result); - if (finder.ancestry.size() >= 2) + if (ancestry.size() >= 2) { - if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) + if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) { if (auto it = module->astTypes.find(idxExpr->expr)) - autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry, result); + autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, ancestry, result); } - else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) + else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) { if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { @@ -1684,7 +1585,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } } - return {result, finder.ancestry}; + return {result, ancestry}; } if (node->is()) @@ -1693,9 +1594,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } if (node->asExpr()) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; else if (node->asStat()) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; return {}; } @@ -1725,32 +1626,4 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName return autocompleteResult; } -OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback) -{ - // TODO: Remove #include "Luau/Parser.h" with this function - auto sourceModule = std::make_unique(); - ParseOptions parseOptions; - parseOptions.captureComments = true; - ParseResult result = Parser::parse(source.data(), source.size(), *sourceModule->names, *sourceModule->allocator, parseOptions); - - if (!result.root) - return {AutocompleteResult{}, {}, nullptr}; - - sourceModule->name = "FRAGMENT_SCRIPT"; - sourceModule->root = result.root; - sourceModule->mode = Mode::Strict; - sourceModule->commentLocations = std::move(result.commentLocations); - - TypeChecker& typeChecker = frontend.typeCheckerForAutocomplete; - ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); - - OwningAutocompleteResult autocompleteResult = { - autocomplete(*sourceModule, module, typeChecker, &frontend.arenaForAutocomplete, position, callback), std::move(module), - std::move(sourceModule)}; - - frontend.arenaForAutocomplete.clear(); - - return autocompleteResult; -} - } // namespace Luau diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 2f57e23c..826179b3 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -9,6 +9,8 @@ #include LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauBuiltInMetatableNoBadSynthetic, false) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -222,14 +224,14 @@ void registerBuiltinTypes(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - // setmetatable({ @metatable MT }, MT) -> { @metatable MT } // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } addGlobalBinding(typeChecker, "setmetatable", arena.addType( FunctionTypeVar{ {genericMT}, {}, - arena.addTypePack(TypePack{{tableMetaMT, genericMT}}), + arena.addTypePack(TypePack{{FFlag::LuauUnknownAndNeverType ? tabTy : tableMetaMT, genericMT}}), arena.addTypePack(TypePack{{tableMetaMT}}) } ), "@luau" @@ -309,6 +311,12 @@ static std::optional> magicFunctionSetMetaTable( { auto [paramPack, _predicates] = withPredicate; + if (FFlag::LuauUnknownAndNeverType) + { + if (size(paramPack) < 2 && finite(paramPack)) + return std::nullopt; + } + TypeArena& arena = typechecker.currentModule->internalTypes; std::vector expectedArgs = typechecker.unTypePack(scope, paramPack, 2, expr.location); @@ -316,6 +324,12 @@ static std::optional> magicFunctionSetMetaTable( TypeId target = follow(expectedArgs[0]); TypeId mt = follow(expectedArgs[1]); + if (FFlag::LuauUnknownAndNeverType) + { + typechecker.tablify(target); + typechecker.tablify(mt); + } + if (const auto& tab = get(target)) { if (target->persistent) @@ -324,7 +338,8 @@ static std::optional> magicFunctionSetMetaTable( } else { - typechecker.tablify(mt); + if (!FFlag::LuauUnknownAndNeverType) + typechecker.tablify(mt); const TableTypeVar* mtTtv = get(mt); MetatableTypeVar mtv{target, mt}; @@ -335,7 +350,7 @@ static std::optional> magicFunctionSetMetaTable( if (tableName == metatableName) mtv.syntheticName = tableName; - else + else if (!FFlag::LuauBuiltInMetatableNoBadSynthetic) mtv.syntheticName = "{ @metatable: " + metatableName + ", " + tableName + " }"; } @@ -343,7 +358,10 @@ static std::optional> magicFunctionSetMetaTable( if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1) { - return WithPredicate{}; + if (FFlag::LuauUnknownAndNeverType) + return std::nullopt; + else + return WithPredicate{}; } if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self) @@ -390,11 +408,21 @@ static std::optional> magicFunctionAssert( if (head.size() > 0) { - std::optional newhead = typechecker.pickTypesFromSense(head[0], true); - if (!newhead) - head = {typechecker.nilType}; + auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true); + if (FFlag::LuauUnknownAndNeverType) + { + if (get(*ty)) + head = {*ty}; + else + head[0] = *ty; + } else - head[0] = *newhead; + { + if (!ty) + head = {typechecker.nilType}; + else + head[0] = *ty; + } } return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index df4e0a6b..51ad61d5 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -48,6 +48,7 @@ struct TypeCloner void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); void operator()(const BlockedTypeVar& t); + void operator()(const PendingExpansionTypeVar& t); void operator()(const PrimitiveTypeVar& t); void operator()(const ConstrainedTypeVar& t); void operator()(const SingletonTypeVar& t); @@ -59,6 +60,8 @@ struct TypeCloner void operator()(const UnionTypeVar& t); void operator()(const IntersectionTypeVar& t); void operator()(const LazyTypeVar& t); + void operator()(const UnknownTypeVar& t); + void operator()(const NeverTypeVar& t); }; struct TypePackCloner @@ -164,6 +167,52 @@ void TypeCloner::operator()(const BlockedTypeVar& t) defaultClone(t); } +void TypeCloner::operator()(const PendingExpansionTypeVar& t) +{ + TypeId res = dest.addType(PendingExpansionTypeVar{t.fn, t.typeArguments, t.packArguments}); + PendingExpansionTypeVar* petv = getMutable(res); + LUAU_ASSERT(petv); + + seenTypes[typeId] = res; + + std::vector typeArguments; + for (TypeId arg : t.typeArguments) + typeArguments.push_back(clone(arg, dest, cloneState)); + + std::vector packArguments; + for (TypePackId arg : t.packArguments) + packArguments.push_back(clone(arg, dest, cloneState)); + + TypeFun fn; + fn.type = clone(t.fn.type, dest, cloneState); + + for (const GenericTypeDefinition& param : t.fn.typeParams) + { + TypeId ty = clone(param.ty, dest, cloneState); + std::optional defaultValue = param.defaultValue; + + if (defaultValue) + defaultValue = clone(*defaultValue, dest, cloneState); + + fn.typeParams.push_back(GenericTypeDefinition{ty, defaultValue}); + } + + for (const GenericTypePackDefinition& param : t.fn.typePackParams) + { + TypePackId tp = clone(param.tp, dest, cloneState); + std::optional defaultValue = param.defaultValue; + + if (defaultValue) + defaultValue = clone(*defaultValue, dest, cloneState); + + fn.typePackParams.push_back(GenericTypePackDefinition{tp, defaultValue}); + } + + petv->fn = std::move(fn); + petv->typeArguments = std::move(typeArguments); + petv->packArguments = std::move(packArguments); +} + void TypeCloner::operator()(const PrimitiveTypeVar& t) { defaultClone(t); @@ -310,6 +359,16 @@ void TypeCloner::operator()(const LazyTypeVar& t) defaultClone(t); } +void TypeCloner::operator()(const UnknownTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const NeverTypeVar& t) +{ + defaultClone(t); +} + } // anonymous namespace TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) @@ -440,6 +499,11 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log) ConstrainedTypeVar clone{ctv->level, ctv->parts}; result = dest.addType(std::move(clone)); } + else if (const PendingExpansionTypeVar* petv = get(ty)) + { + PendingExpansionTypeVar clone{petv->fn, petv->typeArguments, petv->packArguments}; + result = dest.addType(std::move(clone)); + } else return result; diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 3b9000cd..ea7037bf 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -14,7 +14,7 @@ namespace Luau const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp ConstraintGraphBuilder::ConstraintGraphBuilder( - const ModuleName& moduleName, TypeArena* arena, NotNull ice, NotNull globalScope) + const ModuleName& moduleName, TypeArena* arena, NotNull ice, NotNull globalScope) : moduleName(moduleName) , singletonTypes(getSingletonTypes()) , arena(arena) @@ -25,36 +25,34 @@ ConstraintGraphBuilder::ConstraintGraphBuilder( LUAU_ASSERT(arena); } -TypeId ConstraintGraphBuilder::freshType(NotNull scope) +TypeId ConstraintGraphBuilder::freshType(const ScopePtr& scope) { - return arena->addType(FreeTypeVar{scope}); + return arena->addType(FreeTypeVar{scope.get()}); } -TypePackId ConstraintGraphBuilder::freshTypePack(NotNull scope) +TypePackId ConstraintGraphBuilder::freshTypePack(const ScopePtr& scope) { - FreeTypePack f{scope}; + FreeTypePack f{scope.get()}; return arena->addTypePack(TypePackVar{std::move(f)}); } -NotNull ConstraintGraphBuilder::childScope(Location location, NotNull parent) +ScopePtr ConstraintGraphBuilder::childScope(Location location, const ScopePtr& parent) { - auto scope = std::make_unique(); - NotNull borrow = NotNull(scope.get()); - scopes.emplace_back(location, std::move(scope)); + auto scope = std::make_shared(parent); + scopes.emplace_back(location, scope); - borrow->parent = parent; - borrow->returnType = parent->returnType; - parent->children.push_back(borrow); + scope->returnType = parent->returnType; + parent->children.push_back(NotNull(scope.get())); - return borrow; + return scope; } -void ConstraintGraphBuilder::addConstraint(NotNull scope, ConstraintV cv) +void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, ConstraintV cv) { scope->constraints.emplace_back(new Constraint{std::move(cv)}); } -void ConstraintGraphBuilder::addConstraint(NotNull scope, std::unique_ptr c) +void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) { scope->constraints.emplace_back(std::move(c)); } @@ -63,25 +61,25 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) { LUAU_ASSERT(scopes.empty()); LUAU_ASSERT(rootScope == nullptr); - scopes.emplace_back(block->location, std::make_unique()); - rootScope = scopes.back().second.get(); - NotNull borrow = NotNull(rootScope); + ScopePtr scope = std::make_shared(singletonTypes.anyTypePack); + rootScope = scope.get(); + scopes.emplace_back(block->location, scope); - rootScope->returnType = freshTypePack(borrow); + rootScope->returnType = freshTypePack(scope); - prepopulateGlobalScope(borrow, block); + prepopulateGlobalScope(scope, block); // TODO: We should share the global scope. - rootScope->typeBindings["nil"] = singletonTypes.nilType; - rootScope->typeBindings["number"] = singletonTypes.numberType; - rootScope->typeBindings["string"] = singletonTypes.stringType; - rootScope->typeBindings["boolean"] = singletonTypes.booleanType; - rootScope->typeBindings["thread"] = singletonTypes.threadType; + rootScope->typeBindings["nil"] = TypeFun{singletonTypes.nilType}; + rootScope->typeBindings["number"] = TypeFun{singletonTypes.numberType}; + rootScope->typeBindings["string"] = TypeFun{singletonTypes.stringType}; + rootScope->typeBindings["boolean"] = TypeFun{singletonTypes.booleanType}; + rootScope->typeBindings["thread"] = TypeFun{singletonTypes.threadType}; - visitBlockWithoutChildScope(borrow, block); + visitBlockWithoutChildScope(scope, block); } -void ConstraintGraphBuilder::visitBlockWithoutChildScope(NotNull scope, AstStatBlock* block) +void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) { RecursionCounter counter{&recursionCount}; @@ -91,11 +89,58 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(NotNull scope, return; } + std::unordered_map aliasDefinitionLocations; + + // In order to enable mutually-recursive type aliases, we need to + // populate the type bindings before we actually check any of the + // alias statements. Since we're not ready to actually resolve + // any of the annotations, we just use a fresh type for now. + for (AstStat* stat : block->body) + { + if (auto alias = stat->as()) + { + if (scope->typeBindings.count(alias->name.value) != 0) + { + auto it = aliasDefinitionLocations.find(alias->name.value); + LUAU_ASSERT(it != aliasDefinitionLocations.end()); + reportError(alias->location, DuplicateTypeDefinition{alias->name.value, it->second}); + continue; + } + + bool hasGenerics = alias->generics.size > 0 || alias->genericPacks.size > 0; + + ScopePtr defnScope = scope; + if (hasGenerics) + { + defnScope = childScope(alias->location, scope); + } + + TypeId initialType = freshType(scope); + TypeFun initialFun = TypeFun{initialType}; + + for (const auto& [name, gen] : createGenerics(defnScope, alias->generics)) + { + initialFun.typeParams.push_back(gen); + defnScope->typeBindings[name] = TypeFun{gen.ty}; + } + + for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks)) + { + initialFun.typePackParams.push_back(genPack); + defnScope->typePackBindings[name] = genPack.tp; + } + + scope->typeBindings[alias->name.value] = std::move(initialFun); + astTypeAliasDefiningScopes[alias] = defnScope; + aliasDefinitionLocations[alias->name.value] = alias->location; + } + } + for (AstStat* stat : block->body) visit(scope, stat); } -void ConstraintGraphBuilder::visit(NotNull scope, AstStat* stat) +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) { RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit}; @@ -103,6 +148,8 @@ void ConstraintGraphBuilder::visit(NotNull scope, AstStat* stat) visit(scope, s); else if (auto s = stat->as()) visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); else if (auto f = stat->as()) visit(scope, f); else if (auto f = stat->as()) @@ -117,26 +164,34 @@ void ConstraintGraphBuilder::visit(NotNull scope, AstStat* stat) visit(scope, i); else if (auto a = stat->as()) visit(scope, a); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); else LUAU_ASSERT(0); } -void ConstraintGraphBuilder::visit(NotNull scope, AstStatLocal* local) +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) { std::vector varTypes; for (AstLocal* local : local->vars) { TypeId ty = freshType(scope); + Location location = local->location; if (local->annotation) { - TypeId annotation = resolveType(scope, local->annotation); + location = local->annotation->location; + TypeId annotation = resolveType(scope, local->annotation, /* topLevel */ true); addConstraint(scope, SubtypeConstraint{ty, annotation}); } varTypes.push_back(ty); - scope->bindings[local] = ty; + scope->bindings[local] = Binding{ty, location}; } for (size_t i = 0; i < local->values.size; ++i) @@ -167,18 +222,38 @@ void ConstraintGraphBuilder::visit(NotNull scope, AstStatLocal* local) } } -void addConstraints(Constraint* constraint, NotNull scope) +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) +{ + auto checkNumber = [&](AstExpr* expr) { + if (!expr) + return; + + TypeId t = check(scope, expr); + addConstraint(scope, SubtypeConstraint{t, singletonTypes.numberType}); + }; + + checkNumber(for_->from); + checkNumber(for_->to); + checkNumber(for_->step); + + ScopePtr forScope = childScope(for_->location, scope); + forScope->bindings[for_->var] = Binding{singletonTypes.numberType, for_->var->location}; + + visit(forScope, for_->body); +} + +void addConstraints(Constraint* constraint, NotNull scope) { scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); for (const auto& c : scope->constraints) constraint->dependencies.push_back(NotNull{c.get()}); - for (NotNull childScope : scope->children) + for (NotNull childScope : scope->children) addConstraints(constraint, childScope); } -void ConstraintGraphBuilder::visit(NotNull scope, AstStatLocalFunction* function) +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* function) { // Local // Global @@ -190,21 +265,21 @@ void ConstraintGraphBuilder::visit(NotNull scope, AstStatLocalFunction* LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. functionType = arena->addType(BlockedTypeVar{}); - scope->bindings[function->name] = functionType; + scope->bindings[function->name] = Binding{functionType, function->name->location}; FunctionSignature sig = checkFunctionSignature(scope, function->func); - sig.bodyScope->bindings[function->name] = sig.signature; + sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; checkFunctionBody(sig.bodyScope, function->func); std::unique_ptr c{ - new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope : sig.bodyScope}}}; - addConstraints(c.get(), sig.bodyScope); + new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}}}; + addConstraints(c.get(), NotNull(sig.bodyScope.get())); addConstraint(scope, std::move(c)); } -void ConstraintGraphBuilder::visit(NotNull scope, AstStatFunction* function) +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* function) { // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. // With or without self @@ -224,9 +299,9 @@ void ConstraintGraphBuilder::visit(NotNull scope, AstStatFunction* funct else { functionType = arena->addType(BlockedTypeVar{}); - scope->bindings[localName->local] = functionType; + scope->bindings[localName->local] = Binding{functionType, localName->location}; } - sig.bodyScope->bindings[localName->local] = sig.signature; + sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location}; } else if (AstExprGlobal* globalName = function->name->as()) { @@ -239,9 +314,9 @@ void ConstraintGraphBuilder::visit(NotNull scope, AstStatFunction* funct else { functionType = arena->addType(BlockedTypeVar{}); - rootScope->bindings[globalName->name] = functionType; + rootScope->bindings[globalName->name] = Binding{functionType, globalName->location}; } - sig.bodyScope->bindings[globalName->name] = sig.signature; + sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; } else if (AstExprIndexName* indexName = function->name->as()) { @@ -268,39 +343,26 @@ void ConstraintGraphBuilder::visit(NotNull scope, AstStatFunction* funct checkFunctionBody(sig.bodyScope, function->func); std::unique_ptr c{ - new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope : sig.bodyScope}}}; - addConstraints(c.get(), sig.bodyScope); + new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}}}; + addConstraints(c.get(), NotNull(sig.bodyScope.get())); addConstraint(scope, std::move(c)); } -void ConstraintGraphBuilder::visit(NotNull scope, AstStatReturn* ret) +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) { TypePackId exprTypes = checkPack(scope, ret->list); addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); } -void ConstraintGraphBuilder::visit(NotNull scope, AstStatBlock* block) +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) { - NotNull innerScope = childScope(block->location, scope); - - // In order to enable mutually-recursive type aliases, we need to - // populate the type bindings before we actually check any of the - // alias statements. Since we're not ready to actually resolve - // any of the annotations, we just use a fresh type for now. - for (AstStat* stat : block->body) - { - if (auto alias = stat->as()) - { - TypeId initialType = freshType(scope); - scope->typeBindings[alias->name.value] = initialType; - } - } + ScopePtr innerScope = childScope(block->location, scope); visitBlockWithoutChildScope(innerScope, block); } -void ConstraintGraphBuilder::visit(NotNull scope, AstStatAssign* assign) +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { TypePackId varPackId = checkExprList(scope, assign->vars); TypePackId valuePack = checkPack(scope, assign->values); @@ -308,47 +370,66 @@ void ConstraintGraphBuilder::visit(NotNull scope, AstStatAssign* assign) addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}); } -void ConstraintGraphBuilder::visit(NotNull scope, AstStatIf* ifStatement) +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) { check(scope, ifStatement->condition); - NotNull thenScope = childScope(ifStatement->thenbody->location, scope); + ScopePtr thenScope = childScope(ifStatement->thenbody->location, scope); visit(thenScope, ifStatement->thenbody); if (ifStatement->elsebody) { - NotNull elseScope = childScope(ifStatement->elsebody->location, scope); + ScopePtr elseScope = childScope(ifStatement->elsebody->location, scope); visit(elseScope, ifStatement->elsebody); } } -void ConstraintGraphBuilder::visit(NotNull scope, AstStatTypeAlias* alias) +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alias) { // TODO: Exported type aliases - // TODO: Generic type aliases - auto it = scope->typeBindings.find(alias->name.value); - // This should always be here since we do a separate pass over the - // AST to set up typeBindings. If it's not, we've somehow skipped - // this alias in that first pass. - LUAU_ASSERT(it != scope->typeBindings.end()); - if (it == scope->typeBindings.end()) + auto bindingIt = scope->typeBindings.find(alias->name.value); + ScopePtr* defnIt = astTypeAliasDefiningScopes.find(alias); + // These will be undefined if the alias was a duplicate definition, in which + // case we just skip over it. + if (bindingIt == scope->typeBindings.end() || defnIt == nullptr) { - ice->ice("Type alias does not have a pre-populated binding", alias->location); + return; } - TypeId ty = resolveType(scope, alias->type); + ScopePtr resolvingScope = *defnIt; + TypeId ty = resolveType(resolvingScope, alias->type, /* topLevel */ true); + + LUAU_ASSERT(get(bindingIt->second.type)); // Rather than using a subtype constraint, we instead directly bind // the free type we generated in the first pass to the resolved type. // This prevents a case where you could cause another constraint to // bind the free alias type to an unrelated type, causing havoc. - asMutable(it->second)->ty.emplace(ty); + asMutable(bindingIt->second.type)->ty.emplace(ty); addConstraint(scope, NameConstraint{ty, alias->name.value}); } -TypePackId ConstraintGraphBuilder::checkPack(NotNull scope, AstArray exprs) +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) +{ + LUAU_ASSERT(global->type); + + TypeId globalTy = resolveType(scope, global->type); + scope->bindings[global->name] = Binding{globalTy, global->location}; +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* global) +{ + LUAU_ASSERT(false); // TODO: implement +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global) +{ + LUAU_ASSERT(false); // TODO: implement +} + +TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs) { if (exprs.size == 0) return arena->addTypePack({}); @@ -369,7 +450,7 @@ TypePackId ConstraintGraphBuilder::checkPack(NotNull scope, AstArrayaddTypePack(TypePack{std::move(types), last}); } -TypePackId ConstraintGraphBuilder::checkExprList(NotNull scope, const AstArray& exprs) +TypePackId ConstraintGraphBuilder::checkExprList(const ScopePtr& scope, const AstArray& exprs) { TypePackId result = arena->addTypePack({}); TypePack* resultPack = getMutable(result); @@ -390,7 +471,7 @@ TypePackId ConstraintGraphBuilder::checkExprList(NotNull scope, const As return result; } -TypePackId ConstraintGraphBuilder::checkPack(NotNull scope, AstExpr* expr) +TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr) { RecursionCounter counter{&recursionCount}; @@ -445,7 +526,7 @@ TypePackId ConstraintGraphBuilder::checkPack(NotNull scope, AstExpr* exp return result; } -TypeId ConstraintGraphBuilder::check(NotNull scope, AstExpr* expr) +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr) { RecursionCounter counter{&recursionCount}; @@ -525,7 +606,7 @@ TypeId ConstraintGraphBuilder::check(NotNull scope, AstExpr* expr) return result; } -TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprIndexName* indexName) +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) { TypeId obj = check(scope, indexName->expr); TypeId result = freshType(scope); @@ -541,7 +622,7 @@ TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprIndexName* in return result; } -TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprIndexExpr* indexExpr) +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) { TypeId obj = check(scope, indexExpr->expr); TypeId indexType = check(scope, indexExpr->index); @@ -556,7 +637,7 @@ TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprIndexExpr* in return result; } -TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprUnary* unary) +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { TypeId operandType = check(scope, unary->expr); @@ -576,7 +657,7 @@ TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprUnary* unary) return singletonTypes.errorRecoveryType(); } -TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprBinary* binary) +TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary) { TypeId leftType = check(scope, binary->left); TypeId rightType = check(scope, binary->right); @@ -601,7 +682,7 @@ TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprBinary* binar return nullptr; } -TypeId ConstraintGraphBuilder::checkExprTable(NotNull scope, AstExprTable* expr) +TypeId ConstraintGraphBuilder::checkExprTable(const ScopePtr& scope, AstExprTable* expr) { TypeId ty = arena->addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(ty); @@ -651,10 +732,10 @@ TypeId ConstraintGraphBuilder::checkExprTable(NotNull scope, AstExprTabl return ty; } -ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(NotNull parent, AstExprFunction* fn) +ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn) { - Scope2* signatureScope = nullptr; - Scope2* bodyScope = nullptr; + ScopePtr signatureScope = nullptr; + ScopePtr bodyScope = nullptr; TypePackId returnType = nullptr; std::vector genericTypes; @@ -667,25 +748,24 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS // generics properly. if (hasGenerics) { - NotNull signatureBorrow = childScope(fn->location, parent); - signatureScope = signatureBorrow.get(); + signatureScope = childScope(fn->location, parent); // We need to assign returnType before creating bodyScope so that the // return type gets propogated to bodyScope. - returnType = freshTypePack(signatureBorrow); + returnType = freshTypePack(signatureScope); signatureScope->returnType = returnType; - bodyScope = childScope(fn->body->location, signatureBorrow).get(); + bodyScope = childScope(fn->body->location, signatureScope); - std::vector> genericDefinitions = createGenerics(signatureBorrow, fn->generics); - std::vector> genericPackDefinitions = createGenericPacks(signatureBorrow, fn->genericPacks); + std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); // We do not support default values on function generics, so we only // care about the types involved. for (const auto& [name, g] : genericDefinitions) { genericTypes.push_back(g.ty); - signatureScope->typeBindings[name] = g.ty; + signatureScope->typeBindings[name] = TypeFun{g.ty}; } for (const auto& [name, g] : genericPackDefinitions) @@ -696,11 +776,10 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS } else { - NotNull bodyBorrow = childScope(fn->body->location, parent); - bodyScope = bodyBorrow.get(); + bodyScope = childScope(fn->body->location, parent); - returnType = freshTypePack(bodyBorrow); - bodyBorrow->returnType = returnType; + returnType = freshTypePack(bodyScope); + bodyScope->returnType = returnType; // To eliminate the need to branch on hasGenerics below, we say that the // signature scope is the body scope when there is no real signature @@ -708,27 +787,24 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS signatureScope = bodyScope; } - NotNull bodyBorrow = NotNull(bodyScope); - NotNull signatureBorrow = NotNull(signatureScope); - if (fn->returnAnnotation) { - TypePackId annotatedRetType = resolveTypePack(signatureBorrow, *fn->returnAnnotation); - addConstraint(signatureBorrow, PackSubtypeConstraint{returnType, annotatedRetType}); + TypePackId annotatedRetType = resolveTypePack(signatureScope, *fn->returnAnnotation); + addConstraint(signatureScope, PackSubtypeConstraint{returnType, annotatedRetType}); } std::vector argTypes; for (AstLocal* local : fn->args) { - TypeId t = freshType(signatureBorrow); + TypeId t = freshType(signatureScope); argTypes.push_back(t); - signatureScope->bindings[local] = t; + signatureScope->bindings[local] = Binding{t, local->location}; if (local->annotation) { - TypeId argAnnotation = resolveType(signatureBorrow, local->annotation); - addConstraint(signatureBorrow, SubtypeConstraint{t, argAnnotation}); + TypeId argAnnotation = resolveType(signatureScope, local->annotation, /* topLevel */ true); + addConstraint(signatureScope, SubtypeConstraint{t, argAnnotation}); } } @@ -749,11 +825,11 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS // Undo the workaround we made above: if there's no signature scope, // don't report it. /* signatureScope */ hasGenerics ? signatureScope : nullptr, - /* bodyScope */ bodyBorrow, + /* bodyScope */ bodyScope, }; } -void ConstraintGraphBuilder::checkFunctionBody(NotNull scope, AstExprFunction* fn) +void ConstraintGraphBuilder::checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn) { visitBlockWithoutChildScope(scope, fn->body); @@ -766,20 +842,65 @@ void ConstraintGraphBuilder::checkFunctionBody(NotNull scope, AstExprFun } } -TypeId ConstraintGraphBuilder::resolveType(NotNull scope, AstType* ty) +TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, bool topLevel) { TypeId result = nullptr; if (auto ref = ty->as()) { // TODO: Support imported types w/ require tracing. - // TODO: Support generic type references. LUAU_ASSERT(!ref->prefix); - LUAU_ASSERT(!ref->hasParameterList); - // TODO: If it doesn't exist, should we introduce a free binding? - // This is probably important for handling type aliases. - result = scope->lookupTypeBinding(ref->name.value).value_or(singletonTypes.errorRecoveryType()); + std::optional alias = scope->lookupTypeBinding(ref->name.value); + + if (alias.has_value()) + { + // If the alias is not generic, we don't need to set up a blocked + // type and an instantiation constraint. + if (alias->typeParams.empty() && alias->typePackParams.empty()) + { + result = alias->type; + } + else + { + std::vector parameters; + std::vector packParameters; + + for (const AstTypeOrPack& p : ref->parameters) + { + // We do not enforce the ordering of types vs. type packs here; + // that is done in the parser. + if (p.type) + { + parameters.push_back(resolveType(scope, p.type)); + } + else if (p.typePack) + { + packParameters.push_back(resolveTypePack(scope, p.typePack)); + } + else + { + // This indicates a parser bug: one of these two pointers + // should be set. + LUAU_ASSERT(false); + } + } + + result = arena->addType(PendingExpansionTypeVar{*alias, parameters, packParameters}); + + if (topLevel) + { + addConstraint(scope, TypeAliasExpansionConstraint{ + /* target */ result, + }); + } + } + } + else + { + reportError(ty->location, UnknownSymbol{ref->name.value, UnknownSymbol::Context::Type}); + result = singletonTypes.errorRecoveryType(); + } } else if (auto tab = ty->as()) { @@ -811,7 +932,7 @@ TypeId ConstraintGraphBuilder::resolveType(NotNull scope, AstType* ty) { // TODO: Recursion limit. bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; - Scope2* signatureScope = nullptr; + ScopePtr signatureScope = nullptr; std::vector genericTypes; std::vector genericTypePacks; @@ -820,22 +941,21 @@ TypeId ConstraintGraphBuilder::resolveType(NotNull scope, AstType* ty) // for the generic bindings to live on. if (hasGenerics) { - NotNull signatureBorrow = childScope(fn->location, scope); - signatureScope = signatureBorrow.get(); + signatureScope = childScope(fn->location, scope); - std::vector> genericDefinitions = createGenerics(signatureBorrow, fn->generics); - std::vector> genericPackDefinitions = createGenericPacks(signatureBorrow, fn->genericPacks); + std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); for (const auto& [name, g] : genericDefinitions) { genericTypes.push_back(g.ty); - signatureBorrow->typeBindings[name] = g.ty; + signatureScope->typeBindings[name] = TypeFun{g.ty}; } for (const auto& [name, g] : genericPackDefinitions) { genericTypePacks.push_back(g.tp); - signatureBorrow->typePackBindings[name] = g.tp; + signatureScope->typePackBindings[name] = g.tp; } } else @@ -843,13 +963,11 @@ TypeId ConstraintGraphBuilder::resolveType(NotNull scope, AstType* ty) // To eliminate the need to branch on hasGenerics below, we say that // the signature scope is the parent scope if we don't have // generics. - signatureScope = scope.get(); + signatureScope = scope; } - NotNull signatureBorrow(signatureScope); - - TypePackId argTypes = resolveTypePack(signatureBorrow, fn->argTypes); - TypePackId returnTypes = resolveTypePack(signatureBorrow, fn->returnTypes); + TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes); + TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes); // TODO: FunctionTypeVar needs a pointer to the scope so that we know // how to quantify/instantiate it. @@ -927,7 +1045,7 @@ TypeId ConstraintGraphBuilder::resolveType(NotNull scope, AstType* ty) return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull scope, AstTypePack* tp) +TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTypePack* tp) { TypePackId result; if (auto expl = tp->as()) @@ -941,7 +1059,15 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull scope, AstTyp } else if (auto gen = tp->as()) { - result = arena->addTypePack(TypePackVar{GenericTypePack{scope, gen->genericName.value}}); + if (std::optional lookup = scope->lookupTypePackBinding(gen->genericName.value)) + { + result = *lookup; + } + else + { + reportError(tp->location, UnknownSymbol{gen->genericName.value, UnknownSymbol::Context::Type}); + result = singletonTypes.errorRecoveryTypePack(); + } } else { @@ -953,7 +1079,7 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull scope, AstTyp return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull scope, const AstTypeList& list) +TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const AstTypeList& list) { std::vector head; @@ -971,12 +1097,12 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull scope, const return arena->addTypePack(TypePack{head, tail}); } -std::vector> ConstraintGraphBuilder::createGenerics(NotNull scope, AstArray generics) +std::vector> ConstraintGraphBuilder::createGenerics(const ScopePtr& scope, AstArray generics) { std::vector> result; for (const auto& generic : generics) { - TypeId genericTy = arena->addType(GenericTypeVar{scope, generic.name.value}); + TypeId genericTy = arena->addType(GenericTypeVar{scope.get(), generic.name.value}); std::optional defaultTy = std::nullopt; if (generic.defaultValue) @@ -992,12 +1118,12 @@ std::vector> ConstraintGraphBuilder::crea } std::vector> ConstraintGraphBuilder::createGenericPacks( - NotNull scope, AstArray generics) + const ScopePtr& scope, AstArray generics) { std::vector> result; for (const auto& generic : generics) { - TypePackId genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope, generic.name.value}}); + TypePackId genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic.name.value}}); std::optional defaultTy = std::nullopt; if (generic.defaultValue) @@ -1012,7 +1138,7 @@ std::vector> ConstraintGraphBuilder:: return result; } -TypeId ConstraintGraphBuilder::flattenPack(NotNull scope, Location location, TypePackId tp) +TypeId ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, TypePackId tp) { if (auto f = first(tp)) return *f; @@ -1038,10 +1164,10 @@ void ConstraintGraphBuilder::reportCodeTooComplex(Location location) struct GlobalPrepopulator : AstVisitor { - const NotNull globalScope; + const NotNull globalScope; const NotNull arena; - GlobalPrepopulator(NotNull globalScope, NotNull arena) + GlobalPrepopulator(NotNull globalScope, NotNull arena) : globalScope(globalScope) , arena(arena) { @@ -1050,29 +1176,29 @@ struct GlobalPrepopulator : AstVisitor bool visit(AstStatFunction* function) override { if (AstExprGlobal* g = function->name->as()) - globalScope->bindings[g->name] = arena->addType(BlockedTypeVar{}); + globalScope->bindings[g->name] = Binding{arena->addType(BlockedTypeVar{})}; return true; } }; -void ConstraintGraphBuilder::prepopulateGlobalScope(NotNull globalScope, AstStatBlock* program) +void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program) { - GlobalPrepopulator gp{NotNull{globalScope}, arena}; + GlobalPrepopulator gp{NotNull{globalScope.get()}, arena}; program->visit(&gp); } -void collectConstraints(std::vector>& result, NotNull scope) +void collectConstraints(std::vector>& result, NotNull scope) { for (const auto& c : scope->constraints) result.push_back(NotNull{c.get()}); - for (NotNull child : scope->children) + for (NotNull child : scope->children) collectConstraints(result, child); } -std::vector> collectConstraints(NotNull rootScope) +std::vector> collectConstraints(NotNull rootScope) { std::vector> result; collectConstraints(result, rootScope); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 077a4e28..0898f9aa 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1,11 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ApplyTypeFunction.h" #include "Luau/ConstraintSolver.h" #include "Luau/Instantiation.h" #include "Luau/Location.h" #include "Luau/Quantify.h" #include "Luau/ToString.h" #include "Luau/Unifier.h" +#include "Luau/VisitTypeVar.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); @@ -13,31 +15,195 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { -[[maybe_unused]] static void dumpBindings(NotNull scope, ToStringOptions& opts) +[[maybe_unused]] static void dumpBindings(NotNull scope, ToStringOptions& opts) { for (const auto& [k, v] : scope->bindings) { - auto d = toStringDetailed(v, opts); + auto d = toStringDetailed(v.typeId, opts); opts.nameMap = d.nameMap; printf("\t%s : %s\n", k.c_str(), d.name.c_str()); } - for (NotNull child : scope->children) + for (NotNull child : scope->children) dumpBindings(child, opts); } -static void dumpConstraints(NotNull scope, ToStringOptions& opts) +static void dumpConstraints(NotNull scope, ToStringOptions& opts) { for (const ConstraintPtr& c : scope->constraints) { printf("\t%s\n", toString(*c, opts).c_str()); } - for (NotNull child : scope->children) + for (NotNull child : scope->children) dumpConstraints(child, opts); } -void dump(NotNull rootScope, ToStringOptions& opts) +static std::pair, std::vector> saturateArguments( + const TypeFun& fn, const std::vector& rawTypeArguments, const std::vector& rawPackArguments, TypeArena* arena) +{ + std::vector saturatedTypeArguments; + std::vector extraTypes; + std::vector saturatedPackArguments; + + for (size_t i = 0; i < rawTypeArguments.size(); ++i) + { + TypeId ty = rawTypeArguments[i]; + + if (i < fn.typeParams.size()) + saturatedTypeArguments.push_back(ty); + else + extraTypes.push_back(ty); + } + + // If we collected extra types, put them in a type pack now. This case is + // mutually exclusive with the type pack -> type conversion we do below: + // extraTypes will only have elements in it if we have more types than we + // have parameter slots for them to go into. + if (!extraTypes.empty()) + { + saturatedPackArguments.push_back(arena->addTypePack(extraTypes)); + } + + for (size_t i = 0; i < rawPackArguments.size(); ++i) + { + TypePackId tp = rawPackArguments[i]; + + // If we are short on regular type saturatedTypeArguments and we have a single + // element type pack, we can decompose that to the type it contains and + // use that as a type parameter. + if (saturatedTypeArguments.size() < fn.typeParams.size() && size(tp) == 1 && finite(tp) && first(tp) && saturatedPackArguments.empty()) + { + saturatedTypeArguments.push_back(*first(tp)); + } + else + { + saturatedPackArguments.push_back(tp); + } + } + + size_t typesProvided = saturatedTypeArguments.size(); + size_t typesRequired = fn.typeParams.size(); + + size_t packsProvided = saturatedPackArguments.size(); + size_t packsRequired = fn.typePackParams.size(); + + // Extra types should be accumulated in extraTypes, not saturatedTypeArguments. Extra + // packs will be accumulated in saturatedPackArguments, so we don't have an + // assertion for that. + LUAU_ASSERT(typesProvided <= typesRequired); + + // If we didn't provide enough types, but we did provide a type pack, we + // don't want to use defaults. The rationale for this is that if the user + // provides a pack but doesn't provide enough types, we want to report an + // error, rather than simply using the default saturatedTypeArguments, if they exist. If + // they did provide enough types, but not enough packs, we of course want to + // use the default packs. + bool needsDefaults = (typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired); + + if (needsDefaults) + { + // Default types can reference earlier types. It's legal to write + // something like + // type T = (A, B) -> number + // and we need to respect that. We use an ApplyTypeFunction for this. + ApplyTypeFunction atf{arena}; + + for (size_t i = 0; i < typesProvided; ++i) + atf.typeArguments[fn.typeParams[i].ty] = saturatedTypeArguments[i]; + + for (size_t i = typesProvided; i < typesRequired; ++i) + { + TypeId defaultTy = fn.typeParams[i].defaultValue.value_or(nullptr); + + // We will fill this in with the error type later. + if (!defaultTy) + break; + + TypeId instantiatedDefault = atf.substitute(defaultTy).value_or(getSingletonTypes().errorRecoveryType()); + atf.typeArguments[fn.typeParams[i].ty] = instantiatedDefault; + saturatedTypeArguments.push_back(instantiatedDefault); + } + + for (size_t i = 0; i < packsProvided; ++i) + { + atf.typePackArguments[fn.typePackParams[i].tp] = saturatedPackArguments[i]; + } + + for (size_t i = packsProvided; i < packsRequired; ++i) + { + TypePackId defaultTp = fn.typePackParams[i].defaultValue.value_or(nullptr); + + // We will fill this in with the error type pack later. + if (!defaultTp) + break; + + TypePackId instantiatedDefault = atf.substitute(defaultTp).value_or(getSingletonTypes().errorRecoveryTypePack()); + atf.typePackArguments[fn.typePackParams[i].tp] = instantiatedDefault; + saturatedPackArguments.push_back(instantiatedDefault); + } + } + + // If we didn't create an extra type pack from overflowing parameter packs, + // and we're still missing a type pack, plug in an empty type pack as the + // value of the empty packs. + if (extraTypes.empty() && saturatedPackArguments.size() + 1 == fn.typePackParams.size()) + { + saturatedPackArguments.push_back(arena->addTypePack({})); + } + + // We need to have _something_ when we substitute the generic saturatedTypeArguments, + // even if they're missing, so we use the error type as a filler. + for (size_t i = saturatedTypeArguments.size(); i < typesRequired; ++i) + { + saturatedTypeArguments.push_back(getSingletonTypes().errorRecoveryType()); + } + + for (size_t i = saturatedPackArguments.size(); i < packsRequired; ++i) + { + saturatedPackArguments.push_back(getSingletonTypes().errorRecoveryTypePack()); + } + + // At this point, these two conditions should be true. If they aren't we + // will run into access violations. + LUAU_ASSERT(saturatedTypeArguments.size() == fn.typeParams.size()); + LUAU_ASSERT(saturatedPackArguments.size() == fn.typePackParams.size()); + + return {saturatedTypeArguments, saturatedPackArguments}; +} + +bool InstantiationSignature::operator==(const InstantiationSignature& rhs) const +{ + return fn == rhs.fn && arguments == rhs.arguments && packArguments == rhs.packArguments; +} + +size_t HashInstantiationSignature::operator()(const InstantiationSignature& signature) const +{ + size_t hash = std::hash{}(signature.fn.type); + for (const GenericTypeDefinition& p : signature.fn.typeParams) + { + hash ^= (std::hash{}(p.ty) << 1); + } + + for (const GenericTypePackDefinition& p : signature.fn.typePackParams) + { + hash ^= (std::hash{}(p.tp) << 1); + } + + for (const TypeId a : signature.arguments) + { + hash ^= (std::hash{}(a) << 1); + } + + for (const TypePackId a : signature.packArguments) + { + hash ^= (std::hash{}(a) << 1); + } + + return hash; +} + +void dump(NotNull rootScope, ToStringOptions& opts) { printf("constraints:\n"); dumpConstraints(rootScope, opts); @@ -55,7 +221,7 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } } -ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull rootScope) +ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull rootScope) : arena(arena) , constraints(collectConstraints(rootScope)) , rootScope(rootScope) @@ -77,6 +243,7 @@ void ConstraintSolver::run() return; ToStringOptions opts; + opts.exhaustive = true; if (FFlag::DebugLuauLogSolver) { @@ -186,6 +353,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*bc, constraint, force); else if (auto nc = get(*constraint)) success = tryDispatch(*nc, constraint); + else if (auto taec = get(*constraint)) + success = tryDispatch(*taec, constraint); else LUAU_ASSERT(0); @@ -325,6 +494,198 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNullarena); + + if (follow(petv.fn.type) == follow(signature.fn.type) && (signature.arguments != typeArguments || signature.packArguments != packArguments)) + { + foundInfiniteType = true; + return false; + } + + return true; + } +}; + +struct InstantiationQueuer : TypeVarOnceVisitor +{ + ConstraintSolver* solver; + const InstantiationSignature& signature; + + explicit InstantiationQueuer(ConstraintSolver* solver, const InstantiationSignature& signature) + : solver(solver) + , signature(signature) + { + } + + bool visit(TypeId ty, const PendingExpansionTypeVar& petv) override + { + solver->pushConstraint(TypeAliasExpansionConstraint{ty}); + return false; + } +}; + +bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint) +{ + const PendingExpansionTypeVar* petv = get(follow(c.target)); + if (!petv) + { + unblock(c.target); + return true; + } + + auto bindResult = [this, &c](TypeId result) { + asMutable(c.target)->ty.emplace(result); + unblock(c.target); + }; + + // If there are no parameters to the type function we can just use the type + // directly. + if (petv->fn.typeParams.empty() && petv->fn.typePackParams.empty()) + { + bindResult(petv->fn.type); + return true; + } + + auto [typeArguments, packArguments] = saturateArguments(petv->fn, petv->typeArguments, petv->packArguments, arena); + + bool sameTypes = + std::equal(typeArguments.begin(), typeArguments.end(), petv->fn.typeParams.begin(), petv->fn.typeParams.end(), [](auto&& itp, auto&& p) { + return itp == p.ty; + }); + + bool samePacks = std::equal( + packArguments.begin(), packArguments.end(), petv->fn.typePackParams.begin(), petv->fn.typePackParams.end(), [](auto&& itp, auto&& p) { + return itp == p.tp; + }); + + // If we're instantiating the type with its generic saturatedTypeArguments we are + // performing the identity substitution. We can just short-circuit and bind + // to the TypeFun's type. + if (sameTypes && samePacks) + { + bindResult(petv->fn.type); + return true; + } + + InstantiationSignature signature{ + petv->fn, + typeArguments, + packArguments, + }; + + // If we use the same signature, we don't need to bother trying to + // instantiate the alias again, since the instantiation should be + // deterministic. + if (TypeId* cached = instantiatedAliases.find(signature)) + { + bindResult(*cached); + return true; + } + + // In order to prevent infinite types from being expanded and causing us to + // cycle infinitely, we need to scan the type function for cases where we + // expand the same alias with different type saturatedTypeArguments. See + // https://github.com/Roblox/luau/pull/68 for the RFC responsible for this. + // This is a little nicer than using a recursion limit because we can catch + // the infinite expansion before actually trying to expand it. + InfiniteTypeFinder itf{this, signature}; + itf.traverse(petv->fn.type); + + if (itf.foundInfiniteType) + { + // TODO (CLI-56761): Report an error. + bindResult(getSingletonTypes().errorRecoveryType()); + return true; + } + + ApplyTypeFunction applyTypeFunction{arena}; + for (size_t i = 0; i < typeArguments.size(); ++i) + { + applyTypeFunction.typeArguments[petv->fn.typeParams[i].ty] = typeArguments[i]; + } + + for (size_t i = 0; i < packArguments.size(); ++i) + { + applyTypeFunction.typePackArguments[petv->fn.typePackParams[i].tp] = packArguments[i]; + } + + std::optional maybeInstantiated = applyTypeFunction.substitute(petv->fn.type); + // Note that ApplyTypeFunction::encounteredForwardedType is never set in + // DCR, because we do not use free types for forward-declared generic + // aliases. + + if (!maybeInstantiated.has_value()) + { + // TODO (CLI-56761): Report an error. + bindResult(getSingletonTypes().errorRecoveryType()); + return true; + } + + TypeId instantiated = *maybeInstantiated; + TypeId target = follow(instantiated); + // Type function application will happily give us the exact same type if + // there are e.g. generic saturatedTypeArguments that go unused. + bool needsClone = follow(petv->fn.type) == target; + // Only tables have the properties we're trying to set. + TableTypeVar* ttv = getMutableTableType(target); + + if (ttv) + { + if (needsClone) + { + // Substitution::clone is a shallow clone. If this is a + // metatable type, we want to mutate its table, so we need to + // explicitly clone that table as well. If we don't, we will + // mutate another module's type surface and cause a + // use-after-free. + if (get(target)) + { + instantiated = applyTypeFunction.clone(target); + MetatableTypeVar* mtv = getMutable(instantiated); + mtv->table = applyTypeFunction.clone(mtv->table); + ttv = getMutable(mtv->table); + } + else if (get(target)) + { + instantiated = applyTypeFunction.clone(target); + ttv = getMutable(instantiated); + } + + target = follow(instantiated); + } + + ttv->instantiatedTypeParams = typeArguments; + ttv->instantiatedTypePackParams = packArguments; + // TODO: Fill in definitionModuleName. + } + + bindResult(target); + + // The application is not recursive, so we need to queue up application of + // any child type function instantiations within the result in order for it + // to be complete. + InstantiationQueuer queuer{this, signature}; + queuer.traverse(target); + + instantiatedAliases[signature] = target; + + return true; +} + void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { blocked[target].push_back(constraint); @@ -388,7 +749,7 @@ void ConstraintSolver::unblock(TypePackId progressed) bool ConstraintSolver::isBlocked(TypeId ty) { - return nullptr != get(follow(ty)); + return nullptr != get(follow(ty)) || nullptr != get(follow(ty)); } bool ConstraintSolver::isBlocked(NotNull constraint) @@ -415,4 +776,12 @@ void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack) u.log.commit(); } +void ConstraintSolver::pushConstraint(ConstraintV cv) +{ + std::unique_ptr c = std::make_unique(std::move(cv)); + NotNull borrow = NotNull(c.get()); + solverConstraints.push_back(std::move(c)); + unsolvedConstraints.push_back(borrow); +} + } // namespace Luau diff --git a/Analysis/src/ConstraintSolverLogger.cpp b/Analysis/src/ConstraintSolverLogger.cpp index 2f93c280..adb9c54e 100644 --- a/Analysis/src/ConstraintSolverLogger.cpp +++ b/Analysis/src/ConstraintSolverLogger.cpp @@ -2,45 +2,39 @@ #include "Luau/ConstraintSolverLogger.h" +#include "Luau/JsonEmitter.h" + namespace Luau { -static std::string dumpScopeAndChildren(const Scope2* scope, ToStringOptions& opts) +static void dumpScopeAndChildren(const Scope* scope, Json::JsonEmitter& emitter, ToStringOptions& opts) { - std::string output = "{\"bindings\":{"; + emitter.writeRaw("{"); + Json::write(emitter, "bindings"); + emitter.writeRaw(":"); - bool comma = false; - for (const auto& [name, type] : scope->bindings) + Json::ObjectEmitter o = emitter.writeObject(); + + for (const auto& [name, binding] : scope->bindings) { - if (comma) - output += ","; - - output += "\""; - output += name.c_str(); - output += "\": \""; - - ToStringResult result = toStringDetailed(type, opts); + ToStringResult result = toStringDetailed(binding.typeId, opts); opts.nameMap = std::move(result.nameMap); - output += result.name; - output += "\""; - - comma = true; + o.writePair(name.c_str(), result.name); } - output += "},\"children\":["; - comma = false; + o.finish(); + emitter.writeRaw(","); + Json::write(emitter, "children"); + emitter.writeRaw(":"); - for (const Scope2* child : scope->children) + Json::ArrayEmitter a = emitter.writeArray(); + for (const Scope* child : scope->children) { - if (comma) - output += ","; - - output += dumpScopeAndChildren(child, opts); - comma = true; + dumpScopeAndChildren(child, emitter, opts); } - output += "]}"; - return output; + a.finish(); + emitter.writeRaw("}"); } static std::string dumpConstraintsToDot(std::vector>& constraints, ToStringOptions& opts) @@ -80,51 +74,49 @@ static std::string dumpConstraintsToDot(std::vector>& std::string ConstraintSolverLogger::compileOutput() { - std::string output = "["; - bool comma = false; - + Json::JsonEmitter emitter; + emitter.writeRaw("["); for (const std::string& snapshot : snapshots) { - if (comma) - output += ","; - output += snapshot; - - comma = true; + emitter.writeComma(); + emitter.writeRaw(snapshot); } - output += "]"; - return output; + emitter.writeRaw("]"); + return emitter.str(); } -void ConstraintSolverLogger::captureBoundarySnapshot(const Scope2* rootScope, std::vector>& unsolvedConstraints) +void ConstraintSolverLogger::captureBoundarySnapshot(const Scope* rootScope, std::vector>& unsolvedConstraints) { - std::string snapshot = "{\"type\":\"boundary\",\"rootScope\":"; + Json::JsonEmitter emitter; + Json::ObjectEmitter o = emitter.writeObject(); + o.writePair("type", "boundary"); + o.writePair("constraintGraph", dumpConstraintsToDot(unsolvedConstraints, opts)); + emitter.writeComma(); + Json::write(emitter, "rootScope"); + emitter.writeRaw(":"); + dumpScopeAndChildren(rootScope, emitter, opts); + o.finish(); - snapshot += dumpScopeAndChildren(rootScope, opts); - snapshot += ",\"constraintGraph\":\""; - snapshot += dumpConstraintsToDot(unsolvedConstraints, opts); - snapshot += "\"}"; - - snapshots.push_back(std::move(snapshot)); + snapshots.push_back(emitter.str()); } void ConstraintSolverLogger::prepareStepSnapshot( - const Scope2* rootScope, NotNull current, std::vector>& unsolvedConstraints) + const Scope* rootScope, NotNull current, std::vector>& unsolvedConstraints) { - // LUAU_ASSERT(!preparedSnapshot); + Json::JsonEmitter emitter; + Json::ObjectEmitter o = emitter.writeObject(); + o.writePair("type", "step"); + o.writePair("constraintGraph", dumpConstraintsToDot(unsolvedConstraints, opts)); + o.writePair("currentId", std::to_string(reinterpret_cast(current.get()))); + o.writePair("current", toString(*current, opts)); + emitter.writeComma(); + Json::write(emitter, "rootScope"); + emitter.writeRaw(":"); + dumpScopeAndChildren(rootScope, emitter, opts); + o.finish(); - std::string snapshot = "{\"type\":\"step\",\"rootScope\":"; - - snapshot += dumpScopeAndChildren(rootScope, opts); - snapshot += ",\"constraintGraph\":\""; - snapshot += dumpConstraintsToDot(unsolvedConstraints, opts); - snapshot += "\",\"currentId\":\""; - snapshot += std::to_string(reinterpret_cast(current.get())); - snapshot += "\",\"current\":\""; - snapshot += toString(*current, opts); - snapshot += "\"}"; - - preparedSnapshot = std::move(snapshot); + preparedSnapshot = emitter.str(); } void ConstraintSolverLogger::commitPreparedStepSnapshot() diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 1b5275fd..45663531 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAG(LuauCheckLenMT) +LUAU_FASTFLAG(LuauUnknownAndNeverType) namespace Luau { @@ -116,14 +116,13 @@ declare function typeof(value: T): string -- `assert` has a magic function attached that will give more detailed type information declare function assert(value: T, errorMessage: string?): T -declare function error(message: T, level: number?) - declare function tostring(value: T): string declare function tonumber(value: T, radix: number?): number? declare function rawequal(a: T1, b: T2): boolean declare function rawget(tab: {[K]: V}, k: K): V declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} +declare function rawlen(obj: {[K]: V} | string): number declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? @@ -204,11 +203,13 @@ declare function unpack(tab: {V}, i: number?, j: number?): ...V std::string getBuiltinDefinitionSource() { + std::string result = kBuiltinDefinitionLuaSrc; - // TODO: move this into kBuiltinDefinitionLuaSrc - if (FFlag::LuauCheckLenMT) - result += "declare function rawlen(obj: {[K]: V} | string): number\n"; + if (FFlag::LuauUnknownAndNeverType) + result += "declare function error(message: T, level: number?): never\n"; + else + result += "declare function error(message: T, level: number?)\n"; return result; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index c26493c2..607a90b8 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -102,8 +102,11 @@ void loadModuleIntoScope(TypeChecker& typeChecker, ModulePtr module, ScopePtr sc } } -LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName) +LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, const std::string& packageName) { + if (!FFlag::DebugLuauDeferredConstraintResolution) + return Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, source, packageName); + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); Luau::Allocator allocator; @@ -121,12 +124,12 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t module.root = parseResult.root; module.mode = Mode::Definition; - ModulePtr checkedModule = typeChecker.check(module, Mode::Definition); + ModulePtr checkedModule = check(module, Mode::Definition, globalScope); if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, checkedModule}; - loadModuleIntoScope(typeChecker, checkedModule, targetScope, packageName); + loadModuleIntoScope(typeChecker, checkedModule, globalScope, packageName); return LoadDefinitionFileResult{true, parseResult, checkedModule}; } @@ -503,6 +506,8 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalastTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); + module->astResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); module->scopes.resize(1); } @@ -692,29 +697,6 @@ LintResult Frontend::lint(const ModuleName& name, std::optional Frontend::lintFragment(std::string_view source, std::optional enabledLintWarnings) -{ - LUAU_TIMETRACE_SCOPE("Frontend::lintFragment", "Frontend"); - - const Config& config = configResolver->getConfig(""); - - SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions); - - uint64_t ignoreLints = LintWarning::parseMask(sourceModule.hotcomments); - - Luau::LintOptions lintOptions = enabledLintWarnings.value_or(config.enabledLint); - lintOptions.warningMask &= ~ignoreLints; - - double timestamp = getTimestamp(); - - std::vector warnings = Luau::lint(sourceModule.root, *sourceModule.names.get(), typeChecker.globalScope, nullptr, - sourceModule.hotcomments, enabledLintWarnings.value_or(config.enabledLint)); - - stats.timeLint += getTimestamp() - timestamp; - - return {std::move(sourceModule), classifyLints(warnings, config)}; -} - LintResult Frontend::lint(const SourceModule& module, std::optional enabledLintWarnings) { LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); @@ -819,35 +801,28 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons return const_cast(this)->getSourceModule(moduleName); } -NotNull Frontend::getGlobalScope2() +NotNull Frontend::getGlobalScope() { - if (!globalScope2) + if (!globalScope) { - const SingletonTypes& singletonTypes = getSingletonTypes(); - - globalScope2 = std::make_unique(); - globalScope2->typeBindings["nil"] = singletonTypes.nilType; - globalScope2->typeBindings["number"] = singletonTypes.numberType; - globalScope2->typeBindings["string"] = singletonTypes.stringType; - globalScope2->typeBindings["boolean"] = singletonTypes.booleanType; - globalScope2->typeBindings["thread"] = singletonTypes.threadType; + globalScope = typeChecker.globalScope; } - return NotNull(globalScope2.get()); + return NotNull(globalScope.get()); } ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope) { ModulePtr result = std::make_shared(); - ConstraintGraphBuilder cgb{sourceModule.name, &result->internalTypes, NotNull(&iceHandler), getGlobalScope2()}; + ConstraintGraphBuilder cgb{sourceModule.name, &result->internalTypes, NotNull(&iceHandler), getGlobalScope()}; cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); ConstraintSolver cs{&result->internalTypes, NotNull(cgb.rootScope)}; cs.run(); - result->scope2s = std::move(cgb.scopes); + result->scopes = std::move(cgb.scopes); result->astTypes = std::move(cgb.astTypes); result->astTypePacks = std::move(cgb.astTypePacks); result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes); @@ -988,7 +963,7 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module { // CLI-43699 // If we can't find the current module name, that's because we bypassed the frontend's initializer - // and called typeChecker.check directly. (This is done by autocompleteSource, for example). + // and called typeChecker.check directly. // In that case, requires will always fail. return std::nullopt; } diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 77c62422..1a6013af 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -4,6 +4,8 @@ #include "Luau/TxnLog.h" #include "Luau/TypeArena.h" +LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) + namespace Luau { @@ -31,6 +33,8 @@ bool Instantiation::ignoreChildren(TypeId ty) { if (log->getMutable(ty)) return true; + else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + return true; else return false; } diff --git a/Analysis/src/JsonEmitter.cpp b/Analysis/src/JsonEmitter.cpp new file mode 100644 index 00000000..e99619ba --- /dev/null +++ b/Analysis/src/JsonEmitter.cpp @@ -0,0 +1,220 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/JsonEmitter.h" + +#include "Luau/StringUtils.h" + +#include + +namespace Luau::Json +{ + +static constexpr int CHUNK_SIZE = 1024; + +ObjectEmitter::ObjectEmitter(NotNull emitter) + : emitter(emitter), finished(false) +{ + comma = emitter->pushComma(); + emitter->writeRaw('{'); +} + +ObjectEmitter::~ObjectEmitter() +{ + finish(); +} + +void ObjectEmitter::finish() +{ + if (finished) + return; + + emitter->writeRaw('}'); + emitter->popComma(comma); + finished = true; +} + +ArrayEmitter::ArrayEmitter(NotNull emitter) + : emitter(emitter), finished(false) +{ + comma = emitter->pushComma(); + emitter->writeRaw('['); +} + +ArrayEmitter::~ArrayEmitter() +{ + finish(); +} + +void ArrayEmitter::finish() +{ + if (finished) + return; + + emitter->writeRaw(']'); + emitter->popComma(comma); + finished = true; +} + +JsonEmitter::JsonEmitter() +{ + newChunk(); +} + +std::string JsonEmitter::str() +{ + return join(chunks, ""); +} + +bool JsonEmitter::pushComma() +{ + bool current = comma; + comma = false; + return current; +} + +void JsonEmitter::popComma(bool c) +{ + comma = c; +} + +void JsonEmitter::writeRaw(std::string_view sv) +{ + if (sv.size() > CHUNK_SIZE) + { + chunks.emplace_back(sv); + newChunk(); + return; + } + + auto& chunk = chunks.back(); + if (chunk.size() + sv.size() < CHUNK_SIZE) + { + chunk.append(sv.data(), sv.size()); + return; + } + + size_t prefix = CHUNK_SIZE - chunk.size(); + chunk.append(sv.data(), prefix); + newChunk(); + + chunks.back().append(sv.data() + prefix, sv.size() - prefix); +} + +void JsonEmitter::writeRaw(char c) +{ + writeRaw(std::string_view{&c, 1}); +} + +void write(JsonEmitter& emitter, bool b) +{ + if (b) + emitter.writeRaw("true"); + else + emitter.writeRaw("false"); +} + +void write(JsonEmitter& emitter, double d) +{ + emitter.writeRaw(std::to_string(d)); +} + +void write(JsonEmitter& emitter, int i) +{ + emitter.writeRaw(std::to_string(i)); +} + +void write(JsonEmitter& emitter, long i) +{ + emitter.writeRaw(std::to_string(i)); +} + +void write(JsonEmitter& emitter, long long i) +{ + emitter.writeRaw(std::to_string(i)); +} + +void write(JsonEmitter& emitter, unsigned int i) +{ + emitter.writeRaw(std::to_string(i)); +} + +void write(JsonEmitter& emitter, unsigned long i) +{ + emitter.writeRaw(std::to_string(i)); +} + +void write(JsonEmitter& emitter, unsigned long long i) +{ + emitter.writeRaw(std::to_string(i)); +} + +void write(JsonEmitter& emitter, std::string_view sv) +{ + emitter.writeRaw('\"'); + + for (char c : sv) + { + if (c == '"') + emitter.writeRaw("\\\""); + else if (c == '\\') + emitter.writeRaw("\\\\"); + else if (c == '\n') + emitter.writeRaw("\\n"); + else if (c < ' ') + emitter.writeRaw(format("\\u%04x", c)); + else + emitter.writeRaw(c); + } + + emitter.writeRaw('\"'); +} + +void write(JsonEmitter& emitter, char c) +{ + write(emitter, std::string_view{&c, 1}); +} + +void write(JsonEmitter& emitter, const char* str) +{ + write(emitter, std::string_view{str, strlen(str)}); +} + +void write(JsonEmitter& emitter, const std::string& str) +{ + write(emitter, std::string_view{str}); +} + +void write(JsonEmitter& emitter, std::nullptr_t) +{ + emitter.writeRaw("null"); +} + +void write(JsonEmitter& emitter, std::nullopt_t) +{ + emitter.writeRaw("null"); +} + +void JsonEmitter::writeComma() +{ + if (comma) + writeRaw(','); + else + comma = true; +} + +ObjectEmitter JsonEmitter::writeObject() +{ + return ObjectEmitter{NotNull(this)}; +} + +ArrayEmitter JsonEmitter::writeArray() +{ + return ArrayEmitter{NotNull(this)}; +} + +void JsonEmitter::newChunk() +{ + chunks.emplace_back(); + chunks.back().reserve(CHUNK_SIZE); +} + +} // namespace Luau::Json diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 50868e56..9fce79a2 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -48,6 +48,7 @@ static const char* kWarningNames[] = { "DuplicateCondition", "MisleadingAndOr", "CommentDirective", + "IntegerParsing", }; // clang-format on @@ -1433,7 +1434,7 @@ private: const char* checkStringFormat(const char* data, size_t size) { const char* flags = "-+ #0"; - const char* options = "cdiouxXeEfgGqs"; + const char* options = "cdiouxXeEfgGqs*"; for (size_t i = 0; i < size; ++i) { @@ -2589,6 +2590,45 @@ private: } }; +class LintIntegerParsing : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintIntegerParsing pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + bool visit(AstExprConstantNumber* node) override + { + switch (node->parseResult) + { + case ConstantNumberParseResult::Ok: + case ConstantNumberParseResult::Malformed: + break; + case ConstantNumberParseResult::BinOverflow: + emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, + "Binary number literal exceeded available precision and has been truncated to 2^64"); + break; + case ConstantNumberParseResult::HexOverflow: + emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, + "Hexadecimal number literal exceeded available precision and has been truncated to 2^64"); + break; + case ConstantNumberParseResult::DoublePrefix: + emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, + "Hexadecimal number literal has a double prefix, which will fail to parse in the future; remove the extra 0x to fix"); + break; + } + + return true; + } +}; + static void fillBuiltinGlobals(LintContext& context, const AstNameTable& names, const ScopePtr& env) { ScopePtr current = env; @@ -2688,6 +2728,21 @@ static void lintComments(LintContext& context, const std::vector& ho else seenMode = true; } + else if (first == "optimize") + { + size_t notspace = hc.content.find_first_not_of(" \t", space); + + if (space == std::string::npos || notspace == std::string::npos) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "optimize directive requires an optimization level"); + else + { + const char* level = hc.content.c_str() + notspace; + + if (strcmp(level, "0") && strcmp(level, "1") && strcmp(level, "2")) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "optimize directive uses unknown optimization level '%s', 0..2 expected", level); + } + } else { static const char* kHotComments[] = { @@ -2695,6 +2750,7 @@ static void lintComments(LintContext& context, const std::vector& ho "nocheck", "nonstrict", "strict", + "optimize", }; if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments))) @@ -2794,6 +2850,9 @@ std::vector lint(AstStat* root, const AstNameTable& names, const Sc if (context.warningEnabled(LintWarning::Code_CommentDirective)) lintComments(context, hotcomments); + if (context.warningEnabled(LintWarning::Code_IntegerParsing)) + LintIntegerParsing::process(context); + std::sort(context.result.begin(), context.result.end(), WarningComparator()); return context.result; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 95eb125e..2b46da87 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -15,8 +15,8 @@ #include LUAU_FASTFLAG(LuauLowerBoundsCalculation); -LUAU_FASTFLAG(LuauNormalizeFlagIsConservative); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAGVARIABLE(LuauForceExportSurfacesToBeNormal, false); namespace Luau { @@ -99,40 +99,37 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) CloneState cloneState; - ScopePtr moduleScope = FFlag::DebugLuauDeferredConstraintResolution ? nullptr : getModuleScope(); - Scope2* moduleScope2 = FFlag::DebugLuauDeferredConstraintResolution ? getModuleScope2() : nullptr; + ScopePtr moduleScope = getModuleScope(); - TypePackId returnType = FFlag::DebugLuauDeferredConstraintResolution ? moduleScope2->returnType : moduleScope->returnType; + TypePackId returnType = moduleScope->returnType; std::optional varargPack = FFlag::DebugLuauDeferredConstraintResolution ? std::nullopt : moduleScope->varargPack; std::unordered_map* exportedTypeBindings = FFlag::DebugLuauDeferredConstraintResolution ? nullptr : &moduleScope->exportedTypeBindings; returnType = clone(returnType, interfaceTypes, cloneState); - if (moduleScope) + moduleScope->returnType = returnType; + if (varargPack) { - moduleScope->returnType = returnType; - if (varargPack) - { - varargPack = clone(*varargPack, interfaceTypes, cloneState); - moduleScope->varargPack = varargPack; - } - } - else - { - LUAU_ASSERT(moduleScope2); - moduleScope2->returnType = returnType; // TODO varargPack + varargPack = clone(*varargPack, interfaceTypes, cloneState); + moduleScope->varargPack = varargPack; } + ForceNormal forceNormal{&interfaceTypes}; + if (FFlag::LuauLowerBoundsCalculation) { normalize(returnType, interfaceTypes, ice); + if (FFlag::LuauForceExportSurfacesToBeNormal) + forceNormal.traverse(returnType); if (varargPack) + { normalize(*varargPack, interfaceTypes, ice); + if (FFlag::LuauForceExportSurfacesToBeNormal) + forceNormal.traverse(*varargPack); + } } - ForceNormal forceNormal{&interfaceTypes}; - if (exportedTypeBindings) { for (auto& [name, tf] : *exportedTypeBindings) @@ -142,11 +139,18 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) { normalize(tf.type, interfaceTypes, ice); - if (FFlag::LuauNormalizeFlagIsConservative) + // We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables + // won't be marked normal. If the types aren't normal by now, they never will be. + forceNormal.traverse(tf.type); + for (GenericTypeDefinition param : tf.typeParams) { - // We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables - // won't be marked normal. If the types aren't normal by now, they never will be. - forceNormal.traverse(tf.type); + forceNormal.traverse(param.ty); + + if (param.defaultValue) + { + normalize(*param.defaultValue, interfaceTypes, ice); + forceNormal.traverse(*param.defaultValue); + } } } } @@ -166,7 +170,12 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) { ty = clone(ty, interfaceTypes, cloneState); if (FFlag::LuauLowerBoundsCalculation) + { normalize(ty, interfaceTypes, ice); + + if (FFlag::LuauForceExportSurfacesToBeNormal) + forceNormal.traverse(ty); + } } freeze(internalTypes); @@ -179,10 +188,4 @@ ScopePtr Module::getModuleScope() const return scopes.front().second; } -Scope2* Module::getModuleScope2() const -{ - LUAU_ASSERT(!scope2s.empty()); - return scope2s.front().second.get(); -} - } // namespace Luau diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 8ce7f742..9ae3b404 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -13,8 +13,8 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); -LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); -LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); +LUAU_FASTFLAGVARIABLE(LuauFixNormalizationOfCyclicUnions, false); +LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau @@ -88,13 +88,7 @@ static bool areNormal_(const T& t, const std::unordered_set& seen, Intern if (count >= FInt::LuauNormalizeIterationLimit) ice.ice("Luau::areNormal hit iteration limit"); - if (FFlag::LuauNormalizeFlagIsConservative) - return ty->normal; - else - { - // The follow is here because a bound type may not be normal, but the bound type is normal. - return ty->normal || follow(ty)->normal || seen.find(asMutable(ty)) != seen.end(); - } + return ty->normal; }; return std::all_of(begin(t), end(t), isNormal); @@ -182,7 +176,6 @@ struct Normalize final : TypeVarVisitor { if (!ty->normal) asMutable(ty)->normal = true; - return false; } @@ -193,6 +186,20 @@ struct Normalize final : TypeVarVisitor return false; } + bool visit(TypeId ty, const UnknownTypeVar&) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool visit(TypeId ty, const NeverTypeVar&) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override { CHECK_ITERATION_LIMIT(false); @@ -327,13 +334,19 @@ struct Normalize final : TypeVarVisitor return false; UnionTypeVar* utv = &const_cast(utvRef); - std::vector options = std::move(utv->options); + + // TODO: Clip tempOptions and optionsRef when clipping FFlag::LuauFixNormalizationOfCyclicUnions + std::vector tempOptions; + if (!FFlag::LuauFixNormalizationOfCyclicUnions) + tempOptions = std::move(utv->options); + + std::vector& optionsRef = FFlag::LuauFixNormalizationOfCyclicUnions ? utv->options : tempOptions; // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar - for (TypeId option : options) + for (TypeId option : optionsRef) traverse(option); - std::vector newOptions = normalizeUnion(options); + std::vector newOptions = normalizeUnion(optionsRef); const bool normal = areNormal(newOptions, seen, ice); @@ -358,51 +371,106 @@ struct Normalize final : TypeVarVisitor IntersectionTypeVar* itv = &const_cast(itvRef); - std::vector oldParts = std::move(itv->parts); - - for (TypeId part : oldParts) - traverse(part); - - std::vector tables; - for (TypeId part : oldParts) + if (FFlag::LuauFixNormalizationOfCyclicUnions) { - part = follow(part); - if (get(part)) - tables.push_back(part); - else - { - Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD - combineIntoIntersection(replacer, itv, part); - } - } + std::vector oldParts = itv->parts; + IntersectionTypeVar newIntersection; - // Don't allocate a new table if there's just one in the intersection. - if (tables.size() == 1) - itv->parts.push_back(tables[0]); - else if (!tables.empty()) - { - const TableTypeVar* first = get(tables[0]); - LUAU_ASSERT(first); + for (TypeId part : oldParts) + traverse(part); - TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); - TableTypeVar* ttv = getMutable(newTable); - for (TypeId part : tables) + std::vector tables; + for (TypeId part : oldParts) { - // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need - // to be rewritten to point at 'newTable' in the clone. - Replacer replacer{&arena, part, newTable}; - combineIntoTable(replacer, ttv, part); + part = follow(part); + if (get(part)) + tables.push_back(part); + else + { + Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD + combineIntoIntersection(replacer, &newIntersection, part); + } } - itv->parts.push_back(newTable); + // Don't allocate a new table if there's just one in the intersection. + if (tables.size() == 1) + newIntersection.parts.push_back(tables[0]); + else if (!tables.empty()) + { + const TableTypeVar* first = get(tables[0]); + LUAU_ASSERT(first); + + TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); + TableTypeVar* ttv = getMutable(newTable); + for (TypeId part : tables) + { + // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need + // to be rewritten to point at 'newTable' in the clone. + Replacer replacer{&arena, part, newTable}; + combineIntoTable(replacer, ttv, part); + } + + newIntersection.parts.push_back(newTable); + } + + itv->parts = std::move(newIntersection.parts); + + asMutable(ty)->normal = areNormal(itv->parts, seen, ice); + + if (itv->parts.size() == 1) + { + TypeId part = itv->parts[0]; + *asMutable(ty) = BoundTypeVar{part}; + } } - - asMutable(ty)->normal = areNormal(itv->parts, seen, ice); - - if (itv->parts.size() == 1) + else { - TypeId part = itv->parts[0]; - *asMutable(ty) = BoundTypeVar{part}; + std::vector oldParts = std::move(itv->parts); + + for (TypeId part : oldParts) + traverse(part); + + std::vector tables; + for (TypeId part : oldParts) + { + part = follow(part); + if (get(part)) + tables.push_back(part); + else + { + Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD + combineIntoIntersection(replacer, itv, part); + } + } + + // Don't allocate a new table if there's just one in the intersection. + if (tables.size() == 1) + itv->parts.push_back(tables[0]); + else if (!tables.empty()) + { + const TableTypeVar* first = get(tables[0]); + LUAU_ASSERT(first); + + TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); + TableTypeVar* ttv = getMutable(newTable); + for (TypeId part : tables) + { + // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need + // to be rewritten to point at 'newTable' in the clone. + Replacer replacer{&arena, part, newTable}; + combineIntoTable(replacer, ttv, part); + } + + itv->parts.push_back(newTable); + } + + asMutable(ty)->normal = areNormal(itv->parts, seen, ice); + + if (itv->parts.size() == 1) + { + TypeId part = itv->parts[0]; + *asMutable(ty) = BoundTypeVar{part}; + } } return false; @@ -416,7 +484,13 @@ struct Normalize final : TypeVarVisitor std::vector result; for (TypeId part : options) + { + // AnyTypeVar always win the battle no matter what we do, so we're done. + if (FFlag::LuauUnknownAndNeverType && get(follow(part))) + return {part}; + combineIntoUnion(result, part); + } return result; } @@ -427,7 +501,17 @@ struct Normalize final : TypeVarVisitor if (auto utv = get(ty)) { for (TypeId t : utv) + { + // AnyTypeVar always win the battle no matter what we do, so we're done. + if (FFlag::LuauUnknownAndNeverType && get(t)) + { + result = {t}; + return; + } + combineIntoUnion(result, t); + } + return; } @@ -561,6 +645,24 @@ struct Normalize final : TypeVarVisitor table->props.insert({propName, prop}); } + if (FFlag::LuauFixNormalizationOfCyclicUnions) + { + if (tyTable->indexer) + { + if (table->indexer) + { + table->indexer->indexType = combine(replacer, replacer.smartClone(tyTable->indexer->indexType), table->indexer->indexType); + table->indexer->indexResultType = + combine(replacer, replacer.smartClone(tyTable->indexer->indexResultType), table->indexer->indexResultType); + } + else + { + table->indexer = + TableIndexer{replacer.smartClone(tyTable->indexer->indexType), replacer.smartClone(tyTable->indexer->indexResultType)}; + } + } + } + table->state = combineTableStates(table->state, tyTable->state); table->level = max(table->level, tyTable->level); } @@ -571,8 +673,7 @@ struct Normalize final : TypeVarVisitor */ TypeId combine(Replacer& replacer, TypeId a, TypeId b) { - if (FFlag::LuauNormalizeCombineEqFix) - b = follow(b); + b = follow(b); if (FFlag::LuauNormalizeCombineTableFix && a == b) return a; @@ -592,7 +693,7 @@ struct Normalize final : TypeVarVisitor } else if (auto ttv = getMutable(a)) { - if (FFlag::LuauNormalizeCombineTableFix && !get(FFlag::LuauNormalizeCombineEqFix ? b : follow(b))) + if (FFlag::LuauNormalizeCombineTableFix && !get(b)) return arena.addType(IntersectionTypeVar{{a, b}}); combineIntoTable(replacer, ttv, b); return a; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 294c479d..03049cca 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -7,7 +7,6 @@ #include "Luau/TxnLog.h" #include "Luau/VisitTypeVar.h" -LUAU_FASTFLAG(LuauAlwaysQuantify); LUAU_FASTFLAG(DebugLuauSharedSelf) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false) @@ -16,13 +15,13 @@ namespace Luau { /// @return true if outer encloses inner -static bool subsumes(Scope2* outer, Scope2* inner) +static bool subsumes(Scope* outer, Scope* inner) { while (inner) { if (inner == outer) return true; - inner = inner->parent; + inner = inner->parent.get(); } return false; @@ -33,7 +32,7 @@ struct Quantifier final : TypeVarOnceVisitor TypeLevel level; std::vector generics; std::vector genericPacks; - Scope2* scope = nullptr; + Scope* scope = nullptr; bool seenGenericType = false; bool seenMutableType = false; @@ -43,20 +42,20 @@ struct Quantifier final : TypeVarOnceVisitor LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); } - explicit Quantifier(Scope2* scope) + explicit Quantifier(Scope* scope) : scope(scope) { LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); } /// @return true if outer encloses inner - bool subsumes(Scope2* outer, Scope2* inner) + bool subsumes(Scope* outer, Scope* inner) { while (inner) { if (inner == outer) return true; - inner = inner->parent; + inner = inner->parent.get(); } return false; @@ -203,36 +202,20 @@ void quantify(TypeId ty, TypeLevel level) FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); - if (FFlag::LuauAlwaysQuantify) - { - ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); - } - else - { - ftv->generics = q.generics; - ftv->genericPacks = q.genericPacks; - } + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); } } -void quantify(TypeId ty, Scope2* scope) +void quantify(TypeId ty, Scope* scope) { Quantifier q{scope}; q.traverse(ty); FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); - if (FFlag::LuauAlwaysQuantify) - { - ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); - } - else - { - ftv->generics = q.generics; - ftv->genericPacks = q.genericPacks; - } + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) ftv->hasNoGenerics = true; @@ -240,11 +223,11 @@ void quantify(TypeId ty, Scope2* scope) struct PureQuantifier : Substitution { - Scope2* scope; + Scope* scope; std::vector insertedGenerics; std::vector insertedGenericPacks; - PureQuantifier(TypeArena* arena, Scope2* scope) + PureQuantifier(TypeArena* arena, Scope* scope) : Substitution(TxnLog::empty(), arena) , scope(scope) { @@ -322,7 +305,7 @@ struct PureQuantifier : Substitution } }; -TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope) +TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope) { PureQuantifier quantifier{arena, scope}; std::optional result = quantifier.substitute(ty); diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 247a9dd6..bee16908 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -21,22 +21,6 @@ Scope::Scope(const ScopePtr& parent, int subLevel) level.subLevel = subLevel; } -std::optional Scope::lookup(const Symbol& name) -{ - Scope* scope = this; - - while (scope) - { - auto it = scope->bindings.find(name); - if (it != scope->bindings.end()) - return it->second.typeId; - - scope = scope->parent.get(); - } - - return std::nullopt; -} - std::optional Scope::lookupType(const Name& name) { const Scope* scope = this; @@ -121,48 +105,48 @@ std::optional Scope::linearSearchForBinding(const std::string& name, bo return std::nullopt; } -std::optional Scope2::lookup(Symbol sym) +std::optional Scope::lookup(Symbol sym) { - Scope2* s = this; + Scope* s = this; while (true) { auto it = s->bindings.find(sym); if (it != s->bindings.end()) - return it->second; + return it->second.typeId; if (s->parent) - s = s->parent; + s = s->parent.get(); else return std::nullopt; } } -std::optional Scope2::lookupTypeBinding(const Name& name) +std::optional Scope::lookupTypeBinding(const Name& name) { - Scope2* s = this; + Scope* s = this; while (s) { auto it = s->typeBindings.find(name); if (it != s->typeBindings.end()) return it->second; - s = s->parent; + s = s->parent.get(); } return std::nullopt; } -std::optional Scope2::lookupTypePackBinding(const Name& name) +std::optional Scope::lookupTypePackBinding(const Name& name) { - Scope2* s = this; + Scope* s = this; while (s) { auto it = s->typePackBindings.find(name); if (it != s->typePackBindings.end()) return it->second; - s = s->parent; + s = s->parent.get(); } return std::nullopt; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 9c4ce829..148c9ee2 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -8,8 +8,13 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauAnyificationMustClone, false) +LUAU_FASTFLAGVARIABLE(LuauSubstitutionFixMissingFields, false) LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) +LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false) +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false) namespace Luau { @@ -26,6 +31,14 @@ void Tarjan::visitChildren(TypeId ty, int index) if (const FunctionTypeVar* ftv = get(ty)) { + if (FFlag::LuauSubstitutionFixMissingFields) + { + for (TypeId generic : ftv->generics) + visitChild(generic); + for (TypePackId genericPack : ftv->genericPacks) + visitChild(genericPack); + } + visitChild(ftv->argTypes); visitChild(ftv->retTypes); } @@ -66,6 +79,25 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypeId part : ctv->parts) visitChild(part); } + else if (const PendingExpansionTypeVar* petv = get(ty)) + { + for (TypeId a : petv->typeArguments) + visitChild(a); + + for (TypePackId a : petv->packArguments) + visitChild(a); + } + else if (const ClassTypeVar* ctv = get(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) + { + for (auto [name, prop] : ctv->props) + visitChild(prop.type); + + if (ctv->parent) + visitChild(*ctv->parent); + + if (ctv->metatable) + visitChild(*ctv->metatable); + } } void Tarjan::visitChildren(TypePackId tp, int index) @@ -154,7 +186,7 @@ TarjanResult Tarjan::loop() if (currEdge == -1) { ++childCount; - if (childLimit > 0 && childLimit < childCount) + if (childLimit > 0 && (FFlag::LuauUnknownAndNeverType ? childLimit <= childCount : childLimit < childCount)) return TarjanResult::TooManyChildren; stack.push_back(index); @@ -265,6 +297,24 @@ TarjanResult Tarjan::visitRoot(TypePackId tp) return loop(); } +void FindDirty::clearTarjan() +{ + dirty.clear(); + + typeToIndex.clear(); + packToIndex.clear(); + indexToType.clear(); + indexToPack.clear(); + + stack.clear(); + onStack.clear(); + lowlink.clear(); + + edgesTy.clear(); + edgesTp.clear(); + worklist.clear(); +} + bool FindDirty::getDirty(int index) { if (dirty.size() <= size_t(index)) @@ -328,16 +378,46 @@ std::optional Substitution::substitute(TypeId ty) { ty = log->follow(ty); + // clear algorithm state for reentrancy + if (FFlag::LuauSubstitutionReentrant) + clearTarjan(); + auto result = findDirty(ty); if (result != TarjanResult::Ok) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - if (!ignoreChildren(oldTy)) - replaceChildren(newTy); + { + if (FFlag::LuauSubstitutionReentrant) + { + if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) + { + replaceChildren(newTy); + replacedTypes.insert(newTy); + } + } + else + { + if (!ignoreChildren(oldTy)) + replaceChildren(newTy); + } + } for (auto [oldTp, newTp] : newPacks) - if (!ignoreChildren(oldTp)) - replaceChildren(newTp); + { + if (FFlag::LuauSubstitutionReentrant) + { + if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) + { + replaceChildren(newTp); + replacedTypePacks.insert(newTp); + } + } + else + { + if (!ignoreChildren(oldTp)) + replaceChildren(newTp); + } + } TypeId newTy = replace(ty); return newTy; } @@ -346,16 +426,46 @@ std::optional Substitution::substitute(TypePackId tp) { tp = log->follow(tp); + // clear algorithm state for reentrancy + if (FFlag::LuauSubstitutionReentrant) + clearTarjan(); + auto result = findDirty(tp); if (result != TarjanResult::Ok) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - if (!ignoreChildren(oldTy)) - replaceChildren(newTy); + { + if (FFlag::LuauSubstitutionReentrant) + { + if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) + { + replaceChildren(newTy); + replacedTypes.insert(newTy); + } + } + else + { + if (!ignoreChildren(oldTy)) + replaceChildren(newTy); + } + } for (auto [oldTp, newTp] : newPacks) - if (!ignoreChildren(oldTp)) - replaceChildren(newTp); + { + if (FFlag::LuauSubstitutionReentrant) + { + if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) + { + replaceChildren(newTp); + replacedTypePacks.insert(newTp); + } + } + else + { + if (!ignoreChildren(oldTp)) + replaceChildren(newTp); + } + } TypePackId newTp = replace(tp); return newTp; } @@ -383,6 +493,8 @@ TypePackId Substitution::clone(TypePackId tp) { VariadicTypePack clone; clone.ty = vtp->ty; + if (FFlag::LuauSubstitutionFixMissingFields) + clone.hidden = vtp->hidden; return addTypePack(std::move(clone)); } else @@ -393,6 +505,9 @@ void Substitution::foundDirty(TypeId ty) { ty = log->follow(ty); + if (FFlag::LuauSubstitutionReentrant && newTypes.contains(ty)) + return; + if (isDirty(ty)) newTypes[ty] = follow(clean(ty)); else @@ -403,6 +518,9 @@ void Substitution::foundDirty(TypePackId tp) { tp = log->follow(tp); + if (FFlag::LuauSubstitutionReentrant && newPacks.contains(tp)) + return; + if (isDirty(tp)) newPacks[tp] = follow(clean(tp)); else @@ -439,8 +557,19 @@ void Substitution::replaceChildren(TypeId ty) if (ignoreChildren(ty)) return; + if (FFlag::LuauAnyificationMustClone && ty->owningArena != arena) + return; + if (FunctionTypeVar* ftv = getMutable(ty)) { + if (FFlag::LuauSubstitutionFixMissingFields) + { + for (TypeId& generic : ftv->generics) + generic = replace(generic); + for (TypePackId& genericPack : ftv->genericPacks) + genericPack = replace(genericPack); + } + ftv->argTypes = replace(ftv->argTypes); ftv->retTypes = replace(ftv->retTypes); } @@ -481,6 +610,25 @@ void Substitution::replaceChildren(TypeId ty) for (TypeId& part : ctv->parts) part = replace(part); } + else if (PendingExpansionTypeVar* petv = getMutable(ty)) + { + for (TypeId& a : petv->typeArguments) + a = replace(a); + + for (TypePackId& a : petv->packArguments) + a = replace(a); + } + else if (ClassTypeVar* ctv = getMutable(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) + { + for (auto& [name, prop] : ctv->props) + prop.type = replace(prop.type); + + if (ctv->parent) + ctv->parent = replace(*ctv->parent); + + if (ctv->metatable) + ctv->metatable = replace(*ctv->metatable); + } } void Substitution::replaceChildren(TypePackId tp) @@ -490,6 +638,9 @@ void Substitution::replaceChildren(TypePackId tp) if (ignoreChildren(tp)) return; + if (FFlag::LuauAnyificationMustClone && tp->owningArena != arena) + return; + if (TypePack* tpp = getMutable(tp)) { for (TypeId& tv : tpp->head) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index bbbad9d5..79d2e125 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,6 +11,8 @@ #include LUAU_FASTFLAG(LuauLowerBoundsCalculation) +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false) /* * Prefix generic typenames with gen- @@ -18,7 +20,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) * Fair warning: Setting this will break a lot of Luau unit tests. */ LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) -LUAU_FASTFLAGVARIABLE(LuauToStringTableBracesNewlines, false) namespace Luau { @@ -231,6 +232,11 @@ struct StringifierState emit(std::to_string(i).c_str()); } + void emit(size_t i) + { + emit(std::to_string(i).c_str()); + } + void indent() { indentation += 4; @@ -277,7 +283,10 @@ struct TypeVarStringifier if (tv->ty.valueless_by_exception()) { state.result.error = true; - state.emit("< VALUELESS BY EXCEPTION >"); + if (FFlag::LuauSpecialTypesAsterisked) + state.emit("* VALUELESS BY EXCEPTION *"); + else + state.emit("< VALUELESS BY EXCEPTION >"); return; } @@ -406,6 +415,13 @@ struct TypeVarStringifier state.emit("*"); } + void operator()(TypeId ty, const PendingExpansionTypeVar& petv) + { + state.emit("*pending-expansion-"); + state.emit(petv.index); + state.emit("*"); + } + void operator()(TypeId, const PrimitiveTypeVar& ptv) { switch (ptv.type) @@ -453,7 +469,10 @@ struct TypeVarStringifier if (state.hasSeen(&ftv)) { state.result.cycle = true; - state.emit(""); + if (FFlag::LuauSpecialTypesAsterisked) + state.emit("*CYCLE*"); + else + state.emit(""); return; } @@ -561,7 +580,10 @@ struct TypeVarStringifier if (state.hasSeen(&ttv)) { state.result.cycle = true; - state.emit(""); + if (FFlag::LuauSpecialTypesAsterisked) + state.emit("*CYCLE*"); + else + state.emit(""); return; } @@ -571,54 +593,22 @@ struct TypeVarStringifier { case TableState::Sealed: state.result.invalid = true; - if (FFlag::LuauToStringTableBracesNewlines) - { - openbrace = "{|"; - closedbrace = "|}"; - } - else - { - openbrace = "{| "; - closedbrace = " |}"; - } + openbrace = "{|"; + closedbrace = "|}"; break; case TableState::Unsealed: - if (FFlag::LuauToStringTableBracesNewlines) - { - openbrace = "{"; - closedbrace = "}"; - } - else - { - openbrace = "{ "; - closedbrace = " }"; - } + openbrace = "{"; + closedbrace = "}"; break; case TableState::Free: state.result.invalid = true; - if (FFlag::LuauToStringTableBracesNewlines) - { - openbrace = "{-"; - closedbrace = "-}"; - } - else - { - openbrace = "{- "; - closedbrace = " -}"; - } + openbrace = "{-"; + closedbrace = "-}"; break; case TableState::Generic: state.result.invalid = true; - if (FFlag::LuauToStringTableBracesNewlines) - { - openbrace = "{+"; - closedbrace = "+}"; - } - else - { - openbrace = "{+ "; - closedbrace = " +}"; - } + openbrace = "{+"; + closedbrace = "+}"; break; } @@ -637,8 +627,7 @@ struct TypeVarStringifier bool comma = false; if (ttv.indexer) { - if (FFlag::LuauToStringTableBracesNewlines) - state.newline(); + state.newline(); state.emit("["); stringify(ttv.indexer->indexType); state.emit("]: "); @@ -655,10 +644,8 @@ struct TypeVarStringifier state.emit(","); state.newline(); } - else if (FFlag::LuauToStringTableBracesNewlines) - { + else state.newline(); - } size_t length = state.result.name.length() - oldLength; @@ -685,13 +672,10 @@ struct TypeVarStringifier } state.dedent(); - if (FFlag::LuauToStringTableBracesNewlines) - { - if (comma) - state.newline(); - else - state.emit(" "); - } + if (comma) + state.newline(); + else + state.emit(" "); state.emit(closedbrace); state.unsee(&ttv); @@ -700,6 +684,12 @@ struct TypeVarStringifier void operator()(TypeId, const MetatableTypeVar& mtv) { state.result.invalid = true; + if (!state.exhaustive && mtv.syntheticName) + { + state.emit(*mtv.syntheticName); + return; + } + state.emit("{ @metatable "); stringify(mtv.metatable); state.emit(","); @@ -723,7 +713,10 @@ struct TypeVarStringifier if (state.hasSeen(&uv)) { state.result.cycle = true; - state.emit(""); + if (FFlag::LuauSpecialTypesAsterisked) + state.emit("*CYCLE*"); + else + state.emit(""); return; } @@ -790,7 +783,10 @@ struct TypeVarStringifier if (state.hasSeen(&uv)) { state.result.cycle = true; - state.emit(""); + if (FFlag::LuauSpecialTypesAsterisked) + state.emit("*CYCLE*"); + else + state.emit(""); return; } @@ -835,7 +831,10 @@ struct TypeVarStringifier void operator()(TypeId, const ErrorTypeVar& tv) { state.result.error = true; - state.emit("*unknown*"); + if (FFlag::LuauSpecialTypesAsterisked) + state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); + else + state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); } void operator()(TypeId, const LazyTypeVar& ltv) @@ -844,7 +843,16 @@ struct TypeVarStringifier state.emit("lazy?"); } -}; // namespace + void operator()(TypeId, const UnknownTypeVar& ttv) + { + state.emit("unknown"); + } + + void operator()(TypeId, const NeverTypeVar& ttv) + { + state.emit("never"); + } +}; struct TypePackStringifier { @@ -880,7 +888,10 @@ struct TypePackStringifier if (tp->ty.valueless_by_exception()) { state.result.error = true; - state.emit("< VALUELESS TP BY EXCEPTION >"); + if (FFlag::LuauSpecialTypesAsterisked) + state.emit("* VALUELESS TP BY EXCEPTION *"); + else + state.emit("< VALUELESS TP BY EXCEPTION >"); return; } @@ -904,7 +915,10 @@ struct TypePackStringifier if (state.hasSeen(&tp)) { state.result.cycle = true; - state.emit(""); + if (FFlag::LuauSpecialTypesAsterisked) + state.emit("*CYCLETP*"); + else + state.emit(""); return; } @@ -949,14 +963,22 @@ struct TypePackStringifier void operator()(TypePackId, const Unifiable::Error& error) { state.result.error = true; - state.emit("*unknown*"); + if (FFlag::LuauSpecialTypesAsterisked) + state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); + else + state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); } void operator()(TypePackId, const VariadicTypePack& pack) { state.emit("..."); if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) - state.emit(""); + { + if (FFlag::LuauSpecialTypesAsterisked) + state.emit("*hidden*"); + else + state.emit(""); + } stringify(pack.ty); } @@ -1151,7 +1173,11 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) { result.truncated = true; - result.name += "... "; + + if (FFlag::LuauSpecialTypesAsterisked) + result.name += "... *TRUNCATED*"; + else + result.name += "... "; } return result; @@ -1222,7 +1248,12 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) } if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) - result.name += "... "; + { + if (FFlag::LuauSpecialTypesAsterisked) + result.name += "... *TRUNCATED*"; + else + result.name += "... "; + } return result; } @@ -1440,6 +1471,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) opts.nameMap = std::move(namedStr.nameMap); return "@name(" + namedStr.name + ") = " + c.name; } + else if constexpr (std::is_same_v) + { + ToStringResult targetStr = toStringDetailed(c.target, opts); + opts.nameMap = std::move(targetStr.nameMap); + return "expand " + targetStr.name; + } else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 1577bd63..9feff1c0 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -205,20 +205,6 @@ struct Printer } } - void visualizeWithSelf(AstExpr& expr, bool self) - { - if (!self) - return visualize(expr); - - AstExprIndexName* func = expr.as(); - LUAU_ASSERT(func); - - visualize(*func->expr); - writer.symbol(":"); - advance(func->indexLocation.begin); - writer.identifier(func->index.value); - } - void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg) { advance(annotation.location.begin); @@ -366,7 +352,7 @@ struct Printer } else if (const auto& a = expr.as()) { - visualizeWithSelf(*a->func, a->self); + visualize(*a->func); writer.symbol("("); bool first = true; @@ -385,7 +371,7 @@ struct Printer else if (const auto& a = expr.as()) { visualize(*a->expr); - writer.symbol("."); + writer.symbol(std::string(1, a->op)); writer.write(a->index.value); } else if (const auto& a = expr.as()) @@ -766,7 +752,7 @@ struct Printer else if (const auto& a = program.as()) { writer.keyword("function"); - visualizeWithSelf(*a->name, a->func->self != nullptr); + visualize(*a->name); visualizeFunctionBody(*a->func); } else if (const auto& a = program.as()) diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 4c6d54e0..b3f60d30 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,7 +7,7 @@ #include #include -LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) +LUAU_FASTFLAG(LuauUnknownAndNeverType) namespace Luau { @@ -81,34 +81,10 @@ void TxnLog::concat(TxnLog rhs) void TxnLog::commit() { for (auto& [ty, rep] : typeVarChanges) - { - if (FFlag::LuauNonCopyableTypeVarFields) - { - asMutable(ty)->reassign(rep.get()->pending); - } - else - { - TypeArena* owningArena = ty->owningArena; - TypeVar* mtv = asMutable(ty); - *mtv = rep.get()->pending; - mtv->owningArena = owningArena; - } - } + asMutable(ty)->reassign(rep.get()->pending); for (auto& [tp, rep] : typePackChanges) - { - if (FFlag::LuauNonCopyableTypeVarFields) - { - asMutable(tp)->reassign(rep.get()->pending); - } - else - { - TypeArena* owningArena = tp->owningArena; - TypePackVar* mpv = asMutable(tp); - *mpv = rep.get()->pending; - mpv->owningArena = owningArena; - } - } + asMutable(tp)->reassign(rep.get()->pending); clear(); } @@ -196,9 +172,7 @@ PendingType* TxnLog::queue(TypeId ty) if (!pending) { pending = std::make_unique(*ty); - - if (FFlag::LuauNonCopyableTypeVarFields) - pending->pending.owningArena = nullptr; + pending->pending.owningArena = nullptr; } return pending.get(); @@ -214,9 +188,7 @@ PendingTypePack* TxnLog::queue(TypePackId tp) if (!pending) { pending = std::make_unique(*tp); - - if (FFlag::LuauNonCopyableTypeVarFields) - pending->pending.owningArena = nullptr; + pending->pending.owningArena = nullptr; } return pending.get(); @@ -255,24 +227,14 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) { PendingType* newTy = queue(ty); - - if (FFlag::LuauNonCopyableTypeVarFields) - newTy->pending.reassign(replacement); - else - newTy->pending = replacement; - + newTy->pending.reassign(replacement); return newTy; } PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) { PendingTypePack* newTp = queue(tp); - - if (FFlag::LuauNonCopyableTypeVarFields) - newTp->pending.reassign(replacement); - else - newTp->pending = replacement; - + newTp->pending.reassign(replacement); return newTp; } @@ -289,7 +251,7 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { - LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(get(ty) || get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); if (FreeTypeVar* ftv = Luau::getMutable(newTy)) @@ -305,6 +267,11 @@ PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { ftv->level = newLevel; } + else if (ConstrainedTypeVar* ctv = Luau::getMutable(newTy)) + { + if (FFlag::LuauUnknownAndNeverType) + ctv->level = newLevel; + } return newTy; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 6cca7127..f21a4fa9 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -99,6 +99,11 @@ public: return allocator->alloc(Location(), std::nullopt, AstName("*blocked*")); } + AstType* operator()(const PendingExpansionTypeVar& petv) + { + return allocator->alloc(Location(), std::nullopt, AstName("*pending-expansion*")); + } + AstType* operator()(const ConstrainedTypeVar& ctv) { AstArray types; @@ -335,6 +340,14 @@ public: { return allocator->alloc(Location(), std::nullopt, AstName("")); } + AstType* operator()(const UnknownTypeVar& ttv) + { + return allocator->alloc(Location(), std::nullopt, AstName{"unknown"}); + } + AstType* operator()(const NeverTypeVar& ttv) + { + return allocator->alloc(Location(), std::nullopt, AstName{"never"}); + } private: Allocator* allocator; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 30e498af..53b069cf 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -67,8 +67,18 @@ struct TypeChecker2 : public AstVisitor return follow(*ty); } + TypePackId lookupPackAnnotation(AstTypePack* annotation) + { + TypePackId* tp = module->astResolvedTypePacks.find(annotation); + LUAU_ASSERT(tp); + return follow(*tp); + } + TypePackId reconstructPack(AstArray exprs, TypeArena& arena) { + if (exprs.size == 0) + return arena.addTypePack(TypePack{{}, std::nullopt}); + std::vector head; for (size_t i = 0; i < exprs.size - 1; ++i) @@ -80,14 +90,14 @@ struct TypeChecker2 : public AstVisitor return arena.addTypePack(TypePack{head, tail}); } - Scope2* findInnermostScope(Location location) + Scope* findInnermostScope(Location location) { - Scope2* bestScope = module->getModuleScope2(); - Location bestLocation = module->scope2s[0].first; + Scope* bestScope = module->getModuleScope().get(); + Location bestLocation = module->scopes[0].first; - for (size_t i = 0; i < module->scope2s.size(); ++i) + for (size_t i = 0; i < module->scopes.size(); ++i) { - auto& [scopeBounds, scope] = module->scope2s[i]; + auto& [scopeBounds, scope] = module->scopes[i]; if (scopeBounds.encloses(location)) { if (scopeBounds.begin > bestLocation.begin || scopeBounds.end < bestLocation.end) @@ -181,7 +191,7 @@ struct TypeChecker2 : public AstVisitor bool visit(AstStatReturn* ret) override { - Scope2* scope = findInnermostScope(ret->location); + Scope* scope = findInnermostScope(ret->location); TypePackId expectedRetType = scope->returnType; TypeArena arena; @@ -322,10 +332,13 @@ struct TypeChecker2 : public AstVisitor { pack = follow(pack); - while (auto tp = get(pack)) + while (true) { - if (tp->head.empty() && tp->tail) + auto tp = get(pack); + if (tp && tp->head.empty() && tp->tail) pack = *tp->tail; + else + break; } if (auto ty = first(pack)) @@ -356,13 +369,154 @@ struct TypeChecker2 : public AstVisitor bool visit(AstTypeReference* ty) override { - Scope2* scope = findInnermostScope(ty->location); + Scope* scope = findInnermostScope(ty->location); + LUAU_ASSERT(scope); // TODO: Imported types - // TODO: Generic types - if (!scope->lookupTypeBinding(ty->name.value)) + + std::optional alias = scope->lookupTypeBinding(ty->name.value); + + if (alias.has_value()) { - reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location); + size_t typesRequired = alias->typeParams.size(); + size_t packsRequired = alias->typePackParams.size(); + + bool hasDefaultTypes = std::any_of(alias->typeParams.begin(), alias->typeParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + + bool hasDefaultPacks = std::any_of(alias->typePackParams.begin(), alias->typePackParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + + if (!ty->hasParameterList) + { + if ((!alias->typeParams.empty() && !hasDefaultTypes) || (!alias->typePackParams.empty() && !hasDefaultPacks)) + { + reportError(GenericError{"Type parameter list is required"}, ty->location); + } + } + + size_t typesProvided = 0; + size_t extraTypes = 0; + size_t packsProvided = 0; + + for (const AstTypeOrPack& p : ty->parameters) + { + if (p.type) + { + if (packsProvided != 0) + { + reportError(GenericError{"Type parameters must come before type pack parameters"}, ty->location); + } + + if (typesProvided < typesRequired) + { + typesProvided += 1; + } + else + { + extraTypes += 1; + } + } + else if (p.typePack) + { + TypePackId tp = lookupPackAnnotation(p.typePack); + + if (typesProvided < typesRequired && size(tp) == 1 && finite(tp) && first(tp)) + { + typesProvided += 1; + } + else + { + packsProvided += 1; + } + } + } + + if (extraTypes != 0 && packsProvided == 0) + { + packsProvided += 1; + } + + for (size_t i = typesProvided; i < typesRequired; ++i) + { + if (alias->typeParams[i].defaultValue) + { + typesProvided += 1; + } + } + + for (size_t i = packsProvided; i < packsProvided; ++i) + { + if (alias->typePackParams[i].defaultValue) + { + packsProvided += 1; + } + } + + if (extraTypes == 0 && packsProvided + 1 == packsRequired) + { + packsProvided += 1; + } + + if (typesProvided != typesRequired || packsProvided != packsRequired) + { + reportError(IncorrectGenericParameterCount{ + /* name */ ty->name.value, + /* typeFun */ *alias, + /* actualParameters */ typesProvided, + /* actualPackParameters */ packsProvided, + }, + ty->location); + } + } + else + { + if (scope->lookupTypePackBinding(ty->name.value)) + { + reportError( + SwappedGenericTypeParameter{ + ty->name.value, + SwappedGenericTypeParameter::Kind::Type, + }, + ty->location); + } + else + { + reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location); + } + } + + return true; + } + + bool visit(AstTypePack*) override + { + return true; + } + + bool visit(AstTypePackGeneric* tp) override + { + Scope* scope = findInnermostScope(tp->location); + LUAU_ASSERT(scope); + + std::optional alias = scope->lookupTypePackBinding(tp->genericName.value); + if (!alias.has_value()) + { + if (scope->lookupTypeBinding(tp->genericName.value)) + { + reportError( + SwappedGenericTypeParameter{ + tp->genericName.value, + SwappedGenericTypeParameter::Kind::Pack, + }, + tp->location); + } + else + { + reportError(UnknownSymbol{tp->genericName.value, UnknownSymbol::Context::Type}, tp->location); + } } return true; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index d9486a4f..bdda195c 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" +#include "Luau/ApplyTypeFunction.h" #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Instantiation.h" @@ -31,20 +32,22 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) +LUAU_FASTFLAGVARIABLE(LuauExpectedTableUnionIndexerType, false) +LUAU_FASTFLAGVARIABLE(LuauIndexSilenceErrors, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix2, false) -LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) +LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix3, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. -LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) -LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false); -LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) +LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAG(LuauQuantifyConstrained) LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) -LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) -LUAU_FASTFLAGVARIABLE(LuauCheckLenMT, false) +LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false) +LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) +LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) +LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) +LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) namespace Luau { @@ -258,7 +261,11 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , booleanType(getSingletonTypes().booleanType) , threadType(getSingletonTypes().threadType) , anyType(getSingletonTypes().anyType) + , unknownType(getSingletonTypes().unknownType) + , neverType(getSingletonTypes().neverType) , anyTypePack(getSingletonTypes().anyTypePack) + , neverTypePack(getSingletonTypes().neverTypePack) + , uninhabitableTypePack(getSingletonTypes().uninhabitableTypePack) , duplicateTypeAliases{{false, {}}} { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -269,6 +276,11 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan globalScope->exportedTypeBindings["string"] = TypeFun{{}, stringType}; globalScope->exportedTypeBindings["boolean"] = TypeFun{{}, booleanType}; globalScope->exportedTypeBindings["thread"] = TypeFun{{}, threadType}; + if (FFlag::LuauUnknownAndNeverType) + { + globalScope->exportedTypeBindings["unknown"] = TypeFun{{}, unknownType}; + globalScope->exportedTypeBindings["never"] = TypeFun{{}, neverType}; + } } ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -456,6 +468,59 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } } +struct InplaceDemoter : TypeVarOnceVisitor +{ + TypeLevel newLevel; + TypeArena* arena; + + InplaceDemoter(TypeLevel level, TypeArena* arena) + : newLevel(level) + , arena(arena) + { + } + + bool demote(TypeId ty) + { + if (auto level = getMutableLevel(ty)) + { + if (level->subsumesStrict(newLevel)) + { + *level = newLevel; + return true; + } + } + + return false; + } + + bool visit(TypeId ty, const BoundTypeVar& btyRef) override + { + return true; + } + + bool visit(TypeId ty) override + { + if (ty->owningArena != arena) + return false; + return demote(ty); + } + + bool visit(TypePackId tp, const FreeTypePack& ftpRef) override + { + if (tp->owningArena != arena) + return false; + + FreeTypePack* ftp = &const_cast(ftpRef); + if (ftp->level.subsumesStrict(newLevel)) + { + ftp->level = newLevel; + return true; + } + + return false; + } +}; + void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) { int subLevel = 0; @@ -559,7 +624,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A tablify(baseTy); if (!fun->func->self) - expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, false); + expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, /* addErrors= */ false); else if (auto ttv = getMutableTableType(baseTy)) { if (!baseTy->persistent && ttv->state != TableState::Sealed && !ttv->selfTy) @@ -579,7 +644,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A if (auto name = fun->name->as()) { TypeId exprTy = checkExpr(scope, *name->expr).type; - expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false); + expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false); } } } @@ -634,15 +699,8 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std TypeId type = bindings[name].type; if (get(follow(type))) { - if (FFlag::LuauNonCopyableTypeVarFields) - { - TypeVar* mty = asMutable(follow(type)); - mty->reassign(*errorRecoveryType(anyType)); - } - else - { - *asMutable(type) = *errorRecoveryType(anyType); - } + TypeVar* mty = asMutable(follow(type)); + mty->reassign(*errorRecoveryType(anyType)); reportError(TypeError{typealias->location, OccursCheckFailed{}}); } @@ -830,7 +888,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; - if (FFlag::LuauReturnTypeInferenceInNonstrict ? FFlag::LuauLowerBoundsCalculation : useConstrainedIntersections()) + if (useConstrainedIntersections()) { unifyLowerBound(retPack, scope->returnType, demoter.demotedLevel(scope->level), return_.location); return; @@ -1206,7 +1264,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); } - if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location)) + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location, /* addErrors= */ true)) { // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments @@ -1232,6 +1290,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) for (size_t i = 2; i < varTypes.size(); ++i) unify(nilType, varTypes[i], forin.location); } + else if (isNonstrictMode()) + { + for (TypeId var : varTypes) + unify(anyType, var, forin.location); + } else { TypeId varTy = errorRecoveryType(loopScope); @@ -1253,7 +1316,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) for (TypeId var : varTypes) unify(varTy, var, forin.location); - if (!get(iterTy) && !get(iterTy) && !get(iterTy)) + if (!get(iterTy) && !get(iterTy) && !get(iterTy) && !get(iterTy)) reportError(firstValue->location, CannotCallNonFunction{iterTy}); return check(loopScope, *forin.body); @@ -1325,12 +1388,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco // If in nonstrict mode and allowing redefinition of global function, restore the previous definition type // in case this function has a differing signature. The signature discrepancy will be caught in checkBlock. if (previouslyDefined) - { - if (FFlag::LuauReturnTypeInferenceInNonstrict && FFlag::LuauLowerBoundsCalculation) - quantify(funScope, ty, exprName->location); - globalBindings[name] = oldBinding; - } else globalBindings[name] = {quantify(funScope, ty, exprName->location), exprName->location}; @@ -1350,7 +1408,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco TypeId exprTy = checkExpr(scope, *name->expr).type; TableTypeVar* ttv = getMutableTableType(exprTy); - if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false)) + if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false)) { if (ttv || isTableIntersection(exprTy)) reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); @@ -1376,6 +1434,12 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); + if (FFlag::LuauUnknownAndNeverType) + { + InplaceDemoter demoter{funScope->level, ¤tModule->internalTypes}; + demoter.traverse(ty); + } + if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } @@ -1600,7 +1664,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); - if (FFlag::LuauSelfCallAutocompleteFix2) + if (FFlag::LuauSelfCallAutocompleteFix3) ftv->hasSelf = true; } } @@ -1729,7 +1793,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) - result = checkExpr(scope, *a); + result = checkExpr(scope, *a, FFlag::LuauBinaryNeedsExpectedTypesToo ? expectedType : std::nullopt); else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -1851,41 +1915,56 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp lhsType = stripFromNilAndReport(lhsType, expr.expr->location); - if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) + if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, /* addErrors= */ true)) return {*ty}; return {errorRecoveryType(scope)}; } -std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) +std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors) { ErrorVec errors; auto result = Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); - reportErrors(errors); + if (!FFlag::LuauIndexSilenceErrors || addErrors) + reportErrors(errors); return result; } -std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location) +std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors) { ErrorVec errors; auto result = Luau::findMetatableEntry(errors, type, entry, location); - reportErrors(errors); + if (!FFlag::LuauIndexSilenceErrors || addErrors) + reportErrors(errors); return result; } std::optional TypeChecker::getIndexTypeFromType( - const ScopePtr& scope, TypeId type, const std::string& name, const Location& location, bool addErrors) + const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) +{ + size_t errorCount = currentModule->errors.size(); + + std::optional result = getIndexTypeFromTypeImpl(scope, type, name, location, addErrors); + + if (FFlag::LuauIndexSilenceErrors && !addErrors) + LUAU_ASSERT(errorCount == currentModule->errors.size()); + + return result; +} + +std::optional TypeChecker::getIndexTypeFromTypeImpl( + const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) { type = follow(type); - if (get(type) || get(type)) + if (get(type) || get(type) || get(type)) return type; tablify(type); if (isString(type)) { - std::optional mtIndex = findMetatableEntry(stringType, "__index", location); + std::optional mtIndex = findMetatableEntry(stringType, "__index", location, addErrors); LUAU_ASSERT(mtIndex); type = *mtIndex; } @@ -1919,7 +1998,7 @@ std::optional TypeChecker::getIndexTypeFromType( return result; } - if (auto found = findTablePropertyRespectingMeta(type, name, location)) + if (auto found = findTablePropertyRespectingMeta(type, name, location, addErrors)) return *found; } else if (const ClassTypeVar* cls = get(type)) @@ -1941,7 +2020,7 @@ std::optional TypeChecker::getIndexTypeFromType( if (get(follow(t))) return t; - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false)) goodOptions.push_back(*ty); else badOptions.push_back(t); @@ -1972,6 +2051,8 @@ std::optional TypeChecker::getIndexTypeFromType( else { std::vector result = reduceUnion(goodOptions); + if (FFlag::LuauUnknownAndNeverType && result.empty()) + return neverType; if (result.size() == 1) return result[0]; @@ -1987,7 +2068,7 @@ std::optional TypeChecker::getIndexTypeFromType( { RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false)) parts.push_back(*ty); } @@ -2017,36 +2098,24 @@ std::vector TypeChecker::reduceUnion(const std::vector& types) for (TypeId t : types) { t = follow(t); + if (get(t)) + continue; + if (get(t) || get(t)) return {t}; if (const UnionTypeVar* utv = get(t)) { - if (FFlag::LuauReduceUnionRecursion) + for (TypeId ty : utv) { - for (TypeId ty : utv) - { - if (FFlag::LuauNormalizeFlagIsConservative) - ty = follow(ty); - if (get(ty) || get(ty)) - return {ty}; + ty = follow(ty); + if (get(ty)) + continue; + if (get(ty) || get(ty)) + return {ty}; - if (result.end() == std::find(result.begin(), result.end(), ty)) - result.push_back(ty); - } - } - else - { - std::vector r = reduceUnion(utv->options); - for (TypeId ty : r) - { - ty = follow(ty); - if (get(ty) || get(ty)) - return {ty}; - - if (std::find(result.begin(), result.end(), ty) == result.end()) - result.push_back(ty); - } + if (result.end() == std::find(result.begin(), result.end(), ty)) + result.push_back(ty); } } else if (std::find(result.begin(), result.end(), t) == result.end()) @@ -2275,9 +2344,16 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp { std::vector expectedResultTypes; for (TypeId expectedOption : expectedUnion) + { if (const TableTypeVar* ttv = get(follow(expectedOption))) + { if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end()) expectedResultTypes.push_back(prop->second.type); + else if (FFlag::LuauExpectedTableUnionIndexerType && ttv->indexer && maybeString(ttv->indexer->indexType)) + expectedResultTypes.push_back(ttv->indexer->indexResultType); + } + } + if (expectedResultTypes.size() == 1) expectedResultType = expectedResultTypes[0]; else if (expectedResultTypes.size() > 1) @@ -2314,14 +2390,14 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {booleanType, {NotPredicate{std::move(result.predicates)}}}; case AstExprUnary::Minus: { - const bool operandIsAny = get(operandType) || get(operandType); + const bool operandIsAny = get(operandType) || get(operandType) || get(operandType); if (operandIsAny) return {operandType}; if (typeCouldHaveMetatable(operandType)) { - if (auto fnt = findMetatableEntry(operandType, "__unm", expr.location)) + if (auto fnt = findMetatableEntry(operandType, "__unm", expr.location, /* addErrors= */ true)) { TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); @@ -2355,14 +2431,21 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp operandType = stripFromNilAndReport(operandType, expr.location); - if (get(operandType)) - return {errorRecoveryType(scope)}; + // # operator is guaranteed to return number + if ((FFlag::LuauNeverTypesAndOperatorsInference && get(operandType)) || get(operandType) || + get(operandType)) + { + if (FFlag::LuauNeverTypesAndOperatorsInference) + return {numberType}; + else + return {!FFlag::LuauUnknownAndNeverType ? errorRecoveryType(scope) : operandType}; + } DenseHashSet seen{nullptr}; - if (FFlag::LuauCheckLenMT && typeCouldHaveMetatable(operandType)) + if (typeCouldHaveMetatable(operandType)) { - if (auto fnt = findMetatableEntry(operandType, "__len", expr.location)) + if (auto fnt = findMetatableEntry(operandType, "__len", expr.location, /* addErrors= */ true)) { TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); @@ -2433,6 +2516,9 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b return a; std::vector types = reduceUnion({a, b}); + if (FFlag::LuauUnknownAndNeverType && types.empty()) + return neverType; + if (types.size() == 1) return types[0]; @@ -2485,7 +2571,7 @@ TypeId TypeChecker::checkRelationalOperation( // If we know nothing at all about the lhs type, we can usually say nothing about the result. // The notable exception to this is the equality and inequality operators, which always produce a boolean. - const bool lhsIsAny = get(lhsType) || get(lhsType); + const bool lhsIsAny = get(lhsType) || get(lhsType) || get(lhsType); // Peephole check for `cond and a or b -> type(a)|type(b)` // TODO: Kill this when singleton types arrive. :( @@ -2508,7 +2594,7 @@ TypeId TypeChecker::checkRelationalOperation( if (isNonstrictMode() && (isNil(lhsType) || isNil(rhsType))) return booleanType; - const bool rhsIsAny = get(rhsType) || get(rhsType); + const bool rhsIsAny = get(rhsType) || get(rhsType) || get(rhsType); if (lhsIsAny || rhsIsAny) return booleanType; @@ -2519,6 +2605,13 @@ TypeId TypeChecker::checkRelationalOperation( case AstExprBinary::CompareGe: case AstExprBinary::CompareLe: { + if (FFlag::LuauNeverTypesAndOperatorsInference) + { + // If one of the operand is never, it doesn't make sense to unify these. + if (get(lhsType) || get(rhsType)) + return booleanType; + } + /* Subtlety here: * We need to do this unification first, but there are situations where we don't actually want to * report any problems that might have been surfaced as a result of this step because we might already @@ -2596,7 +2689,7 @@ TypeId TypeChecker::checkRelationalOperation( if (leftMetatable) { - std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location); + std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true); if (metamethod) { if (const FunctionTypeVar* ftv = get(*metamethod)) @@ -2696,8 +2789,10 @@ TypeId TypeChecker::checkBinaryOperation( // If we know nothing at all about the lhs type, we can usually say nothing about the result. // The notable exception to this is the equality and inequality operators, which always produce a boolean. - const bool lhsIsAny = get(lhsType) || get(lhsType); - const bool rhsIsAny = get(rhsType) || get(rhsType); + const bool lhsIsAny = get(lhsType) || get(lhsType) || + (FFlag::LuauUnknownAndNeverType && FFlag::LuauNeverTypesAndOperatorsInference && get(lhsType)); + const bool rhsIsAny = get(rhsType) || get(rhsType) || + (FFlag::LuauUnknownAndNeverType && FFlag::LuauNeverTypesAndOperatorsInference && get(rhsType)); if (lhsIsAny) return lhsType; @@ -2757,9 +2852,9 @@ TypeId TypeChecker::checkBinaryOperation( }; std::string op = opToMetaTableEntry(expr.op); - if (auto fnt = findMetatableEntry(lhsType, op, expr.location)) + if (auto fnt = findMetatableEntry(lhsType, op, expr.location, /* addErrors= */ true)) return checkMetatableCall(*fnt, lhsType, rhsType); - if (auto fnt = findMetatableEntry(rhsType, op, expr.location)) + if (auto fnt = findMetatableEntry(rhsType, op, expr.location, /* addErrors= */ true)) { // Note the intentionally reversed arguments here. return checkMetatableCall(*fnt, rhsType, lhsType); @@ -2793,27 +2888,27 @@ TypeId TypeChecker::checkBinaryOperation( } } -WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional expectedType) { if (expr.op == AstExprBinary::And) { - auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left); + auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left, expectedType); ScopePtr innerScope = childScope(scope, expr.location); resolve(lhsPredicates, innerScope, true); - auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); + auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right, expectedType); return {checkBinaryOperation(scope, expr, lhsTy, rhsTy), {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::Or) { - auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left); + auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left, expectedType); ScopePtr innerScope = childScope(scope, expr.location); resolve(lhsPredicates, innerScope, false); - auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); + auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right, expectedType); // Because of C++, I'm not sure if lhsPredicates was not moved out by the time we call checkBinaryOperation. TypeId result = checkBinaryOperation(scope, expr, lhsTy, rhsTy, lhsPredicates); @@ -2824,6 +2919,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; + // For these, passing expectedType is worse than simply forcing them, because their implementation + // may inadvertently check if expectedTypes exist first and use it, instead of forceSingleton first. WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); @@ -2842,6 +2939,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } else { + // Expected types are not useful for other binary operators. WithPredicate lhs = checkExpr(scope, *expr.left); WithPredicate rhs = checkExpr(scope, *expr.right); @@ -2896,6 +2994,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {trueType.type}; std::vector types = reduceUnion({trueType.type, falseType.type}); + if (FFlag::LuauUnknownAndNeverType && types.empty()) + return {neverType}; return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; } @@ -2927,7 +3027,10 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& exp TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr) { if (std::optional ty = scope->lookup(expr.local)) - return *ty; + { + ty = follow(*ty); + return get(*ty) ? unknownType : *ty; + } reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); return errorRecoveryType(scope); @@ -2941,7 +3044,10 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGloba const auto it = moduleScope->bindings.find(expr.name); if (it != moduleScope->bindings.end()) - return it->second.typeId; + { + TypeId ty = follow(it->second.typeId); + return get(ty) ? unknownType : ty; + } TypeId result = freshType(scope); Binding& binding = moduleScope->bindings[expr.name]; @@ -2962,6 +3068,9 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex if (get(lhs) || get(lhs)) return lhs; + if (get(lhs)) + return unknownType; + tablify(lhs); Name name = expr.index.value; @@ -3023,7 +3132,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (get(lhs)) { - if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, false)) + if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, /* addErrors= */ false)) return *ty; // If intersection has a table part, report that it cannot be extended just as a sealed table @@ -3050,6 +3159,9 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex if (get(exprType) || get(exprType)) return exprType; + if (get(exprType)) + return unknownType; + AstExprConstantString* value = expr.index->as(); if (value) @@ -3156,7 +3268,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T if (!ttv || ttv->state == TableState::Sealed) { - if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, false)) + if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, /* addErrors= */ false)) return *ty; return errorRecoveryType(scope); @@ -3228,9 +3340,12 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& } } - // We do not infer type binders, so if a generic function is required we do not propagate - if (expectedFunctionType && !(expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty())) - expectedFunctionType = nullptr; + if (!FFlag::LuauCheckGenericHOFTypes) + { + // We do not infer type binders, so if a generic function is required we do not propagate + if (expectedFunctionType && !(expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty())) + expectedFunctionType = nullptr; + } } auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); @@ -3238,9 +3353,10 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& TypePackId retPack; if (expr.returnAnnotation) retPack = resolveTypePack(funScope, *expr.returnAnnotation); - else if (FFlag::LuauReturnTypeInferenceInNonstrict ? (!FFlag::LuauLowerBoundsCalculation && isNonstrictMode()) : isNonstrictMode()) + else if (isNonstrictMode()) retPack = anyTypePack; - else if (expectedFunctionType) + else if (expectedFunctionType && + (!FFlag::LuauCheckGenericHOFTypes || (expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty()))) { auto [head, tail] = flatten(expectedFunctionType->retTypes); @@ -3371,16 +3487,50 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0)); std::vector genericTys; - genericTys.reserve(generics.size()); - std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { - return el.ty; - }); + // if we have a generic expected function type and no generics, we should use the expected ones. + if (FFlag::LuauCheckGenericHOFTypes) + { + if (expectedFunctionType && generics.empty()) + { + genericTys = expectedFunctionType->generics; + } + else + { + genericTys.reserve(generics.size()); + for (const GenericTypeDefinition& generic : generics) + genericTys.push_back(generic.ty); + } + } + else + { + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + } std::vector genericTps; - genericTps.reserve(genericPacks.size()); - std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { - return el.tp; - }); + // if we have a generic expected function type and no generic typepacks, we should use the expected ones. + if (FFlag::LuauCheckGenericHOFTypes) + { + if (expectedFunctionType && genericPacks.empty()) + { + genericTps = expectedFunctionType->genericPacks; + } + else + { + genericTps.reserve(genericPacks.size()); + for (const GenericTypePackDefinition& generic : genericPacks) + genericTps.push_back(generic.tp); + } + } + else + { + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + } TypeId funTy = addType(FunctionTypeVar(funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, std::move(defn), bool(expr.self))); @@ -3468,15 +3618,31 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retTypes}); } } + + if (!currentModule->astTypes.find(&function)) + currentModule->astTypes[&function] = ty; } else ice("Checking non functional type"); } WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) +{ + if (FFlag::LuauUnknownAndNeverType) + { + WithPredicate result = checkExprPackHelper(scope, expr); + if (containsNever(result.type)) + return {uninhabitableTypePack}; + return result; + } + else + return checkExprPackHelper(scope, expr); +} + +WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr) { if (auto a = expr.as()) - return checkExprPack(scope, *a); + return checkExprPackHelper(scope, *a); else if (expr.is()) { if (!scope->varargPack) @@ -3616,7 +3782,10 @@ void TypeChecker::checkArgumentList( } TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); - state.tryUnify(varPack, tail); + if (FFlag::LuauReturnsFromCallsitesAreNotWidened) + state.tryUnify(tail, varPack); + else + state.tryUnify(varPack, tail); return; } } @@ -3739,7 +3908,7 @@ void TypeChecker::checkArgumentList( } } -WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr) { // evaluate type of function // decompose an intersection into its component overloads @@ -3763,7 +3932,7 @@ WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, cons selfType = checkExpr(scope, *indexExpr->expr).type; selfType = stripFromNilAndReport(selfType, expr.func->location); - if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true)) + if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, /* addErrors= */ true)) { functionType = *propTy; actualFunctionType = instantiate(scope, functionType, expr.func->location); @@ -3813,11 +3982,25 @@ WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, cons if (get(argPack)) return {errorRecoveryTypePack(scope)}; - TypePack* args = getMutable(argPack); - LUAU_ASSERT(args != nullptr); + TypePack* args = nullptr; + if (FFlag::LuauUnknownAndNeverType) + { + if (expr.self) + { + argPack = addTypePack(TypePack{{selfType}, argPack}); + argListResult.type = argPack; + } + args = getMutable(argPack); + LUAU_ASSERT(args); + } + else + { + args = getMutable(argPack); + LUAU_ASSERT(args != nullptr); - if (expr.self) - args->head.insert(args->head.begin(), selfType); + if (expr.self) + args->head.insert(args->head.begin(), selfType); + } std::vector argLocations; argLocations.reserve(expr.args.size + 1); @@ -3876,7 +4059,10 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st else { std::vector result = reduceUnion({*el, ty}); - el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); + if (FFlag::LuauUnknownAndNeverType && result.empty()) + el = neverType; + else + el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); } } }; @@ -3930,6 +4116,9 @@ std::optional> TypeChecker::checkCallOverload(const Sc return {{errorRecoveryTypePack(scope)}}; } + if (get(fn)) + return {{uninhabitableTypePack}}; + if (auto ftv = get(fn)) { // fn is one of the overloads of actualFunctionType, which @@ -3975,7 +4164,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc // Might be a callable table if (const MetatableTypeVar* mttv = get(fn)) { - if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) + if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false)) { // Construct arguments with 'self' added in front TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); @@ -4202,6 +4391,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil, const std::vector& instantiateGenerics, const std::vector>& expectedTypes) { + bool uninhabitable = false; TypePackId pack = addTypePack(TypePack{}); PredicateVec predicates; // At the moment we will be pushing all predicate sets into this. Do we need some way to split them up? @@ -4232,7 +4422,14 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons auto [typePack, exprPredicates] = checkExprPack(scope, *expr); insert(exprPredicates); - if (std::optional firstTy = first(typePack)) + if (FFlag::LuauUnknownAndNeverType && containsNever(typePack)) + { + // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, + // ...never) + uninhabitable = true; + continue; + } + else if (std::optional firstTy = first(typePack)) { if (!currentModule->astTypes.find(expr)) currentModule->astTypes[expr] = follow(*firstTy); @@ -4248,6 +4445,14 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons auto [type, exprPredicates] = checkExpr(scope, *expr, expectedType); insert(exprPredicates); + if (FFlag::LuauUnknownAndNeverType && get(type)) + { + // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, + // ...never) + uninhabitable = true; + continue; + } + TypeId actualType = substituteFreeForNil && expr->is() ? freshType(scope) : type; if (instantiateGenerics.size() > i && instantiateGenerics[i]) @@ -4272,6 +4477,8 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons for (TxnLog& log : inverseLogs) log.commit(); + if (FFlag::LuauUnknownAndNeverType && uninhabitable) + return {uninhabitableTypePack}; return {pack, predicates}; } @@ -4542,16 +4749,8 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location { const FunctionTypeVar* ftv = get(ty); - if (FFlag::LuauAlwaysQuantify) - { - if (ftv) - Luau::quantify(ty, scope->level); - } - else - { - if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) - Luau::quantify(ty, scope->level); - } + if (ftv) + Luau::quantify(ty, scope->level); if (FFlag::LuauLowerBoundsCalculation && ftv) { @@ -4830,7 +5029,7 @@ TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) }; } -std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) +std::optional TypeChecker::filterMapImpl(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); if (!types.empty()) @@ -4838,7 +5037,21 @@ std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predic return std::nullopt; } -std::optional TypeChecker::pickTypesFromSense(TypeId type, bool sense) +std::pair, bool> TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) +{ + if (FFlag::LuauUnknownAndNeverType) + { + TypeId ty = filterMapImpl(type, predicate).value_or(neverType); + return {ty, !bool(get(ty))}; + } + else + { + std::optional ty = filterMapImpl(type, predicate); + return {ty, bool(ty)}; + } +} + +std::pair, bool> TypeChecker::pickTypesFromSense(TypeId type, bool sense) { return filterMap(type, mkTruthyPredicate(sense)); } @@ -4884,6 +5097,13 @@ TypePackId TypeChecker::freshTypePack(TypeLevel level) } TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation) +{ + TypeId ty = resolveTypeWorker(scope, annotation); + currentModule->astResolvedTypes[&annotation] = ty; + return ty; +} + +TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& annotation) { if (const auto& lit = annotation.as()) { @@ -5004,7 +5224,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (notEnoughParameters && hasDefaultParameters) { // 'applyTypeFunction' is used to substitute default types that reference previous generic types - ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes, scope->level}; + ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes}; for (size_t i = 0; i < typesProvided; ++i) applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i]; @@ -5200,9 +5420,10 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypeList TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation) { + TypePackId result; if (const AstTypePackVariadic* variadic = annotation.as()) { - return addTypePack(TypePackVar{VariadicTypePack{resolveType(scope, *variadic->variadicType)}}); + result = addTypePack(TypePackVar{VariadicTypePack{resolveType(scope, *variadic->variadicType)}}); } else if (const AstTypePackGeneric* generic = annotation.as()) { @@ -5216,10 +5437,12 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack else reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}}); - return errorRecoveryTypePack(scope); + result = errorRecoveryTypePack(scope); + } + else + { + result = *genericTy; } - - return *genericTy; } else if (const AstTypePackExplicit* explicitTp = annotation.as()) { @@ -5229,66 +5452,17 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack types.push_back(resolveType(scope, *type)); if (auto tailType = explicitTp->typeList.tailType) - return addTypePack(types, resolveTypePack(scope, *tailType)); - - return addTypePack(types); + result = addTypePack(types, resolveTypePack(scope, *tailType)); + else + result = addTypePack(types); } else { ice("Unknown AstTypePack kind"); } -} -bool ApplyTypeFunction::isDirty(TypeId ty) -{ - if (typeArguments.count(ty)) - return true; - else if (const FreeTypeVar* ftv = get(ty)) - { - if (ftv->forwardedTypeAlias) - encounteredForwardedType = true; - return false; - } - else - return false; -} - -bool ApplyTypeFunction::isDirty(TypePackId tp) -{ - if (typePackArguments.count(tp)) - return true; - else - return false; -} - -bool ApplyTypeFunction::ignoreChildren(TypeId ty) -{ - if (get(ty)) - return true; - else - return false; -} - -bool ApplyTypeFunction::ignoreChildren(TypePackId tp) -{ - if (get(tp)) - return true; - else - return false; -} - -TypeId ApplyTypeFunction::clean(TypeId ty) -{ - TypeId& arg = typeArguments[ty]; - LUAU_ASSERT(arg); - return arg; -} - -TypePackId ApplyTypeFunction::clean(TypePackId tp) -{ - TypePackId& arg = typePackArguments[tp]; - LUAU_ASSERT(arg); - return arg; + currentModule->astResolvedTypePacks[&annotation] = result; + return result; } TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, @@ -5297,7 +5471,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (tf.typeParams.empty() && tf.typePackParams.empty()) return tf.type; - ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes, scope->level}; + ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes}; for (size_t i = 0; i < tf.typeParams.size(); ++i) applyTypeFunction.typeArguments[tf.typeParams[i].ty] = typeParams[i]; @@ -5452,10 +5626,18 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const // If we do not have a key, it means we're not trying to discriminate anything, so it's a simple matter of just filtering for a subset. if (!key) { - if (std::optional result = filterMap(*ty, predicate)) + auto [result, ok] = filterMap(*ty, predicate); + if (FFlag::LuauUnknownAndNeverType) + { addRefinement(refis, *target, *result); + } else - addRefinement(refis, *target, errorRecoveryType(scope)); + { + if (ok) + addRefinement(refis, *target, *result); + else + addRefinement(refis, *target, errorRecoveryType(scope)); + } return; } @@ -5471,17 +5653,29 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const { std::optional discriminantTy; if (auto field = Luau::get(*key)) // need to fully qualify Luau::get because of ADL. - discriminantTy = getIndexTypeFromType(scope, option, field->key, Location(), false); + discriminantTy = getIndexTypeFromType(scope, option, field->key, Location(), /* addErrors= */ false); else LUAU_ASSERT(!"Unhandled LValue alternative?"); if (!discriminantTy) return; // Do nothing. An error was already reported, as per usual. - if (std::optional result = filterMap(*discriminantTy, predicate)) + auto [result, ok] = filterMap(*discriminantTy, predicate); + if (FFlag::LuauUnknownAndNeverType) { - viableTargetOptions.insert(option); - viableChildOptions.insert(*result); + if (!get(*result)) + { + viableTargetOptions.insert(option); + viableChildOptions.insert(*result); + } + } + else + { + if (ok) + { + viableTargetOptions.insert(option); + viableChildOptions.insert(*result); + } } } @@ -5560,7 +5754,7 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV continue; else if (auto field = get(key)) { - found = getIndexTypeFromType(scope, *found, field->key, Location(), false); + found = getIndexTypeFromType(scope, *found, field->key, Location(), /* addErrors= */ false); if (!found) return std::nullopt; // Turns out this type doesn't have the property at all. We're done. } @@ -5740,6 +5934,9 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r auto mkFilter = [](ConditionFunc f, std::optional other = std::nullopt) -> SenseToTypeIdPredicate { return [f, other](bool sense) -> TypeIdPredicate { return [f, other, sense](TypeId ty) -> std::optional { + if (FFlag::LuauUnknownAndNeverType && sense && get(ty)) + return other.value_or(ty); + if (f(ty) == sense) return ty; @@ -5847,8 +6044,15 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp for (size_t i = 0; i < expectedLength; ++i) expectedPack->head.push_back(freshType(scope)); + size_t oldErrorsSize = currentModule->errors.size(); + unify(tp, expectedTypePack, location); + // HACK: tryUnify would undo the changes to the expectedTypePack if the length mismatches, but + // we want to tie up free types to be error types, so we do this instead. + if (FFlag::LuauUnknownAndNeverType) + currentModule->errors.resize(oldErrorsSize); + for (TypeId& tp : expectedPack->head) tp = follow(tp); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 82451bd1..d4544483 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) - namespace Luau { @@ -40,19 +38,10 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp) TypePackVar& TypePackVar::operator=(const TypePackVar& rhs) { - if (FFlag::LuauNonCopyableTypeVarFields) - { - LUAU_ASSERT(owningArena == rhs.owningArena); - LUAU_ASSERT(!rhs.persistent); + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); - reassign(rhs); - } - else - { - ty = rhs.ty; - persistent = rhs.persistent; - owningArena = rhs.owningArena; - } + reassign(rhs); return *this; } @@ -294,6 +283,16 @@ std::optional first(TypePackId tp, bool ignoreHiddenVariadics) return std::nullopt; } +TypePackVar* asMutable(TypePackId tp) +{ + return const_cast(tp); +} + +TypePack* asMutable(const TypePack* tp) +{ + return const_cast(tp); +} + bool isEmpty(TypePackId tp) { tp = follow(tp); @@ -360,13 +359,25 @@ bool isVariadic(TypePackId tp, const TxnLog& log) return false; } -TypePackVar* asMutable(TypePackId tp) +bool containsNever(TypePackId tp) { - return const_cast(tp); + auto it = begin(tp); + auto endIt = end(tp); + + while (it != endIt) + { + if (get(follow(*it))) + return true; + ++it; + } + + if (auto tail = it.tail()) + { + if (auto vtp = get(*tail); vtp && get(follow(vtp->ty))) + return true; + } + + return false; } -TypePack* asMutable(const TypePack* tp) -{ - return const_cast(tp); -} } // namespace Luau diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 3d97e6eb..66b38cf3 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -24,7 +24,7 @@ std::optional findMetatableEntry(ErrorVec& errors, TypeId type, std::str const TableTypeVar* mtt = getTableType(unwrapped); if (!mtt) { - errors.push_back(TypeError{location, GenericError{"Metatable was not a table."}}); + errors.push_back(TypeError{location, GenericError{"Metatable was not a table"}}); return std::nullopt; } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index ade70d72..ada2b012 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -23,7 +23,10 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauDeduceGmatchReturnTypes, false) +LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) +LUAU_FASTFLAGVARIABLE(LuauDeduceFindMatchReturnTypes, false) namespace Luau { @@ -31,6 +34,15 @@ namespace Luau std::optional> magicFunctionFormat( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionGmatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + +static std::optional> magicFunctionMatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + +static std::optional> magicFunctionFind( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeId follow(TypeId t) { return follow(t, [](TypeId t) { @@ -159,10 +171,12 @@ bool isNumber(TypeId ty) // Returns true when ty is a subtype of string bool isString(TypeId ty) { - if (isPrim(ty, PrimitiveTypeVar::String) || get(get(follow(ty)))) + ty = follow(ty); + + if (isPrim(ty, PrimitiveTypeVar::String) || get(get(ty))) return true; - if (auto utv = get(follow(ty))) + if (auto utv = get(ty)) return std::all_of(begin(utv), end(utv), isString); return false; @@ -194,7 +208,7 @@ bool isOptional(TypeId ty) ty = follow(ty); - if (get(ty)) + if (get(ty) || (FFlag::LuauUnknownAndNeverType && get(ty))) return true; auto utv = get(ty); @@ -228,6 +242,8 @@ bool isOverloadedFunction(TypeId ty) std::optional getMetatable(TypeId type) { + type = follow(type); + if (const MetatableTypeVar* mtType = get(type)) return mtType->metatable; else if (const ClassTypeVar* classType = get(type)) @@ -334,6 +350,28 @@ bool isGeneric(TypeId ty) bool maybeGeneric(TypeId ty) { + if (FFlag::LuauMaybeGenericIntersectionTypes) + { + ty = follow(ty); + + if (get(ty)) + return true; + + if (auto ttv = get(ty)) + { + // TODO: recurse on table types CLI-39914 + (void)ttv; + return true; + } + + if (auto itv = get(ty)) + { + return std::any_of(begin(itv), end(itv), maybeGeneric); + } + + return isGeneric(ty); + } + ty = follow(ty); if (get(ty)) return true; @@ -407,6 +445,16 @@ BlockedTypeVar::BlockedTypeVar() int BlockedTypeVar::nextIndex = 0; +PendingExpansionTypeVar::PendingExpansionTypeVar(TypeFun fn, std::vector typeArguments, std::vector packArguments) + : fn(fn) + , typeArguments(typeArguments) + , packArguments(packArguments) + , index(++nextIndex) +{ +} + +size_t PendingExpansionTypeVar::nextIndex = 0; + FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : argTypes(argTypes) , retTypes(retTypes) @@ -646,20 +694,10 @@ TypeVar& TypeVar::operator=(TypeVariant&& rhs) TypeVar& TypeVar::operator=(const TypeVar& rhs) { - if (FFlag::LuauNonCopyableTypeVarFields) - { - LUAU_ASSERT(owningArena == rhs.owningArena); - LUAU_ASSERT(!rhs.persistent); + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); - reassign(rhs); - } - else - { - ty = rhs.ty; - persistent = rhs.persistent; - normal = rhs.normal; - owningArena = rhs.owningArena; - } + reassign(rhs); return *this; } @@ -676,10 +714,14 @@ static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persist static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}; static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true}; +static TypeVar unknownType_{UnknownTypeVar{}, /*persistent*/ true}; +static TypeVar neverType_{NeverTypeVar{}, /*persistent*/ true}; static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true}; -static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; -static TypePackVar errorTypePack_{Unifiable::Error{}}; +static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, /*persistent*/ true}; +static TypePackVar errorTypePack_{Unifiable::Error{}, /*persistent*/ true}; +static TypePackVar neverTypePack_{VariadicTypePack{&neverType_}, /*persistent*/ true}; +static TypePackVar uninhabitableTypePack_{TypePack{{&neverType_}, &neverTypePack_}, /*persistent*/ true}; SingletonTypes::SingletonTypes() : nilType(&nilType_) @@ -690,7 +732,11 @@ SingletonTypes::SingletonTypes() , trueType(&trueType_) , falseType(&falseType_) , anyType(&anyType_) + , unknownType(&unknownType_) + , neverType(&neverType_) , anyTypePack(&anyTypePack_) + , neverTypePack(&neverTypePack_) + , uninhabitableTypePack(&uninhabitableTypePack_) , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); @@ -738,19 +784,26 @@ TypeId SingletonTypes::makeStringMetatable() const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); const TypeId gmatchFunc = makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})}); + attachMagicFunction(gmatchFunc, magicFunctionGmatch); + + const TypeId matchFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), + arena->addTypePack(TypePackVar{VariadicTypePack{FFlag::LuauDeduceFindMatchReturnTypes ? stringType : optionalString}})}); + attachMagicFunction(matchFunc, magicFunctionMatch); + + const TypeId findFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); + attachMagicFunction(findFunc, magicFunctionFind); TableTypeVar::Props stringLib = { {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, {"char", {arena->addType(FunctionTypeVar{numberVariadicList, arena->addTypePack({stringType})})}}, - {"find", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), - arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})})}}, + {"find", {findFunc}}, {"format", {formatFn}}, // FIXME {"gmatch", {gmatchFunc}}, {"gsub", {gsubFunc}}, {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, {"lower", {stringToStringType}}, - {"match", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), - arena->addTypePack(TypePackVar{VariadicTypePack{optionalString}})})}}, + {"match", {matchFunc}}, {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, {"reverse", {stringToStringType}}, {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, @@ -911,6 +964,8 @@ const TypeLevel* getLevel(TypeId ty) return &ttv->level; else if (auto ftv = get(ty)) return &ftv->level; + else if (auto ctv = get(ty)) + return &ctv->level; else return nullptr; } @@ -965,94 +1020,19 @@ bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent) return false; } -UnionTypeVarIterator::UnionTypeVarIterator(const UnionTypeVar* utv) +const std::vector& getTypes(const UnionTypeVar* utv) { - LUAU_ASSERT(utv); - - if (!utv->options.empty()) - stack.push_front({utv, 0}); - - seen.insert(utv); + return utv->options; } -UnionTypeVarIterator& UnionTypeVarIterator::operator++() +const std::vector& getTypes(const IntersectionTypeVar* itv) { - advance(); - descend(); - return *this; + return itv->parts; } -UnionTypeVarIterator UnionTypeVarIterator::operator++(int) +const std::vector& getTypes(const ConstrainedTypeVar* ctv) { - UnionTypeVarIterator copy = *this; - ++copy; - return copy; -} - -bool UnionTypeVarIterator::operator!=(const UnionTypeVarIterator& rhs) -{ - return !(*this == rhs); -} - -bool UnionTypeVarIterator::operator==(const UnionTypeVarIterator& rhs) -{ - if (!stack.empty() && !rhs.stack.empty()) - return stack.front() == rhs.stack.front(); - - return stack.empty() && rhs.stack.empty(); -} - -const TypeId& UnionTypeVarIterator::operator*() -{ - LUAU_ASSERT(!stack.empty()); - - descend(); - - auto [utv, currentIndex] = stack.front(); - LUAU_ASSERT(utv); - LUAU_ASSERT(currentIndex < utv->options.size()); - - const TypeId& ty = utv->options[currentIndex]; - LUAU_ASSERT(!get(follow(ty))); - return ty; -} - -void UnionTypeVarIterator::advance() -{ - while (!stack.empty()) - { - auto& [utv, currentIndex] = stack.front(); - ++currentIndex; - - if (currentIndex >= utv->options.size()) - stack.pop_front(); - else - break; - } -} - -void UnionTypeVarIterator::descend() -{ - while (!stack.empty()) - { - auto [utv, currentIndex] = stack.front(); - if (auto innerUnion = get(follow(utv->options[currentIndex]))) - { - // If we're about to descend into a cyclic UnionTypeVar, we should skip over this. - // Ideally this should never happen, but alas it does from time to time. :( - if (seen.find(innerUnion) != seen.end()) - advance(); - else - { - seen.insert(innerUnion); - stack.push_front({innerUnion, 0}); - } - - continue; - } - - break; - } + return ctv->parts; } UnionTypeVarIterator begin(const UnionTypeVar* utv) @@ -1065,9 +1045,30 @@ UnionTypeVarIterator end(const UnionTypeVar* utv) return UnionTypeVarIterator{}; } +IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv) +{ + return IntersectionTypeVarIterator{itv}; +} + +IntersectionTypeVarIterator end(const IntersectionTypeVar* itv) +{ + return IntersectionTypeVarIterator{}; +} + +ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv) +{ + return ConstrainedTypeVarIterator{ctv}; +} + +ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv) +{ + return ConstrainedTypeVarIterator{}; +} + + static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) { - const char* options = "cdiouxXeEfgGqs"; + const char* options = "cdiouxXeEfgGqs*"; std::vector result; @@ -1081,7 +1082,7 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha continue; // we just ignore all characters (including flags/precision) up until first alphabetic character - while (i < size && !(data[i] > 0 && isalpha(data[i]))) + while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*'))) i++; if (i == size) @@ -1089,6 +1090,8 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha if (data[i] == 'q' || data[i] == 's') result.push_back(typechecker.stringType); + else if (data[i] == '*') + result.push_back(typechecker.unknownType); else if (strchr(options, data[i])) result.push_back(typechecker.numberType); else @@ -1144,6 +1147,197 @@ std::optional> magicFunctionFormat( return WithPredicate{arena.addTypePack({typechecker.stringType})}; } +static std::vector parsePatternString(TypeChecker& typechecker, const char* data, size_t size) +{ + std::vector result; + int depth = 0; + bool parsingSet = false; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + ++i; + if (!parsingSet && i < size && data[i] == 'b') + i += 2; + } + else if (!parsingSet && data[i] == '[') + { + parsingSet = true; + if (i + 1 < size && data[i + 1] == ']') + i += 1; + } + else if (parsingSet && data[i] == ']') + { + parsingSet = false; + } + else if (data[i] == '(') + { + if (parsingSet) + continue; + + if (i + 1 < size && data[i + 1] == ')') + { + i++; + result.push_back(typechecker.numberType); + continue; + } + + ++depth; + result.push_back(typechecker.stringType); + } + else if (data[i] == ')') + { + if (parsingSet) + continue; + + --depth; + + if (depth < 0) + break; + } + } + + if (depth != 0 || parsingSet) + return std::vector(); + + if (result.empty()) + result.push_back(typechecker.stringType); + + return result; +} + +static std::optional> magicFunctionGmatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + if (!FFlag::LuauDeduceGmatchReturnTypes) + return std::nullopt; + + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() != 2) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t index = expr.self ? 0 : 1; + if (expr.args.size > index) + pattern = expr.args.data[index]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location); + + const TypePackId emptyPack = arena.addTypePack({}); + const TypePackId returnList = arena.addTypePack(returnTypes); + const TypeId iteratorType = arena.addType(FunctionTypeVar{emptyPack, returnList}); + return WithPredicate{arena.addTypePack({iteratorType})}; +} + +static std::optional> magicFunctionMatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + if (!FFlag::LuauDeduceFindMatchReturnTypes) + return std::nullopt; + + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() < 2 || params.size() > 3) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = expr.self ? 0 : 1; + if (expr.args.size > patternIndex) + pattern = expr.args.data[patternIndex]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location); + + const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}}); + + size_t initIndex = expr.self ? 1 : 2; + if (params.size() == 3 && expr.args.size > initIndex) + typechecker.unify(params[2], optionalNumber, expr.args.data[initIndex]->location); + + const TypePackId returnList = arena.addTypePack(returnTypes); + return WithPredicate{returnList}; +} + +static std::optional> magicFunctionFind( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + if (!FFlag::LuauDeduceFindMatchReturnTypes) + return std::nullopt; + + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() < 2 || params.size() > 4) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = expr.self ? 0 : 1; + if (expr.args.size > patternIndex) + pattern = expr.args.data[patternIndex]->as(); + + if (!pattern) + return std::nullopt; + + bool plain = false; + size_t plainIndex = expr.self ? 2 : 3; + if (expr.args.size > plainIndex) + { + AstExprConstantBool* p = expr.args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + } + + typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location); + + const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}}); + const TypeId optionalBoolean = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.booleanType}}); + + size_t initIndex = expr.self ? 1 : 2; + if (params.size() >= 3 && expr.args.size > initIndex) + typechecker.unify(params[2], optionalNumber, expr.args.data[initIndex]->location); + + if (params.size() == 4 && expr.args.size > plainIndex) + typechecker.unify(params[3], optionalBoolean, expr.args.data[plainIndex]->location); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena.addTypePack(returnTypes); + return WithPredicate{returnList}; +} + std::vector filterMap(TypeId type, TypeIdPredicate predicate) { type = follow(type); @@ -1228,4 +1422,19 @@ bool hasTag(const Property& prop, const std::string& tagName) return hasTag(prop.tags, tagName); } +bool TypeFun::operator==(const TypeFun& rhs) const +{ + return type == rhs.type && typeParams == rhs.typeParams && typePackParams == rhs.typePackParams; +} + +bool GenericTypeDefinition::operator==(const GenericTypeDefinition& rhs) const +{ + return ty == rhs.ty && defaultValue == rhs.defaultValue; +} + +bool GenericTypePackDefinition::operator==(const GenericTypePackDefinition& rhs) const +{ + return tp == rhs.tp && defaultValue == rhs.defaultValue; +} + } // namespace Luau diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index 8d23aa49..63d8647d 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -12,7 +12,7 @@ Free::Free(TypeLevel level) { } -Free::Free(Scope2* scope) +Free::Free(Scope* scope) : scope(scope) { } @@ -39,7 +39,7 @@ Generic::Generic(const Name& name) { } -Generic::Generic(Scope2* scope) +Generic::Generic(Scope* scope) : index(++nextIndex) , scope(scope) { @@ -53,7 +53,7 @@ Generic::Generic(TypeLevel level, const Name& name) { } -Generic::Generic(Scope2* scope, const Name& name) +Generic::Generic(Scope* scope, const Name& name) : index(++nextIndex) , scope(scope) , name(name) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 0792a350..e099817f 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -19,7 +19,9 @@ LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); +LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauQuantifyConstrained) +LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) namespace Luau { @@ -47,33 +49,6 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor } } - // TODO cycle and operator() need to be clipped when FFlagLuauUseVisitRecursionLimit is clipped - template - void cycle(TID) - { - } - template - bool operator()(TID ty, const T&) - { - return visit(ty); - } - bool operator()(TypeId ty, const FreeTypeVar& ftv) - { - return visit(ty, ftv); - } - bool operator()(TypeId ty, const FunctionTypeVar& ftv) - { - return visit(ty, ftv); - } - bool operator()(TypeId ty, const TableTypeVar& ttv) - { - return visit(ty, ttv); - } - bool operator()(TypePackId tp, const FreeTypePack& ftp) - { - return visit(tp, ftp); - } - bool visit(TypeId ty) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -103,6 +78,15 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor return true; } + bool visit(TypeId ty, const ConstrainedTypeVar&) override + { + if (!FFlag::LuauUnknownAndNeverType) + return visit(ty); + + promote(ty, log.getMutable(ty)); + return true; + } + bool visit(TypeId ty, const FunctionTypeVar&) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -445,6 +429,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } else if (subFree) { + if (FFlag::LuauUnknownAndNeverType) + { + // Normally, if the subtype is free, it should not be bound to any, unknown, or error types. + // But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors. + if (log.get(superTy)) + return; + } + TypeLevel subLevel = subFree->level; occursCheck(subTy, superTy); @@ -468,7 +460,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return; } - if (get(superTy) || get(superTy)) + if (get(superTy) || get(superTy) || get(superTy)) return tryUnifyWithAny(subTy, superTy); if (get(subTy)) @@ -482,7 +474,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return tryUnifyWithAny(superTy, subTy); } - if (get(subTy)) + if (log.get(subTy)) + return tryUnifyWithAny(superTy, subTy); + + if (log.get(subTy)) return tryUnifyWithAny(superTy, subTy); auto& cache = sharedState.cachedUnify; @@ -544,6 +539,16 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyTables(subTy, superTy, isIntersection); } + else if (FFlag::LuauScalarShapeSubtyping && log.get(superTy) && + (log.get(subTy) || log.get(subTy))) + { + tryUnifyScalarShape(subTy, superTy, /*reversed*/ false); + } + else if (FFlag::LuauScalarShapeSubtyping && log.get(subTy) && + (log.get(superTy) || log.get(superTy))) + { + tryUnifyScalarShape(subTy, superTy, /*reversed*/ true); + } // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. else if (log.getMutable(superTy)) @@ -1606,6 +1611,60 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } } +void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) +{ + LUAU_ASSERT(FFlag::LuauScalarShapeSubtyping); + + TypeId osubTy = subTy; + TypeId osuperTy = superTy; + + if (reversed) + std::swap(subTy, superTy); + + if (auto ttv = log.get(superTy); !ttv || ttv->state != TableState::Free) + return reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); + + auto fail = [&](std::optional e) { + std::string reason = "The former's metatable does not satisfy the requirements."; + if (e) + reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason, *e}}); + else + reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason}}); + }; + + // Given t1 where t1 = { lower: (t1) -> (a, b...) } + // It should be the case that `string <: t1` iff `(subtype's metatable).__index <: t1` + if (auto metatable = getMetatable(subTy)) + { + auto mttv = log.get(*metatable); + if (!mttv) + fail(std::nullopt); + + if (auto it = mttv->props.find("__index"); it != mttv->props.end()) + { + TypeId ty = it->second.type; + Unifier child = makeChildUnifier(); + child.tryUnify_(ty, superTy); + + if (auto e = hasUnificationTooComplex(child.errors)) + reportError(*e); + else if (!child.errors.empty()) + fail(child.errors.front()); + + log.concat(std::move(child.log)); + + return; + } + else + { + return fail(std::nullopt); + } + } + + reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}}); + return; +} + TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) { ty = follow(ty); @@ -1862,6 +1921,7 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas if (state.log.getMutable(ty)) { + // TODO: Only bind if the anyType isn't any, unknown, or error (?) state.log.replace(ty, BoundTypeVar{anyType}); } else if (auto fun = state.log.getMutable(ty)) @@ -1901,22 +1961,28 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) { - LUAU_ASSERT(get(anyTy) || get(anyTy)); + LUAU_ASSERT(get(anyTy) || get(anyTy) || get(anyTy) || get(anyTy)); // These types are not visited in general loop below if (get(subTy) || get(subTy) || get(subTy)) return; - const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); - - const TypePackId anyTP = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + TypePackId anyTp; + if (FFlag::LuauUnknownAndNeverType) + anyTp = types->addTypePack(TypePackVar{VariadicTypePack{anyTy}}); + else + { + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); + anyTp = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + } std::vector queue = {subTy}; sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, getSingletonTypes().anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, + FFlag::LuauUnknownAndNeverType ? anyTy : getSingletonTypes().anyType, anyTp); } void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 6f39e3fd..1e164d04 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -474,16 +474,26 @@ public: bool value; }; +enum class ConstantNumberParseResult +{ + Ok, + Malformed, + BinOverflow, + HexOverflow, + DoublePrefix, +}; + class AstExprConstantNumber : public AstExpr { public: LUAU_RTTI(AstExprConstantNumber) - AstExprConstantNumber(const Location& location, double value); + AstExprConstantNumber(const Location& location, double value, ConstantNumberParseResult parseResult = ConstantNumberParseResult::Ok); void visit(AstVisitor* visitor) override; double value; + ConstantNumberParseResult parseResult; }; class AstExprConstantString : public AstExpr diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 24a280da..3066b756 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -50,9 +50,10 @@ void AstExprConstantBool::visit(AstVisitor* visitor) visitor->visit(this); } -AstExprConstantNumber::AstExprConstantNumber(const Location& location, double value) +AstExprConstantNumber::AstExprConstantNumber(const Location& location, double value, ConstantNumberParseResult parseResult) : AstExpr(ClassIndex(), location) , value(value) + , parseResult(parseResult) { } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 70c92555..1eb9565f 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -15,14 +15,14 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) -LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false) LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false) bool lua_telemetry_parsed_named_non_function_type = false; -LUAU_FASTFLAGVARIABLE(LuauErrorParseIntegerIssues, false) +LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false) +LUAU_FASTFLAGVARIABLE(LuauLintParseIntegerIssues, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) bool lua_telemetry_parsed_out_of_range_bin_integer = false; @@ -1134,10 +1134,9 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector Parser::parseOptionalReturnTypeAnnotation() { - if (options.allowTypeAnnotations && - (lexer.current().type == ':' || (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow))) + if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)) { - if (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow) + if (lexer.current().type == Lexeme::SkinnyArrow) report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); nextLexeme(); @@ -1373,12 +1372,10 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) if (FFlag::LuauFixNamedFunctionParse && !names.empty()) forceFunctionType = true; - bool returnTypeIntroducer = - FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false; + bool returnTypeIntroducer = lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':'; // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element - if (params.size() == 1 && !varargAnnotation && !forceFunctionType && - (FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow)) + if (params.size() == 1 && !varargAnnotation && !forceFunctionType && !returnTypeIntroducer) { if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) lua_telemetry_parsed_named_non_function_type = true; @@ -1389,8 +1386,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) return {params[0], {}}; } - if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && !forceFunctionType && - allowPack) + if (!forceFunctionType && !returnTypeIntroducer && allowPack) { if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) lua_telemetry_parsed_named_non_function_type = true; @@ -1409,7 +1405,7 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray' instead of ':'"); lexer.next(); @@ -2037,8 +2033,10 @@ AstExpr* Parser::parseAssertionExpr() return expr; } -static const char* parseInteger(double& result, const char* data, int base) +static const char* parseInteger_DEPRECATED(double& result, const char* data, int base) { + LUAU_ASSERT(!FFlag::LuauLintParseIntegerIssues); + char* end = nullptr; unsigned long long value = strtoull(data, &end, base); @@ -2058,9 +2056,6 @@ static const char* parseInteger(double& result, const char* data, int base) else lua_telemetry_parsed_out_of_range_hex_integer = true; } - - if (FFlag::LuauErrorParseIntegerIssues) - return "Integer number value is out of range"; } } @@ -2068,11 +2063,13 @@ static const char* parseInteger(double& result, const char* data, int base) return *end == 0 ? nullptr : "Malformed number"; } -static const char* parseNumber(double& result, const char* data) +static const char* parseNumber_DEPRECATED2(double& result, const char* data) { + LUAU_ASSERT(!FFlag::LuauLintParseIntegerIssues); + // binary literal if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) - return parseInteger(result, data + 2, 2); + return parseInteger_DEPRECATED(result, data + 2, 2); // hexadecimal literal if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2]) @@ -2080,10 +2077,7 @@ static const char* parseNumber(double& result, const char* data) if (DFFlag::LuaReportParseIntegerIssues && data[2] == '0' && (data[3] == 'x' || data[3] == 'X')) lua_telemetry_parsed_double_prefix_hex_integer = true; - if (FFlag::LuauErrorParseIntegerIssues) - return parseInteger(result, data, 16); // keep prefix, it's handled by 'strtoull' - else - return parseInteger(result, data + 2, 16); + return parseInteger_DEPRECATED(result, data + 2, 16); } char* end = nullptr; @@ -2095,6 +2089,8 @@ static const char* parseNumber(double& result, const char* data) static bool parseNumber_DEPRECATED(double& result, const char* data) { + LUAU_ASSERT(!FFlag::LuauLintParseIntegerIssues); + // binary literal if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) { @@ -2123,6 +2119,73 @@ static bool parseNumber_DEPRECATED(double& result, const char* data) } } +static ConstantNumberParseResult parseInteger(double& result, const char* data, int base) +{ + LUAU_ASSERT(FFlag::LuauLintParseIntegerIssues); + LUAU_ASSERT(base == 2 || base == 16); + + char* end = nullptr; + unsigned long long value = strtoull(data, &end, base); + + if (*end != 0) + return ConstantNumberParseResult::Malformed; + + result = double(value); + + if (value == ULLONG_MAX && errno == ERANGE) + { + // 'errno' might have been set before we called 'strtoull', but we don't want the overhead of resetting a TLS variable on each call + // so we only reset it when we get a result that might be an out-of-range error and parse again to make sure + errno = 0; + value = strtoull(data, &end, base); + + if (errno == ERANGE) + { + if (DFFlag::LuaReportParseIntegerIssues) + { + if (base == 2) + lua_telemetry_parsed_out_of_range_bin_integer = true; + else + lua_telemetry_parsed_out_of_range_hex_integer = true; + } + + return base == 2 ? ConstantNumberParseResult::BinOverflow : ConstantNumberParseResult::HexOverflow; + } + } + + return ConstantNumberParseResult::Ok; +} + +static ConstantNumberParseResult parseNumber(double& result, const char* data) +{ + LUAU_ASSERT(FFlag::LuauLintParseIntegerIssues); + + // binary literal + if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) + return parseInteger(result, data + 2, 2); + + // hexadecimal literal + if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2]) + { + if (!FFlag::LuauErrorDoubleHexPrefix && data[2] == '0' && (data[3] == 'x' || data[3] == 'X')) + { + if (DFFlag::LuaReportParseIntegerIssues) + lua_telemetry_parsed_double_prefix_hex_integer = true; + + ConstantNumberParseResult parseResult = parseInteger(result, data + 2, 16); + return parseResult == ConstantNumberParseResult::Malformed ? parseResult : ConstantNumberParseResult::DoublePrefix; + } + + return parseInteger(result, data, 16); // pass in '0x' prefix, it's handled by 'strtoull' + } + + char* end = nullptr; + double value = strtod(data, &end); + + result = value; + return *end == 0 ? ConstantNumberParseResult::Ok : ConstantNumberParseResult::Malformed; +} + // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp AstExpr* Parser::parseSimpleExpr() { @@ -2163,10 +2226,21 @@ AstExpr* Parser::parseSimpleExpr() scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end()); } - if (DFFlag::LuaReportParseIntegerIssues || FFlag::LuauErrorParseIntegerIssues) + if (FFlag::LuauLintParseIntegerIssues) { double value = 0; - if (const char* error = parseNumber(value, scratchData.c_str())) + ConstantNumberParseResult result = parseNumber(value, scratchData.c_str()); + nextLexeme(); + + if (result == ConstantNumberParseResult::Malformed) + return reportExprError(start, {}, "Malformed number"); + + return allocator.alloc(start, value, result); + } + else if (DFFlag::LuaReportParseIntegerIssues) + { + double value = 0; + if (const char* error = parseNumber_DEPRECATED2(value, scratchData.c_str())) { nextLexeme(); @@ -2923,37 +2997,34 @@ AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const void Parser::nextLexeme() { - if (options.captureComments) - { - Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type; + Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type; - while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) - { - const Lexeme& lexeme = lexer.current(); + while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) + { + const Lexeme& lexeme = lexer.current(); + + if (options.captureComments) commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. - // The parser will turn this into a proper syntax error. - if (lexeme.type == Lexeme::BrokenComment) - return; + // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. + // The parser will turn this into a proper syntax error. + if (lexeme.type == Lexeme::BrokenComment) + return; - // Comments starting with ! are called "hot comments" and contain directives for type checking / linting - if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') - { - const char* text = lexeme.data; + // Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling + if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') + { + const char* text = lexeme.data; - unsigned int end = lexeme.length; - while (end > 0 && isSpace(text[end - 1])) - --end; + unsigned int end = lexeme.length; + while (end > 0 && isSpace(text[end - 1])) + --end; - hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); - } - - type = lexer.next(/* skipComments= */ false, /* updatePrevLocation= */ false).type; + hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); } + + type = lexer.next(/* skipComments= */ false, /* updatePrevLocation= */ false).type; } - else - lexer.next(); } } // namespace Luau diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 81db7c35..cd50ef00 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -7,6 +7,11 @@ #include "Luau/Transpiler.h" #include "FileUtils.h" +#include "Flags.h" + +#ifdef CALLGRIND +#include +#endif LUAU_FASTFLAG(DebugLuauTimeTracing) LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution) @@ -112,6 +117,7 @@ static void displayHelp(const char* argv0) printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); + printf(" --mode=strict: default to strict mode when typechecking\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); } @@ -178,9 +184,9 @@ struct CliConfigResolver : Luau::ConfigResolver mutable std::unordered_map configCache; mutable std::vector> configErrors; - CliConfigResolver() + CliConfigResolver(Luau::Mode mode) { - defaultConfig.mode = Luau::Mode::Nonstrict; + defaultConfig.mode = mode; } const Luau::Config& getConfig(const Luau::ModuleName& name) const override @@ -218,9 +224,7 @@ int main(int argc, char** argv) { Luau::assertHandler() = assertionHandler; - for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) - if (strncmp(flag->name, "Luau", 4) == 0) - flag->value = true; + setLuauFlagsDefault(); if (argc >= 2 && strcmp(argv[1], "--help") == 0) { @@ -229,6 +233,7 @@ int main(int argc, char** argv) } ReportFormat format = ReportFormat::Default; + Luau::Mode mode = Luau::Mode::Nonstrict; bool annotate = false; for (int i = 1; i < argc; ++i) @@ -240,16 +245,20 @@ int main(int argc, char** argv) format = ReportFormat::Luacheck; else if (strcmp(argv[i], "--formatter=gnu") == 0) format = ReportFormat::Gnu; + else if (strcmp(argv[i], "--mode=strict") == 0) + mode = Luau::Mode::Strict; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; else if (strcmp(argv[i], "--timetrace") == 0) FFlag::DebugLuauTimeTracing.value = true; + else if (strncmp(argv[i], "--fflags=", 9) == 0) + setLuauFlags(argv[i] + 9); } #if !defined(LUAU_ENABLE_TIME_TRACE) if (FFlag::DebugLuauTimeTracing) { - printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + fprintf(stderr, "To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); return 1; } #endif @@ -258,12 +267,16 @@ int main(int argc, char** argv) frontendOptions.retainFullTypeGraphs = annotate; CliFileResolver fileResolver; - CliConfigResolver configResolver; + CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); Luau::registerBuiltinTypes(frontend.typeChecker); Luau::freeze(frontend.typeChecker.globalTypes); +#ifdef CALLGRIND + CALLGRIND_ZERO_STATS; +#endif + std::vector files = getSourceFiles(argc, argv); int failed = 0; diff --git a/CLI/Ast.cpp b/CLI/Ast.cpp index 4ea46236..fd99d225 100644 --- a/CLI/Ast.cpp +++ b/CLI/Ast.cpp @@ -3,7 +3,7 @@ #include "Luau/Common.h" #include "Luau/Ast.h" -#include "Luau/JsonEncoder.h" +#include "Luau/AstJsonEncoder.h" #include "Luau/Parser.h" #include "Luau/ParseOptions.h" @@ -62,6 +62,7 @@ int main(int argc, char** argv) Luau::AstNameTable names(allocator); Luau::ParseOptions options; + options.captureComments = true; options.supportContinueStatement = true; options.allowTypeAnnotations = true; options.allowDeclarationSyntax = true; @@ -78,7 +79,7 @@ int main(int argc, char** argv) fprintf(stderr, "\n"); } - printf("%s", Luau::toJson(parseResult.root).c_str()); + printf("%s", Luau::toJson(parseResult.root, parseResult.commentLocations).c_str()); return parseResult.errors.size() > 0 ? 1 : 0; } diff --git a/CLI/Flags.cpp b/CLI/Flags.cpp new file mode 100644 index 00000000..4e261171 --- /dev/null +++ b/CLI/Flags.cpp @@ -0,0 +1,75 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" +#include "Luau/ExperimentalFlags.h" + +#include + +#include +#include + +static void setLuauFlag(std::string_view name, bool state) +{ + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + { + if (name == flag->name) + { + flag->value = state; + return; + } + } + + fprintf(stderr, "Warning: unrecognized flag '%.*s'.\n", int(name.length()), name.data()); +} + +static void setLuauFlags(bool state) +{ + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = state; +} + +void setLuauFlagsDefault() +{ + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0 && !Luau::isFlagExperimental(flag->name)) + flag->value = true; +} + +void setLuauFlags(const char* list) +{ + std::string_view rest = list; + + while (!rest.empty()) + { + size_t ending = rest.find(","); + std::string_view element = rest.substr(0, ending); + + if (size_t separator = element.find('='); separator != std::string_view::npos) + { + std::string_view key = element.substr(0, separator); + std::string_view value = element.substr(separator + 1); + + if (value == "true" || value == "True") + setLuauFlag(key, true); + else if (value == "false" || value == "False") + setLuauFlag(key, false); + else + fprintf(stderr, "Warning: unrecognized value '%.*s' for flag '%.*s'.\n", int(value.length()), value.data(), int(key.length()), + key.data()); + } + else + { + if (element == "true" || element == "True") + setLuauFlags(true); + else if (element == "false" || element == "False") + setLuauFlags(false); + else + setLuauFlag(element, true); + } + + if (ending != std::string_view::npos) + rest.remove_prefix(ending + 1); + else + break; + } +} diff --git a/CLI/Flags.h b/CLI/Flags.h new file mode 100644 index 00000000..8dfb0a29 --- /dev/null +++ b/CLI/Flags.h @@ -0,0 +1,5 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +void setLuauFlagsDefault(); +void setLuauFlags(const char* list); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 5fe12bec..4d3beec9 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -8,9 +8,10 @@ #include "Luau/BytecodeBuilder.h" #include "Luau/Parser.h" -#include "FileUtils.h" -#include "Profiler.h" #include "Coverage.h" +#include "FileUtils.h" +#include "Flags.h" +#include "Profiler.h" #include "isocline.h" @@ -19,6 +20,9 @@ #ifdef _WIN32 #include #include + +#define WIN32_LEAN_AND_MEAN +#include #endif #ifdef CALLGRIND @@ -26,6 +30,7 @@ #endif #include +#include LUAU_FASTFLAG(DebugLuauTimeTracing) @@ -46,6 +51,35 @@ enum class CompileFormat constexpr int MaxTraversalLimit = 50; +// Ctrl-C handling +static void sigintCallback(lua_State* L, int gc) +{ + if (gc >= 0) + return; + + lua_callbacks(L)->interrupt = NULL; + + lua_rawcheckstack(L, 1); // reserve space for error string + luaL_error(L, "Execution interrupted"); +} + +static lua_State* replState = NULL; + +#ifdef _WIN32 +BOOL WINAPI sigintHandler(DWORD signal) +{ + if (signal == CTRL_C_EVENT && replState) + lua_callbacks(replState)->interrupt = &sigintCallback; + return TRUE; +} +#else +static void sigintHandler(int signum) +{ + if (signum == SIGINT && replState) + lua_callbacks(replState)->interrupt = &sigintCallback; +} +#endif + struct GlobalOptions { int optimizationLevel = 1; @@ -75,8 +109,8 @@ static int lua_loadstring(lua_State* L) return 1; lua_pushnil(L); - lua_insert(L, -2); /* put before error message */ - return 2; /* return nil plus error message */ + lua_insert(L, -2); // put before error message + return 2; // return nil plus error message } static int finishrequire(lua_State* L) @@ -97,7 +131,11 @@ static int lua_require(lua_State* L) // return the module from the cache lua_getfield(L, -1, name.c_str()); if (!lua_isnil(L, -1)) + { + // L stack: _MODULES result return finishrequire(L); + } + lua_pop(L, 1); std::optional source = readFile(name + ".luau"); @@ -109,6 +147,7 @@ static int lua_require(lua_State* L) } // module needs to run in a new thread, isolated from the rest + // note: we create ML on main thread so that it doesn't inherit environment of L lua_State* GL = lua_mainthread(L); lua_State* ML = lua_newthread(GL); lua_xmove(GL, L, 1); @@ -142,11 +181,12 @@ static int lua_require(lua_State* L) } } - // there's now a return value on top of ML; stack of L is MODULES thread + // there's now a return value on top of ML; L stack: _MODULES ML lua_xmove(ML, L, 1); lua_pushvalue(L, -1); lua_setfield(L, -4, name.c_str()); + // L stack: _MODULES ML result return finishrequire(L); } @@ -528,6 +568,15 @@ static void runRepl() lua_State* L = globalState.get(); setupState(L); + + // setup Ctrl+C handling + replState = L; +#ifdef _WIN32 + SetConsoleCtrlHandler(sigintHandler, TRUE); +#else + signal(SIGINT, sigintHandler); +#endif + luaL_sandboxthread(L); runReplImpl(L); } @@ -682,60 +731,11 @@ static int assertionHandler(const char* expr, const char* file, int line, const return 1; } -static void setLuauFlags(bool state) -{ - for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) - { - if (strncmp(flag->name, "Luau", 4) == 0) - flag->value = state; - } -} - -static void setFlag(std::string_view name, bool state) -{ - for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) - { - if (name == flag->name) - { - flag->value = state; - return; - } - } - - fprintf(stderr, "Warning: --fflag unrecognized flag '%.*s'.\n\n", int(name.length()), name.data()); -} - -static void applyFlagKeyValue(std::string_view element) -{ - if (size_t separator = element.find('='); separator != std::string_view::npos) - { - std::string_view key = element.substr(0, separator); - std::string_view value = element.substr(separator + 1); - - if (value == "true") - setFlag(key, true); - else if (value == "false") - setFlag(key, false); - else - fprintf(stderr, "Warning: --fflag unrecognized value '%.*s' for flag '%.*s'.\n\n", int(value.length()), value.data(), int(key.length()), - key.data()); - } - else - { - if (element == "true") - setLuauFlags(true); - else if (element == "false") - setLuauFlags(false); - else - setFlag(element, true); - } -} - int replMain(int argc, char** argv) { Luau::assertHandler() = assertionHandler; - setLuauFlags(true); + setLuauFlagsDefault(); CliMode mode = CliMode::Unknown; CompileFormat compileFormat{}; @@ -818,27 +818,10 @@ int replMain(int argc, char** argv) else if (strcmp(argv[i], "--timetrace") == 0) { FFlag::DebugLuauTimeTracing.value = true; - -#if !defined(LUAU_ENABLE_TIME_TRACE) - printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); - return 1; -#endif } else if (strncmp(argv[i], "--fflags=", 9) == 0) { - std::string_view list = argv[i] + 9; - - while (!list.empty()) - { - size_t ending = list.find(","); - - applyFlagKeyValue(list.substr(0, ending)); - - if (ending != std::string_view::npos) - list.remove_prefix(ending + 1); - else - break; - } + setLuauFlags(argv[i] + 9); } else if (argv[i][0] == '-') { @@ -848,6 +831,14 @@ int replMain(int argc, char** argv) } } +#if !defined(LUAU_ENABLE_TIME_TRACE) + if (FFlag::DebugLuauTimeTracing) + { + fprintf(stderr, "To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; + } +#endif + const std::vector files = getSourceFiles(argc, argv); if (mode == CliMode::Unknown) { diff --git a/CLI/ReplEntry.cpp b/CLI/ReplEntry.cpp index 75995e6a..8543e3f7 100644 --- a/CLI/ReplEntry.cpp +++ b/CLI/ReplEntry.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Repl.h" - - int main(int argc, char** argv) { return replMain(argc, argv); diff --git a/CMakeLists.txt b/CMakeLists.txt index e256e234..92006344 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -175,6 +175,7 @@ endif() if(LUAU_BUILD_TESTS) target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) + target_compile_definitions(Luau.UnitTest PRIVATE DOCTEST_CONFIG_DOUBLE_STRINGIFY) target_include_directories(Luau.UnitTest PRIVATE extern) target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen) diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index c5979d3c..028b2d16 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -58,12 +58,25 @@ public: void jmp(Label& label); void jmp(OperandX64 op); + void call(Label& label); + void call(OperandX64 op); + + void int3(); + // AVX void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vaddsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vaddss(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vsubsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vmulsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vdivsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + + void vxorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + + void vcomisd(OperandX64 src1, OperandX64 src2); + void vsqrtpd(OperandX64 dst, OperandX64 src); void vsqrtps(OperandX64 dst, OperandX64 src); void vsqrtsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); diff --git a/CodeGen/include/Luau/Condition.h b/CodeGen/include/Luau/Condition.h index 36cbda95..78e4515e 100644 --- a/CodeGen/include/Luau/Condition.h +++ b/CodeGen/include/Luau/Condition.h @@ -37,8 +37,6 @@ enum class Condition Zero, NotZero, - // TODO: ordered and unordered floating-point conditions - Count }; diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 27e01781..f88063cf 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -231,6 +231,7 @@ void AssemblyBuilderX64::lea(OperandX64 lhs, OperandX64 rhs) if (logText) log("lea", lhs, rhs); + LUAU_ASSERT(rhs.cat == CategoryX64::mem); placeBinaryRegAndRegMem(lhs, rhs, 0x8d, 0x8d); } @@ -286,11 +287,42 @@ void AssemblyBuilderX64::jmp(OperandX64 op) if (logText) log("jmp", op); + placeRex(op); place(0xff); placeModRegMem(op, 4); commit(); } +void AssemblyBuilderX64::call(Label& label) +{ + place(0xe8); + placeLabel(label); + + if (logText) + log("call", label); + + commit(); +} + +void AssemblyBuilderX64::call(OperandX64 op) +{ + if (logText) + log("call", op); + + placeRex(op); + place(0xff); + placeModRegMem(op, 2); + commit(); +} + +void AssemblyBuilderX64::int3() +{ + if (logText) + log("int3"); + + place(0xcc); +} + void AssemblyBuilderX64::vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2) { placeAvx("vaddpd", dst, src1, src2, 0x58, false, AVX_0F, AVX_66); @@ -311,6 +343,31 @@ void AssemblyBuilderX64::vaddss(OperandX64 dst, OperandX64 src1, OperandX64 src2 placeAvx("vaddss", dst, src1, src2, 0x58, false, AVX_0F, AVX_F3); } +void AssemblyBuilderX64::vsubsd(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + placeAvx("vsubsd", dst, src1, src2, 0x5c, false, AVX_0F, AVX_F2); +} + +void AssemblyBuilderX64::vmulsd(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + placeAvx("vmulsd", dst, src1, src2, 0x59, false, AVX_0F, AVX_F2); +} + +void AssemblyBuilderX64::vdivsd(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + placeAvx("vdivsd", dst, src1, src2, 0x5e, false, AVX_0F, AVX_F2); +} + +void AssemblyBuilderX64::vxorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + placeAvx("vxorpd", dst, src1, src2, 0x57, false, AVX_0F, AVX_66); +} + +void AssemblyBuilderX64::vcomisd(OperandX64 src1, OperandX64 src2) +{ + placeAvx("vcomisd", src1, src2, 0x2f, false, AVX_0F, AVX_66); +} + void AssemblyBuilderX64::vsqrtpd(OperandX64 dst, OperandX64 src) { placeAvx("vsqrtpd", dst, src, 0x51, false, AVX_0F, AVX_66); @@ -471,9 +528,10 @@ void AssemblyBuilderX64::placeBinaryRegMemAndImm(OperandX64 lhs, OperandX64 rhs, LUAU_ASSERT(lhs.cat == CategoryX64::reg || lhs.cat == CategoryX64::mem); LUAU_ASSERT(rhs.cat == CategoryX64::imm); - SizeX64 size = lhs.base.size; + SizeX64 size = lhs.cat == CategoryX64::reg ? lhs.base.size : lhs.memSize; + LUAU_ASSERT(size == SizeX64::byte || size == SizeX64::dword || size == SizeX64::qword); - placeRex(lhs.base); + placeRex(lhs); if (size == SizeX64::byte) { diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 0cb7e1d9..1d6b18e5 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -37,6 +37,14 @@ // Note that Luau runtime doesn't provide indefinite bytecode compatibility: support for older versions gets removed over time. As such, bytecode isn't a durable storage format and it's expected // that Luau users can recompile bytecode from source on Luau version upgrades if necessary. +// # Bytecode version history +// +// Note: due to limitations of the versioning scheme, some bytecode blobs that carry version 2 are using features from version 3. Starting from version 3, version should be sufficient to indicate bytecode compatibility. +// +// Version 1: Baseline version for the open-source release. Supported until 0.521. +// Version 2: Adds Proto::linedefined. Currently supported. +// Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported. + // Bytecode opcode, part of the instruction header enum LuauOpcode { @@ -367,6 +375,20 @@ enum LuauOpcode // D: jump offset (-32768..32767) LOP_FORGPREP, + // JUMPXEQKNIL, JUMPXEQKB: jumps to target offset if the comparison with constant is true (or false, see AUX) + // A: source register 1 + // D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump") + // AUX: constant value (for boolean) in low bit, NOT flag (that flips comparison result) in high bit + LOP_JUMPXEQKNIL, + LOP_JUMPXEQKB, + + // JUMPXEQKN, JUMPXEQKS: jumps to target offset if the comparison with constant is true (or false, see AUX) + // A: source register 1 + // D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump") + // AUX: constant table index in low 24 bits, NOT flag (that flips comparison result) in high bit + LOP_JUMPXEQKN, + LOP_JUMPXEQKS, + // Enum entry for number of opcodes, not a valid opcode by itself! LOP__COUNT }; @@ -391,7 +413,7 @@ enum LuauBytecodeTag { // Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled LBC_VERSION_MIN = 2, - LBC_VERSION_MAX = 2, + LBC_VERSION_MAX = 3, LBC_VERSION_TARGET = 2, // Types of constant table entries LBC_CONSTANT_NIL = 0, diff --git a/Common/include/Luau/Common.h b/Common/include/Luau/Common.h index fbb03a9e..f1846ac3 100644 --- a/Common/include/Luau/Common.h +++ b/Common/include/Luau/Common.h @@ -20,12 +20,6 @@ #define LUAU_DEBUGBREAK() __builtin_trap() #endif - - - - - - namespace Luau { @@ -67,16 +61,13 @@ struct FValue const char* name; FValue* next; - FValue(const char* name, T def, bool dynamic, void (*reg)(const char*, T*, bool) = nullptr) + FValue(const char* name, T def, bool dynamic) : value(def) , dynamic(dynamic) , name(name) , next(list) { list = this; - - if (reg) - reg(name, &value, dynamic); } operator T() const @@ -98,7 +89,7 @@ FValue* FValue::list = nullptr; #define LUAU_FASTFLAGVARIABLE(flag, def) \ namespace FFlag \ { \ - Luau::FValue flag(#flag, def, false, nullptr); \ + Luau::FValue flag(#flag, def, false); \ } #define LUAU_FASTINT(flag) \ namespace FInt \ @@ -108,7 +99,7 @@ FValue* FValue::list = nullptr; #define LUAU_FASTINTVARIABLE(flag, def) \ namespace FInt \ { \ - Luau::FValue flag(#flag, def, false, nullptr); \ + Luau::FValue flag(#flag, def, false); \ } #define LUAU_DYNAMIC_FASTFLAG(flag) \ @@ -119,7 +110,7 @@ FValue* FValue::list = nullptr; #define LUAU_DYNAMIC_FASTFLAGVARIABLE(flag, def) \ namespace DFFlag \ { \ - Luau::FValue flag(#flag, def, true, nullptr); \ + Luau::FValue flag(#flag, def, true); \ } #define LUAU_DYNAMIC_FASTINT(flag) \ namespace DFInt \ @@ -129,5 +120,5 @@ FValue* FValue::list = nullptr; #define LUAU_DYNAMIC_FASTINTVARIABLE(flag, def) \ namespace DFInt \ { \ - Luau::FValue flag(#flag, def, true, nullptr); \ + Luau::FValue flag(#flag, def, true); \ } diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h new file mode 100644 index 00000000..71e76ffe --- /dev/null +++ b/Common/include/Luau/ExperimentalFlags.h @@ -0,0 +1,25 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ + +inline bool isFlagExperimental(const char* flag) +{ + // Flags in this list are disabled by default in various command-line tools. They may have behavior that is not fully final, + // or critical bugs that are found after the code has been submitted. + static const char* kList[] = { + "LuauLowerBoundsCalculation", + nullptr, // makes sure we always have at least one entry + }; + + for (const char* item : kList) + if (item && strcmp(item, flag) == 0) + return true; + + return false; +} + +} // namespace Luau diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index 65e962da..eec70d7a 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -8,8 +8,8 @@ namespace Luau { -class AstStatBlock; class AstNameTable; +struct ParseResult; class BytecodeBuilder; class BytecodeEncoder; @@ -58,7 +58,7 @@ private: }; // compiles bytecode into bytecode builder using either a pre-parsed AST or parsing it from source; throws on errors -void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options = {}); +void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& options = {}); void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const CompileOptions& options = {}, const ParseOptions& parseOptions = {}); // compiles bytecode into a bytecode blob, that either contains the valid bytecode or an encoded error that luau_load can decode diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h index e235a2e7..5f69f69e 100644 --- a/Compiler/include/luacode.h +++ b/Compiler/include/luacode.h @@ -3,7 +3,7 @@ #include -/* Can be used to reconfigure visibility/exports for public APIs */ +// Can be used to reconfigure visibility/exports for public APIs #ifndef LUACODE_API #define LUACODE_API extern #endif @@ -35,5 +35,5 @@ struct lua_CompileOptions const char** mutableGlobals; }; -/* compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy */ +// compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize); diff --git a/Compiler/src/BuiltinFolding.cpp b/Compiler/src/BuiltinFolding.cpp new file mode 100644 index 00000000..e76da4e2 --- /dev/null +++ b/Compiler/src/BuiltinFolding.cpp @@ -0,0 +1,463 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "BuiltinFolding.h" + +#include "Luau/Bytecode.h" + +#include + +namespace Luau +{ +namespace Compile +{ + +const double kRadDeg = 3.14159265358979323846 / 180.0; + +static Constant cvar() +{ + return Constant(); +} + +static Constant cbool(bool v) +{ + Constant res = {Constant::Type_Boolean}; + res.valueBoolean = v; + return res; +} + +static Constant cnum(double v) +{ + Constant res = {Constant::Type_Number}; + res.valueNumber = v; + return res; +} + +static Constant cstring(const char* v) +{ + Constant res = {Constant::Type_String}; + res.stringLength = unsigned(strlen(v)); + res.valueString = v; + return res; +} + +static Constant ctype(const Constant& c) +{ + LUAU_ASSERT(c.type != Constant::Type_Unknown); + + switch (c.type) + { + case Constant::Type_Nil: + return cstring("nil"); + + case Constant::Type_Boolean: + return cstring("boolean"); + + case Constant::Type_Number: + return cstring("number"); + + case Constant::Type_String: + return cstring("string"); + + default: + LUAU_ASSERT(!"Unsupported constant type"); + return cvar(); + } +} + +static uint32_t bit32(double v) +{ + // convert through signed 64-bit integer to match runtime behavior and gracefully truncate negative integers + return uint32_t(int64_t(v)); +} + +Constant foldBuiltin(int bfid, const Constant* args, size_t count) +{ + switch (bfid) + { + case LBF_MATH_ABS: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(fabs(args[0].valueNumber)); + break; + + case LBF_MATH_ACOS: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(acos(args[0].valueNumber)); + break; + + case LBF_MATH_ASIN: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(asin(args[0].valueNumber)); + break; + + case LBF_MATH_ATAN2: + if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) + return cnum(atan2(args[0].valueNumber, args[1].valueNumber)); + break; + + case LBF_MATH_ATAN: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(atan(args[0].valueNumber)); + break; + + case LBF_MATH_CEIL: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(ceil(args[0].valueNumber)); + break; + + case LBF_MATH_COSH: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(cosh(args[0].valueNumber)); + break; + + case LBF_MATH_COS: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(cos(args[0].valueNumber)); + break; + + case LBF_MATH_DEG: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(args[0].valueNumber / kRadDeg); + break; + + case LBF_MATH_EXP: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(exp(args[0].valueNumber)); + break; + + case LBF_MATH_FLOOR: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(floor(args[0].valueNumber)); + break; + + case LBF_MATH_FMOD: + if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) + return cnum(fmod(args[0].valueNumber, args[1].valueNumber)); + break; + + // Note: FREXP isn't folded since it returns multiple values + + case LBF_MATH_LDEXP: + if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) + return cnum(ldexp(args[0].valueNumber, int(args[1].valueNumber))); + break; + + case LBF_MATH_LOG10: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(log10(args[0].valueNumber)); + break; + + case LBF_MATH_LOG: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(log(args[0].valueNumber)); + else if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) + { + if (args[1].valueNumber == 2.0) + return cnum(log2(args[0].valueNumber)); + else if (args[1].valueNumber == 10.0) + return cnum(log10(args[0].valueNumber)); + else + return cnum(log(args[0].valueNumber) / log(args[1].valueNumber)); + } + break; + + case LBF_MATH_MAX: + if (count >= 1 && args[0].type == Constant::Type_Number) + { + double r = args[0].valueNumber; + + for (size_t i = 1; i < count; ++i) + { + if (args[i].type != Constant::Type_Number) + return cvar(); + + double a = args[i].valueNumber; + + r = (a > r) ? a : r; + } + + return cnum(r); + } + break; + + case LBF_MATH_MIN: + if (count >= 1 && args[0].type == Constant::Type_Number) + { + double r = args[0].valueNumber; + + for (size_t i = 1; i < count; ++i) + { + if (args[i].type != Constant::Type_Number) + return cvar(); + + double a = args[i].valueNumber; + + r = (a < r) ? a : r; + } + + return cnum(r); + } + break; + + // Note: MODF isn't folded since it returns multiple values + + case LBF_MATH_POW: + if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) + return cnum(pow(args[0].valueNumber, args[1].valueNumber)); + break; + + case LBF_MATH_RAD: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(args[0].valueNumber * kRadDeg); + break; + + case LBF_MATH_SINH: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(sinh(args[0].valueNumber)); + break; + + case LBF_MATH_SIN: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(sin(args[0].valueNumber)); + break; + + case LBF_MATH_SQRT: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(sqrt(args[0].valueNumber)); + break; + + case LBF_MATH_TANH: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(tanh(args[0].valueNumber)); + break; + + case LBF_MATH_TAN: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(tan(args[0].valueNumber)); + break; + + case LBF_BIT32_ARSHIFT: + if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) + { + uint32_t u = bit32(args[0].valueNumber); + int s = int(args[1].valueNumber); + + if (unsigned(s) < 32) + return cnum(double(uint32_t(int32_t(u) >> s))); + } + break; + + case LBF_BIT32_BAND: + if (count >= 1 && args[0].type == Constant::Type_Number) + { + uint32_t r = bit32(args[0].valueNumber); + + for (size_t i = 1; i < count; ++i) + { + if (args[i].type != Constant::Type_Number) + return cvar(); + + r &= bit32(args[i].valueNumber); + } + + return cnum(double(r)); + } + break; + + case LBF_BIT32_BNOT: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(double(uint32_t(~bit32(args[0].valueNumber)))); + break; + + case LBF_BIT32_BOR: + if (count >= 1 && args[0].type == Constant::Type_Number) + { + uint32_t r = bit32(args[0].valueNumber); + + for (size_t i = 1; i < count; ++i) + { + if (args[i].type != Constant::Type_Number) + return cvar(); + + r |= bit32(args[i].valueNumber); + } + + return cnum(double(r)); + } + break; + + case LBF_BIT32_BXOR: + if (count >= 1 && args[0].type == Constant::Type_Number) + { + uint32_t r = bit32(args[0].valueNumber); + + for (size_t i = 1; i < count; ++i) + { + if (args[i].type != Constant::Type_Number) + return cvar(); + + r ^= bit32(args[i].valueNumber); + } + + return cnum(double(r)); + } + break; + + case LBF_BIT32_BTEST: + if (count >= 1 && args[0].type == Constant::Type_Number) + { + uint32_t r = bit32(args[0].valueNumber); + + for (size_t i = 1; i < count; ++i) + { + if (args[i].type != Constant::Type_Number) + return cvar(); + + r &= bit32(args[i].valueNumber); + } + + return cbool(r != 0); + } + break; + + case LBF_BIT32_EXTRACT: + if (count == 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number) + { + uint32_t u = bit32(args[0].valueNumber); + int f = int(args[1].valueNumber); + int w = int(args[2].valueNumber); + + if (f >= 0 && w > 0 && f + w <= 32) + { + uint32_t m = ~(0xfffffffeu << (w - 1)); + + return cnum(double((u >> f) & m)); + } + } + break; + + case LBF_BIT32_LROTATE: + if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) + { + uint32_t u = bit32(args[0].valueNumber); + int s = int(args[1].valueNumber); + + return cnum(double((u << (s & 31)) | (u >> ((32 - s) & 31)))); + } + break; + + case LBF_BIT32_LSHIFT: + if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) + { + uint32_t u = bit32(args[0].valueNumber); + int s = int(args[1].valueNumber); + + if (unsigned(s) < 32) + return cnum(double(u << s)); + } + break; + + case LBF_BIT32_REPLACE: + if (count == 4 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number && + args[3].type == Constant::Type_Number) + { + uint32_t n = bit32(args[0].valueNumber); + uint32_t v = bit32(args[1].valueNumber); + int f = int(args[2].valueNumber); + int w = int(args[3].valueNumber); + + if (f >= 0 && w > 0 && f + w <= 32) + { + uint32_t m = ~(0xfffffffeu << (w - 1)); + + return cnum(double((n & ~(m << f)) | ((v & m) << f))); + } + } + break; + + case LBF_BIT32_RROTATE: + if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) + { + uint32_t u = bit32(args[0].valueNumber); + int s = int(args[1].valueNumber); + + return cnum(double((u >> (s & 31)) | (u << ((32 - s) & 31)))); + } + break; + + case LBF_BIT32_RSHIFT: + if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) + { + uint32_t u = bit32(args[0].valueNumber); + int s = int(args[1].valueNumber); + + if (unsigned(s) < 32) + return cnum(double(u >> s)); + } + break; + + case LBF_TYPE: + if (count == 1 && args[0].type != Constant::Type_Unknown) + return ctype(args[0]); + break; + + case LBF_STRING_BYTE: + if (count == 1 && args[0].type == Constant::Type_String) + { + if (args[0].stringLength > 0) + return cnum(double(uint8_t(args[0].valueString[0]))); + } + else if (count == 2 && args[0].type == Constant::Type_String && args[1].type == Constant::Type_Number) + { + int i = int(args[1].valueNumber); + + if (i > 0 && unsigned(i) <= args[0].stringLength) + return cnum(double(uint8_t(args[0].valueString[i - 1]))); + } + break; + + case LBF_STRING_LEN: + if (count == 1 && args[0].type == Constant::Type_String) + return cnum(double(args[0].stringLength)); + break; + + case LBF_TYPEOF: + if (count == 1 && args[0].type != Constant::Type_Unknown) + return ctype(args[0]); + break; + + case LBF_MATH_CLAMP: + if (count == 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number) + { + double min = args[1].valueNumber; + double max = args[2].valueNumber; + + if (min <= max) + { + double v = args[0].valueNumber; + v = v < min ? min : v; + v = v > max ? max : v; + + return cnum(v); + } + } + break; + + case LBF_MATH_SIGN: + if (count == 1 && args[0].type == Constant::Type_Number) + { + double v = args[0].valueNumber; + + return cnum(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0); + } + break; + + case LBF_MATH_ROUND: + if (count == 1 && args[0].type == Constant::Type_Number) + return cnum(round(args[0].valueNumber)); + break; + } + + return cvar(); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/BuiltinFolding.h b/Compiler/src/BuiltinFolding.h new file mode 100644 index 00000000..1904e14f --- /dev/null +++ b/Compiler/src/BuiltinFolding.h @@ -0,0 +1,14 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "ConstantFolding.h" + +namespace Luau +{ +namespace Compile +{ + +Constant foldBuiltin(int bfid, const Constant* args, size_t count); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index 6bd24b6d..26933730 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,8 +4,6 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" -LUAU_FASTFLAGVARIABLE(LuauCompileRawlen, false) - namespace Luau { namespace Compile @@ -40,11 +38,8 @@ Builtin getBuiltin(AstExpr* node, const DenseHashMap& globals, } } -int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) +static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) { - if (builtin.empty()) - return -1; - if (builtin.isGlobal("assert")) return LBF_ASSERT; @@ -60,7 +55,7 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) return LBF_RAWGET; if (builtin.isGlobal("rawequal")) return LBF_RAWEQUAL; - if (FFlag::LuauCompileRawlen && builtin.isGlobal("rawlen")) + if (builtin.isGlobal("rawlen")) return LBF_RAWLEN; if (builtin.isGlobal("unpack")) @@ -200,5 +195,49 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) return -1; } +struct BuiltinVisitor : AstVisitor +{ + DenseHashMap& result; + + const DenseHashMap& globals; + const DenseHashMap& variables; + + const CompileOptions& options; + + BuiltinVisitor(DenseHashMap& result, const DenseHashMap& globals, + const DenseHashMap& variables, const CompileOptions& options) + : result(result) + , globals(globals) + , variables(variables) + , options(options) + { + } + + bool visit(AstExprCall* node) override + { + Builtin builtin = node->self ? Builtin() : getBuiltin(node->func, globals, variables); + if (builtin.empty()) + return true; + + int bfid = getBuiltinFunctionId(builtin, options); + + // getBuiltinFunctionId optimistically assumes all select() calls are builtin but actually the second argument must be a vararg + if (bfid == LBF_SELECT_VARARG && !(node->args.size == 2 && node->args.data[1]->is())) + bfid = -1; + + if (bfid >= 0) + result[node] = bfid; + + return true; // propagate to nested calls + } +}; + +void analyzeBuiltins(DenseHashMap& result, const DenseHashMap& globals, + const DenseHashMap& variables, const CompileOptions& options, AstNode* root) +{ + BuiltinVisitor visitor{result, globals, variables, options}; + root->visit(&visitor); +} + } // namespace Compile } // namespace Luau diff --git a/Compiler/src/Builtins.h b/Compiler/src/Builtins.h index 60df53a1..4399c532 100644 --- a/Compiler/src/Builtins.h +++ b/Compiler/src/Builtins.h @@ -35,7 +35,9 @@ struct Builtin }; Builtin getBuiltin(AstExpr* node, const DenseHashMap& globals, const DenseHashMap& variables); -int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options); + +void analyzeBuiltins(DenseHashMap& result, const DenseHashMap& globals, + const DenseHashMap& variables, const CompileOptions& options, AstNode* root); } // namespace Compile } // namespace Luau diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 5e2669b2..46ab2648 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,6 +6,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauCompileBytecodeV3, false) + namespace Luau { @@ -77,6 +79,10 @@ static int getOpLength(LuauOpcode op) case LOP_JUMPIFNOTEQK: case LOP_FASTCALL2: case LOP_FASTCALL2K: + case LOP_JUMPXEQKNIL: + case LOP_JUMPXEQKB: + case LOP_JUMPXEQKN: + case LOP_JUMPXEQKS: return 2; default: @@ -108,6 +114,10 @@ inline bool isJumpD(LuauOpcode op) case LOP_JUMPBACK: case LOP_JUMPIFEQK: case LOP_JUMPIFNOTEQK: + case LOP_JUMPXEQKNIL: + case LOP_JUMPXEQKB: + case LOP_JUMPXEQKN: + case LOP_JUMPXEQKS: return true; default: @@ -120,6 +130,17 @@ inline bool isSkipC(LuauOpcode op) switch (op) { case LOP_LOADB: + return true; + + default: + return false; + } +} + +inline bool isFastCall(LuauOpcode op) +{ + switch (op) + { case LOP_FASTCALL: case LOP_FASTCALL1: case LOP_FASTCALL2: @@ -137,6 +158,8 @@ static int getJumpTarget(uint32_t insn, uint32_t pc) if (isJumpD(op)) return int(pc + LUAU_INSN_D(insn) + 1); + else if (isFastCall(op)) + return int(pc + LUAU_INSN_C(insn) + 2); else if (isSkipC(op) && LUAU_INSN_C(insn)) return int(pc + LUAU_INSN_C(insn) + 1); else if (op == LOP_JUMPX) @@ -479,7 +502,7 @@ bool BytecodeBuilder::patchSkipC(size_t jumpLabel, size_t targetLabel) unsigned int jumpInsn = insns[jumpLabel]; (void)jumpInsn; - LUAU_ASSERT(isSkipC(LuauOpcode(LUAU_INSN_OP(jumpInsn)))); + LUAU_ASSERT(isSkipC(LuauOpcode(LUAU_INSN_OP(jumpInsn))) || isFastCall(LuauOpcode(LUAU_INSN_OP(jumpInsn)))); LUAU_ASSERT(LUAU_INSN_C(jumpInsn) == 0); int offset = int(targetLabel) - int(jumpLabel) - 1; @@ -1056,6 +1079,9 @@ std::string BytecodeBuilder::getError(const std::string& message) uint8_t BytecodeBuilder::getVersion() { + if (FFlag::LuauCompileBytecodeV3) + return 3; + // This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags return LBC_VERSION_TARGET; } @@ -1233,6 +1259,24 @@ void BytecodeBuilder::validate() const VJUMP(LUAU_INSN_D(insn)); break; + case LOP_JUMPXEQKNIL: + case LOP_JUMPXEQKB: + VREG(LUAU_INSN_A(insn)); + VJUMP(LUAU_INSN_D(insn)); + break; + + case LOP_JUMPXEQKN: + VREG(LUAU_INSN_A(insn)); + VCONST(insns[i + 1] & 0xffffff, Number); + VJUMP(LUAU_INSN_D(insn)); + break; + + case LOP_JUMPXEQKS: + VREG(LUAU_INSN_A(insn)); + VCONST(insns[i + 1] & 0xffffff, String); + VJUMP(LUAU_INSN_D(insn)); + break; + case LOP_ADD: case LOP_SUB: case LOP_MUL: @@ -1766,6 +1810,26 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, formatAppend(result, "JUMPIFNOTEQK R%d K%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel); break; + case LOP_JUMPXEQKNIL: + formatAppend(result, "JUMPXEQKNIL R%d L%d%s\n", LUAU_INSN_A(insn), targetLabel, *code >> 31 ? " NOT" : ""); + code++; + break; + + case LOP_JUMPXEQKB: + formatAppend(result, "JUMPXEQKB R%d %d L%d%s\n", LUAU_INSN_A(insn), *code & 1, targetLabel, *code >> 31 ? " NOT" : ""); + code++; + break; + + case LOP_JUMPXEQKN: + formatAppend(result, "JUMPXEQKN R%d K%d L%d%s\n", LUAU_INSN_A(insn), *code & 0xffffff, targetLabel, *code >> 31 ? " NOT" : ""); + code++; + break; + + case LOP_JUMPXEQKS: + formatAppend(result, "JUMPXEQKS R%d K%d L%d%s\n", LUAU_INSN_A(insn), *code & 0xffffff, targetLabel, *code >> 31 ? " NOT" : ""); + code++; + break; + default: LUAU_ASSERT(!"Unsupported opcode"); } diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index d7c8155c..2ee20cab 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -25,6 +25,9 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAGVARIABLE(LuauCompileNoIpairs, false) +LUAU_FASTFLAGVARIABLE(LuauCompileFreeReassign, false) +LUAU_FASTFLAGVARIABLE(LuauCompileXEQ, false) + namespace Luau { @@ -75,6 +78,12 @@ static BytecodeBuilder::StringRef sref(AstArray data) return {data.data, data.size}; } +static BytecodeBuilder::StringRef sref(AstArray data) +{ + LUAU_ASSERT(data.data); + return {data.data, data.size}; +} + struct Compiler { struct RegScope; @@ -89,6 +98,7 @@ struct Compiler , constants(nullptr) , locstants(nullptr) , tableShapes(nullptr) + , builtins(nullptr) { // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays localStack.reserve(16); @@ -245,7 +255,7 @@ struct Compiler { f.canInline = true; f.stackSize = stackSize; - f.costModel = modelCost(func->body, func->args.data, func->args.size); + f.costModel = modelCost(func->body, func->args.data, func->args.size, builtins); // track functions that only ever return a single value so that we can convert multret calls to fixedret calls if (allPathsEndWithReturn(func->body)) @@ -262,23 +272,44 @@ struct Compiler return fid; } + // returns true if node can return multiple values; may conservatively return true even if expr is known to return just a single value + bool isExprMultRet(AstExpr* node) + { + AstExprCall* expr = node->as(); + if (!expr) + return node->is(); + + // conservative version, optimized for compilation throughput + if (options.optimizationLevel <= 1) + return true; + + // handles builtin calls that can be constant-folded + // without this we may omit some optimizations eg compiling fast calls without use of FASTCALL2K + if (isConstant(expr)) + return false; + + // handles local function calls where we know only one argument is returned + AstExprFunction* func = getFunctionExpr(expr->func); + Function* fi = func ? functions.find(func) : nullptr; + + if (fi && fi->returnsOne) + return false; + + // unrecognized call, so we conservatively assume multret + return true; + } + // note: this doesn't just clobber target (assuming it's temp), but also clobbers *all* allocated registers >= target! // this is important to be able to support "multret" semantics due to Lua call frame structure bool compileExprTempMultRet(AstExpr* node, uint8_t target) { if (AstExprCall* expr = node->as()) { - // Optimization: convert multret calls to functions that always return one value to fixedret calls; this facilitates inlining - if (options.optimizationLevel >= 2) + // Optimization: convert multret calls that always return one value to fixedret calls; this facilitates inlining/constant folding + if (options.optimizationLevel >= 2 && !isExprMultRet(node)) { - AstExprFunction* func = getFunctionExpr(expr->func); - Function* fi = func ? functions.find(func) : nullptr; - - if (fi && fi->returnsOne) - { - compileExprTemp(node, target); - return false; - } + compileExprTemp(node, target); + return false; } // We temporarily swap out regTop to have targetTop work correctly... @@ -483,8 +514,7 @@ struct Compiler varc[i] = isConstant(expr->args.data[i]); // if the last argument only returns a single value, all following arguments are nil - if (expr->args.size != 0 && - !(expr->args.data[expr->args.size - 1]->is() || expr->args.data[expr->args.size - 1]->is())) + if (expr->args.size != 0 && !isExprMultRet(expr->args.data[expr->args.size - 1])) for (size_t i = expr->args.size; i < func->args.size && i < 8; ++i) varc[i] = true; @@ -523,7 +553,7 @@ struct Compiler AstLocal* var = func->args.data[i]; AstExpr* arg = i < expr->args.size ? expr->args.data[i] : nullptr; - if (i + 1 == expr->args.size && func->args.size > expr->args.size && (arg->is() || arg->is())) + if (i + 1 == expr->args.size && func->args.size > expr->args.size && isExprMultRet(arg)) { // if the last argument can return multiple values, we need to compute all of them into the remaining arguments unsigned int tail = unsigned(func->args.size - expr->args.size) + 1; @@ -566,7 +596,7 @@ struct Compiler } else { - AstExprLocal* le = arg->as(); + AstExprLocal* le = FFlag::LuauCompileFreeReassign ? getExprLocal(arg) : arg->as(); Variable* lv = le ? variables.find(le->local) : nullptr; // if the argument is a local that isn't mutated, we will simply reuse the existing register @@ -591,7 +621,7 @@ struct Compiler } // fold constant values updated above into expressions in the function body - foldConstants(constants, variables, locstants, func->body); + foldConstants(constants, variables, locstants, builtinsFold, func->body); bool usedFallthrough = false; @@ -632,7 +662,7 @@ struct Compiler if (Constant* var = locstants.find(func->args.data[i])) var->type = Constant::Type_Unknown; - foldConstants(constants, variables, locstants, func->body); + foldConstants(constants, variables, locstants, builtinsFold, func->body); } void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) @@ -675,29 +705,23 @@ struct Compiler int bfid = -1; - if (options.optimizationLevel >= 1) - { - Builtin builtin = getBuiltin(expr->func, globals, variables); - bfid = getBuiltinFunctionId(builtin, options); - } + if (options.optimizationLevel >= 1 && !expr->self) + if (const int* id = builtins.find(expr)) + bfid = *id; if (bfid == LBF_SELECT_VARARG) { // Optimization: compile select(_, ...) as FASTCALL1; the builtin will read variadic arguments directly // note: for now we restrict this to single-return expressions since our runtime code doesn't deal with general cases - if (multRet == false && targetCount == 1 && expr->args.size == 2 && expr->args.data[1]->is()) + if (multRet == false && targetCount == 1) return compileExprSelectVararg(expr, target, targetCount, targetTop, multRet, regs); else bfid = -1; } // Optimization: for 1/2 argument fast calls use specialized opcodes - if (!expr->self && bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) - { - AstExpr* last = expr->args.data[expr->args.size - 1]; - if (!last->is() && !last->is()) - return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); - } + if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2 && !isExprMultRet(expr->args.data[expr->args.size - 1])) + return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); if (expr->self) { @@ -985,9 +1009,8 @@ struct Compiler size_t compileCompareJump(AstExprBinary* expr, bool not_ = false) { RegScope rs(this); - LuauOpcode opc = getJumpOpCompare(expr->op, not_); - bool isEq = (opc == LOP_JUMPIFEQ || opc == LOP_JUMPIFNOTEQ); + bool isEq = (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe); AstExpr* left = expr->left; AstExpr* right = expr->right; @@ -999,36 +1022,112 @@ struct Compiler std::swap(left, right); } - uint8_t rl = compileExprAuto(left, rs); - int32_t rr = -1; - - if (isEq && operandIsConstant) + if (FFlag::LuauCompileXEQ) { - if (opc == LOP_JUMPIFEQ) - opc = LOP_JUMPIFEQK; - else if (opc == LOP_JUMPIFNOTEQ) - opc = LOP_JUMPIFNOTEQK; + uint8_t rl = compileExprAuto(left, rs); - rr = getConstantIndex(right); - LUAU_ASSERT(rr >= 0); - } - else - rr = compileExprAuto(right, rs); + if (isEq && operandIsConstant) + { + const Constant* cv = constants.find(right); + LUAU_ASSERT(cv && cv->type != Constant::Type_Unknown); - size_t jumpLabel = bytecode.emitLabel(); + LuauOpcode opc = LOP_NOP; + int32_t cid = -1; + uint32_t flip = (expr->op == AstExprBinary::CompareEq) == not_ ? 0x80000000 : 0; - if (expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::CompareGe) - { - bytecode.emitAD(opc, uint8_t(rr), 0); - bytecode.emitAux(rl); + switch (cv->type) + { + case Constant::Type_Nil: + opc = LOP_JUMPXEQKNIL; + cid = 0; + break; + + case Constant::Type_Boolean: + opc = LOP_JUMPXEQKB; + cid = cv->valueBoolean; + break; + + case Constant::Type_Number: + opc = LOP_JUMPXEQKN; + cid = getConstantIndex(right); + break; + + case Constant::Type_String: + opc = LOP_JUMPXEQKS; + cid = getConstantIndex(right); + break; + + default: + LUAU_ASSERT(!"Unexpected constant type"); + } + + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + size_t jumpLabel = bytecode.emitLabel(); + + bytecode.emitAD(opc, rl, 0); + bytecode.emitAux(cid | flip); + + return jumpLabel; + } + else + { + LuauOpcode opc = getJumpOpCompare(expr->op, not_); + + uint8_t rr = compileExprAuto(right, rs); + + size_t jumpLabel = bytecode.emitLabel(); + + if (expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::CompareGe) + { + bytecode.emitAD(opc, rr, 0); + bytecode.emitAux(rl); + } + else + { + bytecode.emitAD(opc, rl, 0); + bytecode.emitAux(rr); + } + + return jumpLabel; + } } else { - bytecode.emitAD(opc, rl, 0); - bytecode.emitAux(rr); - } + LuauOpcode opc = getJumpOpCompare(expr->op, not_); - return jumpLabel; + uint8_t rl = compileExprAuto(left, rs); + int32_t rr = -1; + + if (isEq && operandIsConstant) + { + if (opc == LOP_JUMPIFEQ) + opc = LOP_JUMPIFEQK; + else if (opc == LOP_JUMPIFNOTEQ) + opc = LOP_JUMPIFNOTEQK; + + rr = getConstantIndex(right); + LUAU_ASSERT(rr >= 0); + } + else + rr = compileExprAuto(right, rs); + + size_t jumpLabel = bytecode.emitLabel(); + + if (expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::CompareGe) + { + bytecode.emitAD(opc, uint8_t(rr), 0); + bytecode.emitAux(rl); + } + else + { + bytecode.emitAD(opc, rl, 0); + bytecode.emitAux(rr); + } + + return jumpLabel; + } } int32_t getConstantNumber(AstExpr* node) @@ -2156,19 +2255,27 @@ struct Compiler compileLValueUse(lv, source, /* set= */ true); } - int getExprLocalReg(AstExpr* node) + AstExprLocal* getExprLocal(AstExpr* node) { if (AstExprLocal* expr = node->as()) + return expr; + else if (AstExprGroup* expr = node->as()) + return getExprLocal(expr->expr); + else if (AstExprTypeAssertion* expr = node->as()) + return getExprLocal(expr->expr); + else + return nullptr; + } + + int getExprLocalReg(AstExpr* node) + { + if (AstExprLocal* expr = getExprLocal(node)) { // note: this can't check expr->upvalue because upvalues may be upgraded to locals during inlining Local* l = locals.find(expr->local); return l && l->allocated ? l->reg : -1; } - else if (AstExprGroup* expr = node->as()) - return getExprLocalReg(expr->expr); - else if (AstExprTypeAssertion* expr = node->as()) - return getExprLocalReg(expr->expr); else return -1; } @@ -2454,6 +2561,22 @@ struct Compiler if (options.optimizationLevel >= 1 && options.debugLevel <= 1 && areLocalsRedundant(stat)) return; + // Optimization: for 1-1 local assignments, we can reuse the register *if* neither local is mutated + if (FFlag::LuauCompileFreeReassign && options.optimizationLevel >= 1 && stat->vars.size == 1 && stat->values.size == 1) + { + if (AstExprLocal* re = getExprLocal(stat->values.data[0])) + { + Variable* lv = variables.find(stat->vars.data[0]); + Variable* rv = variables.find(re->local); + + if (int reg = getExprLocalReg(re); reg >= 0 && (!lv || !lv->written) && (!rv || !rv->written)) + { + pushLocal(stat->vars.data[0], uint8_t(reg)); + return; + } + } + } + // note: allocReg in this case allocates into parent block register - note that we don't have RegScope here uint8_t vars = allocReg(stat, unsigned(stat->vars.size)); @@ -2495,7 +2618,7 @@ struct Compiler } AstLocal* var = stat->var; - uint64_t costModel = modelCost(stat->body, &var, 1); + uint64_t costModel = modelCost(stat->body, &var, 1, builtins); // we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to unrolling bool varc = true; @@ -2533,7 +2656,7 @@ struct Compiler locstants[var].type = Constant::Type_Number; locstants[var].valueNumber = from + iv * step; - foldConstants(constants, variables, locstants, stat); + foldConstants(constants, variables, locstants, builtinsFold, stat); size_t iterJumps = loopJumps.size(); @@ -2561,7 +2684,7 @@ struct Compiler // clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again locstants[var].type = Constant::Type_Unknown; - foldConstants(constants, variables, locstants, stat); + foldConstants(constants, variables, locstants, builtinsFold, stat); } void compileStatFor(AstStatFor* stat) @@ -3368,26 +3491,7 @@ struct Compiler bool visit(AstStatReturn* stat) override { - if (stat->list.size == 1) - { - AstExpr* value = stat->list.data[0]; - - if (AstExprCall* expr = value->as()) - { - AstExprFunction* func = self->getFunctionExpr(expr->func); - Function* fi = func ? self->functions.find(func) : nullptr; - - returnsOne &= fi && fi->returnsOne; - } - else if (value->is()) - { - returnsOne = false; - } - } - else - { - returnsOne = false; - } + returnsOne &= stat->list.size == 1 && !self->isExprMultRet(stat->list.data[0]); return false; } @@ -3487,6 +3591,8 @@ struct Compiler DenseHashMap constants; DenseHashMap locstants; DenseHashMap tableShapes; + DenseHashMap builtins; + const DenseHashMap* builtinsFold = nullptr; unsigned int regTop = 0; unsigned int stackSize = 0; @@ -3502,10 +3608,21 @@ struct Compiler std::vector captures; }; -void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options) +void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& inputOptions) { LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler"); + LUAU_ASSERT(parseResult.root); + LUAU_ASSERT(parseResult.errors.empty()); + + CompileOptions options = inputOptions; + + for (const HotComment& hc : parseResult.hotcomments) + if (hc.header && hc.content.compare(0, 9, "optimize ") == 0) + options.optimizationLevel = std::max(0, std::min(2, atoi(hc.content.c_str() + 9))); + + AstStatBlock* root = parseResult.root; + Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables @@ -3514,10 +3631,17 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName // this pass analyzes mutability of locals/globals and associates locals with their initial values trackValues(compiler.globals, compiler.variables, root); + // builtin folding is enabled on optimization level 2 since we can't deoptimize folding at runtime + if (options.optimizationLevel >= 2) + compiler.builtinsFold = &compiler.builtins; + if (options.optimizationLevel >= 1) { + // this pass tracks which calls are builtins and can be compiled more efficiently + analyzeBuiltins(compiler.builtins, compiler.globals, compiler.variables, options, root); + // this pass analyzes constantness of expressions - foldConstants(compiler.constants, compiler.variables, compiler.locstants, root); + foldConstants(compiler.constants, compiler.variables, compiler.locstants, compiler.builtinsFold, root); // this pass analyzes table assignments to estimate table shapes for initially empty tables predictTableShapes(compiler.tableShapes, root); @@ -3559,9 +3683,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const if (!result.errors.empty()) throw ParseErrors(result.errors); - AstStatBlock* root = result.root; - - compileOrThrow(bytecode, root, names, options); + compileOrThrow(bytecode, result, names, options); } std::string compile(const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions, BytecodeEncoder* encoder) @@ -3584,7 +3706,7 @@ std::string compile(const std::string& source, const CompileOptions& options, co try { BytecodeBuilder bcb(encoder); - compileOrThrow(bcb, result.root, names, options); + compileOrThrow(bcb, result, names, options); return bcb.getBytecode(); } diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index a62beeb1..34f79544 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "ConstantFolding.h" +#include "BuiltinFolding.h" + #include namespace Luau @@ -193,13 +195,18 @@ struct ConstantVisitor : AstVisitor DenseHashMap& variables; DenseHashMap& locals; + const DenseHashMap* builtins; + bool wasEmpty = false; - ConstantVisitor( - DenseHashMap& constants, DenseHashMap& variables, DenseHashMap& locals) + std::vector builtinArgs; + + ConstantVisitor(DenseHashMap& constants, DenseHashMap& variables, + DenseHashMap& locals, const DenseHashMap* builtins) : constants(constants) , variables(variables) , locals(locals) + , builtins(builtins) { // since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries wasEmpty = constants.empty() && locals.empty(); @@ -253,8 +260,37 @@ struct ConstantVisitor : AstVisitor { analyze(expr->func); - for (size_t i = 0; i < expr->args.size; ++i) - analyze(expr->args.data[i]); + if (const int* bfid = builtins ? builtins->find(expr) : nullptr) + { + // since recursive calls to analyze() may reuse the vector we need to be careful and preserve existing contents + size_t offset = builtinArgs.size(); + bool canFold = true; + + builtinArgs.reserve(offset + expr->args.size); + + for (size_t i = 0; i < expr->args.size; ++i) + { + Constant ac = analyze(expr->args.data[i]); + + if (ac.type == Constant::Type_Unknown) + canFold = false; + else + builtinArgs.push_back(ac); + } + + if (canFold) + { + LUAU_ASSERT(builtinArgs.size() == offset + expr->args.size); + result = foldBuiltin(*bfid, builtinArgs.data() + offset, expr->args.size); + } + + builtinArgs.resize(offset); + } + else + { + for (size_t i = 0; i < expr->args.size; ++i) + analyze(expr->args.data[i]); + } } else if (AstExprIndexName* expr = node->as()) { @@ -395,9 +431,9 @@ struct ConstantVisitor : AstVisitor }; void foldConstants(DenseHashMap& constants, DenseHashMap& variables, - DenseHashMap& locals, AstNode* root) + DenseHashMap& locals, const DenseHashMap* builtins, AstNode* root) { - ConstantVisitor visitor{constants, variables, locals}; + ConstantVisitor visitor{constants, variables, locals, builtins}; root->visit(&visitor); } diff --git a/Compiler/src/ConstantFolding.h b/Compiler/src/ConstantFolding.h index 0a995d75..d67d9285 100644 --- a/Compiler/src/ConstantFolding.h +++ b/Compiler/src/ConstantFolding.h @@ -26,7 +26,7 @@ struct Constant { bool valueBoolean; double valueNumber; - char* valueString = nullptr; // length stored in stringLength + const char* valueString = nullptr; // length stored in stringLength }; bool isTruthful() const @@ -35,7 +35,7 @@ struct Constant return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false); } - AstArray getString() const + AstArray getString() const { LUAU_ASSERT(type == Type_String); return {valueString, stringLength}; @@ -43,7 +43,7 @@ struct Constant }; void foldConstants(DenseHashMap& constants, DenseHashMap& variables, - DenseHashMap& locals, AstNode* root); + DenseHashMap& locals, const DenseHashMap* builtins, AstNode* root); } // namespace Compile } // namespace Luau diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp index 5608cd86..81cbfd7a 100644 --- a/Compiler/src/CostModel.cpp +++ b/Compiler/src/CostModel.cpp @@ -113,11 +113,14 @@ struct Cost struct CostVisitor : AstVisitor { + const DenseHashMap& builtins; + DenseHashMap vars; Cost result; - CostVisitor() - : vars(nullptr) + CostVisitor(const DenseHashMap& builtins) + : builtins(builtins) + , vars(nullptr) { } @@ -148,14 +151,21 @@ struct CostVisitor : AstVisitor } else if (AstExprCall* expr = node->as()) { - Cost cost = 3; - cost += model(expr->func); + // builtin cost modeling is different from regular calls because we use FASTCALL to compile these + // thus we use a cheaper baseline, don't account for function, and assume constant/local copy is free + bool builtin = builtins.find(expr) != nullptr; + bool builtinShort = builtin && expr->args.size <= 2; // FASTCALL1/2 + + Cost cost = builtin ? 2 : 3; + + if (!builtin) + cost += model(expr->func); for (size_t i = 0; i < expr->args.size; ++i) { Cost ac = model(expr->args.data[i]); // for constants/locals we still need to copy them to the argument list - cost += ac.model == 0 ? Cost(1) : ac; + cost += ac.model == 0 && !builtinShort ? Cost(1) : ac; } return cost; @@ -327,9 +337,9 @@ struct CostVisitor : AstVisitor } }; -uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount) +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount, const DenseHashMap& builtins) { - CostVisitor visitor; + CostVisitor visitor{builtins}; for (size_t i = 0; i < varCount && i < 7; ++i) visitor.vars[vars[i]] = 0xffull << (i * 8 + 8); diff --git a/Compiler/src/CostModel.h b/Compiler/src/CostModel.h index 17defafb..e8f3e166 100644 --- a/Compiler/src/CostModel.h +++ b/Compiler/src/CostModel.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Ast.h" +#include "Luau/DenseHash.h" namespace Luau { @@ -9,7 +10,7 @@ namespace Compile { // cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant -uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount); +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount, const DenseHashMap& builtins); // cost is computed as B - sum(Di * Ci), where B is baseline cost, Di is the discount for each variable and Ci is 1 when variable #i is constant int computeCost(uint64_t model, const bool* varsConst, size_t varCount); diff --git a/Makefile b/Makefile index b8077897..dd32d237 100644 --- a/Makefile +++ b/Makefile @@ -31,15 +31,15 @@ ISOCLINE_SOURCES=extern/isocline/src/isocline.c ISOCLINE_OBJECTS=$(ISOCLINE_SOURCES:%=$(BUILD)/%.o) ISOCLINE_TARGET=$(BUILD)/libisocline.a -TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp +TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/FileUtils.cpp CLI/Flags.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests -REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp CLI/ReplEntry.cpp +REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp CLI/ReplEntry.cpp REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o) REPL_CLI_TARGET=$(BUILD)/luau -ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Analyze.cpp +ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Analyze.cpp ANALYZE_CLI_OBJECTS=$(ANALYZE_CLI_SOURCES:%=$(BUILD)/%.o) ANALYZE_CLI_TARGET=$(BUILD)/luau-analyze @@ -50,8 +50,12 @@ TESTS_ARGS= ifneq ($(flags),) TESTS_ARGS+=--fflags=$(flags) endif +ifneq ($(opt),) + TESTS_ARGS+=-O$(opt) +endif OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS) +EXECUTABLE_ALIASES = luau luau-analyze luau-tests # common flags CXXFLAGS=-g -Wall @@ -104,7 +108,7 @@ $(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnaly $(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include $(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include -$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern +$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY $(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -Iextern -Iextern/isocline/include $(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -Iextern $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IVM/include @@ -114,15 +118,18 @@ $(REPL_CLI_TARGET): LDFLAGS+=-lpthread fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a build/libprotobuf-mutator/external.protobuf/lib/libprotobuf.a # pseudo targets -.PHONY: all test clean coverage format luau-size +.PHONY: all test clean coverage format luau-size aliases -all: $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET) $(TESTS_TARGET) +all: $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET) $(TESTS_TARGET) aliases + +aliases: $(EXECUTABLE_ALIASES) test: $(TESTS_TARGET) $(TESTS_TARGET) $(TESTS_ARGS) clean: rm -rf $(BUILD) + rm -rf $(EXECUTABLE_ALIASES) coverage: $(TESTS_TARGET) $(TESTS_TARGET) --fflags=true @@ -148,6 +155,9 @@ luau: $(REPL_CLI_TARGET) luau-analyze: $(ANALYZE_CLI_TARGET) ln -fs $^ $@ +luau-tests: $(TESTS_TARGET) + ln -fs $^ $@ + # executable targets $(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) $(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) diff --git a/README.md b/README.md index 2ed7348d..2b44c89a 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -Luau ![CI](https://github.com/Roblox/luau/workflows/build/badge.svg) [![Coverage](https://coveralls.io/repos/github/Roblox/luau/badge.svg?branch=master&t=2PXMow)](https://coveralls.io/github/Roblox/luau?branch=master) +Luau ![CI](https://github.com/Roblox/luau/workflows/build/badge.svg) [![codecov](https://codecov.io/gh/Roblox/luau/branch/master/graph/badge.svg?token=S3U44WN416)](https://codecov.io/gh/Roblox/luau) ==== Luau (lowercase u, /ˈlu.aÊŠ/) is a fast, small, safe, gradually typed embeddable scripting language derived from [Lua](https://lua.org). diff --git a/Sources.cmake b/Sources.cmake index f261cba6..9a6019a9 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -4,6 +4,7 @@ if(NOT ${CMAKE_VERSION} VERSION_LESS "3.19") target_sources(Luau.Common PRIVATE Common/include/Luau/Common.h Common/include/Luau/Bytecode.h + Common/include/Luau/ExperimentalFlags.h ) endif() @@ -38,12 +39,14 @@ target_sources(Luau.Compiler PRIVATE Compiler/src/BytecodeBuilder.cpp Compiler/src/Compiler.cpp Compiler/src/Builtins.cpp + Compiler/src/BuiltinFolding.cpp Compiler/src/ConstantFolding.cpp Compiler/src/CostModel.cpp Compiler/src/TableShape.cpp Compiler/src/ValueTracking.cpp Compiler/src/lcode.cpp Compiler/src/Builtins.h + Compiler/src/BuiltinFolding.h Compiler/src/ConstantFolding.h Compiler/src/CostModel.h Compiler/src/TableShape.h @@ -63,6 +66,8 @@ target_sources(Luau.CodeGen PRIVATE # Luau.Analysis Sources target_sources(Luau.Analysis PRIVATE + Analysis/include/Luau/ApplyTypeFunction.h + Analysis/include/Luau/AstJsonEncoder.h Analysis/include/Luau/AstQuery.h Analysis/include/Luau/Autocomplete.h Analysis/include/Luau/BuiltinDefinitions.h @@ -78,7 +83,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Frontend.h Analysis/include/Luau/Instantiation.h Analysis/include/Luau/IostreamHelpers.h - Analysis/include/Luau/JsonEncoder.h + Analysis/include/Luau/JsonEmitter.h Analysis/include/Luau/Linter.h Analysis/include/Luau/LValue.h Analysis/include/Luau/Module.h @@ -110,6 +115,8 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Variant.h Analysis/include/Luau/VisitTypeVar.h + Analysis/src/ApplyTypeFunction.cpp + Analysis/src/AstJsonEncoder.cpp Analysis/src/AstQuery.cpp Analysis/src/Autocomplete.cpp Analysis/src/BuiltinDefinitions.cpp @@ -123,7 +130,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Frontend.cpp Analysis/src/Instantiation.cpp Analysis/src/IostreamHelpers.cpp - Analysis/src/JsonEncoder.cpp + Analysis/src/JsonEmitter.cpp Analysis/src/Linter.cpp Analysis/src/LValue.cpp Analysis/src/Module.cpp @@ -218,6 +225,8 @@ if(TARGET Luau.Repl.CLI) CLI/Coverage.cpp CLI/FileUtils.h CLI/FileUtils.cpp + CLI/Flags.h + CLI/Flags.cpp CLI/Profiler.h CLI/Profiler.cpp CLI/Repl.cpp @@ -229,6 +238,8 @@ if(TARGET Luau.Analyze.CLI) target_sources(Luau.Analyze.CLI PRIVATE CLI/FileUtils.h CLI/FileUtils.cpp + CLI/Flags.h + CLI/Flags.cpp CLI/Analyze.cpp) endif() @@ -247,23 +258,27 @@ if(TARGET Luau.UnitTest) tests/IostreamOptional.h tests/ScopedFlags.h tests/Fixture.cpp + tests/AssemblyBuilderX64.test.cpp + tests/AstJsonEncoder.test.cpp tests/AstQuery.test.cpp tests/AstVisitor.test.cpp tests/Autocomplete.test.cpp tests/BuiltinDefinitions.test.cpp tests/Compiler.test.cpp tests/Config.test.cpp + tests/ConstraintGraphBuilder.test.cpp + tests/ConstraintSolver.test.cpp tests/CostModel.test.cpp tests/Error.test.cpp tests/Frontend.test.cpp - tests/JsonEncoder.test.cpp + tests/JsonEmitter.test.cpp + tests/Lexer.test.cpp tests/Linter.test.cpp tests/LValue.test.cpp tests/Module.test.cpp tests/NonstrictMode.test.cpp tests/Normalize.test.cpp - tests/ConstraintGraphBuilder.test.cpp - tests/ConstraintSolver.test.cpp + tests/NotNull.test.cpp tests/Parser.test.cpp tests/RequireTracer.test.cpp tests/RuntimeLimits.test.cpp @@ -295,11 +310,11 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.tryUnify.test.cpp tests/TypeInfer.typePacks.cpp tests/TypeInfer.unionTypes.test.cpp + tests/TypeInfer.unknownnever.test.cpp tests/TypePack.test.cpp tests/TypeVar.test.cpp tests/Variant.test.cpp tests/VisitTypeVar.test.cpp - tests/AssemblyBuilderX64.test.cpp tests/main.cpp) endif() @@ -317,6 +332,8 @@ if(TARGET Luau.CLI.Test) CLI/Coverage.cpp CLI/FileUtils.h CLI/FileUtils.cpp + CLI/Flags.h + CLI/Flags.cpp CLI/Profiler.h CLI/Profiler.cpp CLI/Repl.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index 7f9647c8..f986c2b3 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -11,7 +11,7 @@ -/* option for multiple returns in `lua_pcall' and `lua_call' */ +// option for multiple returns in `lua_pcall' and `lua_call' #define LUA_MULTRET (-1) /* @@ -23,7 +23,7 @@ #define lua_upvalueindex(i) (LUA_GLOBALSINDEX - (i)) #define lua_ispseudo(i) ((i) <= LUA_REGISTRYINDEX) -/* thread status; 0 is OK */ +// thread status; 0 is OK enum lua_Status { LUA_OK = 0, @@ -32,7 +32,7 @@ enum lua_Status LUA_ERRSYNTAX, LUA_ERRMEM, LUA_ERRERR, - LUA_BREAK, /* yielded for a debug breakpoint */ + LUA_BREAK, // yielded for a debug breakpoint }; typedef struct lua_State lua_State; @@ -46,7 +46,7 @@ typedef int (*lua_Continuation)(lua_State* L, int status); typedef void* (*lua_Alloc)(void* ud, void* ptr, size_t osize, size_t nsize); -/* non-return type */ +// non-return type #define l_noret void LUA_NORETURN /* @@ -61,39 +61,39 @@ typedef void* (*lua_Alloc)(void* ud, void* ptr, size_t osize, size_t nsize); // clang-format off enum lua_Type { - LUA_TNIL = 0, /* must be 0 due to lua_isnoneornil */ - LUA_TBOOLEAN = 1, /* must be 1 due to l_isfalse */ + LUA_TNIL = 0, // must be 0 due to lua_isnoneornil + LUA_TBOOLEAN = 1, // must be 1 due to l_isfalse + - LUA_TLIGHTUSERDATA, LUA_TNUMBER, LUA_TVECTOR, - LUA_TSTRING, /* all types above this must be value types, all types below this must be GC types - see iscollectable */ + LUA_TSTRING, // all types above this must be value types, all types below this must be GC types - see iscollectable + - LUA_TTABLE, LUA_TFUNCTION, LUA_TUSERDATA, LUA_TTHREAD, - /* values below this line are used in GCObject tags but may never show up in TValue type tags */ + // values below this line are used in GCObject tags but may never show up in TValue type tags LUA_TPROTO, LUA_TUPVAL, LUA_TDEADKEY, - /* the count of TValue type tags */ + // the count of TValue type tags LUA_T_COUNT = LUA_TPROTO }; // clang-format on -/* type of numbers in Luau */ +// type of numbers in Luau typedef double lua_Number; -/* type for integer functions */ +// type for integer functions typedef int lua_Integer; -/* unsigned integer type */ +// unsigned integer type typedef unsigned lua_Unsigned; /* @@ -117,7 +117,7 @@ LUA_API void lua_remove(lua_State* L, int idx); LUA_API void lua_insert(lua_State* L, int idx); LUA_API void lua_replace(lua_State* L, int idx); LUA_API int lua_checkstack(lua_State* L, int sz); -LUA_API void lua_rawcheckstack(lua_State* L, int sz); /* allows for unlimited stack frames */ +LUA_API void lua_rawcheckstack(lua_State* L, int sz); // allows for unlimited stack frames LUA_API void lua_xmove(lua_State* from, lua_State* to, int n); LUA_API void lua_xpush(lua_State* from, lua_State* to, int idx); @@ -231,18 +231,18 @@ LUA_API void lua_setthreaddata(lua_State* L, void* data); enum lua_GCOp { - /* stop and resume incremental garbage collection */ + // stop and resume incremental garbage collection LUA_GCSTOP, LUA_GCRESTART, - /* run a full GC cycle; not recommended for latency sensitive applications */ + // run a full GC cycle; not recommended for latency sensitive applications LUA_GCCOLLECT, - /* return the heap size in KB and the remainder in bytes */ + // return the heap size in KB and the remainder in bytes LUA_GCCOUNT, LUA_GCCOUNTB, - /* return 1 if GC is active (not stopped); note that GC may not be actively collecting even if it's running */ + // return 1 if GC is active (not stopped); note that GC may not be actively collecting even if it's running LUA_GCISRUNNING, /* @@ -300,6 +300,7 @@ LUA_API uintptr_t lua_encodepointer(lua_State* L, uintptr_t p); LUA_API double lua_clock(); +LUA_API void lua_setuserdatatag(lua_State* L, int idx, int tag); LUA_API void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, void*)); LUA_API void lua_clonefunction(lua_State* L, int idx); @@ -358,9 +359,9 @@ LUA_API void lua_unref(lua_State* L, int ref); ** ======================================================================= */ -typedef struct lua_Debug lua_Debug; /* activation record */ +typedef struct lua_Debug lua_Debug; // activation record -/* Functions to be called by the debugger in specific events */ +// Functions to be called by the debugger in specific events typedef void (*lua_Hook)(lua_State* L, lua_Debug* ar); LUA_API int lua_stackdepth(lua_State* L); @@ -372,30 +373,30 @@ LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n); LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); LUA_API void lua_singlestep(lua_State* L, int enabled); -LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled); +LUA_API int lua_breakpoint(lua_State* L, int funcindex, int line, int enabled); typedef void (*lua_Coverage)(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size); LUA_API void lua_getcoverage(lua_State* L, int funcindex, void* context, lua_Coverage callback); -/* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */ +// Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. LUA_API const char* lua_debugtrace(lua_State* L); struct lua_Debug { - const char* name; /* (n) */ - const char* what; /* (s) `Lua', `C', `main', `tail' */ - const char* source; /* (s) */ - int linedefined; /* (s) */ - int currentline; /* (l) */ - unsigned char nupvals; /* (u) number of upvalues */ - unsigned char nparams; /* (a) number of parameters */ - char isvararg; /* (a) */ - char short_src[LUA_IDSIZE]; /* (s) */ - void* userdata; /* only valid in luau_callhook */ + const char* name; // (n) + const char* what; // (s) `Lua', `C', `main', `tail' + const char* source; // (s) + int linedefined; // (s) + int currentline; // (l) + unsigned char nupvals; // (u) number of upvalues + unsigned char nparams; // (a) number of parameters + char isvararg; // (a) + char short_src[LUA_IDSIZE]; // (s) + void* userdata; // only valid in luau_callhook }; -/* }====================================================================== */ +// }====================================================================== /* Callbacks that can be used to reconfigure behavior of the VM dynamically. * These are shared between all coroutines. @@ -404,18 +405,18 @@ struct lua_Debug * can only be changed when the VM is not running any code */ struct lua_Callbacks { - void* userdata; /* arbitrary userdata pointer that is never overwritten by Luau */ + void* userdata; // arbitrary userdata pointer that is never overwritten by Luau - void (*interrupt)(lua_State* L, int gc); /* gets called at safepoints (loop back edges, call/ret, gc) if set */ - void (*panic)(lua_State* L, int errcode); /* gets called when an unprotected error is raised (if longjmp is used) */ + void (*interrupt)(lua_State* L, int gc); // gets called at safepoints (loop back edges, call/ret, gc) if set + void (*panic)(lua_State* L, int errcode); // gets called when an unprotected error is raised (if longjmp is used) - void (*userthread)(lua_State* LP, lua_State* L); /* gets called when L is created (LP == parent) or destroyed (LP == NULL) */ - int16_t (*useratom)(const char* s, size_t l); /* gets called when a string is created; returned atom can be retrieved via tostringatom */ + void (*userthread)(lua_State* LP, lua_State* L); // gets called when L is created (LP == parent) or destroyed (LP == NULL) + int16_t (*useratom)(const char* s, size_t l); // gets called when a string is created; returned atom can be retrieved via tostringatom - void (*debugbreak)(lua_State* L, lua_Debug* ar); /* gets called when BREAK instruction is encountered */ - void (*debugstep)(lua_State* L, lua_Debug* ar); /* gets called after each instruction in single step mode */ - void (*debuginterrupt)(lua_State* L, lua_Debug* ar); /* gets called when thread execution is interrupted by break in another thread */ - void (*debugprotectederror)(lua_State* L); /* gets called when protected call results in an error */ + void (*debugbreak)(lua_State* L, lua_Debug* ar); // gets called when BREAK instruction is encountered + void (*debugstep)(lua_State* L, lua_Debug* ar); // gets called after each instruction in single step mode + void (*debuginterrupt)(lua_State* L, lua_Debug* ar); // gets called when thread execution is interrupted by break in another thread + void (*debugprotectederror)(lua_State* L); // gets called when protected call results in an error }; typedef struct lua_Callbacks lua_Callbacks; diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index b93cbf7c..7b0f4c30 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -33,14 +33,14 @@ #define LUA_NORETURN __attribute__((__noreturn__)) #endif -/* Can be used to reconfigure visibility/exports for public APIs */ +// Can be used to reconfigure visibility/exports for public APIs #ifndef LUA_API #define LUA_API extern #endif #define LUALIB_API LUA_API -/* Can be used to reconfigure visibility for internal APIs */ +// Can be used to reconfigure visibility for internal APIs #if defined(__GNUC__) #define LUAI_FUNC __attribute__((visibility("hidden"))) extern #define LUAI_DATA LUAI_FUNC @@ -49,67 +49,67 @@ #define LUAI_DATA extern #endif -/* Can be used to reconfigure internal error handling to use longjmp instead of C++ EH */ +// Can be used to reconfigure internal error handling to use longjmp instead of C++ EH #ifndef LUA_USE_LONGJMP #define LUA_USE_LONGJMP 0 #endif -/* LUA_IDSIZE gives the maximum size for the description of the source */ +// LUA_IDSIZE gives the maximum size for the description of the source #ifndef LUA_IDSIZE #define LUA_IDSIZE 256 #endif -/* LUA_MINSTACK is the guaranteed number of Lua stack slots available to a C function */ +// LUA_MINSTACK is the guaranteed number of Lua stack slots available to a C function #ifndef LUA_MINSTACK #define LUA_MINSTACK 20 #endif -/* LUAI_MAXCSTACK limits the number of Lua stack slots that a C function can use */ +// LUAI_MAXCSTACK limits the number of Lua stack slots that a C function can use #ifndef LUAI_MAXCSTACK #define LUAI_MAXCSTACK 8000 #endif -/* LUAI_MAXCALLS limits the number of nested calls */ +// LUAI_MAXCALLS limits the number of nested calls #ifndef LUAI_MAXCALLS #define LUAI_MAXCALLS 20000 #endif -/* LUAI_MAXCCALLS is the maximum depth for nested C calls; this limit depends on native stack size */ +// LUAI_MAXCCALLS is the maximum depth for nested C calls; this limit depends on native stack size #ifndef LUAI_MAXCCALLS #define LUAI_MAXCCALLS 200 #endif -/* buffer size used for on-stack string operations; this limit depends on native stack size */ +// buffer size used for on-stack string operations; this limit depends on native stack size #ifndef LUA_BUFFERSIZE #define LUA_BUFFERSIZE 512 #endif -/* number of valid Lua userdata tags */ +// number of valid Lua userdata tags #ifndef LUA_UTAG_LIMIT #define LUA_UTAG_LIMIT 128 #endif -/* upper bound for number of size classes used by page allocator */ +// upper bound for number of size classes used by page allocator #ifndef LUA_SIZECLASSES #define LUA_SIZECLASSES 32 #endif -/* available number of separate memory categories */ +// available number of separate memory categories #ifndef LUA_MEMORY_CATEGORIES #define LUA_MEMORY_CATEGORIES 256 #endif -/* minimum size for the string table (must be power of 2) */ +// minimum size for the string table (must be power of 2) #ifndef LUA_MINSTRTABSIZE #define LUA_MINSTRTABSIZE 32 #endif -/* maximum number of captures supported by pattern matching */ +// maximum number of captures supported by pattern matching #ifndef LUA_MAXCAPTURES #define LUA_MAXCAPTURES 32 #endif -/* }================================================================== */ +// }================================================================== /* @@ LUAI_USER_ALIGNMENT_T is a type that requires maximum alignment. @@ -126,6 +126,6 @@ long l; \ } -#define LUA_VECTOR_SIZE 3 /* must be 3 or 4 */ +#define LUA_VECTOR_SIZE 3 // must be 3 or 4 #define LUA_EXTRA_SIZE LUA_VECTOR_SIZE - 2 diff --git a/VM/include/lualib.h b/VM/include/lualib.h index bebd0a0f..955604de 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -72,7 +72,7 @@ LUALIB_API const char* luaL_typename(lua_State* L, int idx); #define luaL_opt(L, f, n, d) (lua_isnoneornil(L, (n)) ? (d) : f(L, (n))) -/* generic buffer manipulation */ +// generic buffer manipulation struct luaL_Buffer { @@ -102,7 +102,7 @@ LUALIB_API void luaL_addvalue(luaL_Buffer* B); LUALIB_API void luaL_pushresult(luaL_Buffer* B); LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size); -/* builtin libraries */ +// builtin libraries LUALIB_API int luaopen_base(lua_State* L); #define LUA_COLIBNAME "coroutine" @@ -129,9 +129,9 @@ LUALIB_API int luaopen_math(lua_State* L); #define LUA_DBLIBNAME "debug" LUALIB_API int luaopen_debug(lua_State* L); -/* open all builtin libraries */ +// open all builtin libraries LUALIB_API void luaL_openlibs(lua_State* L); -/* sandbox libraries and globals */ +// sandbox libraries and globals LUALIB_API void luaL_sandbox(lua_State* L); LUALIB_API void luaL_sandboxthread(lua_State* L); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 3c3b7bd0..bb994fb4 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -51,10 +51,16 @@ const char* luau_ident = "$Luau: Copyright (C) 2019-2022 Roblox Corporation $\n" L->top++; \ } +#define updateatom(L, ts) \ + { \ + if (ts->atom == ATOM_UNDEF) \ + ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1; \ + } + static Table* getcurrenv(lua_State* L) { - if (L->ci == L->base_ci) /* no enclosing function? */ - return L->gt; /* use global table as environment */ + if (L->ci == L->base_ci) // no enclosing function? + return L->gt; // use global table as environment else return curr_func(L)->env; } @@ -63,7 +69,7 @@ static LUAU_NOINLINE TValue* pseudo2addr(lua_State* L, int idx) { api_check(L, lua_ispseudo(idx)); switch (idx) - { /* pseudo-indices */ + { // pseudo-indices case LUA_REGISTRYINDEX: return registry(L); case LUA_ENVIRONINDEX: @@ -123,7 +129,7 @@ int lua_checkstack(lua_State* L, int size) { int res = 1; if (size > LUAI_MAXCSTACK || (L->top - L->base + size) > LUAI_MAXCSTACK) - res = 0; /* stack overflow */ + res = 0; // stack overflow else if (size > 0) { luaD_checkstack(L, size); @@ -213,7 +219,7 @@ void lua_settop(lua_State* L, int idx) else { api_check(L, -(idx + 1) <= (L->top - L->base)); - L->top += idx + 1; /* `subtract' index (index is negative) */ + L->top += idx + 1; // `subtract' index (index is negative) } return; } @@ -261,7 +267,7 @@ void lua_replace(lua_State* L, int idx) else { setobj(L, o, L->top - 1); - if (idx < LUA_GLOBALSINDEX) /* function upvalue? */ + if (idx < LUA_GLOBALSINDEX) // function upvalue? luaC_barrier(L, curr_func(L), L->top - 1); } L->top--; @@ -423,13 +429,13 @@ const char* lua_tolstring(lua_State* L, int idx, size_t* len) { luaC_checkthreadsleep(L); if (!luaV_tostring(L, o)) - { /* conversion failed? */ + { // conversion failed? if (len != NULL) *len = 0; return NULL; } luaC_checkGC(L); - o = index2addr(L, idx); /* previous call may reallocate the stack */ + o = index2addr(L, idx); // previous call may reallocate the stack } if (len != NULL) *len = tsvalue(o)->len; @@ -441,19 +447,25 @@ const char* lua_tostringatom(lua_State* L, int idx, int* atom) StkId o = index2addr(L, idx); if (!ttisstring(o)) return NULL; - const TString* s = tsvalue(o); + TString* s = tsvalue(o); if (atom) + { + updateatom(L, s); *atom = s->atom; + } return getstr(s); } const char* lua_namecallatom(lua_State* L, int* atom) { - const TString* s = L->namecall; + TString* s = L->namecall; if (!s) return NULL; if (atom) + { + updateatom(L, s); *atom = s->atom; + } return getstr(s); } @@ -648,7 +660,7 @@ void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, in void lua_pushboolean(lua_State* L, int b) { - setbvalue(L->top, (b != 0)); /* ensure that true is 1 */ + setbvalue(L->top, (b != 0)); // ensure that true is 1 api_incr_top(L); return; } @@ -817,7 +829,7 @@ void lua_settable(lua_State* L, int idx) StkId t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_settable(L, t, L->top - 2, L->top - 1); - L->top -= 2; /* pop index and value */ + L->top -= 2; // pop index and value return; } @@ -839,7 +851,7 @@ void lua_rawset(lua_State* L, int idx) StkId t = index2addr(L, idx); api_check(L, ttistable(t)); if (hvalue(t)->readonly) - luaG_runerror(L, "Attempt to modify a readonly table"); + luaG_readonlyerror(L); setobj2t(L, luaH_set(L, hvalue(t), L->top - 2), L->top - 1); luaC_barriert(L, hvalue(t), L->top - 1); L->top -= 2; @@ -852,7 +864,7 @@ void lua_rawseti(lua_State* L, int idx, int n) StkId o = index2addr(L, idx); api_check(L, ttistable(o)); if (hvalue(o)->readonly) - luaG_runerror(L, "Attempt to modify a readonly table"); + luaG_readonlyerror(L); setobj2t(L, luaH_setnum(L, hvalue(o), n), L->top - 1); luaC_barriert(L, hvalue(o), L->top - 1); L->top--; @@ -875,7 +887,7 @@ int lua_setmetatable(lua_State* L, int objindex) case LUA_TTABLE: { if (hvalue(obj)->readonly) - luaG_runerror(L, "Attempt to modify a readonly table"); + luaG_readonlyerror(L); hvalue(obj)->metatable = mt; if (mt) luaC_objbarrier(L, hvalue(obj), mt); @@ -955,7 +967,7 @@ void lua_call(lua_State* L, int nargs, int nresults) ** Execute a protected call. */ struct CallS -{ /* data to `f_call' */ +{ // data to `f_call' StkId func; int nresults; }; @@ -980,7 +992,7 @@ int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) func = savestack(L, o); } struct CallS c; - c.func = L->top - (nargs + 1); /* function to be called */ + c.func = L->top - (nargs + 1); // function to be called c.nresults = nresults; int status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); @@ -1032,7 +1044,7 @@ int lua_gc(lua_State* L, int what, int data) } case LUA_GCCOUNT: { - /* GC values are expressed in Kbytes: #bytes/2^10 */ + // GC values are expressed in Kbytes: #bytes/2^10 res = cast_int(g->totalbytes >> 10); break; } @@ -1072,8 +1084,8 @@ int lua_gc(lua_State* L, int what, int data) actualwork += stepsize; if (g->gcstate == GCSpause) - { /* end of cycle? */ - res = 1; /* signal it */ + { // end of cycle? + res = 1; // signal it break; } } @@ -1125,13 +1137,13 @@ int lua_gc(lua_State* L, int what, int data) } case LUA_GCSETSTEPSIZE: { - /* GC values are expressed in Kbytes: #bytes/2^10 */ + // GC values are expressed in Kbytes: #bytes/2^10 res = g->gcstepsize >> 10; g->gcstepsize = data << 10; break; } default: - res = -1; /* invalid option */ + res = -1; // invalid option } return res; } @@ -1157,8 +1169,8 @@ int lua_next(lua_State* L, int idx) { api_incr_top(L); } - else /* no more elements */ - L->top -= 1; /* remove key */ + else // no more elements + L->top -= 1; // remove key return more; } @@ -1173,12 +1185,12 @@ void lua_concat(lua_State* L, int n) L->top -= (n - 1); } else if (n == 0) - { /* push empty string */ + { // push empty string luaC_checkthreadsleep(L); setsvalue2s(L, L->top, luaS_newlstr(L, "", 0)); api_incr_top(L); } - /* else n == 1; nothing to do */ + // else n == 1; nothing to do return; } @@ -1265,7 +1277,7 @@ uintptr_t lua_encodepointer(lua_State* L, uintptr_t p) int lua_ref(lua_State* L, int idx) { - api_check(L, idx != LUA_REGISTRYINDEX); /* idx is a stack index for value */ + api_check(L, idx != LUA_REGISTRYINDEX); // idx is a stack index for value int ref = LUA_REFNIL; global_State* g = L->global; StkId p = index2addr(L, idx); @@ -1274,13 +1286,13 @@ int lua_ref(lua_State* L, int idx) Table* reg = hvalue(registry(L)); if (g->registryfree != 0) - { /* reuse existing slot */ + { // reuse existing slot ref = g->registryfree; } else - { /* no free elements */ + { // no free elements ref = luaH_getn(reg); - ref++; /* create new reference */ + ref++; // create new reference } TValue* slot = luaH_setnum(L, reg, ref); @@ -1300,11 +1312,19 @@ void lua_unref(lua_State* L, int ref) global_State* g = L->global; Table* reg = hvalue(registry(L)); TValue* slot = luaH_setnum(L, reg, ref); - setnvalue(slot, g->registryfree); /* NB: no barrier needed because value isn't collectable */ + setnvalue(slot, g->registryfree); // NB: no barrier needed because value isn't collectable g->registryfree = ref; return; } +void lua_setuserdatatag(lua_State* L, int idx, int tag) +{ + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + StkId o = index2addr(L, idx); + api_check(L, ttisuserdata(o)); + uvalue(o)->tag = uint8_t(tag); +} + void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, void*)) { api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 72169a86..c42e5ccc 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -11,7 +11,7 @@ #include -/* convert a stack index to positive */ +// convert a stack index to positive #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) /* @@ -75,7 +75,7 @@ void luaL_where(lua_State* L, int level) lua_pushfstring(L, "%s:%d: ", ar.short_src, ar.currentline); return; } - lua_pushliteral(L, ""); /* else, no information available... */ + lua_pushliteral(L, ""); // else, no information available... } l_noret luaL_errorL(lua_State* L, const char* fmt, ...) @@ -89,7 +89,7 @@ l_noret luaL_errorL(lua_State* L, const char* fmt, ...) lua_error(L); } -/* }====================================================== */ +// }====================================================== int luaL_checkoption(lua_State* L, int narg, const char* def, const char* const lst[]) { @@ -104,13 +104,13 @@ int luaL_checkoption(lua_State* L, int narg, const char* def, const char* const int luaL_newmetatable(lua_State* L, const char* tname) { - lua_getfield(L, LUA_REGISTRYINDEX, tname); /* get registry.name */ - if (!lua_isnil(L, -1)) /* name already in use? */ - return 0; /* leave previous value on top, but return 0 */ + lua_getfield(L, LUA_REGISTRYINDEX, tname); // get registry.name + if (!lua_isnil(L, -1)) // name already in use? + return 0; // leave previous value on top, but return 0 lua_pop(L, 1); - lua_newtable(L); /* create metatable */ + lua_newtable(L); // create metatable lua_pushvalue(L, -1); - lua_setfield(L, LUA_REGISTRYINDEX, tname); /* registry.name = metatable */ + lua_setfield(L, LUA_REGISTRYINDEX, tname); // registry.name = metatable return 1; } @@ -118,18 +118,18 @@ void* luaL_checkudata(lua_State* L, int ud, const char* tname) { void* p = lua_touserdata(L, ud); if (p != NULL) - { /* value is a userdata? */ + { // value is a userdata? if (lua_getmetatable(L, ud)) - { /* does it have a metatable? */ - lua_getfield(L, LUA_REGISTRYINDEX, tname); /* get correct metatable */ + { // does it have a metatable? + lua_getfield(L, LUA_REGISTRYINDEX, tname); // get correct metatable if (lua_rawequal(L, -1, -2)) - { /* does it have the correct mt? */ - lua_pop(L, 2); /* remove both metatables */ + { // does it have the correct mt? + lua_pop(L, 2); // remove both metatables return p; } } } - luaL_typeerrorL(L, ud, tname); /* else error */ + luaL_typeerrorL(L, ud, tname); // else error } void luaL_checkstack(lua_State* L, int space, const char* mes) @@ -243,18 +243,18 @@ const float* luaL_optvector(lua_State* L, int narg, const float* def) int luaL_getmetafield(lua_State* L, int obj, const char* event) { - if (!lua_getmetatable(L, obj)) /* no metatable? */ + if (!lua_getmetatable(L, obj)) // no metatable? return 0; lua_pushstring(L, event); lua_rawget(L, -2); if (lua_isnil(L, -1)) { - lua_pop(L, 2); /* remove metatable and metafield */ + lua_pop(L, 2); // remove metatable and metafield return 0; } else { - lua_remove(L, -2); /* remove only metatable */ + lua_remove(L, -2); // remove only metatable return 1; } } @@ -262,7 +262,7 @@ int luaL_getmetafield(lua_State* L, int obj, const char* event) int luaL_callmeta(lua_State* L, int obj, const char* event) { obj = abs_index(L, obj); - if (!luaL_getmetafield(L, obj, event)) /* no metafield? */ + if (!luaL_getmetafield(L, obj, event)) // no metafield? return 0; lua_pushvalue(L, obj); lua_call(L, 1, 1); @@ -282,19 +282,19 @@ void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l) if (libname) { int size = libsize(l); - /* check whether lib already exists */ + // check whether lib already exists luaL_findtable(L, LUA_REGISTRYINDEX, "_LOADED", 1); - lua_getfield(L, -1, libname); /* get _LOADED[libname] */ + lua_getfield(L, -1, libname); // get _LOADED[libname] if (!lua_istable(L, -1)) - { /* not found? */ - lua_pop(L, 1); /* remove previous result */ - /* try global variable (and create one if it does not exist) */ + { // not found? + lua_pop(L, 1); // remove previous result + // try global variable (and create one if it does not exist) if (luaL_findtable(L, LUA_GLOBALSINDEX, libname, size) != NULL) luaL_error(L, "name conflict for module '%s'", libname); lua_pushvalue(L, -1); - lua_setfield(L, -3, libname); /* _LOADED[libname] = new table */ + lua_setfield(L, -3, libname); // _LOADED[libname] = new table } - lua_remove(L, -2); /* remove _LOADED table */ + lua_remove(L, -2); // remove _LOADED table } for (; l->name; l++) { @@ -315,19 +315,19 @@ const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) lua_pushlstring(L, fname, e - fname); lua_rawget(L, -2); if (lua_isnil(L, -1)) - { /* no such field? */ - lua_pop(L, 1); /* remove this nil */ - lua_createtable(L, 0, (*e == '.' ? 1 : szhint)); /* new table for field */ + { // no such field? + lua_pop(L, 1); // remove this nil + lua_createtable(L, 0, (*e == '.' ? 1 : szhint)); // new table for field lua_pushlstring(L, fname, e - fname); lua_pushvalue(L, -2); - lua_settable(L, -4); /* set new table into field */ + lua_settable(L, -4); // set new table into field } else if (!lua_istable(L, -1)) - { /* field has a non-table value? */ - lua_pop(L, 2); /* remove table and value */ - return fname; /* return problematic part of the name */ + { // field has a non-table value? + lua_pop(L, 2); // remove table and value + return fname; // return problematic part of the name } - lua_remove(L, -2); /* remove previous table */ + lua_remove(L, -2); // remove previous table fname = e + 1; } while (*e == '.'); return NULL; @@ -470,11 +470,11 @@ void luaL_pushresultsize(luaL_Buffer* B, size_t size) luaL_pushresult(B); } -/* }====================================================== */ +// }====================================================== const char* luaL_tolstring(lua_State* L, int idx, size_t* len) { - if (luaL_callmeta(L, idx, "__tostring")) /* is there a metafield? */ + if (luaL_callmeta(L, idx, "__tostring")) // is there a metafield? { if (!lua_isstring(L, -1)) luaL_error(L, "'__tostring' must return a string"); diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 4fc5033e..f4dac61f 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -11,8 +11,6 @@ #include #include -LUAU_FASTFLAG(LuauLenTM) - static void writestring(const char* s, size_t l) { fwrite(s, 1, l, stdout); @@ -20,15 +18,15 @@ static void writestring(const char* s, size_t l) static int luaB_print(lua_State* L) { - int n = lua_gettop(L); /* number of arguments */ + int n = lua_gettop(L); // number of arguments for (int i = 1; i <= n; i++) { size_t l; - const char* s = luaL_tolstring(L, i, &l); /* convert to string using __tostring et al */ + const char* s = luaL_tolstring(L, i, &l); // convert to string using __tostring et al if (i > 1) writestring("\t", 1); writestring(s, l); - lua_pop(L, 1); /* pop result */ + lua_pop(L, 1); // pop result } writestring("\n", 1); return 0; @@ -38,7 +36,7 @@ static int luaB_tonumber(lua_State* L) { int base = luaL_optinteger(L, 2, 10); if (base == 10) - { /* standard conversion */ + { // standard conversion int isnum = 0; double n = lua_tonumberx(L, 1, &isnum); if (isnum) @@ -46,7 +44,7 @@ static int luaB_tonumber(lua_State* L) lua_pushnumber(L, n); return 1; } - luaL_checkany(L, 1); /* error if we don't have any argument */ + luaL_checkany(L, 1); // error if we don't have any argument } else { @@ -56,17 +54,17 @@ static int luaB_tonumber(lua_State* L) unsigned long long n; n = strtoull(s1, &s2, base); if (s1 != s2) - { /* at least one valid digit? */ + { // at least one valid digit? while (isspace((unsigned char)(*s2))) - s2++; /* skip trailing spaces */ + s2++; // skip trailing spaces if (*s2 == '\0') - { /* no invalid trailing characters? */ + { // no invalid trailing characters? lua_pushnumber(L, (double)n); return 1; } } } - lua_pushnil(L); /* else not a number */ + lua_pushnil(L); // else not a number return 1; } @@ -75,7 +73,7 @@ static int luaB_error(lua_State* L) int level = luaL_optinteger(L, 2, 1); lua_settop(L, 1); if (lua_isstring(L, 1) && level > 0) - { /* add extra information? */ + { // add extra information? luaL_where(L, level); lua_pushvalue(L, 1); lua_concat(L, 2); @@ -89,10 +87,10 @@ static int luaB_getmetatable(lua_State* L) if (!lua_getmetatable(L, 1)) { lua_pushnil(L); - return 1; /* no metatable */ + return 1; // no metatable } luaL_getmetafield(L, 1, "__metatable"); - return 1; /* returns either __metatable field (if present) or metatable */ + return 1; // returns either __metatable field (if present) or metatable } static int luaB_setmetatable(lua_State* L) @@ -126,8 +124,8 @@ static void getfunc(lua_State* L, int opt) static int luaB_getfenv(lua_State* L) { getfunc(L, 1); - if (lua_iscfunction(L, -1)) /* is a C function? */ - lua_pushvalue(L, LUA_GLOBALSINDEX); /* return the thread's global env. */ + if (lua_iscfunction(L, -1)) // is a C function? + lua_pushvalue(L, LUA_GLOBALSINDEX); // return the thread's global env. else lua_getfenv(L, -1); lua_setsafeenv(L, -1, false); @@ -142,7 +140,7 @@ static int luaB_setfenv(lua_State* L) lua_setsafeenv(L, -1, false); if (lua_isnumber(L, 1) && lua_tonumber(L, 1) == 0) { - /* change environment of current thread */ + // change environment of current thread lua_pushthread(L); lua_insert(L, -2); lua_setfenv(L, -2); @@ -182,9 +180,6 @@ static int luaB_rawset(lua_State* L) static int luaB_rawlen(lua_State* L) { - if (!FFlag::LuauLenTM) - luaL_error(L, "'rawlen' is not available"); - int tt = lua_type(L, 1); luaL_argcheck(L, tt == LUA_TTABLE || tt == LUA_TSTRING, 1, "table or string expected"); int len = lua_objlen(L, 1); @@ -201,7 +196,7 @@ static int luaB_gcinfo(lua_State* L) static int luaB_type(lua_State* L) { luaL_checkany(L, 1); - /* resulting name doesn't differentiate between userdata types */ + // resulting name doesn't differentiate between userdata types lua_pushstring(L, lua_typename(L, lua_type(L, 1))); return 1; } @@ -209,7 +204,7 @@ static int luaB_type(lua_State* L) static int luaB_typeof(lua_State* L) { luaL_checkany(L, 1); - /* resulting name returns __type if specified unless the input is a newproxy-created userdata */ + // resulting name returns __type if specified unless the input is a newproxy-created userdata lua_pushstring(L, luaL_typename(L, 1)); return 1; } @@ -217,7 +212,7 @@ static int luaB_typeof(lua_State* L) int luaB_next(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); - lua_settop(L, 2); /* create a 2nd argument if there isn't one */ + lua_settop(L, 2); // create a 2nd argument if there isn't one if (lua_next(L, 1)) return 2; else @@ -230,9 +225,9 @@ int luaB_next(lua_State* L) static int luaB_pairs(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); - lua_pushvalue(L, lua_upvalueindex(1)); /* return generator, */ - lua_pushvalue(L, 1); /* state, */ - lua_pushnil(L); /* and initial value */ + lua_pushvalue(L, lua_upvalueindex(1)); // return generator, + lua_pushvalue(L, 1); // state, + lua_pushnil(L); // and initial value return 3; } @@ -240,7 +235,7 @@ int luaB_inext(lua_State* L) { int i = luaL_checkinteger(L, 2); luaL_checktype(L, 1, LUA_TTABLE); - i++; /* next value */ + i++; // next value lua_pushinteger(L, i); lua_rawgeti(L, 1, i); return (lua_isnil(L, -1)) ? 0 : 2; @@ -249,9 +244,9 @@ int luaB_inext(lua_State* L) static int luaB_ipairs(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); - lua_pushvalue(L, lua_upvalueindex(1)); /* return generator, */ - lua_pushvalue(L, 1); /* state, */ - lua_pushinteger(L, 0); /* and initial value */ + lua_pushvalue(L, lua_upvalueindex(1)); // return generator, + lua_pushvalue(L, 1); // state, + lua_pushinteger(L, 0); // and initial value return 3; } @@ -340,12 +335,12 @@ static int luaB_xpcally(lua_State* L) { luaL_checktype(L, 2, LUA_TFUNCTION); - /* swap function & error function */ + // swap function & error function lua_pushvalue(L, 1); lua_pushvalue(L, 2); lua_replace(L, 1); lua_replace(L, 2); - /* at this point the stack looks like err, f, args */ + // at this point the stack looks like err, f, args // any errors from this point on are handled by continuation L->ci->flags |= LUA_CALLINFO_HANDLE; @@ -386,7 +381,7 @@ static int luaB_xpcallcont(lua_State* L, int status) lua_rawcheckstack(L, 1); lua_pushboolean(L, true); lua_replace(L, 1); // replace error function with status - return lua_gettop(L); /* return status + all results */ + return lua_gettop(L); // return status + all results } else { @@ -462,16 +457,16 @@ static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFuncti int luaopen_base(lua_State* L) { - /* set global _G */ + // set global _G lua_pushvalue(L, LUA_GLOBALSINDEX); lua_setglobal(L, "_G"); - /* open lib into global table */ + // open lib into global table luaL_register(L, "_G", base_funcs); lua_pushliteral(L, "Luau"); - lua_setglobal(L, "_VERSION"); /* set global _VERSION */ + lua_setglobal(L, "_VERSION"); // set global _VERSION - /* `ipairs' and `pairs' need auxiliary functions as upvalues */ + // `ipairs' and `pairs' need auxiliary functions as upvalues auxopen(L, "ipairs", luaB_ipairs, luaB_inext); auxopen(L, "pairs", luaB_pairs, luaB_next); diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index 093400f2..47445b80 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -8,10 +8,10 @@ #define ALLONES ~0u #define NBITS int(8 * sizeof(unsigned)) -/* macro to trim extra bits */ +// macro to trim extra bits #define trim(x) ((x)&ALLONES) -/* builds a number with 'n' ones (1 <= n <= NBITS) */ +// builds a number with 'n' ones (1 <= n <= NBITS) #define mask(n) (~((ALLONES << 1) << ((n)-1))) typedef unsigned b_uint; @@ -69,7 +69,7 @@ static int b_not(lua_State* L) static int b_shift(lua_State* L, b_uint r, int i) { if (i < 0) - { /* shift right? */ + { // shift right? i = -i; r = trim(r); if (i >= NBITS) @@ -78,7 +78,7 @@ static int b_shift(lua_State* L, b_uint r, int i) r >>= i; } else - { /* shift left */ + { // shift left if (i >= NBITS) r = 0; else @@ -106,11 +106,11 @@ static int b_arshift(lua_State* L) if (i < 0 || !(r & ((b_uint)1 << (NBITS - 1)))) return b_shift(L, r, -i); else - { /* arithmetic shift for 'negative' number */ + { // arithmetic shift for 'negative' number if (i >= NBITS) r = ALLONES; else - r = trim((r >> i) | ~(~(b_uint)0 >> i)); /* add signal bit */ + r = trim((r >> i) | ~(~(b_uint)0 >> i)); // add signal bit lua_pushunsigned(L, r); return 1; } @@ -119,9 +119,9 @@ static int b_arshift(lua_State* L) static int b_rot(lua_State* L, int i) { b_uint r = luaL_checkunsigned(L, 1); - i &= (NBITS - 1); /* i = i % NBITS */ + i &= (NBITS - 1); // i = i % NBITS r = trim(r); - if (i != 0) /* avoid undefined shift of NBITS when i == 0 */ + if (i != 0) // avoid undefined shift of NBITS when i == 0 r = (r << i) | (r >> (NBITS - i)); lua_pushunsigned(L, trim(r)); return 1; @@ -172,7 +172,7 @@ static int b_replace(lua_State* L) b_uint v = luaL_checkunsigned(L, 2); int f = fieldargs(L, 3, &w); int m = mask(w); - v &= m; /* erase bits outside given width */ + v &= m; // erase bits outside given width r = (r & ~(m << f)) | (v << f); lua_pushunsigned(L, r); return 1; diff --git a/VM/src/lcommon.h b/VM/src/lcommon.h index ac79cd97..c9d95c77 100644 --- a/VM/src/lcommon.h +++ b/VM/src/lcommon.h @@ -11,7 +11,7 @@ typedef LUAI_USER_ALIGNMENT_T L_Umaxalign; -/* internal assertions for in-house debugging */ +// internal assertions for in-house debugging #define check_exp(c, e) (LUAU_ASSERT(c), (e)) #define api_check(l, e) LUAU_ASSERT(e) diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 7592a14c..7b967e34 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,9 +5,9 @@ #include "lstate.h" #include "lvm.h" -#define CO_RUN 0 /* running */ -#define CO_SUS 1 /* suspended */ -#define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ +#define CO_RUN 0 // running +#define CO_SUS 1 // suspended +#define CO_NOR 2 // 'normal' (it resumed another coroutine) #define CO_DEAD 3 #define CO_STATUS_ERROR -1 @@ -23,13 +23,13 @@ static int auxstatus(lua_State* L, lua_State* co) return CO_SUS; if (co->status == LUA_BREAK) return CO_NOR; - if (co->status != 0) /* some error occurred */ + if (co->status != 0) // some error occurred return CO_DEAD; - if (co->ci != co->base_ci) /* does it have frames? */ + if (co->ci != co->base_ci) // does it have frames? return CO_NOR; if (co->top == co->base) return CO_DEAD; - return CO_SUS; /* initial state */ + return CO_SUS; // initial state } static int costatus(lua_State* L) @@ -68,10 +68,10 @@ static int auxresume(lua_State* L, lua_State* co, int narg) int nres = cast_int(co->top - co->base); if (nres) { - /* +1 accounts for true/false status in resumefinish */ + // +1 accounts for true/false status in resumefinish if (nres + 1 > LUA_MINSTACK && !lua_checkstack(L, nres + 1)) luaL_error(L, "too many results to resume"); - lua_xmove(co, L, nres); /* move yielded values */ + lua_xmove(co, L, nres); // move yielded values } return nres; } @@ -81,7 +81,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg) } else { - lua_xmove(co, L, 1); /* move error message */ + lua_xmove(co, L, 1); // move error message return CO_STATUS_ERROR; } } @@ -102,13 +102,13 @@ static int auxresumecont(lua_State* L, lua_State* co) int nres = cast_int(co->top - co->base); if (!lua_checkstack(L, nres + 1)) luaL_error(L, "too many results to resume"); - lua_xmove(co, L, nres); /* move yielded values */ + lua_xmove(co, L, nres); // move yielded values return nres; } else { lua_rawcheckstack(L, 2); - lua_xmove(co, L, 1); /* move error message */ + lua_xmove(co, L, 1); // move error message return CO_STATUS_ERROR; } } @@ -119,13 +119,13 @@ static int coresumefinish(lua_State* L, int r) { lua_pushboolean(L, 0); lua_insert(L, -2); - return 2; /* return false + error message */ + return 2; // return false + error message } else { lua_pushboolean(L, 1); lua_insert(L, -(r + 1)); - return r + 1; /* return true + `resume' returns */ + return r + 1; // return true + `resume' returns } } @@ -161,12 +161,12 @@ static int auxwrapfinish(lua_State* L, int r) if (r < 0) { if (lua_isstring(L, -1)) - { /* error object is a string? */ - luaL_where(L, 1); /* add extra info */ + { // error object is a string? + luaL_where(L, 1); // add extra info lua_insert(L, -2); lua_concat(L, 2); } - lua_error(L); /* propagate error */ + lua_error(L); // propagate error } return r; } @@ -221,7 +221,7 @@ static int coyield(lua_State* L) static int corunning(lua_State* L) { if (lua_pushthread(L)) - lua_pushnil(L); /* main thread is not a coroutine */ + lua_pushnil(L); // main thread is not a coroutine return 1; } @@ -250,7 +250,7 @@ static int coclose(lua_State* L) { lua_pushboolean(L, false); if (lua_gettop(co)) - lua_xmove(co, L, 1); /* move error message */ + lua_xmove(co, L, 1); // move error message lua_resetthread(co); return 2; } diff --git a/VM/src/ldblib.cpp b/VM/src/ldblib.cpp index 93d8703a..ece4f551 100644 --- a/VM/src/ldblib.cpp +++ b/VM/src/ldblib.cpp @@ -82,9 +82,9 @@ static int db_info(lua_State* L) case 'f': if (L1 == L) - lua_pushvalue(L, -1 - results); /* function is right before results */ + lua_pushvalue(L, -1 - results); // function is right before results else - lua_xmove(L1, L, 1); /* function is at top of L1 */ + lua_xmove(L1, L, 1); // function is at top of L1 results++; break; @@ -130,15 +130,14 @@ static int db_traceback(lua_State* L) if (ar.currentline > 0) { - char line[32]; -#ifdef _MSC_VER - _itoa(ar.currentline, line, 10); // 5x faster than sprintf -#else - sprintf(line, "%d", ar.currentline); -#endif + char line[32]; // manual conversion for performance + char* lineend = line + sizeof(line); + char* lineptr = lineend; + for (unsigned int r = ar.currentline; r > 0; r /= 10) + *--lineptr = '0' + (r % 10); luaL_addchar(&buf, ':'); - luaL_addstring(&buf, line); + luaL_addlstring(&buf, lineptr, lineend - lineptr); } if (ar.name) diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index e050050e..c44ccbed 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauDebuggerBreakpointHitOnNextBestLine, false); + static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -84,7 +86,7 @@ const char* lua_setlocal(lua_State* L, int level, int n) const LocVar* var = fp ? luaF_getlocal(fp, n, currentpc(L, ci)) : NULL; if (var) setobjs2s(L, ci->base + var->reg, L->top - 1); - L->top--; /* pop value */ + L->top--; // pop value const char* name = var ? getstr(var->varname) : NULL; return name; } @@ -267,12 +269,24 @@ l_noret luaG_indexerror(lua_State* L, const TValue* p1, const TValue* p2) luaG_runerror(L, "attempt to index %s with %s", t1, t2); } +l_noret luaG_methoderror(lua_State* L, const TValue* p1, const TString* p2) +{ + const char* t1 = luaT_objtypename(L, p1); + + luaG_runerror(L, "attempt to call missing method '%s' of %s", getstr(p2), t1); +} + +l_noret luaG_readonlyerror(lua_State* L) +{ + luaG_runerror(L, "attempt to modify a readonly table"); +} + static void pusherror(lua_State* L, const char* msg) { CallInfo* ci = L->ci; if (isLua(ci)) { - char buff[LUA_IDSIZE]; /* add file:line information */ + char buff[LUA_IDSIZE]; // add file:line information luaO_chunkid(buff, getstr(getluaproto(ci)->source), LUA_IDSIZE); int line = currentline(L, ci); luaO_pushfstring(L, "%s:%d: %s", buff, line, msg); @@ -367,14 +381,6 @@ void lua_singlestep(lua_State* L, int enabled) L->singlestep = bool(enabled); } -void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) -{ - const TValue* func = luaA_toobject(L, funcindex); - api_check(L, ttisfunction(func) && !clvalue(func)->isC); - - luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled)); -} - static int getmaxline(Proto* p) { int result = -1; @@ -394,6 +400,71 @@ static int getmaxline(Proto* p) return result; } +// Find the line number with instructions. If the provided line doesn't have any instruction, it should return the next line number with +// instructions. +static int getnextline(Proto* p, int line) +{ + int closest = -1; + if (p->lineinfo) + { + for (int i = 0; i < p->sizecode; ++i) + { + // note: we keep prologue as is, instead opting to break at the first meaningful instruction + if (LUAU_INSN_OP(p->code[i]) == LOP_PREPVARARGS) + continue; + + int current = luaG_getline(p, i); + if (current >= line) + { + closest = current; + break; + } + } + } + + for (int i = 0; i < p->sizep; ++i) + { + // Find the closest line number to the intended one. + int candidate = getnextline(p->p[i], line); + if (closest == -1 || (candidate >= line && candidate < closest)) + { + closest = candidate; + } + } + + return closest; +} + +int lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) +{ + int target = -1; + + if (FFlag::LuauDebuggerBreakpointHitOnNextBestLine) + { + const TValue* func = luaA_toobject(L, funcindex); + api_check(L, ttisfunction(func) && !clvalue(func)->isC); + + Proto* p = clvalue(func)->l.p; + // Find line number to add the breakpoint to. + target = getnextline(p, line); + + if (target != -1) + { + // Add breakpoint on the exact line + luaG_breakpoint(L, p, target, bool(enabled)); + } + } + else + { + const TValue* func = luaA_toobject(L, funcindex); + api_check(L, ttisfunction(func) && !clvalue(func)->isC); + + luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled)); + } + + return target; +} + static void getcoverage(Proto* p, int depth, int* buffer, size_t size, void* context, lua_Coverage callback) { memset(buffer, -1, size * sizeof(int)); @@ -465,7 +536,7 @@ const char* lua_debugtrace(lua_State* L) if (ar.currentline > 0) { char line[32]; - sprintf(line, ":%d", ar.currentline); + snprintf(line, sizeof(line), ":%d", ar.currentline); offset = append(buf, sizeof(buf), offset, line); } @@ -481,7 +552,7 @@ const char* lua_debugtrace(lua_State* L) if (depth > limit1 + limit2 && level == limit1 - 1) { char skip[32]; - sprintf(skip, "... (+%d frames)\n", int(depth - limit1 - limit2)); + snprintf(skip, sizeof(skip), "... (+%d frames)\n", int(depth - limit1 - limit2)); offset = append(buf, sizeof(buf), offset, skip); diff --git a/VM/src/ldebug.h b/VM/src/ldebug.h index 75bb8dcc..a93e412f 100644 --- a/VM/src/ldebug.h +++ b/VM/src/ldebug.h @@ -19,6 +19,8 @@ LUAI_FUNC l_noret luaG_concaterror(lua_State* L, StkId p1, StkId p2); LUAI_FUNC l_noret luaG_aritherror(lua_State* L, const TValue* p1, const TValue* p2, TMS op); LUAI_FUNC l_noret luaG_ordererror(lua_State* L, const TValue* p1, const TValue* p2, TMS op); LUAI_FUNC l_noret luaG_indexerror(lua_State* L, const TValue* p1, const TValue* p2); +LUAI_FUNC l_noret luaG_methoderror(lua_State* L, const TValue* p1, const TString* p2); +LUAI_FUNC l_noret luaG_readonlyerror(lua_State* L); LUAI_FUNC LUA_PRINTF_ATTR(2, 3) l_noret luaG_runerrorL(lua_State* L, const char* fmt, ...); LUAI_FUNC void luaG_pusherror(lua_State* L, const char* error); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 0642cb6d..6016e41f 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -31,7 +31,7 @@ struct lua_jmpbuf jmp_buf buf; }; -/* use POSIX versions of setjmp/longjmp if possible: they don't save/restore signal mask and are therefore faster */ +// use POSIX versions of setjmp/longjmp if possible: they don't save/restore signal mask and are therefore faster #if defined(__linux__) || defined(__APPLE__) #define LUAU_SETJMP(buf) _setjmp(buf) #define LUAU_LONGJMP(buf, code) _longjmp(buf, code) @@ -153,7 +153,7 @@ l_noret luaD_throw(lua_State* L, int errcode) } #endif -/* }====================================================== */ +// }====================================================== static void correctstack(lua_State* L, TValue* oldstack) { @@ -177,7 +177,7 @@ void luaD_reallocstack(lua_State* L, int newsize) luaM_reallocarray(L, L->stack, L->stacksize, realsize, TValue, L->memcat); TValue* newstack = L->stack; for (int i = L->stacksize; i < realsize; i++) - setnilvalue(newstack + i); /* erase new segment */ + setnilvalue(newstack + i); // erase new segment L->stacksize = realsize; L->stack_last = newstack + newsize; correctstack(L, oldstack); @@ -194,7 +194,7 @@ void luaD_reallocCI(lua_State* L, int newsize) void luaD_growstack(lua_State* L, int n) { - if (n <= L->stacksize) /* double size is enough? */ + if (n <= L->stacksize) // double size is enough? luaD_reallocstack(L, 2 * L->stacksize); else luaD_reallocstack(L, L->stacksize + n); @@ -202,11 +202,11 @@ void luaD_growstack(lua_State* L, int n) CallInfo* luaD_growCI(lua_State* L) { - /* allow extra stack space to handle stack overflow in xpcall */ + // allow extra stack space to handle stack overflow in xpcall const int hardlimit = LUAI_MAXCALLS + (LUAI_MAXCALLS >> 3); if (L->size_ci >= hardlimit) - luaD_throw(L, LUA_ERRERR); /* error while handling stack error */ + luaD_throw(L, LUA_ERRERR); // error while handling stack error int request = L->size_ci * 2; luaD_reallocCI(L, L->size_ci >= LUAI_MAXCALLS ? hardlimit : request < LUAI_MAXCALLS ? request : LUAI_MAXCALLS); @@ -219,13 +219,13 @@ CallInfo* luaD_growCI(lua_State* L) void luaD_checkCstack(lua_State* L) { - /* allow extra stack space to handle stack overflow in xpcall */ + // allow extra stack space to handle stack overflow in xpcall const int hardlimit = LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3); if (L->nCcalls == LUAI_MAXCCALLS) luaG_runerror(L, "C stack overflow"); else if (L->nCcalls >= hardlimit) - luaD_throw(L, LUA_ERRERR); /* error while handling stack error */ + luaD_throw(L, LUA_ERRERR); // error while handling stack error } /* @@ -240,14 +240,14 @@ void luaD_call(lua_State* L, StkId func, int nResults) luaD_checkCstack(L); if (luau_precall(L, func, nResults) == PCRLUA) - { /* is a Lua function? */ - L->ci->flags |= LUA_CALLINFO_RETURN; /* luau_execute will stop after returning from the stack frame */ + { // is a Lua function? + L->ci->flags |= LUA_CALLINFO_RETURN; // luau_execute will stop after returning from the stack frame int oldactive = luaC_threadactive(L); l_setbit(L->stackstate, THREAD_ACTIVEBIT); luaC_checkthreadsleep(L); - luau_execute(L); /* call it */ + luau_execute(L); // call it if (!oldactive) resetbit(L->stackstate, THREAD_ACTIVEBIT); @@ -263,18 +263,18 @@ static void seterrorobj(lua_State* L, int errcode, StkId oldtop) { case LUA_ERRMEM: { - setsvalue2s(L, oldtop, luaS_newliteral(L, LUA_MEMERRMSG)); /* can not fail because string is pinned in luaopen */ + setsvalue2s(L, oldtop, luaS_newliteral(L, LUA_MEMERRMSG)); // can not fail because string is pinned in luaopen break; } case LUA_ERRERR: { - setsvalue2s(L, oldtop, luaS_newliteral(L, LUA_ERRERRMSG)); /* can not fail because string is pinned in luaopen */ + setsvalue2s(L, oldtop, luaS_newliteral(L, LUA_ERRERRMSG)); // can not fail because string is pinned in luaopen break; } case LUA_ERRSYNTAX: case LUA_ERRRUN: { - setobjs2s(L, oldtop, L->top - 1); /* error message on current top */ + setobjs2s(L, oldtop, L->top - 1); // error message on current top break; } } @@ -430,8 +430,8 @@ static void resume_finish(lua_State* L, int status) resetbit(L->stackstate, THREAD_ACTIVEBIT); if (status != 0) - { /* error? */ - L->status = cast_byte(status); /* mark thread as `dead' */ + { // error? + L->status = cast_byte(status); // mark thread as `dead' seterrorobj(L, status, L->top); L->ci->top = L->top; } @@ -503,7 +503,7 @@ int lua_yield(lua_State* L, int nresults) { if (L->nCcalls > L->baseCcalls) luaG_runerror(L, "attempt to yield across metamethod/C-call boundary"); - L->base = L->top - nresults; /* protect stack slots below */ + L->base = L->top - nresults; // protect stack slots below L->status = LUA_YIELD; return -1; } @@ -535,9 +535,9 @@ static void restore_stack_limit(lua_State* L) { LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK); if (L->size_ci > LUAI_MAXCALLS) - { /* there was an overflow? */ + { // there was an overflow? int inuse = cast_int(L->ci - L->base_ci); - if (inuse + 1 < LUAI_MAXCALLS) /* can `undo' overflow? */ + if (inuse + 1 < LUAI_MAXCALLS) // can `undo' overflow? luaD_reallocCI(L, LUAI_MAXCALLS); } } @@ -576,7 +576,7 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e } StkId oldtop = restorestack(L, old_top); - luaF_close(L, oldtop); /* close eventual pending closures */ + luaF_close(L, oldtop); // close eventual pending closures seterrorobj(L, status, oldtop); L->ci = restoreci(L, old_ci); L->base = L->ci->base; diff --git a/VM/src/ldo.h b/VM/src/ldo.h index 5e9472bf..eac9927c 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -34,12 +34,12 @@ #define saveci(L, p) ((char*)(p) - (char*)L->base_ci) #define restoreci(L, n) ((CallInfo*)((char*)L->base_ci + (n))) -/* results from luaD_precall */ -#define PCRLUA 0 /* initiated a call to a Lua function */ -#define PCRC 1 /* did a call to a C function */ -#define PCRYIELD 2 /* C function yielded */ +// results from luaD_precall +#define PCRLUA 0 // initiated a call to a Lua function +#define PCRC 1 // did a call to a C function +#define PCRYIELD 2 // C function yielded -/* type of protected functions, to be ran by `runprotected' */ +// type of protected functions, to be ran by `runprotected' typedef void (*Pfunc)(lua_State* L, void* ud); LUAI_FUNC CallInfo* luaD_growCI(lua_State* L); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 66447a95..dfde6dcb 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -73,20 +73,20 @@ UpVal* luaF_findupval(lua_State* L, StkId level) { LUAU_ASSERT(p->v != &p->u.value); if (p->v == level) - { /* found a corresponding upvalue? */ - if (isdead(g, obj2gco(p))) /* is it dead? */ - changewhite(obj2gco(p)); /* resurrect it */ + { // found a corresponding upvalue? + if (isdead(g, obj2gco(p))) // is it dead? + changewhite(obj2gco(p)); // resurrect it return p; } pp = &p->u.l.threadnext; } - UpVal* uv = luaM_newgco(L, UpVal, sizeof(UpVal), L->activememcat); /* not found: create a new one */ + UpVal* uv = luaM_newgco(L, UpVal, sizeof(UpVal), L->activememcat); // not found: create a new one uv->tt = LUA_TUPVAL; uv->marked = luaC_white(g); uv->memcat = L->activememcat; - uv->v = level; /* current value lives in the stack */ + uv->v = level; // current value lives in the stack // chain the upvalue in the threads open upvalue list at the proper position UpVal* next = *pp; @@ -121,9 +121,9 @@ void luaF_unlinkupval(UpVal* uv) void luaF_freeupval(lua_State* L, UpVal* uv, lua_Page* page) { - if (uv->v != &uv->u.value) /* is it open? */ - luaF_unlinkupval(uv); /* remove from open list */ - luaM_freegco(L, uv, sizeof(UpVal), uv->memcat, page); /* free upvalue */ + if (uv->v != &uv->u.value) // is it open? + luaF_unlinkupval(uv); // remove from open list + luaM_freegco(L, uv, sizeof(UpVal), uv->memcat, page); // free upvalue } void luaF_close(lua_State* L, StkId level) @@ -179,11 +179,11 @@ const LocVar* luaF_getlocal(const Proto* f, int local_number, int pc) for (i = 0; i < f->sizelocvars; i++) { if (pc >= f->locvars[i].startpc && pc < f->locvars[i].endpc) - { /* is variable active? */ + { // is variable active? local_number--; if (local_number == 0) return &f->locvars[i]; } } - return NULL; /* not found */ + return NULL; // not found } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 70b4dbf9..f7a851f4 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -125,7 +125,7 @@ static void removeentry(LuaNode* n) { LUAU_ASSERT(ttisnil(gval(n))); if (iscollectable(gkey(n))) - setttype(gkey(n), LUA_TDEADKEY); /* dead key; remove it */ + setttype(gkey(n), LUA_TDEADKEY); // dead key; remove it } static void reallymarkobject(global_State* g, GCObject* o) @@ -141,7 +141,7 @@ static void reallymarkobject(global_State* g, GCObject* o) case LUA_TUSERDATA: { Table* mt = gco2u(o)->metatable; - gray2black(o); /* udata are never gray */ + gray2black(o); // udata are never gray if (mt) markobject(g, mt); return; @@ -150,8 +150,8 @@ static void reallymarkobject(global_State* g, GCObject* o) { UpVal* uv = gco2uv(o); markvalue(g, uv->v); - if (uv->v == &uv->u.value) /* closed? */ - gray2black(o); /* open upvalues are never black */ + if (uv->v == &uv->u.value) // closed? + gray2black(o); // open upvalues are never black return; } case LUA_TFUNCTION: @@ -201,15 +201,15 @@ static int traversetable(global_State* g, Table* h) if (h->metatable) markobject(g, cast_to(Table*, h->metatable)); - /* is there a weak mode? */ + // is there a weak mode? if (const char* modev = gettablemode(g, h)) { weakkey = (strchr(modev, 'k') != NULL); weakvalue = (strchr(modev, 'v') != NULL); if (weakkey || weakvalue) - { /* is really weak? */ - h->gclist = g->weak; /* must be cleared after GC, ... */ - g->weak = obj2gco(h); /* ... so put in the appropriate list */ + { // is really weak? + h->gclist = g->weak; // must be cleared after GC, ... + g->weak = obj2gco(h); // ... so put in the appropriate list } } @@ -227,7 +227,7 @@ static int traversetable(global_State* g, Table* h) LuaNode* n = gnode(h, i); LUAU_ASSERT(ttype(gkey(n)) != LUA_TDEADKEY || ttisnil(gval(n))); if (ttisnil(gval(n))) - removeentry(n); /* remove empty entries */ + removeentry(n); // remove empty entries else { LUAU_ASSERT(!ttisnil(gkey(n))); @@ -251,20 +251,20 @@ static void traverseproto(global_State* g, Proto* f) stringmark(f->source); if (f->debugname) stringmark(f->debugname); - for (i = 0; i < f->sizek; i++) /* mark literals */ + for (i = 0; i < f->sizek; i++) // mark literals markvalue(g, &f->k[i]); for (i = 0; i < f->sizeupvalues; i++) - { /* mark upvalue names */ + { // mark upvalue names if (f->upvalues[i]) stringmark(f->upvalues[i]); } for (i = 0; i < f->sizep; i++) - { /* mark nested protos */ + { // mark nested protos if (f->p[i]) markobject(g, f->p[i]); } for (i = 0; i < f->sizelocvars; i++) - { /* mark local-variable names */ + { // mark local-variable names if (f->locvars[i].varname) stringmark(f->locvars[i].varname); } @@ -276,7 +276,7 @@ static void traverseclosure(global_State* g, Closure* cl) if (cl->isC) { int i; - for (i = 0; i < cl->nupvalues; i++) /* mark its upvalues */ + for (i = 0; i < cl->nupvalues; i++) // mark its upvalues markvalue(g, &cl->c.upvals[i]); } else @@ -284,7 +284,7 @@ static void traverseclosure(global_State* g, Closure* cl) int i; LUAU_ASSERT(cl->nupvalues == cl->l.p->nups); markobject(g, cast_to(Proto*, cl->l.p)); - for (i = 0; i < cl->nupvalues; i++) /* mark its upvalues */ + for (i = 0; i < cl->nupvalues; i++) // mark its upvalues markvalue(g, &cl->l.uprefs[i]); } } @@ -296,11 +296,11 @@ static void traversestack(global_State* g, lua_State* l, bool clearstack) stringmark(l->namecall); for (StkId o = l->stack; o < l->top; o++) markvalue(g, o); - /* final traversal? */ + // final traversal? if (g->gcstate == GCSatomic || clearstack) { StkId stack_end = l->stack + l->stacksize; - for (StkId o = l->top; o < stack_end; o++) /* clear not-marked stack slice */ + for (StkId o = l->top; o < stack_end; o++) // clear not-marked stack slice setnilvalue(o); } } @@ -320,8 +320,8 @@ static size_t propagatemark(global_State* g) { Table* h = gco2h(o); g->gray = h->gclist; - if (traversetable(g, h)) /* table is weak? */ - black2gray(o); /* keep it gray */ + if (traversetable(g, h)) // table is weak? + black2gray(o); // keep it gray return sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); } case LUA_TFUNCTION: @@ -393,7 +393,7 @@ static int isobjcleared(GCObject* o) { if (o->gch.tt == LUA_TSTRING) { - stringmark(&o->ts); /* strings are `values', so are never weak */ + stringmark(&o->ts); // strings are `values', so are never weak return 0; } @@ -417,8 +417,8 @@ static size_t cleartable(lua_State* L, GCObject* l) while (i--) { TValue* o = &h->array[i]; - if (iscleared(o)) /* value was collected? */ - setnilvalue(o); /* remove value */ + if (iscleared(o)) // value was collected? + setnilvalue(o); // remove value } i = sizenode(h); int activevalues = 0; @@ -432,8 +432,8 @@ static size_t cleartable(lua_State* L, GCObject* l) // can we clear key or value? if (iscleared(gkey(n)) || iscleared(gval(n))) { - setnilvalue(gval(n)); /* remove value ... */ - removeentry(n); /* remove entry from table */ + setnilvalue(gval(n)); // remove value ... + removeentry(n); // remove entry from table } else { @@ -460,7 +460,7 @@ static size_t cleartable(lua_State* L, GCObject* l) static void shrinkstack(lua_State* L) { - /* compute used stack - note that we can't use th->top if we're in the middle of vararg call */ + // compute used stack - note that we can't use th->top if we're in the middle of vararg call StkId lim = L->top; for (CallInfo* ci = L->base_ci; ci <= L->ci; ci++) { @@ -469,16 +469,16 @@ static void shrinkstack(lua_State* L) lim = ci->top; } - /* shrink stack and callinfo arrays if we aren't using most of the space */ - int ci_used = cast_int(L->ci - L->base_ci); /* number of `ci' in use */ - int s_used = cast_int(lim - L->stack); /* part of stack in use */ - if (L->size_ci > LUAI_MAXCALLS) /* handling overflow? */ - return; /* do not touch the stacks */ + // shrink stack and callinfo arrays if we aren't using most of the space + int ci_used = cast_int(L->ci - L->base_ci); // number of `ci' in use + int s_used = cast_int(lim - L->stack); // part of stack in use + if (L->size_ci > LUAI_MAXCALLS) // handling overflow? + return; // do not touch the stacks if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci) - luaD_reallocCI(L, L->size_ci / 2); /* still big enough... */ + luaD_reallocCI(L, L->size_ci / 2); // still big enough... condhardstacktests(luaD_reallocCI(L, ci_used + 1)); if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) - luaD_reallocstack(L, L->stacksize / 2); /* still big enough... */ + luaD_reallocstack(L, L->stacksize / 2); // still big enough... condhardstacktests(luaD_reallocstack(L, s_used)); } @@ -516,20 +516,20 @@ static void freeobj(lua_State* L, GCObject* o, lua_Page* page) static void shrinkbuffers(lua_State* L) { global_State* g = L->global; - /* check size of string hash */ + // check size of string hash if (g->strt.nuse < cast_to(uint32_t, g->strt.size / 4) && g->strt.size > LUA_MINSTRTABSIZE * 2) - luaS_resize(L, g->strt.size / 2); /* table is too big */ + luaS_resize(L, g->strt.size / 2); // table is too big } static void shrinkbuffersfull(lua_State* L) { global_State* g = L->global; - /* check size of string hash */ + // check size of string hash int hashsize = g->strt.size; while (g->strt.nuse < cast_to(uint32_t, hashsize / 4) && hashsize > LUA_MINSTRTABSIZE * 2) hashsize /= 2; if (hashsize != g->strt.size) - luaS_resize(L, hashsize); /* table is too big */ + luaS_resize(L, hashsize); // table is too big } static bool deletegco(void* context, lua_Page* page, GCObject* gco) @@ -562,7 +562,7 @@ void luaC_freeall(lua_State* L) luaM_visitgco(L, L, deletegco); - for (int i = 0; i < g->strt.size; i++) /* free all string lists */ + for (int i = 0; i < g->strt.size; i++) // free all string lists LUAU_ASSERT(g->strt.hash[i] == NULL); LUAU_ASSERT(L->global->strt.nuse == 0); @@ -577,7 +577,7 @@ static void markmt(global_State* g) markobject(g, g->mt[i]); } -/* mark root set */ +// mark root set static void markroot(lua_State* L) { global_State* g = L->global; @@ -585,7 +585,7 @@ static void markroot(lua_State* L) g->grayagain = NULL; g->weak = NULL; markobject(g, g->mainthread); - /* make global table be traversed before main stack */ + // make global table be traversed before main stack markobject(g, g->mainthread->gt); markvalue(g, registry(L)); markmt(g); @@ -616,28 +616,28 @@ static size_t atomic(lua_State* L) double currts = lua_clock(); #endif - /* remark occasional upvalues of (maybe) dead threads */ + // remark occasional upvalues of (maybe) dead threads work += remarkupvals(g); - /* traverse objects caught by write barrier and by 'remarkupvals' */ + // traverse objects caught by write barrier and by 'remarkupvals' work += propagateall(g); #ifdef LUAI_GCMETRICS g->gcmetrics.currcycle.atomictimeupval += recordGcDeltaTime(currts); #endif - /* remark weak tables */ + // remark weak tables g->gray = g->weak; g->weak = NULL; LUAU_ASSERT(!iswhite(obj2gco(g->mainthread))); - markobject(g, L); /* mark running thread */ - markmt(g); /* mark basic metatables (again) */ + markobject(g, L); // mark running thread + markmt(g); // mark basic metatables (again) work += propagateall(g); #ifdef LUAI_GCMETRICS g->gcmetrics.currcycle.atomictimeweak += recordGcDeltaTime(currts); #endif - /* remark gray again */ + // remark gray again g->gray = g->grayagain; g->grayagain = NULL; work += propagateall(g); @@ -646,7 +646,7 @@ static size_t atomic(lua_State* L) g->gcmetrics.currcycle.atomictimegray += recordGcDeltaTime(currts); #endif - /* remove collected objects from weak tables */ + // remove collected objects from weak tables work += cleartable(L, g->weak); g->weak = NULL; @@ -654,7 +654,7 @@ static size_t atomic(lua_State* L) g->gcmetrics.currcycle.atomictimeclear += recordGcDeltaTime(currts); #endif - /* flip current white */ + // flip current white g->currentwhite = cast_byte(otherwhite(g)); g->sweepgcopage = g->allgcopages; g->gcstate = GCSsweep; @@ -733,7 +733,7 @@ static size_t gcstep(lua_State* L, size_t limit) { case GCSpause: { - markroot(L); /* start a new collection */ + markroot(L); // start a new collection LUAU_ASSERT(g->gcstate == GCSpropagate); break; } @@ -765,7 +765,7 @@ static size_t gcstep(lua_State* L, size_t limit) cost += propagatemark(g); } - if (!g->gray) /* no more `gray' objects */ + if (!g->gray) // no more `gray' objects { #ifdef LUAI_GCMETRICS g->gcmetrics.currcycle.propagateagainwork = @@ -786,7 +786,7 @@ static size_t gcstep(lua_State* L, size_t limit) g->gcstats.atomicstarttimestamp = lua_clock(); g->gcstats.atomicstarttotalsizebytes = g->totalbytes; - cost = atomic(L); /* finish mark phase */ + cost = atomic(L); // finish mark phase LUAU_ASSERT(g->gcstate == GCSsweep); break; @@ -810,7 +810,7 @@ static size_t gcstep(lua_State* L, size_t limit) sweepgco(L, NULL, obj2gco(g->mainthread)); shrinkbuffers(L); - g->gcstate = GCSpause; /* end collection */ + g->gcstate = GCSpause; // end collection } break; } @@ -878,7 +878,7 @@ size_t luaC_step(lua_State* L, bool assist) { global_State* g = L->global; - int lim = g->gcstepsize * g->gcstepmul / 100; /* how much to work */ + int lim = g->gcstepsize * g->gcstepmul / 100; // how much to work LUAU_ASSERT(g->totalbytes >= g->GCthreshold); size_t debt = g->totalbytes - g->GCthreshold; @@ -947,16 +947,16 @@ void luaC_fullgc(lua_State* L) if (g->gcstate <= GCSatomic) { - /* reset sweep marks to sweep all elements (returning them to white) */ + // reset sweep marks to sweep all elements (returning them to white) g->sweepgcopage = g->allgcopages; - /* reset other collector lists */ + // reset other collector lists g->gray = NULL; g->grayagain = NULL; g->weak = NULL; g->gcstate = GCSsweep; } LUAU_ASSERT(g->gcstate == GCSsweep); - /* finish any pending sweep phase */ + // finish any pending sweep phase while (g->gcstate != GCSpause) { LUAU_ASSERT(g->gcstate == GCSsweep); @@ -968,13 +968,13 @@ void luaC_fullgc(lua_State* L) startGcCycleMetrics(g); #endif - /* run a full collection cycle */ + // run a full collection cycle markroot(L); while (g->gcstate != GCSpause) { gcstep(L, SIZE_MAX); } - /* reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) */ + // reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) shrinkbuffersfull(L); size_t heapgoalsizebytes = (g->totalbytes / 100) * g->gcgoal; @@ -1011,11 +1011,11 @@ void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v) global_State* g = L->global; LUAU_ASSERT(isblack(o) && iswhite(v) && !isdead(g, v) && !isdead(g, o)); LUAU_ASSERT(g->gcstate != GCSpause); - /* must keep invariant? */ + // must keep invariant? if (keepinvariant(g)) - reallymarkobject(g, v); /* restore invariant */ - else /* don't mind */ - makewhite(g, o); /* mark as white just to avoid other barriers */ + reallymarkobject(g, v); // restore invariant + else // don't mind + makewhite(g, o); // mark as white just to avoid other barriers } void luaC_barriertable(lua_State* L, Table* t, GCObject* v) @@ -1033,7 +1033,7 @@ void luaC_barriertable(lua_State* L, Table* t, GCObject* v) LUAU_ASSERT(isblack(o) && !isdead(g, o)); LUAU_ASSERT(g->gcstate != GCSpause); - black2gray(o); /* make table gray (again) */ + black2gray(o); // make table gray (again) t->gclist = g->grayagain; g->grayagain = o; } @@ -1044,7 +1044,7 @@ void luaC_barrierback(lua_State* L, Table* t) GCObject* o = obj2gco(t); LUAU_ASSERT(isblack(o) && !isdead(g, o)); LUAU_ASSERT(g->gcstate != GCSpause); - black2gray(o); /* make table gray (again) */ + black2gray(o); // make table gray (again) t->gclist = g->grayagain; g->grayagain = o; } @@ -1066,11 +1066,11 @@ void luaC_initupval(lua_State* L, UpVal* uv) { if (keepinvariant(g)) { - gray2black(o); /* closed upvalues need barrier */ + gray2black(o); // closed upvalues need barrier luaC_barrier(L, uv, uv->v); } else - { /* sweep phase: sweep it (turning it into white) */ + { // sweep phase: sweep it (turning it into white) makewhite(g, o); LUAU_ASSERT(g->gcstate != GCSpause); } diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 797284a2..7b03a25d 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -9,9 +9,9 @@ /* ** Default settings for GC tunables (settable via lua_gc) */ -#define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ -#define LUAI_GCSTEPMUL 200 /* GC runs 'twice the speed' of memory allocation */ -#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ +#define LUAI_GCGOAL 200 // 200% (allow heap to double compared to live heap size) +#define LUAI_GCSTEPMUL 200 // GC runs 'twice the speed' of memory allocation +#define LUAI_GCSTEPSIZE 1 // GC runs every KB of memory allocation /* ** Possible states of the Garbage Collector diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index 2b38619b..bc997d44 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -19,7 +19,7 @@ static void validateobjref(global_State* g, GCObject* f, GCObject* t) if (keepinvariant(g)) { - /* basic incremental invariant: black can't point to white */ + // basic incremental invariant: black can't point to white LUAU_ASSERT(!(isblack(f) && iswhite(t))); } } @@ -135,7 +135,7 @@ static void validateproto(global_State* g, Proto* f) static void validateobj(global_State* g, GCObject* o) { - /* dead objects can only occur during sweep */ + // dead objects can only occur during sweep if (isdead(g, o)) { LUAU_ASSERT(g->gcstate == GCSsweep); diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index a6e7b494..0693b846 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -195,7 +195,7 @@ static int math_ldexp(lua_State* L) static int math_min(lua_State* L) { - int n = lua_gettop(L); /* number of arguments */ + int n = lua_gettop(L); // number of arguments double dmin = luaL_checknumber(L, 1); int i; for (i = 2; i <= n; i++) @@ -210,7 +210,7 @@ static int math_min(lua_State* L) static int math_max(lua_State* L) { - int n = lua_gettop(L); /* number of arguments */ + int n = lua_gettop(L); // number of arguments double dmax = luaL_checknumber(L, 1); int i; for (i = 2; i <= n; i++) @@ -227,29 +227,29 @@ static int math_random(lua_State* L) { global_State* g = L->global; switch (lua_gettop(L)) - { /* check number of arguments */ + { // check number of arguments case 0: - { /* no arguments */ + { // no arguments // Using ldexp instead of division for speed & clarity. // See http://mumble.net/~campbell/tmp/random_real.c for details on generating doubles from integer ranges. uint32_t rl = pcg32_random(&g->rngstate); uint32_t rh = pcg32_random(&g->rngstate); double rd = ldexp(double(rl | (uint64_t(rh) << 32)), -64); - lua_pushnumber(L, rd); /* number between 0 and 1 */ + lua_pushnumber(L, rd); // number between 0 and 1 break; } case 1: - { /* only upper limit */ + { // only upper limit int u = luaL_checkinteger(L, 1); luaL_argcheck(L, 1 <= u, 1, "interval is empty"); uint64_t x = uint64_t(u) * pcg32_random(&g->rngstate); int r = int(1 + (x >> 32)); - lua_pushinteger(L, r); /* int between 1 and `u' */ + lua_pushinteger(L, r); // int between 1 and `u' break; } case 2: - { /* lower and upper limits */ + { // lower and upper limits int l = luaL_checkinteger(L, 1); int u = luaL_checkinteger(L, 2); luaL_argcheck(L, l <= u, 2, "interval is empty"); @@ -258,7 +258,7 @@ static int math_random(lua_State* L) luaL_argcheck(L, ul < UINT_MAX, 2, "interval is too large"); // -INT_MIN..INT_MAX interval can result in integer overflow uint64_t x = uint64_t(ul + 1) * pcg32_random(&g->rngstate); int r = int(l + (x >> 32)); - lua_pushinteger(L, r); /* int between `l' and `u' */ + lua_pushinteger(L, r); // int between `l' and `u' break; } default: diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h index 549b4630..5b27e2b8 100644 --- a/VM/src/lnumutils.h +++ b/VM/src/lnumutils.h @@ -42,7 +42,7 @@ LUAU_FASTMATH_END #define luai_num2int(i, d) ((i) = (int)(d)) -/* On MSVC in 32-bit, double to unsigned cast compiles into a call to __dtoui3, so we invoke x87->int64 conversion path manually */ +// On MSVC in 32-bit, double to unsigned cast compiles into a call to __dtoui3, so we invoke x87->int64 conversion path manually #if defined(_MSC_VER) && defined(_M_IX86) #define luai_num2unsigned(i, n) \ { \ diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index d5bd76a8..b6a40bb6 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -48,7 +48,7 @@ int luaO_rawequalObj(const TValue* t1, const TValue* t2) case LUA_TVECTOR: return luai_veceq(vvalue(t1), vvalue(t2)); case LUA_TBOOLEAN: - return bvalue(t1) == bvalue(t2); /* boolean true must be 1 !! */ + return bvalue(t1) == bvalue(t2); // boolean true must be 1 !! case LUA_TLIGHTUSERDATA: return pvalue(t1) == pvalue(t2); default: @@ -71,7 +71,7 @@ int luaO_rawequalKey(const TKey* t1, const TValue* t2) case LUA_TVECTOR: return luai_veceq(vvalue(t1), vvalue(t2)); case LUA_TBOOLEAN: - return bvalue(t1) == bvalue(t2); /* boolean true must be 1 !! */ + return bvalue(t1) == bvalue(t2); // boolean true must be 1 !! case LUA_TLIGHTUSERDATA: return pvalue(t1) == pvalue(t2); default: @@ -85,15 +85,15 @@ int luaO_str2d(const char* s, double* result) char* endptr; *result = luai_str2num(s, &endptr); if (endptr == s) - return 0; /* conversion failed */ - if (*endptr == 'x' || *endptr == 'X') /* maybe an hexadecimal constant? */ + return 0; // conversion failed + if (*endptr == 'x' || *endptr == 'X') // maybe an hexadecimal constant? *result = cast_num(strtoul(s, &endptr, 16)); if (*endptr == '\0') - return 1; /* most common case */ + return 1; // most common case while (isspace(cast_to(unsigned char, *endptr))) endptr++; if (*endptr != '\0') - return 0; /* invalid trailing characters? */ + return 0; // invalid trailing characters? return 1; } @@ -121,7 +121,7 @@ void luaO_chunkid(char* out, const char* source, size_t bufflen) { if (*source == '=') { - source++; /* skip the `=' */ + source++; // skip the `=' size_t srclen = strlen(source); size_t dstlen = srclen < bufflen ? srclen : bufflen - 1; memcpy(out, source, dstlen); @@ -130,26 +130,26 @@ void luaO_chunkid(char* out, const char* source, size_t bufflen) else if (*source == '@') { size_t l; - source++; /* skip the `@' */ + source++; // skip the `@' bufflen -= sizeof("..."); l = strlen(source); strcpy(out, ""); if (l > bufflen) { - source += (l - bufflen); /* get last part of file name */ + source += (l - bufflen); // get last part of file name strcat(out, "..."); } strcat(out, source); } else - { /* out = [string "string"] */ - size_t len = strcspn(source, "\n\r"); /* stop at first newline */ + { // out = [string "string"] + size_t len = strcspn(source, "\n\r"); // stop at first newline bufflen -= sizeof("[string \"...\"]"); if (len > bufflen) len = bufflen; strcpy(out, "[string \""); if (source[len] != '\0') - { /* must truncate? */ + { // must truncate? strncat(out, source, len); strcat(out, "..."); } diff --git a/VM/src/lobject.h b/VM/src/lobject.h index bdcb85cb..2097e335 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -49,7 +49,7 @@ typedef struct lua_TValue int tt; } TValue; -/* Macros to test type */ +// Macros to test type #define ttisnil(o) (ttype(o) == LUA_TNIL) #define ttisnumber(o) (ttype(o) == LUA_TNUMBER) #define ttisstring(o) (ttype(o) == LUA_TSTRING) @@ -62,7 +62,7 @@ typedef struct lua_TValue #define ttisvector(o) (ttype(o) == LUA_TVECTOR) #define ttisupval(o) (ttype(o) == LUA_TUPVAL) -/* Macros to access values */ +// Macros to access values #define ttype(o) ((o)->tt) #define gcvalue(o) check_exp(iscollectable(o), (o)->value.gc) #define pvalue(o) check_exp(ttislightuserdata(o), (o)->value.p) @@ -85,7 +85,7 @@ typedef struct lua_TValue #define checkliveness(g, obj) LUAU_ASSERT(!iscollectable(obj) || ((ttype(obj) == (obj)->value.gc->gch.tt) && !isdead(g, (obj)->value.gc))) -/* Macros to set values */ +// Macros to set values #define setnilvalue(obj) ((obj)->tt = LUA_TNIL) #define setnvalue(obj, x) \ @@ -200,18 +200,18 @@ typedef struct lua_TValue ** different types of sets, according to destination */ -/* from stack to (same) stack */ +// from stack to (same) stack #define setobjs2s setobj -/* to stack (not from same stack) */ +// to stack (not from same stack) #define setobj2s setobj #define setsvalue2s setsvalue #define sethvalue2s sethvalue #define setptvalue2s setptvalue -/* from table to same table */ +// from table to same table #define setobjt2t setobj -/* to table */ +// to table #define setobj2t setobj -/* to new object */ +// to new object #define setobj2n setobj #define setsvalue2n setsvalue @@ -219,7 +219,7 @@ typedef struct lua_TValue #define iscollectable(o) (ttype(o) >= LUA_TSTRING) -typedef TValue* StkId; /* index to stack elements */ +typedef TValue* StkId; // index to stack elements /* ** String headers for string table @@ -269,13 +269,13 @@ typedef struct Proto CommonHeader; - TValue* k; /* constants used by the function */ - Instruction* code; /* function bytecode */ - struct Proto** p; /* functions defined inside the function */ - uint8_t* lineinfo; /* for each instruction, line number as a delta from baseline */ - int* abslineinfo; /* baseline line info, one entry for each 1<global, i_o); \ } -/* copy a value from a key */ +// copy a value from a key #define getnodekey(L, obj, node) \ { \ TValue* i_o = (obj); \ @@ -418,22 +418,22 @@ typedef struct Table CommonHeader; - uint8_t tmcache; /* 1<

tm_sec); setfield(L, "min", stm->tm_min); setfield(L, "hour", stm->tm_hour); @@ -122,7 +122,7 @@ static int os_date(lua_State* L) luaL_buffinit(L, &b); for (; *s; s++) { - if (*s != '%' || *(s + 1) == '\0') /* no conversion specifier? */ + if (*s != '%' || *(s + 1) == '\0') // no conversion specifier? { luaL_addchar(&b, *s); } @@ -133,7 +133,7 @@ static int os_date(lua_State* L) else { size_t reslen; - char buff[200]; /* should be big enough for any conversion result */ + char buff[200]; // should be big enough for any conversion result cc[1] = *(++s); reslen = strftime(buff, sizeof(buff), cc, stm); luaL_addlstring(&b, buff, reslen); @@ -147,13 +147,13 @@ static int os_date(lua_State* L) static int os_time(lua_State* L) { time_t t; - if (lua_isnoneornil(L, 1)) /* called without args? */ - t = time(NULL); /* get current time */ + if (lua_isnoneornil(L, 1)) // called without args? + t = time(NULL); // get current time else { struct tm ts; luaL_checktype(L, 1, LUA_TTABLE); - lua_settop(L, 1); /* make sure table is at the top */ + lua_settop(L, 1); // make sure table is at the top ts.tm_sec = getfield(L, "sec", 0); ts.tm_min = getfield(L, "min", 0); ts.tm_hour = getfield(L, "hour", 12); diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index fbc6fb1e..4489f840 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -21,22 +21,22 @@ typedef struct LG static void stack_init(lua_State* L1, lua_State* L) { - /* initialize CallInfo array */ + // initialize CallInfo array L1->base_ci = luaM_newarray(L, BASIC_CI_SIZE, CallInfo, L1->memcat); L1->ci = L1->base_ci; L1->size_ci = BASIC_CI_SIZE; L1->end_ci = L1->base_ci + L1->size_ci - 1; - /* initialize stack array */ + // initialize stack array L1->stack = luaM_newarray(L, BASIC_STACK_SIZE + EXTRA_STACK, TValue, L1->memcat); L1->stacksize = BASIC_STACK_SIZE + EXTRA_STACK; TValue* stack = L1->stack; for (int i = 0; i < BASIC_STACK_SIZE + EXTRA_STACK; i++) - setnilvalue(stack + i); /* erase new stack */ + setnilvalue(stack + i); // erase new stack L1->top = stack; L1->stack_last = stack + (L1->stacksize - EXTRA_STACK); - /* initialize first ci */ + // initialize first ci L1->ci->func = L1->top; - setnilvalue(L1->top++); /* `function' entry for this `ci' */ + setnilvalue(L1->top++); // `function' entry for this `ci' L1->base = L1->ci->base = L1->top; L1->ci->top = L1->top + LUA_MINSTACK; } @@ -53,13 +53,13 @@ static void freestack(lua_State* L, lua_State* L1) static void f_luaopen(lua_State* L, void* ud) { global_State* g = L->global; - stack_init(L, L); /* init stack */ - L->gt = luaH_new(L, 0, 2); /* table of globals */ - sethvalue(L, registry(L), luaH_new(L, 0, 2)); /* registry */ - luaS_resize(L, LUA_MINSTRTABSIZE); /* initial size of string table */ + stack_init(L, L); // init stack + L->gt = luaH_new(L, 0, 2); // table of globals + sethvalue(L, registry(L), luaH_new(L, 0, 2)); // registry + luaS_resize(L, LUA_MINSTRTABSIZE); // initial size of string table luaT_init(L); - luaS_fix(luaS_newliteral(L, LUA_MEMERRMSG)); /* pin to make sure we can always throw this error */ - luaS_fix(luaS_newliteral(L, LUA_ERRERRMSG)); /* pin to make sure we can always throw this error */ + luaS_fix(luaS_newliteral(L, LUA_MEMERRMSG)); // pin to make sure we can always throw this error + luaS_fix(luaS_newliteral(L, LUA_ERRERRMSG)); // pin to make sure we can always throw this error g->GCthreshold = 4 * g->totalbytes; } @@ -85,8 +85,8 @@ static void preinit_state(lua_State* L, global_State* g) static void close_state(lua_State* L) { global_State* g = L->global; - luaF_close(L, L->stack); /* close all upvalues for this thread */ - luaC_freeall(L); /* collect all objects */ + luaF_close(L, L->stack); // close all upvalues for this thread + luaC_freeall(L); // collect all objects LUAU_ASSERT(g->strbufgc == NULL); LUAU_ASSERT(g->strt.nuse == 0); luaM_freearray(L, L->global->strt.hash, L->global->strt.size, TString*, 0); @@ -110,8 +110,8 @@ lua_State* luaE_newthread(lua_State* L) luaC_init(L, L1, LUA_TTHREAD); preinit_state(L1, L->global); L1->activememcat = L->activememcat; // inherit the active memory category - stack_init(L1, L); /* init stack */ - L1->gt = L->gt; /* share table of globals */ + stack_init(L1, L); // init stack + L1->gt = L->gt; // share table of globals L1->singlestep = L->singlestep; LUAU_ASSERT(iswhite(obj2gco(L1))); return L1; @@ -119,7 +119,7 @@ lua_State* luaE_newthread(lua_State* L) void luaE_freethread(lua_State* L, lua_State* L1, lua_Page* page) { - luaF_close(L1, L1->stack); /* close all upvalues for this thread */ + luaF_close(L1, L1->stack); // close all upvalues for this thread LUAU_ASSERT(L1->openupval == NULL); global_State* g = L->global; if (g->cb.userthread) @@ -130,9 +130,9 @@ void luaE_freethread(lua_State* L, lua_State* L1, lua_Page* page) void lua_resetthread(lua_State* L) { - /* close upvalues before clearing anything */ + // close upvalues before clearing anything luaF_close(L, L->stack); - /* clear call frames */ + // clear call frames CallInfo* ci = L->base_ci; ci->func = L->stack; ci->base = ci->func + 1; @@ -141,12 +141,12 @@ void lua_resetthread(lua_State* L) L->ci = ci; if (L->size_ci != BASIC_CI_SIZE) luaD_reallocCI(L, BASIC_CI_SIZE); - /* clear thread state */ + // clear thread state L->status = LUA_OK; L->base = L->ci->base; L->top = L->ci->base; L->nCcalls = L->baseCcalls = 0; - /* clear thread stack */ + // clear thread stack if (L->stacksize != BASIC_STACK_SIZE + EXTRA_STACK) luaD_reallocstack(L, BASIC_STACK_SIZE); for (int i = 0; i < L->stacksize; i++) @@ -177,7 +177,7 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->mainthread = L; g->uvhead.u.l.prev = &g->uvhead; g->uvhead.u.l.next = &g->uvhead; - g->GCthreshold = 0; /* mark it as unfinished state */ + g->GCthreshold = 0; // mark it as unfinished state g->registryfree = 0; g->errorjmp = NULL; g->rngstate = 0; @@ -224,7 +224,7 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) if (luaD_rawrunprotected(L, f_luaopen, NULL) != 0) { - /* memory allocation error: free partial state */ + // memory allocation error: free partial state close_state(L); L = NULL; } @@ -233,7 +233,7 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) void lua_close(lua_State* L) { - L = L->global->mainthread; /* only the main thread can be closed */ - luaF_close(L, L->stack); /* close all upvalues for this thread */ + L = L->global->mainthread; // only the main thread can be closed + luaF_close(L, L->stack); // close all upvalues for this thread close_state(L); } diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 423514a7..72a09713 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -5,10 +5,10 @@ #include "lobject.h" #include "ltm.h" -/* registry */ +// registry #define registry(L) (&L->global->registry) -/* extra stack space to handle TM calls and some other extras */ +// extra stack space to handle TM calls and some other extras #define EXTRA_STACK 5 #define BASIC_CI_SIZE 8 @@ -20,7 +20,7 @@ typedef struct stringtable { TString** hash; - uint32_t nuse; /* number of elements */ + uint32_t nuse; // number of elements int size; } stringtable; // clang-format on @@ -57,18 +57,18 @@ typedef struct stringtable typedef struct CallInfo { - StkId base; /* base for this function */ - StkId func; /* function index in the stack */ - StkId top; /* top for this function */ + StkId base; // base for this function + StkId func; // function index in the stack + StkId top; // top for this function const Instruction* savedpc; - int nresults; /* expected number of results from this function */ - unsigned int flags; /* call frame flags, see LUA_CALLINFO_* */ + int nresults; // expected number of results from this function + unsigned int flags; // call frame flags, see LUA_CALLINFO_* } CallInfo; // clang-format on -#define LUA_CALLINFO_RETURN (1 << 0) /* should the interpreter return after returning from this callinfo? first frame must have this set */ -#define LUA_CALLINFO_HANDLE (1 << 1) /* should the error thrown during execution get handled by continuation from this callinfo? func must be C */ +#define LUA_CALLINFO_RETURN (1 << 0) // should the interpreter return after returning from this callinfo? first frame must have this set +#define LUA_CALLINFO_HANDLE (1 << 1) // should the error thrown during execution get handled by continuation from this callinfo? func must be C #define curr_func(L) (clvalue(L->ci->func)) #define ci_func(ci) (clvalue((ci)->func)) @@ -152,55 +152,55 @@ struct GCMetrics // clang-format off typedef struct global_State { - stringtable strt; /* hash table for strings */ + stringtable strt; // hash table for strings - lua_Alloc frealloc; /* function to reallocate memory */ - void* ud; /* auxiliary data to `frealloc' */ + lua_Alloc frealloc; // function to reallocate memory + void* ud; // auxiliary data to `frealloc' uint8_t currentwhite; - uint8_t gcstate; /* state of garbage collector */ + uint8_t gcstate; // state of garbage collector - GCObject* gray; /* list of gray objects */ - GCObject* grayagain; /* list of objects to be traversed atomically */ - GCObject* weak; /* list of weak tables (to be cleared) */ + GCObject* gray; // list of gray objects + GCObject* grayagain; // list of objects to be traversed atomically + GCObject* weak; // list of weak tables (to be cleared) TString* strbufgc; // list of all string buffer objects - size_t GCthreshold; // when totalbytes > GCthreshold; run GC step + size_t GCthreshold; // when totalbytes > GCthreshold, run GC step size_t totalbytes; // number of bytes currently allocated int gcgoal; // see LUAI_GCGOAL int gcstepmul; // see LUAI_GCSTEPMUL int gcstepsize; // see LUAI_GCSTEPSIZE struct lua_Page* freepages[LUA_SIZECLASSES]; // free page linked list for each size class for non-collectable objects - struct lua_Page* freegcopages[LUA_SIZECLASSES]; // free page linked list for each size class for collectable objects + struct lua_Page* freegcopages[LUA_SIZECLASSES]; // free page linked list for each size class for collectable objects struct lua_Page* allgcopages; // page linked list with all pages for all classes struct lua_Page* sweepgcopage; // position of the sweep in `allgcopages' - size_t memcatbytes[LUA_MEMORY_CATEGORIES]; /* total amount of memory used by each memory category */ + size_t memcatbytes[LUA_MEMORY_CATEGORIES]; // total amount of memory used by each memory category struct lua_State* mainthread; - UpVal uvhead; /* head of double-linked list of all open upvalues */ - struct Table* mt[LUA_T_COUNT]; /* metatables for basic types */ - TString* ttname[LUA_T_COUNT]; /* names for basic types */ - TString* tmname[TM_N]; /* array with tag-method names */ + UpVal uvhead; // head of double-linked list of all open upvalues + struct Table* mt[LUA_T_COUNT]; // metatables for basic types + TString* ttname[LUA_T_COUNT]; // names for basic types + TString* tmname[TM_N]; // array with tag-method names - TValue pseudotemp; /* storage for temporary values used in pseudo2addr */ + TValue pseudotemp; // storage for temporary values used in pseudo2addr - TValue registry; /* registry table, used by lua_ref and LUA_REGISTRYINDEX */ - int registryfree; /* next free slot in registry */ + TValue registry; // registry table, used by lua_ref and LUA_REGISTRYINDEX + int registryfree; // next free slot in registry - struct lua_jmpbuf* errorjmp; /* jump buffer data for longjmp-style error handling */ + struct lua_jmpbuf* errorjmp; // jump buffer data for longjmp-style error handling - uint64_t rngstate; /* PCG random number generator state */ - uint64_t ptrenckey[4]; /* pointer encoding key for display */ + uint64_t rngstate; // PCG random number generator state + uint64_t ptrenckey[4]; // pointer encoding key for display - void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); /* for each userdata tag, a gc callback to be called immediately before freeing memory */ + void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory lua_Callbacks cb; @@ -221,39 +221,39 @@ struct lua_State CommonHeader; uint8_t status; - uint8_t activememcat; /* memory category that is used for new GC object allocations */ + uint8_t activememcat; // memory category that is used for new GC object allocations uint8_t stackstate; - bool singlestep; /* call debugstep hook after each instruction */ + bool singlestep; // call debugstep hook after each instruction - StkId top; /* first free slot in the stack */ - StkId base; /* base of current function */ + StkId top; // first free slot in the stack + StkId base; // base of current function global_State* global; - CallInfo* ci; /* call info for current function */ - StkId stack_last; /* last free slot in the stack */ - StkId stack; /* stack base */ + CallInfo* ci; // call info for current function + StkId stack_last; // last free slot in the stack + StkId stack; // stack base - CallInfo* end_ci; /* points after end of ci array*/ - CallInfo* base_ci; /* array of CallInfo's */ + CallInfo* end_ci; // points after end of ci array + CallInfo* base_ci; // array of CallInfo's int stacksize; - int size_ci; /* size of array `base_ci' */ + int size_ci; // size of array `base_ci' - unsigned short nCcalls; /* number of nested C calls */ - unsigned short baseCcalls; /* nested C calls when resuming coroutine */ + unsigned short nCcalls; // number of nested C calls + unsigned short baseCcalls; // nested C calls when resuming coroutine - int cachedslot; /* when table operations or INDEX/NEWINDEX is invoked from Luau, what is the expected slot for lookup? */ + int cachedslot; // when table operations or INDEX/NEWINDEX is invoked from Luau, what is the expected slot for lookup? - Table* gt; /* table of globals */ - UpVal* openupval; /* list of open upvalues in this stack */ + Table* gt; // table of globals + UpVal* openupval; // list of open upvalues in this stack GCObject* gclist; - TString* namecall; /* when invoked from Luau using NAMECALL, what method do we need to invoke? */ + TString* namecall; // when invoked from Luau using NAMECALL, what method do we need to invoke? void* userdata; }; @@ -271,10 +271,10 @@ union GCObject struct Table h; struct Proto p; struct UpVal uv; - struct lua_State th; /* thread */ + struct lua_State th; // thread }; -/* macros to convert a GCObject into a specific value */ +// macros to convert a GCObject into a specific value #define gco2ts(o) check_exp((o)->gch.tt == LUA_TSTRING, &((o)->ts)) #define gco2u(o) check_exp((o)->gch.tt == LUA_TUSERDATA, &((o)->u)) #define gco2cl(o) check_exp((o)->gch.tt == LUA_TFUNCTION, &((o)->cl)) @@ -283,7 +283,7 @@ union GCObject #define gco2uv(o) check_exp((o)->gch.tt == LUA_TUPVAL, &((o)->uv)) #define gco2th(o) check_exp((o)->gch.tt == LUA_TTHREAD, &((o)->th)) -/* macro to convert any Lua object into a GCObject */ +// macro to convert any Lua object into a GCObject #define obj2gco(v) check_exp(iscollectable(v), cast_to(GCObject*, (v) + 0)) LUAI_FUNC lua_State* luaE_newthread(lua_State* L); diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index c0cd3e26..9c266031 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -48,17 +48,17 @@ void luaS_resize(lua_State* L, int newsize) stringtable* tb = &L->global->strt; for (int i = 0; i < newsize; i++) newhash[i] = NULL; - /* rehash */ + // rehash for (int i = 0; i < tb->size; i++) { TString* p = tb->hash[i]; while (p) - { /* for each node in the list */ - TString* next = p->next; /* save next */ + { // for each node in the list + TString* next = p->next; // save next unsigned int h = p->hash; - int h1 = lmod(h, newsize); /* new position */ + int h1 = lmod(h, newsize); // new position LUAU_ASSERT(cast_int(h % newsize) == lmod(h, newsize)); - p->next = newhash[h1]; /* chain it */ + p->next = newhash[h1]; // chain it newhash[h1] = p; p = next; } @@ -81,15 +81,15 @@ static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) ts->tt = LUA_TSTRING; ts->memcat = L->activememcat; memcpy(ts->data, str, l); - ts->data[l] = '\0'; /* ending 0 */ - ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, l) : -1; + ts->data[l] = '\0'; // ending 0 + ts->atom = ATOM_UNDEF; tb = &L->global->strt; h = lmod(h, tb->size); - ts->next = tb->hash[h]; /* chain new entry */ + ts->next = tb->hash[h]; // chain new entry tb->hash[h] = ts; tb->nuse++; if (tb->nuse > cast_to(uint32_t, tb->size) && tb->size <= INT_MAX / 2) - luaS_resize(L, tb->size * 2); /* too crowded */ + luaS_resize(L, tb->size * 2); // too crowded return ts; } @@ -163,9 +163,7 @@ TString* luaS_buffinish(lua_State* L, TString* ts) ts->hash = h; ts->data[ts->len] = '\0'; // ending 0 - - // Complete string object - ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1; + ts->atom = ATOM_UNDEF; ts->next = tb->hash[bucket]; // chain new entry tb->hash[bucket] = ts; @@ -183,13 +181,13 @@ TString* luaS_newlstr(lua_State* L, const char* str, size_t l) { if (el->len == l && (memcmp(str, getstr(el), l) == 0)) { - /* string may be dead */ + // string may be dead if (isdead(L->global, obj2gco(el))) changewhite(obj2gco(el)); return el; } } - return newlstr(L, str, l, h); /* not found */ + return newlstr(L, str, l, h); // not found } static bool unlinkstr(lua_State* L, TString* ts) diff --git a/VM/src/lstring.h b/VM/src/lstring.h index 290b64d8..41f9df9a 100644 --- a/VM/src/lstring.h +++ b/VM/src/lstring.h @@ -5,9 +5,12 @@ #include "lobject.h" #include "lstate.h" -/* string size limit */ +// string size limit #define MAXSSIZE (1 << 30) +// string atoms are not defined by default; the storage is 16-bit integer +#define ATOM_UNDEF -32768 + #define sizestring(len) (offsetof(TString, data) + len + 1) #define luaS_new(L, s) (luaS_newlstr(L, s, strlen(s))) diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 74a8aa8a..b3ea1094 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,7 +8,9 @@ #include #include -/* macro to `unsign' a character */ +LUAU_FASTFLAGVARIABLE(LuauTostringFormatSpecifier, false); + +// macro to `unsign' a character #define uchar(c) ((unsigned char)(c)) static int str_len(lua_State* L) @@ -21,7 +23,7 @@ static int str_len(lua_State* L) static int posrelat(int pos, size_t len) { - /* relative string position: negative means back from end */ + // relative string position: negative means back from end if (pos < 0) pos += (int)len + 1; return (pos >= 0) ? pos : 0; @@ -137,9 +139,9 @@ static int str_byte(lua_State* L) if ((size_t)pose > l) pose = (int)l; if (posi > pose) - return 0; /* empty interval; return no values */ + return 0; // empty interval; return no values n = (int)(pose - posi + 1); - if (posi + n <= pose) /* overflow? */ + if (posi + n <= pose) // overflow? luaL_error(L, "string slice too long"); luaL_checkstack(L, n, "string slice too long"); for (i = 0; i < n; i++) @@ -149,7 +151,7 @@ static int str_byte(lua_State* L) static int str_char(lua_State* L) { - int n = lua_gettop(L); /* number of arguments */ + int n = lua_gettop(L); // number of arguments luaL_Buffer b; char* ptr = luaL_buffinitsize(L, &b, n); @@ -176,12 +178,12 @@ static int str_char(lua_State* L) typedef struct MatchState { - int matchdepth; /* control for recursive depth (to avoid C stack overflow) */ - const char* src_init; /* init of source string */ - const char* src_end; /* end ('\0') of source string */ - const char* p_end; /* end ('\0') of pattern */ + int matchdepth; // control for recursive depth (to avoid C stack overflow) + const char* src_init; // init of source string + const char* src_end; // end ('\0') of source string + const char* p_end; // end ('\0') of pattern lua_State* L; - int level; /* total number of captures (finished or unfinished) */ + int level; // total number of captures (finished or unfinished) struct { const char* init; @@ -189,7 +191,7 @@ typedef struct MatchState } capture[LUA_MAXCAPTURES]; } MatchState; -/* recursive function */ +// recursive function static const char* match(MatchState* ms, const char* s, const char* p); #define L_ESC '%' @@ -227,11 +229,11 @@ static const char* classend(MatchState* ms, const char* p) if (*p == '^') p++; do - { /* look for a `]' */ + { // look for a `]' if (p == ms->p_end) luaL_error(ms->L, "malformed pattern (missing ']')"); if (*(p++) == L_ESC && p < ms->p_end) - p++; /* skip escapes (e.g. `%]') */ + p++; // skip escapes (e.g. `%]') } while (*p != ']'); return p + 1; } @@ -279,7 +281,7 @@ static int match_class(int c, int cl) break; case 'z': res = (c == 0); - break; /* deprecated option */ + break; // deprecated option default: return (cl == c); } @@ -292,7 +294,7 @@ static int matchbracketclass(int c, const char* p, const char* ec) if (*(p + 1) == '^') { sig = 0; - p++; /* skip the `^' */ + p++; // skip the `^' } while (++p < ec) { @@ -324,7 +326,7 @@ static int singlematch(MatchState* ms, const char* s, const char* p, const char* switch (*p) { case '.': - return 1; /* matches any char */ + return 1; // matches any char case L_ESC: return match_class(c, uchar(*(p + 1))); case '[': @@ -357,21 +359,21 @@ static const char* matchbalance(MatchState* ms, const char* s, const char* p) cont++; } } - return NULL; /* string ends out of balance */ + return NULL; // string ends out of balance } static const char* max_expand(MatchState* ms, const char* s, const char* p, const char* ep) { - ptrdiff_t i = 0; /* counts maximum expand for item */ + ptrdiff_t i = 0; // counts maximum expand for item while (singlematch(ms, s + i, p, ep)) i++; - /* keeps trying to match with the maximum repetitions */ + // keeps trying to match with the maximum repetitions while (i >= 0) { const char* res = match(ms, (s + i), ep + 1); if (res) return res; - i--; /* else didn't match; reduce 1 repetition to try again */ + i--; // else didn't match; reduce 1 repetition to try again } return NULL; } @@ -384,7 +386,7 @@ static const char* min_expand(MatchState* ms, const char* s, const char* p, cons if (res != NULL) return res; else if (singlematch(ms, s, p, ep)) - s++; /* try with one more repetition */ + s++; // try with one more repetition else return NULL; } @@ -399,8 +401,8 @@ static const char* start_capture(MatchState* ms, const char* s, const char* p, i ms->capture[level].init = s; ms->capture[level].len = what; ms->level = level + 1; - if ((res = match(ms, s, p)) == NULL) /* match failed? */ - ms->level--; /* undo capture */ + if ((res = match(ms, s, p)) == NULL) // match failed? + ms->level--; // undo capture return res; } @@ -408,9 +410,9 @@ static const char* end_capture(MatchState* ms, const char* s, const char* p) { int l = capture_to_close(ms); const char* res; - ms->capture[l].len = s - ms->capture[l].init; /* close capture */ - if ((res = match(ms, s, p)) == NULL) /* match failed? */ - ms->capture[l].len = CAP_UNFINISHED; /* undo capture */ + ms->capture[l].len = s - ms->capture[l].init; // close capture + if ((res = match(ms, s, p)) == NULL) // match failed? + ms->capture[l].len = CAP_UNFINISHED; // undo capture return res; } @@ -429,60 +431,60 @@ static const char* match(MatchState* ms, const char* s, const char* p) { if (ms->matchdepth-- == 0) luaL_error(ms->L, "pattern too complex"); -init: /* using goto's to optimize tail recursion */ +init: // using goto's to optimize tail recursion if (p != ms->p_end) - { /* end of pattern? */ + { // end of pattern? switch (*p) { case '(': - { /* start capture */ - if (*(p + 1) == ')') /* position capture? */ + { // start capture + if (*(p + 1) == ')') // position capture? s = start_capture(ms, s, p + 2, CAP_POSITION); else s = start_capture(ms, s, p + 1, CAP_UNFINISHED); break; } case ')': - { /* end capture */ + { // end capture s = end_capture(ms, s, p + 1); break; } case '$': { - if ((p + 1) != ms->p_end) /* is the `$' the last char in pattern? */ - goto dflt; /* no; go to default */ - s = (s == ms->src_end) ? s : NULL; /* check end of string */ + if ((p + 1) != ms->p_end) // is the `$' the last char in pattern? + goto dflt; // no; go to default + s = (s == ms->src_end) ? s : NULL; // check end of string break; } case L_ESC: - { /* escaped sequences not in the format class[*+?-]? */ + { // escaped sequences not in the format class[*+?-]? switch (*(p + 1)) { case 'b': - { /* balanced string? */ + { // balanced string? s = matchbalance(ms, s, p + 2); if (s != NULL) { p += 4; - goto init; /* return match(ms, s, p + 4); */ - } /* else fail (s == NULL) */ + goto init; // return match(ms, s, p + 4); + } // else fail (s == NULL) break; } case 'f': - { /* frontier? */ + { // frontier? const char* ep; char previous; p += 2; if (*p != '[') luaL_error(ms->L, "missing '[' after '%%f' in pattern"); - ep = classend(ms, p); /* points to what is next */ + ep = classend(ms, p); // points to what is next previous = (s == ms->src_init) ? '\0' : *(s - 1); if (!matchbracketclass(uchar(previous), p, ep - 1) && matchbracketclass(uchar(*s), p, ep - 1)) { p = ep; - goto init; /* return match(ms, s, ep); */ + goto init; // return match(ms, s, ep); } - s = NULL; /* match failed */ + s = NULL; // match failed break; } case '0': @@ -495,12 +497,12 @@ init: /* using goto's to optimize tail recursion */ case '7': case '8': case '9': - { /* capture results (%0-%9)? */ + { // capture results (%0-%9)? s = match_capture(ms, s, uchar(*(p + 1))); if (s != NULL) { p += 2; - goto init; /* return match(ms, s, p + 2) */ + goto init; // return match(ms, s, p + 2) } break; } @@ -511,48 +513,48 @@ init: /* using goto's to optimize tail recursion */ } default: dflt: - { /* pattern class plus optional suffix */ - const char* ep = classend(ms, p); /* points to optional suffix */ - /* does not match at least once? */ + { // pattern class plus optional suffix + const char* ep = classend(ms, p); // points to optional suffix + // does not match at least once? if (!singlematch(ms, s, p, ep)) { if (*ep == '*' || *ep == '?' || *ep == '-') - { /* accept empty? */ + { // accept empty? p = ep + 1; - goto init; /* return match(ms, s, ep + 1); */ + goto init; // return match(ms, s, ep + 1); } - else /* '+' or no suffix */ - s = NULL; /* fail */ + else // '+' or no suffix + s = NULL; // fail } else - { /* matched once */ + { // matched once switch (*ep) - { /* handle optional suffix */ + { // handle optional suffix case '?': - { /* optional */ + { // optional const char* res; if ((res = match(ms, s + 1, ep + 1)) != NULL) s = res; else { p = ep + 1; - goto init; /* else return match(ms, s, ep + 1); */ + goto init; // else return match(ms, s, ep + 1); } break; } - case '+': /* 1 or more repetitions */ - s++; /* 1 match already done */ - /* go through */ - case '*': /* 0 or more repetitions */ + case '+': // 1 or more repetitions + s++; // 1 match already done + // go through + case '*': // 0 or more repetitions s = max_expand(ms, s, p, ep); break; - case '-': /* 0 or more repetitions (minimum) */ + case '-': // 0 or more repetitions (minimum) s = min_expand(ms, s, p, ep); break; - default: /* no suffix */ + default: // no suffix s++; p = ep; - goto init; /* return match(ms, s + 1, ep); */ + goto init; // return match(ms, s + 1, ep); } } break; @@ -566,26 +568,26 @@ init: /* using goto's to optimize tail recursion */ static const char* lmemfind(const char* s1, size_t l1, const char* s2, size_t l2) { if (l2 == 0) - return s1; /* empty strings are everywhere */ + return s1; // empty strings are everywhere else if (l2 > l1) - return NULL; /* avoids a negative `l1' */ + return NULL; // avoids a negative `l1' else { - const char* init; /* to search for a `*s2' inside `s1' */ - l2--; /* 1st char will be checked by `memchr' */ - l1 = l1 - l2; /* `s2' cannot be found after that */ + const char* init; // to search for a `*s2' inside `s1' + l2--; // 1st char will be checked by `memchr' + l1 = l1 - l2; // `s2' cannot be found after that while (l1 > 0 && (init = (const char*)memchr(s1, *s2, l1)) != NULL) { - init++; /* 1st char is already checked */ + init++; // 1st char is already checked if (memcmp(init, s2 + 1, l2) == 0) return init - 1; else - { /* correct `l1' and `s1' to try again */ + { // correct `l1' and `s1' to try again l1 -= init - s1; s1 = init; } } - return NULL; /* not found */ + return NULL; // not found } } @@ -593,8 +595,8 @@ static void push_onecapture(MatchState* ms, int i, const char* s, const char* e) { if (i >= ms->level) { - if (i == 0) /* ms->level == 0, too */ - lua_pushlstring(ms->L, s, e - s); /* add whole match */ + if (i == 0) // ms->level == 0, too + lua_pushlstring(ms->L, s, e - s); // add whole match else luaL_error(ms->L, "invalid capture index"); } @@ -617,20 +619,20 @@ static int push_captures(MatchState* ms, const char* s, const char* e) luaL_checkstack(ms->L, nlevels, "too many captures"); for (i = 0; i < nlevels; i++) push_onecapture(ms, i, s, e); - return nlevels; /* number of strings pushed */ + return nlevels; // number of strings pushed } -/* check whether pattern has no special characters */ +// check whether pattern has no special characters static int nospecials(const char* p, size_t l) { size_t upto = 0; do { if (strpbrk(p + upto, SPECIALS)) - return 0; /* pattern has a special character */ - upto += strlen(p + upto) + 1; /* may have more after \0 */ + return 0; // pattern has a special character + upto += strlen(p + upto) + 1; // may have more after \0 } while (upto <= l); - return 1; /* no special chars found */ + return 1; // no special chars found } static void prepstate(MatchState* ms, lua_State* L, const char* s, size_t ls, const char* p, size_t lp) @@ -657,14 +659,14 @@ static int str_find_aux(lua_State* L, int find) if (init < 1) init = 1; else if (init > (int)ls + 1) - { /* start after string's end? */ - lua_pushnil(L); /* cannot find anything */ + { // start after string's end? + lua_pushnil(L); // cannot find anything return 1; } - /* explicit request or no special characters? */ + // explicit request or no special characters? if (find && (lua_toboolean(L, 4) || nospecials(p, lp))) { - /* do a plain search */ + // do a plain search const char* s2 = lmemfind(s + init - 1, ls - init + 1, p, lp); if (s2) { @@ -681,7 +683,7 @@ static int str_find_aux(lua_State* L, int find) if (anchor) { p++; - lp--; /* skip anchor character */ + lp--; // skip anchor character } prepstate(&ms, L, s, ls, p, lp); do @@ -692,8 +694,8 @@ static int str_find_aux(lua_State* L, int find) { if (find) { - lua_pushinteger(L, (int)(s1 - s + 1)); /* start */ - lua_pushinteger(L, (int)(res - s)); /* end */ + lua_pushinteger(L, (int)(s1 - s + 1)); // start + lua_pushinteger(L, (int)(res - s)); // end return push_captures(&ms, NULL, 0) + 2; } else @@ -701,7 +703,7 @@ static int str_find_aux(lua_State* L, int find) } } while (s1++ < ms.src_end && !anchor); } - lua_pushnil(L); /* not found */ + lua_pushnil(L); // not found return 1; } @@ -731,13 +733,13 @@ static int gmatch_aux(lua_State* L) { int newstart = (int)(e - s); if (e == src) - newstart++; /* empty match? go at least one position */ + newstart++; // empty match? go at least one position lua_pushinteger(L, newstart); lua_replace(L, lua_upvalueindex(3)); return push_captures(&ms, src, e); } } - return 0; /* not found */ + return 0; // not found } static int gmatch(lua_State* L) @@ -763,7 +765,7 @@ static void add_s(MatchState* ms, luaL_Buffer* b, const char* s, const char* e) luaL_addchar(b, news[i]); else { - i++; /* skip ESC */ + i++; // skip ESC if (!isdigit(uchar(news[i]))) { if (news[i] != L_ESC) @@ -775,7 +777,7 @@ static void add_s(MatchState* ms, luaL_Buffer* b, const char* s, const char* e) else { push_onecapture(ms, news[i] - '1', s, e); - luaL_addvalue(b); /* add capture to accumulated result */ + luaL_addvalue(b); // add capture to accumulated result } } } @@ -801,19 +803,19 @@ static void add_value(MatchState* ms, luaL_Buffer* b, const char* s, const char* break; } default: - { /* LUA_TNUMBER or LUA_TSTRING */ + { // LUA_TNUMBER or LUA_TSTRING add_s(ms, b, s, e); return; } } if (!lua_toboolean(L, -1)) - { /* nil or false? */ + { // nil or false? lua_pop(L, 1); - lua_pushlstring(L, s, e - s); /* keep original text */ + lua_pushlstring(L, s, e - s); // keep original text } else if (!lua_isstring(L, -1)) luaL_error(L, "invalid replacement value (a %s)", luaL_typename(L, -1)); - luaL_addvalue(b); /* add result to accumulator */ + luaL_addvalue(b); // add result to accumulator } static int str_gsub(lua_State* L) @@ -832,7 +834,7 @@ static int str_gsub(lua_State* L) if (anchor) { p++; - lp--; /* skip anchor character */ + lp--; // skip anchor character } prepstate(&ms, L, src, srcl, p, lp); while (n < max_s) @@ -845,8 +847,8 @@ static int str_gsub(lua_State* L) n++; add_value(&ms, &b, src, e, tr); } - if (e && e > src) /* non empty match? */ - src = e; /* skip it */ + if (e && e > src) // non empty match? + src = e; // skip it else if (src < ms.src_end) luaL_addchar(&b, *src++); else @@ -856,17 +858,17 @@ static int str_gsub(lua_State* L) } luaL_addlstring(&b, src, ms.src_end - src); luaL_pushresult(&b); - lua_pushinteger(L, n); /* number of substitutions */ + lua_pushinteger(L, n); // number of substitutions return 2; } -/* }====================================================== */ +// }====================================================== -/* valid flags in a format specification */ +// valid flags in a format specification #define FLAGS "-+ #0" -/* maximum size of each formatted item (> len(format('%99.99f', -1e308))) */ +// maximum size of each formatted item (> len(format('%99.99f', -1e308))) #define MAX_ITEM 512 -/* maximum size of each format specification (such as '%-099.99d') */ +// maximum size of each format specification (such as '%-099.99d') #define MAX_FORMAT 32 static void addquoted(lua_State* L, luaL_Buffer* b, int arg) @@ -914,20 +916,20 @@ static const char* scanformat(lua_State* L, const char* strfrmt, char* form, siz { const char* p = strfrmt; while (*p != '\0' && strchr(FLAGS, *p) != NULL) - p++; /* skip flags */ + p++; // skip flags if ((size_t)(p - strfrmt) >= sizeof(FLAGS)) luaL_error(L, "invalid format (repeated flags)"); if (isdigit(uchar(*p))) - p++; /* skip width */ + p++; // skip width if (isdigit(uchar(*p))) - p++; /* (2 digits at most) */ + p++; // (2 digits at most) if (*p == '.') { p++; if (isdigit(uchar(*p))) - p++; /* skip precision */ + p++; // skip precision if (isdigit(uchar(*p))) - p++; /* (2 digits at most) */ + p++; // (2 digits at most) } if (isdigit(uchar(*p))) luaL_error(L, "invalid format (width or precision too long)"); @@ -965,11 +967,11 @@ static int str_format(lua_State* L) if (*strfrmt != L_ESC) luaL_addchar(&b, *strfrmt++); else if (*++strfrmt == L_ESC) - luaL_addchar(&b, *strfrmt++); /* %% */ + luaL_addchar(&b, *strfrmt++); // %% else - { /* format item */ - char form[MAX_FORMAT]; /* to store the format (`%...') */ - char buff[MAX_ITEM]; /* to store the formatted item */ + { // format item + char form[MAX_FORMAT]; // to store the format (`%...') + char buff[MAX_ITEM]; // to store the formatted item if (++arg > top) luaL_error(L, "missing argument #%d", arg); size_t formatItemSize = 0; @@ -979,14 +981,14 @@ static int str_format(lua_State* L) { case 'c': { - sprintf(buff, form, (int)luaL_checknumber(L, arg)); + snprintf(buff, sizeof(buff), form, (int)luaL_checknumber(L, arg)); break; } case 'd': case 'i': { addInt64Format(form, formatIndicator, formatItemSize); - sprintf(buff, form, (long long)luaL_checknumber(L, arg)); + snprintf(buff, sizeof(buff), form, (long long)luaL_checknumber(L, arg)); break; } case 'o': @@ -997,7 +999,7 @@ static int str_format(lua_State* L) double argValue = luaL_checknumber(L, arg); addInt64Format(form, formatIndicator, formatItemSize); unsigned long long v = (argValue < 0) ? (unsigned long long)(long long)argValue : (unsigned long long)argValue; - sprintf(buff, form, v); + snprintf(buff, sizeof(buff), form, v); break; } case 'e': @@ -1006,13 +1008,13 @@ static int str_format(lua_State* L) case 'g': case 'G': { - sprintf(buff, form, (double)luaL_checknumber(L, arg)); + snprintf(buff, sizeof(buff), form, (double)luaL_checknumber(L, arg)); break; } case 'q': { addquoted(L, &b, arg); - continue; /* skip the 'addsize' at the end */ + continue; // skip the 'addsize' at the end } case 's': { @@ -1024,16 +1026,30 @@ static int str_format(lua_State* L) keep original string */ lua_pushvalue(L, arg); luaL_addvalue(&b); - continue; /* skip the `addsize' at the end */ + continue; // skip the `addsize' at the end } else { - sprintf(buff, form, s); + snprintf(buff, sizeof(buff), form, s); break; } } + case '*': + { + if (!FFlag::LuauTostringFormatSpecifier) + luaL_error(L, "invalid option '%%*' to 'format'"); + + if (formatItemSize != 1) + luaL_error(L, "'%%*' does not take a form"); + + size_t length; + const char* string = luaL_tolstring(L, arg, &length); + + luaL_addlstring(&b, string, length); + continue; // skip the `addsize' at the end + } default: - { /* also treat cases `pnLlh' */ + { // also treat cases `pnLlh' luaL_error(L, "invalid option '%%%c' to 'format'", *(strfrmt - 1)); } } @@ -1098,31 +1114,31 @@ static int str_split(lua_State* L) ** ======================================================= */ -/* value used for padding */ +// value used for padding #if !defined(LUAL_PACKPADBYTE) #define LUAL_PACKPADBYTE 0x00 #endif -/* maximum size for the binary representation of an integer */ +// maximum size for the binary representation of an integer #define MAXINTSIZE 16 -/* number of bits in a character */ +// number of bits in a character #define NB CHAR_BIT -/* mask for one character (NB 1's) */ +// mask for one character (NB 1's) #define MC ((1 << NB) - 1) -/* internal size of integers used for pack/unpack */ +// internal size of integers used for pack/unpack #define SZINT (int)sizeof(long long) -/* dummy union to get native endianness */ +// dummy union to get native endianness static const union { int dummy; - char little; /* true iff machine is little endian */ + char little; // true iff machine is little endian } nativeendian = {1}; -/* assume we need to align for double & pointers */ +// assume we need to align for double & pointers #define MAXALIGN 8 /* @@ -1133,7 +1149,7 @@ typedef union Ftypes float f; double d; double n; - char buff[5 * sizeof(double)]; /* enough for any float type */ + char buff[5 * sizeof(double)]; // enough for any float type } Ftypes; /* @@ -1151,15 +1167,15 @@ typedef struct Header */ typedef enum KOption { - Kint, /* signed integers */ - Kuint, /* unsigned integers */ - Kfloat, /* floating-point numbers */ - Kchar, /* fixed-length strings */ - Kstring, /* strings with prefixed length */ - Kzstr, /* zero-terminated strings */ - Kpadding, /* padding */ - Kpaddalign, /* padding for alignment */ - Knop /* no-op (configuration or spaces) */ + Kint, // signed integers + Kuint, // unsigned integers + Kfloat, // floating-point numbers + Kchar, // fixed-length strings + Kstring, // strings with prefixed length + Kzstr, // zero-terminated strings + Kpadding, // padding + Kpaddalign, // padding for alignment + Knop // no-op (configuration or spaces) } KOption; /* @@ -1173,8 +1189,8 @@ static int digit(int c) static int getnum(Header* h, const char** fmt, int df) { - if (!digit(**fmt)) /* no number? */ - return df; /* return default value */ + if (!digit(**fmt)) // no number? + return df; // return default value else { int a = 0; @@ -1216,7 +1232,7 @@ static void initheader(lua_State* L, Header* h) static KOption getoption(Header* h, const char** fmt, int* size) { int opt = *((*fmt)++); - *size = 0; /* default */ + *size = 0; // default switch (opt) { case 'b': @@ -1308,19 +1324,19 @@ static KOption getoption(Header* h, const char** fmt, int* size) static KOption getdetails(Header* h, size_t totalsize, const char** fmt, int* psize, int* ntoalign) { KOption opt = getoption(h, fmt, psize); - int align = *psize; /* usually, alignment follows size */ + int align = *psize; // usually, alignment follows size if (opt == Kpaddalign) - { /* 'X' gets alignment from following option */ + { // 'X' gets alignment from following option if (**fmt == '\0' || getoption(h, fmt, &align) == Kchar || align == 0) luaL_argerror(h->L, 1, "invalid next option for option 'X'"); } - if (align <= 1 || opt == Kchar) /* need no alignment? */ + if (align <= 1 || opt == Kchar) // need no alignment? *ntoalign = 0; else { - if (align > h->maxalign) /* enforce maximum alignment */ + if (align > h->maxalign) // enforce maximum alignment align = h->maxalign; - if ((align & (align - 1)) != 0) /* is 'align' not a power of 2? */ + if ((align & (align - 1)) != 0) // is 'align' not a power of 2? luaL_argerror(h->L, 1, "format asks for alignment not power of 2"); *ntoalign = (align - (int)(totalsize & (align - 1))) & (align - 1); } @@ -1338,18 +1354,18 @@ static void packint(luaL_Buffer* b, unsigned long long n, int islittle, int size LUAU_ASSERT(size <= MAXINTSIZE); char buff[MAXINTSIZE]; int i; - buff[islittle ? 0 : size - 1] = (char)(n & MC); /* first byte */ + buff[islittle ? 0 : size - 1] = (char)(n & MC); // first byte for (i = 1; i < size; i++) { n >>= NB; buff[islittle ? i : size - 1 - i] = (char)(n & MC); } if (neg && size > SZINT) - { /* negative number need sign extension? */ - for (i = SZINT; i < size; i++) /* correct extra bytes */ + { // negative number need sign extension? + for (i = SZINT; i < size; i++) // correct extra bytes buff[islittle ? i : size - 1 - i] = (char)MC; } - luaL_addlstring(b, buff, size); /* add result to buffer */ + luaL_addlstring(b, buff, size); // add result to buffer } /* @@ -1375,11 +1391,11 @@ static int str_pack(lua_State* L) { luaL_Buffer b; Header h; - const char* fmt = luaL_checkstring(L, 1); /* format string */ - int arg = 1; /* current argument to pack */ - size_t totalsize = 0; /* accumulate total size of result */ + const char* fmt = luaL_checkstring(L, 1); // format string + int arg = 1; // current argument to pack + size_t totalsize = 0; // accumulate total size of result initheader(L, &h); - lua_pushnil(L); /* mark to separate arguments from string buffer */ + lua_pushnil(L); // mark to separate arguments from string buffer luaL_buffinit(L, &b); while (*fmt != '\0') { @@ -1387,15 +1403,15 @@ static int str_pack(lua_State* L) KOption opt = getdetails(&h, totalsize, &fmt, &size, &ntoalign); totalsize += ntoalign + size; while (ntoalign-- > 0) - luaL_addchar(&b, LUAL_PACKPADBYTE); /* fill alignment */ + luaL_addchar(&b, LUAL_PACKPADBYTE); // fill alignment arg++; switch (opt) { case Kint: - { /* signed integers */ + { // signed integers long long n = (long long)luaL_checknumber(L, arg); if (size < SZINT) - { /* need overflow check? */ + { // need overflow check? long long lim = (long long)1 << ((size * NB) - 1); luaL_argcheck(L, -lim <= n && n < lim, arg, "integer overflow"); } @@ -1403,64 +1419,64 @@ static int str_pack(lua_State* L) break; } case Kuint: - { /* unsigned integers */ + { // unsigned integers long long n = (long long)luaL_checknumber(L, arg); - if (size < SZINT) /* need overflow check? */ + if (size < SZINT) // need overflow check? luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); packint(&b, (unsigned long long)n, h.islittle, size, 0); break; } case Kfloat: - { /* floating-point options */ + { // floating-point options volatile Ftypes u; char buff[MAXINTSIZE]; - double n = luaL_checknumber(L, arg); /* get argument */ + double n = luaL_checknumber(L, arg); // get argument if (size == sizeof(u.f)) - u.f = (float)n; /* copy it into 'u' */ + u.f = (float)n; // copy it into 'u' else if (size == sizeof(u.d)) u.d = (double)n; else u.n = n; - /* move 'u' to final result, correcting endianness if needed */ + // move 'u' to final result, correcting endianness if needed copywithendian(buff, u.buff, size, h.islittle); luaL_addlstring(&b, buff, size); break; } case Kchar: - { /* fixed-size string */ + { // fixed-size string size_t len; const char* s = luaL_checklstring(L, arg, &len); luaL_argcheck(L, len <= (size_t)size, arg, "string longer than given size"); - luaL_addlstring(&b, s, len); /* add string */ - while (len++ < (size_t)size) /* pad extra space */ + luaL_addlstring(&b, s, len); // add string + while (len++ < (size_t)size) // pad extra space luaL_addchar(&b, LUAL_PACKPADBYTE); break; } case Kstring: - { /* strings with length count */ + { // strings with length count size_t len; const char* s = luaL_checklstring(L, arg, &len); luaL_argcheck(L, size >= (int)sizeof(size_t) || len < ((size_t)1 << (size * NB)), arg, "string length does not fit in given size"); - packint(&b, len, h.islittle, size, 0); /* pack length */ + packint(&b, len, h.islittle, size, 0); // pack length luaL_addlstring(&b, s, len); totalsize += len; break; } case Kzstr: - { /* zero-terminated string */ + { // zero-terminated string size_t len; const char* s = luaL_checklstring(L, arg, &len); luaL_argcheck(L, strlen(s) == len, arg, "string contains zeros"); luaL_addlstring(&b, s, len); - luaL_addchar(&b, '\0'); /* add zero at the end */ + luaL_addchar(&b, '\0'); // add zero at the end totalsize += len + 1; break; } case Kpadding: - luaL_addchar(&b, LUAL_PACKPADBYTE); /* FALLTHROUGH */ + luaL_addchar(&b, LUAL_PACKPADBYTE); // FALLTHROUGH case Kpaddalign: case Knop: - arg--; /* undo increment */ + arg--; // undo increment break; } } @@ -1471,15 +1487,15 @@ static int str_pack(lua_State* L) static int str_packsize(lua_State* L) { Header h; - const char* fmt = luaL_checkstring(L, 1); /* format string */ - int totalsize = 0; /* accumulate total size of result */ + const char* fmt = luaL_checkstring(L, 1); // format string + int totalsize = 0; // accumulate total size of result initheader(L, &h); while (*fmt != '\0') { int size, ntoalign; KOption opt = getdetails(&h, totalsize, &fmt, &size, &ntoalign); luaL_argcheck(L, opt != Kstring && opt != Kzstr, 1, "variable-length format"); - size += ntoalign; /* total space used by option */ + size += ntoalign; // total space used by option luaL_argcheck(L, totalsize <= MAXSSIZE - size, 1, "format result too large"); totalsize += size; } @@ -1506,15 +1522,15 @@ static long long unpackint(lua_State* L, const char* str, int islittle, int size res |= (unsigned char)str[islittle ? i : size - 1 - i]; } if (size < SZINT) - { /* real size smaller than int? */ + { // real size smaller than int? if (issigned) - { /* needs sign extension? */ + { // needs sign extension? unsigned long long mask = (unsigned long long)1 << (size * NB - 1); - res = ((res ^ mask) - mask); /* do sign extension */ + res = ((res ^ mask) - mask); // do sign extension } } else if (size > SZINT) - { /* must check unread bytes */ + { // must check unread bytes int mask = (!issigned || (long long)res >= 0) ? 0 : MC; for (i = limit; i < size; i++) { @@ -1534,7 +1550,7 @@ static int str_unpack(lua_State* L) int pos = posrelat(luaL_optinteger(L, 3, 1), ld) - 1; if (pos < 0) pos = 0; - int n = 0; /* number of results */ + int n = 0; // number of results luaL_argcheck(L, size_t(pos) <= ld, 3, "initial position out of string"); initheader(L, &h); while (*fmt != '\0') @@ -1542,8 +1558,8 @@ static int str_unpack(lua_State* L) int size, ntoalign; KOption opt = getdetails(&h, pos, &fmt, &size, &ntoalign); luaL_argcheck(L, (size_t)ntoalign + size <= ld - pos, 2, "data string too short"); - pos += ntoalign; /* skip alignment */ - /* stack space for item + next position */ + pos += ntoalign; // skip alignment + // stack space for item + next position luaL_checkstack(L, 2, "too many results"); n++; switch (opt) @@ -1584,7 +1600,7 @@ static int str_unpack(lua_State* L) size_t len = (size_t)unpackint(L, data + pos, h.islittle, size, 0); luaL_argcheck(L, len <= ld - pos - size, 2, "data string too short"); lua_pushlstring(L, data + pos + size, len); - pos += (int)len; /* skip string */ + pos += (int)len; // skip string break; } case Kzstr: @@ -1592,22 +1608,22 @@ static int str_unpack(lua_State* L) size_t len = strlen(data + pos); luaL_argcheck(L, pos + len < ld, 2, "unfinished string for format 'z'"); lua_pushlstring(L, data + pos, len); - pos += (int)len + 1; /* skip string plus final '\0' */ + pos += (int)len + 1; // skip string plus final '\0' break; } case Kpaddalign: case Kpadding: case Knop: - n--; /* undo increment */ + n--; // undo increment break; } pos += size; } - lua_pushinteger(L, pos + 1); /* next position */ + lua_pushinteger(L, pos + 1); // next position return n + 1; } -/* }====================================================== */ +// }====================================================== static const luaL_Reg strlib[] = { {"byte", str_byte}, @@ -1632,14 +1648,14 @@ static const luaL_Reg strlib[] = { static void createmetatable(lua_State* L) { - lua_createtable(L, 0, 1); /* create metatable for strings */ - lua_pushliteral(L, ""); /* dummy string */ + lua_createtable(L, 0, 1); // create metatable for strings + lua_pushliteral(L, ""); // dummy string lua_pushvalue(L, -2); - lua_setmetatable(L, -2); /* set string metatable */ - lua_pop(L, 1); /* pop dummy string */ - lua_pushvalue(L, -2); /* string library... */ - lua_setfield(L, -2, "__index"); /* ...is the __index metamethod */ - lua_pop(L, 1); /* pop metatable */ + lua_setmetatable(L, -2); // set string metatable + lua_pop(L, 1); // pop dummy string + lua_pushvalue(L, -2); // string library... + lua_setfield(L, -2, "__index"); // ...is the __index metamethod + lua_pop(L, 1); // pop metatable } /* diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 2316cc3d..8d59ecbc 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -44,13 +44,10 @@ static_assert(TKey{{NULL}, {0}, LUA_TDEADKEY, 0}.tt == LUA_TDEADKEY, "not enough static_assert(TKey{{NULL}, {0}, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not enough bits for next"); static_assert(TKey{{NULL}, {0}, LUA_TNIL, -(MAXSIZE - 1)}.next == -(MAXSIZE - 1), "not enough bits for next"); -// reset cache of absent metamethods, cache is updated in luaT_gettm -#define invalidateTMcache(t) t->tmcache = 0 - // empty hash data points to dummynode so that we can always dereference it const LuaNode luaH_dummynode = { - {{NULL}, {0}, LUA_TNIL}, /* value */ - {{NULL}, {0}, LUA_TNIL, 0} /* key */ + {{NULL}, {0}, LUA_TNIL}, // value + {{NULL}, {0}, LUA_TNIL, 0} // key }; #define dummynode (&luaH_dummynode) @@ -108,9 +105,9 @@ static LuaNode* hashvec(const Table* t, const float* v) memcpy(i, v, sizeof(i)); // convert -0 to 0 to make sure they hash to the same value - i[0] = (i[0] == 0x8000000) ? 0 : i[0]; - i[1] = (i[1] == 0x8000000) ? 0 : i[1]; - i[2] = (i[2] == 0x8000000) ? 0 : i[2]; + i[0] = (i[0] == 0x80000000) ? 0 : i[0]; + i[1] = (i[1] == 0x80000000) ? 0 : i[1]; + i[2] = (i[2] == 0x80000000) ? 0 : i[2]; // scramble bits to make sure that integer coordinates have entropy in lower bits i[0] ^= i[0] >> 17; @@ -121,7 +118,7 @@ static LuaNode* hashvec(const Table* t, const float* v) unsigned int h = (i[0] * 73856093) ^ (i[1] * 19349663) ^ (i[2] * 83492791); #if LUA_VECTOR_SIZE == 4 - i[3] = (i[3] == 0x8000000) ? 0 : i[3]; + i[3] = (i[3] == 0x80000000) ? 0 : i[3]; i[3] ^= i[3] >> 17; h ^= i[3] * 39916801; #endif @@ -173,52 +170,52 @@ static int findindex(lua_State* L, Table* t, StkId key) { int i; if (ttisnil(key)) - return -1; /* first iteration */ + return -1; // first iteration i = ttisnumber(key) ? arrayindex(nvalue(key)) : -1; - if (0 < i && i <= t->sizearray) /* is `key' inside array part? */ - return i - 1; /* yes; that's the index (corrected to C) */ + if (0 < i && i <= t->sizearray) // is `key' inside array part? + return i - 1; // yes; that's the index (corrected to C) else { LuaNode* n = mainposition(t, key); for (;;) - { /* check whether `key' is somewhere in the chain */ - /* key may be dead already, but it is ok to use it in `next' */ + { // check whether `key' is somewhere in the chain + // key may be dead already, but it is ok to use it in `next' if (luaO_rawequalKey(gkey(n), key) || (ttype(gkey(n)) == LUA_TDEADKEY && iscollectable(key) && gcvalue(gkey(n)) == gcvalue(key))) { - i = cast_int(n - gnode(t, 0)); /* key index in hash table */ - /* hash elements are numbered after array ones */ + i = cast_int(n - gnode(t, 0)); // key index in hash table + // hash elements are numbered after array ones return i + t->sizearray; } if (gnext(n) == 0) break; n += gnext(n); } - luaG_runerror(L, "invalid key to 'next'"); /* key not found */ + luaG_runerror(L, "invalid key to 'next'"); // key not found } } int luaH_next(lua_State* L, Table* t, StkId key) { - int i = findindex(L, t, key); /* find original element */ + int i = findindex(L, t, key); // find original element for (i++; i < t->sizearray; i++) - { /* try first array part */ + { // try first array part if (!ttisnil(&t->array[i])) - { /* a non-nil value? */ + { // a non-nil value? setnvalue(key, cast_num(i + 1)); setobj2s(L, key + 1, &t->array[i]); return 1; } } for (i -= t->sizearray; i < sizenode(t); i++) - { /* then hash part */ + { // then hash part if (!ttisnil(gval(gnode(t, i)))) - { /* a non-nil value? */ + { // a non-nil value? getnodekey(L, key, gnode(t, i)); setobj2s(L, key + 1, gval(gnode(t, i))); return 1; } } - return 0; /* no more elements */ + return 0; // no more elements } /* @@ -238,23 +235,23 @@ int luaH_next(lua_State* L, Table* t, StkId key) static int computesizes(int nums[], int* narray) { int i; - int twotoi; /* 2^i */ - int a = 0; /* number of elements smaller than 2^i */ - int na = 0; /* number of elements to go to array part */ - int n = 0; /* optimal size for array part */ + int twotoi; // 2^i + int a = 0; // number of elements smaller than 2^i + int na = 0; // number of elements to go to array part + int n = 0; // optimal size for array part for (i = 0, twotoi = 1; twotoi / 2 < *narray; i++, twotoi *= 2) { if (nums[i] > 0) { a += nums[i]; if (a > twotoi / 2) - { /* more than half elements present? */ - n = twotoi; /* optimal size (till now) */ - na = a; /* all elements smaller than n will go to array part */ + { // more than half elements present? + n = twotoi; // optimal size (till now) + na = a; // all elements smaller than n will go to array part } } if (a == *narray) - break; /* all elements already counted */ + break; // all elements already counted } *narray = n; LUAU_ASSERT(*narray / 2 <= na && na <= *narray); @@ -265,8 +262,8 @@ static int countint(double key, int* nums) { int k = arrayindex(key); if (0 < k && k <= MAXSIZE) - { /* is `key' an appropriate array index? */ - nums[ceillog2(k)]++; /* count as such */ + { // is `key' an appropriate array index? + nums[ceillog2(k)]++; // count as such return 1; } else @@ -276,20 +273,20 @@ static int countint(double key, int* nums) static int numusearray(const Table* t, int* nums) { int lg; - int ttlg; /* 2^lg */ - int ause = 0; /* summation of `nums' */ - int i = 1; /* count to traverse all array keys */ + int ttlg; // 2^lg + int ause = 0; // summation of `nums' + int i = 1; // count to traverse all array keys for (lg = 0, ttlg = 1; lg <= MAXBITS; lg++, ttlg *= 2) - { /* for each slice */ - int lc = 0; /* counter */ + { // for each slice + int lc = 0; // counter int lim = ttlg; if (lim > t->sizearray) { - lim = t->sizearray; /* adjust upper limit */ + lim = t->sizearray; // adjust upper limit if (i > lim) - break; /* no more elements to count */ + break; // no more elements to count } - /* count elements in range (2^(lg-1), 2^lg] */ + // count elements in range (2^(lg-1), 2^lg] for (; i <= lim; i++) { if (!ttisnil(&t->array[i - 1])) @@ -303,8 +300,8 @@ static int numusearray(const Table* t, int* nums) static int numusehash(const Table* t, int* nums, int* pnasize) { - int totaluse = 0; /* total number of elements */ - int ause = 0; /* summation of `nums' */ + int totaluse = 0; // total number of elements + int ause = 0; // summation of `nums' int i = sizenode(t); while (i--) { @@ -335,8 +332,8 @@ static void setnodevector(lua_State* L, Table* t, int size) { int lsize; if (size == 0) - { /* no elements to hash part? */ - t->node = cast_to(LuaNode*, dummynode); /* use common `dummynode' */ + { // no elements to hash part? + t->node = cast_to(LuaNode*, dummynode); // use common `dummynode' lsize = 0; } else @@ -357,7 +354,7 @@ static void setnodevector(lua_State* L, Table* t, int size) } t->lsizenode = cast_byte(lsize); t->nodemask8 = cast_byte((1 << lsize) - 1); - t->lastfree = size; /* all positions are free */ + t->lastfree = size; // all positions are free } static TValue* newkey(lua_State* L, Table* t, const TValue* key); @@ -382,17 +379,17 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) luaG_runerror(L, "table overflow"); int oldasize = t->sizearray; int oldhsize = t->lsizenode; - LuaNode* nold = t->node; /* save old hash ... */ - if (nasize > oldasize) /* array part must grow? */ + LuaNode* nold = t->node; // save old hash ... + if (nasize > oldasize) // array part must grow? setarrayvector(L, t, nasize); - /* create new hash part with appropriate size */ + // create new hash part with appropriate size setnodevector(L, t, nhsize); - /* used for the migration check at the end */ + // used for the migration check at the end LuaNode* nnew = t->node; if (nasize < oldasize) - { /* array part must shrink? */ + { // array part must shrink? t->sizearray = nasize; - /* re-insert elements from vanishing slice */ + // re-insert elements from vanishing slice for (int i = nasize; i < oldasize; i++) { if (!ttisnil(&t->array[i])) @@ -402,12 +399,12 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) setobjt2t(L, newkey(L, t, &ok), &t->array[i]); } } - /* shrink array */ + // shrink array luaM_reallocarray(L, t->array, oldasize, nasize, TValue, t->memcat); } - /* used for the migration check at the end */ + // used for the migration check at the end TValue* anew = t->array; - /* re-insert elements from hash part */ + // re-insert elements from hash part for (int i = twoto(oldhsize) - 1; i >= 0; i--) { LuaNode* old = nold + i; @@ -419,19 +416,19 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) } } - /* make sure we haven't recursively rehashed during element migration */ + // make sure we haven't recursively rehashed during element migration LUAU_ASSERT(nnew == t->node); LUAU_ASSERT(anew == t->array); if (nold != dummynode) - luaM_freearray(L, nold, twoto(oldhsize), LuaNode, t->memcat); /* free old array */ + luaM_freearray(L, nold, twoto(oldhsize), LuaNode, t->memcat); // free old array } static int adjustasize(Table* t, int size, const TValue* ek) { bool tbound = t->node != dummynode || size < t->sizearray; int ekindex = ek && ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; - /* move the array size up until the boundary is guaranteed to be inside the array part */ + // move the array size up until the boundary is guaranteed to be inside the array part while (size + 1 == ekindex || (tbound && !ttisnil(luaH_getnum(t, size + 1)))) size++; return size; @@ -451,22 +448,22 @@ void luaH_resizehash(lua_State* L, Table* t, int nhsize) static void rehash(lua_State* L, Table* t, const TValue* ek) { - int nums[MAXBITS + 1]; /* nums[i] = number of keys between 2^(i-1) and 2^i */ + int nums[MAXBITS + 1]; // nums[i] = number of keys between 2^(i-1) and 2^i for (int i = 0; i <= MAXBITS; i++) - nums[i] = 0; /* reset counts */ - int nasize = numusearray(t, nums); /* count keys in array part */ - int totaluse = nasize; /* all those keys are integer keys */ - totaluse += numusehash(t, nums, &nasize); /* count keys in hash part */ - /* count extra key */ + nums[i] = 0; // reset counts + int nasize = numusearray(t, nums); // count keys in array part + int totaluse = nasize; // all those keys are integer keys + totaluse += numusehash(t, nums, &nasize); // count keys in hash part + // count extra key if (ttisnumber(ek)) nasize += countint(nvalue(ek), nums); totaluse++; - /* compute new size for array part */ + // compute new size for array part int na = computesizes(nums, &nasize); int nh = totaluse - na; - /* enforce the boundary invariant; for performance, only do hash lookups if we must */ + // enforce the boundary invariant; for performance, only do hash lookups if we must nasize = adjustasize(t, nasize, ek); - /* resize the table to new computed sizes */ + // resize the table to new computed sizes resize(L, t, nasize, nh); } @@ -514,7 +511,7 @@ static LuaNode* getfreepos(Table* t) if (ttisnil(gkey(n))) return n; } - return NULL; /* could not find a free place */ + return NULL; // could not find a free place } /* @@ -526,24 +523,24 @@ static LuaNode* getfreepos(Table* t) */ static TValue* newkey(lua_State* L, Table* t, const TValue* key) { - /* enforce boundary invariant */ + // enforce boundary invariant if (ttisnumber(key) && nvalue(key) == t->sizearray + 1) { - rehash(L, t, key); /* grow table */ + rehash(L, t, key); // grow table - /* after rehash, numeric keys might be located in the new array part, but won't be found in the node part */ + // after rehash, numeric keys might be located in the new array part, but won't be found in the node part return arrayornewkey(L, t, key); } LuaNode* mp = mainposition(t, key); if (!ttisnil(gval(mp)) || mp == dummynode) { - LuaNode* n = getfreepos(t); /* get a free place */ + LuaNode* n = getfreepos(t); // get a free place if (n == NULL) - { /* cannot find a free place? */ - rehash(L, t, key); /* grow table */ + { // cannot find a free place? + rehash(L, t, key); // grow table - /* after rehash, numeric keys might be located in the new array part, but won't be found in the node part */ + // after rehash, numeric keys might be located in the new array part, but won't be found in the node part return arrayornewkey(L, t, key); } LUAU_ASSERT(n != dummynode); @@ -551,24 +548,24 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) getnodekey(L, &mk, mp); LuaNode* othern = mainposition(t, &mk); if (othern != mp) - { /* is colliding node out of its main position? */ - /* yes; move colliding node into free position */ + { // is colliding node out of its main position? + // yes; move colliding node into free position while (othern + gnext(othern) != mp) - othern += gnext(othern); /* find previous */ - gnext(othern) = cast_int(n - othern); /* redo the chain with `n' in place of `mp' */ - *n = *mp; /* copy colliding node into free pos. (mp->next also goes) */ + othern += gnext(othern); // find previous + gnext(othern) = cast_int(n - othern); // redo the chain with `n' in place of `mp' + *n = *mp; // copy colliding node into free pos. (mp->next also goes) if (gnext(mp) != 0) { - gnext(n) += cast_int(mp - n); /* correct 'next' */ - gnext(mp) = 0; /* now 'mp' is free */ + gnext(n) += cast_int(mp - n); // correct 'next' + gnext(mp) = 0; // now 'mp' is free } setnilvalue(gval(mp)); } else - { /* colliding node is in its own main position */ - /* new node will go into free position */ + { // colliding node is in its own main position + // new node will go into free position if (gnext(mp) != 0) - gnext(n) = cast_int((mp + gnext(mp)) - n); /* chain new position */ + gnext(n) = cast_int((mp + gnext(mp)) - n); // chain new position else LUAU_ASSERT(gnext(n) == 0); gnext(mp) = cast_int(n - mp); @@ -586,7 +583,7 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) */ const TValue* luaH_getnum(Table* t, int key) { - /* (1 <= key && key <= t->sizearray) */ + // (1 <= key && key <= t->sizearray) if (cast_to(unsigned int, key - 1) < cast_to(unsigned int, t->sizearray)) return &t->array[key - 1]; else if (t->node != dummynode) @@ -594,9 +591,9 @@ const TValue* luaH_getnum(Table* t, int key) double nk = cast_num(key); LuaNode* n = hashnum(t, nk); for (;;) - { /* check whether `key' is somewhere in the chain */ + { // check whether `key' is somewhere in the chain if (ttisnumber(gkey(n)) && luai_numeq(nvalue(gkey(n)), nk)) - return gval(n); /* that's it */ + return gval(n); // that's it if (gnext(n) == 0) break; n += gnext(n); @@ -614,9 +611,9 @@ const TValue* luaH_getstr(Table* t, TString* key) { LuaNode* n = hashstr(t, key); for (;;) - { /* check whether `key' is somewhere in the chain */ + { // check whether `key' is somewhere in the chain if (ttisstring(gkey(n)) && tsvalue(gkey(n)) == key) - return gval(n); /* that's it */ + return gval(n); // that's it if (gnext(n) == 0) break; n += gnext(n); @@ -640,17 +637,17 @@ const TValue* luaH_get(Table* t, const TValue* key) int k; double n = nvalue(key); luai_num2int(k, n); - if (luai_numeq(cast_num(k), nvalue(key))) /* index is int? */ - return luaH_getnum(t, k); /* use specialized version */ - /* else go through */ + if (luai_numeq(cast_num(k), nvalue(key))) // index is int? + return luaH_getnum(t, k); // use specialized version + // else go through } default: { LuaNode* n = mainposition(t, key); for (;;) - { /* check whether `key' is somewhere in the chain */ + { // check whether `key' is somewhere in the chain if (luaO_rawequalKey(gkey(n), key)) - return gval(n); /* that's it */ + return gval(n); // that's it if (gnext(n) == 0) break; n += gnext(n); @@ -667,23 +664,26 @@ TValue* luaH_set(lua_State* L, Table* t, const TValue* key) if (p != luaO_nilobject) return cast_to(TValue*, p); else - { - if (ttisnil(key)) - luaG_runerror(L, "table index is nil"); - else if (ttisnumber(key) && luai_numisnan(nvalue(key))) - luaG_runerror(L, "table index is NaN"); - else if (ttisvector(key) && luai_vecisnan(vvalue(key))) - luaG_runerror(L, "table index contains NaN"); - return newkey(L, t, key); - } + return luaH_newkey(L, t, key); +} + +TValue* luaH_newkey(lua_State* L, Table* t, const TValue* key) +{ + if (ttisnil(key)) + luaG_runerror(L, "table index is nil"); + else if (ttisnumber(key) && luai_numisnan(nvalue(key))) + luaG_runerror(L, "table index is NaN"); + else if (ttisvector(key) && luai_vecisnan(vvalue(key))) + luaG_runerror(L, "table index contains NaN"); + return newkey(L, t, key); } TValue* luaH_setnum(lua_State* L, Table* t, int key) { - /* (1 <= key && key <= t->sizearray) */ + // (1 <= key && key <= t->sizearray) if (cast_to(unsigned int, key - 1) < cast_to(unsigned int, t->sizearray)) return &t->array[key - 1]; - /* hash fallback */ + // hash fallback const TValue* p = luaH_getnum(t, key); if (p != luaO_nilobject) return cast_to(TValue*, p); @@ -739,9 +739,9 @@ int luaH_getn(Table* t) if (boundary > 0) { if (!ttisnil(&t->array[t->sizearray - 1]) && t->node == dummynode) - return t->sizearray; /* fast-path: the end of the array in `t' already refers to a boundary */ + return t->sizearray; // fast-path: the end of the array in `t' already refers to a boundary if (boundary < t->sizearray && !ttisnil(&t->array[boundary - 1]) && ttisnil(&t->array[boundary])) - return boundary; /* fast-path: boundary already refers to a boundary in `t' */ + return boundary; // fast-path: boundary already refers to a boundary in `t' int foundboundary = updateaboundary(t, boundary); if (foundboundary > 0) @@ -767,7 +767,7 @@ int luaH_getn(Table* t) } else { - /* validate boundary invariant */ + // validate boundary invariant LUAU_ASSERT(t->node == dummynode || ttisnil(luaH_getnum(t, j + 1))); return j; } @@ -812,7 +812,7 @@ Table* luaH_clone(lua_State* L, Table* tt) void luaH_clear(Table* tt) { - /* clear array part */ + // clear array part for (int i = 0; i < tt->sizearray; ++i) { setnilvalue(&tt->array[i]); @@ -820,7 +820,7 @@ void luaH_clear(Table* tt) maybesetaboundary(tt, 0); - /* clear hash part */ + // clear hash part if (tt->node != dummynode) { int size = sizenode(tt); @@ -834,6 +834,6 @@ void luaH_clear(Table* tt) } } - /* back to empty -> no tag methods present */ + // back to empty -> no tag methods present tt->tmcache = cast_byte(~0); } diff --git a/VM/src/ltable.h b/VM/src/ltable.h index e8413c85..021f21bf 100644 --- a/VM/src/ltable.h +++ b/VM/src/ltable.h @@ -11,12 +11,16 @@ #define gval2slot(t, v) int(cast_to(LuaNode*, static_cast(v)) - t->node) +// reset cache of absent metamethods, cache is updated in luaT_gettm +#define invalidateTMcache(t) t->tmcache = 0 + LUAI_FUNC const TValue* luaH_getnum(Table* t, int key); LUAI_FUNC TValue* luaH_setnum(lua_State* L, Table* t, int key); LUAI_FUNC const TValue* luaH_getstr(Table* t, TString* key); LUAI_FUNC TValue* luaH_setstr(lua_State* L, Table* t, TString* key); LUAI_FUNC const TValue* luaH_get(Table* t, const TValue* key); LUAI_FUNC TValue* luaH_set(lua_State* L, Table* t, const TValue* key); +LUAI_FUNC TValue* luaH_newkey(lua_State* L, Table* t, const TValue* key); LUAI_FUNC Table* luaH_new(lua_State* L, int narray, int lnhash); LUAI_FUNC void luaH_resizearray(lua_State* L, Table* t, int nasize); LUAI_FUNC void luaH_resizehash(lua_State* L, Table* t, int nhsize); @@ -26,4 +30,6 @@ LUAI_FUNC int luaH_getn(Table* t); LUAI_FUNC Table* luaH_clone(lua_State* L, Table* tt); LUAI_FUNC void luaH_clear(Table* tt); +#define luaH_setslot(L, t, slot, key) (invalidateTMcache(t), (slot == luaO_nilobject ? luaH_newkey(L, t, key) : cast_to(TValue*, slot))) + extern const LuaNode luaH_dummynode; diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 27187c61..6dd94149 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -18,13 +18,13 @@ static int foreachi(lua_State* L) int n = lua_objlen(L, 1); for (i = 1; i <= n; i++) { - lua_pushvalue(L, 2); /* function */ - lua_pushinteger(L, i); /* 1st argument */ - lua_rawgeti(L, 1, i); /* 2nd argument */ + lua_pushvalue(L, 2); // function + lua_pushinteger(L, i); // 1st argument + lua_rawgeti(L, 1, i); // 2nd argument lua_call(L, 2, 1); if (!lua_isnil(L, -1)) return 1; - lua_pop(L, 1); /* remove nil result */ + lua_pop(L, 1); // remove nil result } return 0; } @@ -33,16 +33,16 @@ static int foreach (lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); luaL_checktype(L, 2, LUA_TFUNCTION); - lua_pushnil(L); /* first key */ + lua_pushnil(L); // first key while (lua_next(L, 1)) { - lua_pushvalue(L, 2); /* function */ - lua_pushvalue(L, -3); /* key */ - lua_pushvalue(L, -3); /* value */ + lua_pushvalue(L, 2); // function + lua_pushvalue(L, -3); // key + lua_pushvalue(L, -3); // value lua_call(L, 2, 1); if (!lua_isnil(L, -1)) return 1; - lua_pop(L, 2); /* remove value and result */ + lua_pop(L, 2); // remove value and result } return 0; } @@ -51,10 +51,10 @@ static int maxn(lua_State* L) { double max = 0; luaL_checktype(L, 1, LUA_TTABLE); - lua_pushnil(L); /* first key */ + lua_pushnil(L); // first key while (lua_next(L, 1)) { - lua_pop(L, 1); /* remove value */ + lua_pop(L, 1); // remove value if (lua_type(L, -1) == LUA_TNUMBER) { double v = lua_tonumber(L, -1); @@ -79,9 +79,9 @@ static void moveelements(lua_State* L, int srct, int dstt, int f, int e, int t) Table* dst = hvalue(L->base + (dstt - 1)); if (dst->readonly) - luaG_runerror(L, "Attempt to modify a readonly table"); + luaG_readonlyerror(L); - int n = e - f + 1; /* number of elements to move */ + int n = e - f + 1; // number of elements to move if (cast_to(unsigned int, f - 1) < cast_to(unsigned int, src->sizearray) && cast_to(unsigned int, t - 1) < cast_to(unsigned int, dst->sizearray) && @@ -137,19 +137,19 @@ static int tinsert(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); int n = lua_objlen(L, 1); - int pos; /* where to insert new element */ + int pos; // where to insert new element switch (lua_gettop(L)) { case 2: - { /* called with only 2 arguments */ - pos = n + 1; /* insert new element at the end */ + { // called with only 2 arguments + pos = n + 1; // insert new element at the end break; } case 3: { - pos = luaL_checkinteger(L, 2); /* 2nd argument is the position */ + pos = luaL_checkinteger(L, 2); // 2nd argument is the position - /* move up elements if necessary */ + // move up elements if necessary if (1 <= pos && pos <= n) moveelements(L, 1, 1, pos, n, pos + 1); break; @@ -159,7 +159,7 @@ static int tinsert(lua_State* L) luaL_error(L, "wrong number of arguments to 'insert'"); } } - lua_rawseti(L, 1, pos); /* t[pos] = v */ + lua_rawseti(L, 1, pos); // t[pos] = v return 0; } @@ -169,14 +169,14 @@ static int tremove(lua_State* L) int n = lua_objlen(L, 1); int pos = luaL_optinteger(L, 2, n); - if (!(1 <= pos && pos <= n)) /* position is outside bounds? */ - return 0; /* nothing to remove */ - lua_rawgeti(L, 1, pos); /* result = t[pos] */ + if (!(1 <= pos && pos <= n)) // position is outside bounds? + return 0; // nothing to remove + lua_rawgeti(L, 1, pos); // result = t[pos] moveelements(L, 1, 1, pos + 1, n, pos); lua_pushnil(L); - lua_rawseti(L, 1, n); /* t[n] = nil */ + lua_rawseti(L, 1, n); // t[n] = nil return 1; } @@ -192,28 +192,28 @@ static int tmove(lua_State* L) int f = luaL_checkinteger(L, 2); int e = luaL_checkinteger(L, 3); int t = luaL_checkinteger(L, 4); - int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ + int tt = !lua_isnoneornil(L, 5) ? 5 : 1; // destination table luaL_checktype(L, tt, LUA_TTABLE); if (e >= f) - { /* otherwise, nothing to move */ + { // otherwise, nothing to move luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move"); - int n = e - f + 1; /* number of elements to move */ + int n = e - f + 1; // number of elements to move luaL_argcheck(L, t <= INT_MAX - n + 1, 4, "destination wrap around"); Table* dst = hvalue(L->base + (tt - 1)); - if (dst->readonly) /* also checked in moveelements, but this blocks resizes of r/o tables */ - luaG_runerror(L, "Attempt to modify a readonly table"); + if (dst->readonly) // also checked in moveelements, but this blocks resizes of r/o tables + luaG_readonlyerror(L); if (t > 0 && (t - 1) <= dst->sizearray && (t - 1 + n) > dst->sizearray) - { /* grow the destination table array */ + { // grow the destination table array luaH_resizearray(L, dst, t - 1 + n); } moveelements(L, 1, tt, f, e, t); } - lua_pushvalue(L, tt); /* return destination table */ + lua_pushvalue(L, tt); // return destination table return 1; } @@ -240,7 +240,7 @@ static int tconcat(lua_State* L) addfield(L, &b, i); luaL_addlstring(&b, sep, lsep); } - if (i == last) /* add last value (if interval was not empty) */ + if (i == last) // add last value (if interval was not empty) addfield(L, &b, i); luaL_pushresult(&b); return 1; @@ -248,8 +248,8 @@ static int tconcat(lua_State* L) static int tpack(lua_State* L) { - int n = lua_gettop(L); /* number of elements to pack */ - lua_createtable(L, n, 1); /* create result table */ + int n = lua_gettop(L); // number of elements to pack + lua_createtable(L, n, 1); // create result table Table* t = hvalue(L->top - 1); @@ -259,11 +259,11 @@ static int tpack(lua_State* L) setobj2t(L, e, L->base + i); } - /* t.n = number of elements */ + // t.n = number of elements TValue* nv = luaH_setstr(L, t, luaS_newliteral(L, "n")); setnvalue(nv, n); - return 1; /* return table */ + return 1; // return table } static int tunpack(lua_State* L) @@ -274,8 +274,8 @@ static int tunpack(lua_State* L) int i = luaL_optinteger(L, 2, 1); int e = luaL_opt(L, luaL_checkinteger, 3, lua_objlen(L, 1)); if (i > e) - return 0; /* empty range */ - unsigned n = (unsigned)e - i; /* number of elements minus 1 (avoid overflows) */ + return 0; // empty range + unsigned n = (unsigned)e - i; // number of elements minus 1 (avoid overflows) if (n >= (unsigned int)INT_MAX || !lua_checkstack(L, (int)(++n))) luaL_error(L, "too many results to unpack"); @@ -288,10 +288,10 @@ static int tunpack(lua_State* L) } else { - /* push arg[i..e - 1] (to avoid overflows) */ + // push arg[i..e - 1] (to avoid overflows) for (; i < e; i++) lua_rawgeti(L, 1, i); - lua_rawgeti(L, 1, e); /* push last element */ + lua_rawgeti(L, 1, e); // push last element } return (int)n; } @@ -312,85 +312,85 @@ static void set2(lua_State* L, int i, int j) static int sort_comp(lua_State* L, int a, int b) { if (!lua_isnil(L, 2)) - { /* function? */ + { // function? int res; lua_pushvalue(L, 2); - lua_pushvalue(L, a - 1); /* -1 to compensate function */ - lua_pushvalue(L, b - 2); /* -2 to compensate function and `a' */ + lua_pushvalue(L, a - 1); // -1 to compensate function + lua_pushvalue(L, b - 2); // -2 to compensate function and `a' lua_call(L, 2, 1); res = lua_toboolean(L, -1); lua_pop(L, 1); return res; } - else /* a < b? */ + else // a < b? return lua_lessthan(L, a, b); } static void auxsort(lua_State* L, int l, int u) { while (l < u) - { /* for tail recursion */ + { // for tail recursion int i, j; - /* sort elements a[l], a[(l+u)/2] and a[u] */ + // sort elements a[l], a[(l+u)/2] and a[u] lua_rawgeti(L, 1, l); lua_rawgeti(L, 1, u); - if (sort_comp(L, -1, -2)) /* a[u] < a[l]? */ - set2(L, l, u); /* swap a[l] - a[u] */ + if (sort_comp(L, -1, -2)) // a[u] < a[l]? + set2(L, l, u); // swap a[l] - a[u] else lua_pop(L, 2); if (u - l == 1) - break; /* only 2 elements */ + break; // only 2 elements i = (l + u) / 2; lua_rawgeti(L, 1, i); lua_rawgeti(L, 1, l); - if (sort_comp(L, -2, -1)) /* a[i]= P */ + { // invariant: a[l..i] <= P <= a[j..u] + // repeat ++i until a[i] >= P while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) { if (i >= u) luaL_error(L, "invalid order function for sorting"); - lua_pop(L, 1); /* remove a[i] */ + lua_pop(L, 1); // remove a[i] } - /* repeat --j until a[j] <= P */ + // repeat --j until a[j] <= P while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) { if (j <= l) luaL_error(L, "invalid order function for sorting"); - lua_pop(L, 1); /* remove a[j] */ + lua_pop(L, 1); // remove a[j] } if (j < i) { - lua_pop(L, 3); /* pop pivot, a[i], a[j] */ + lua_pop(L, 3); // pop pivot, a[i], a[j] break; } set2(L, i, j); } lua_rawgeti(L, 1, u - 1); lua_rawgeti(L, 1, i); - set2(L, u - 1, i); /* swap pivot (a[u-1]) with a[i] */ - /* a[l..i-1] <= a[i] == P <= a[i+1..u] */ - /* adjust so that smaller half is in [j..i] and larger one in [l..u] */ + set2(L, u - 1, i); // swap pivot (a[u-1]) with a[i] + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // adjust so that smaller half is in [j..i] and larger one in [l..u] if (i - l < u - i) { j = l; @@ -403,23 +403,23 @@ static void auxsort(lua_State* L, int l, int u) i = u; u = j - 2; } - auxsort(L, j, i); /* call recursively the smaller one */ - } /* repeat the routine for the larger one */ + auxsort(L, j, i); // call recursively the smaller one + } // repeat the routine for the larger one } static int sort(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); int n = lua_objlen(L, 1); - luaL_checkstack(L, 40, ""); /* assume array is smaller than 2^40 */ - if (!lua_isnoneornil(L, 2)) /* is there a 2nd argument? */ + luaL_checkstack(L, 40, ""); // assume array is smaller than 2^40 + if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? luaL_checktype(L, 2, LUA_TFUNCTION); - lua_settop(L, 2); /* make sure there is two arguments */ + lua_settop(L, 2); // make sure there is two arguments auxsort(L, 1, n); return 0; } -/* }====================================================== */ +// }====================================================== static int tcreate(lua_State* L) { @@ -482,7 +482,7 @@ static int tclear(lua_State* L) Table* tt = hvalue(L->base); if (tt->readonly) - luaG_runerror(L, "Attempt to modify a readonly table"); + luaG_readonlyerror(L); luaH_clear(tt); return 0; diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 49982b28..cb7ba097 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -12,7 +12,7 @@ // clang-format off const char* const luaT_typenames[] = { - /* ORDER TYPE */ + // ORDER TYPE "nil", "boolean", @@ -31,7 +31,7 @@ const char* const luaT_typenames[] = { }; const char* const luaT_eventname[] = { - /* ORDER TM */ + // ORDER TM "__index", "__newindex", @@ -70,12 +70,12 @@ void luaT_init(lua_State* L) for (i = 0; i < LUA_T_COUNT; i++) { L->global->ttname[i] = luaS_new(L, luaT_typenames[i]); - luaS_fix(L->global->ttname[i]); /* never collect these names */ + luaS_fix(L->global->ttname[i]); // never collect these names } for (i = 0; i < TM_N; i++) { L->global->tmname[i] = luaS_new(L, luaT_eventname[i]); - luaS_fix(L->global->tmname[i]); /* never collect these names */ + luaS_fix(L->global->tmname[i]); // never collect these names } } @@ -88,8 +88,8 @@ const TValue* luaT_gettm(Table* events, TMS event, TString* ename) const TValue* tm = luaH_getstr(events, ename); LUAU_ASSERT(event <= TM_EQ); if (ttisnil(tm)) - { /* no tag method? */ - events->tmcache |= cast_byte(1u << event); /* cache this fact */ + { // no tag method? + events->tmcache |= cast_byte(1u << event); // cache this fact return NULL; } else diff --git a/VM/src/ltm.h b/VM/src/ltm.h index e11ddb3a..f20ce1b2 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -20,7 +20,7 @@ typedef enum TM_ITER, TM_LEN, - TM_EQ, /* last tag method with `fast' access */ + TM_EQ, // last tag method with `fast' access TM_ADD, @@ -37,7 +37,7 @@ typedef enum TM_CONCAT, TM_TYPE, - TM_N /* number of elements in the enum */ + TM_N // number of elements in the enum } TMS; // clang-format on diff --git a/VM/src/ludata.h b/VM/src/ludata.h index f24e4a32..9b7ba26a 100644 --- a/VM/src/ludata.h +++ b/VM/src/ludata.h @@ -4,10 +4,10 @@ #include "lobject.h" -/* special tag value is used for user data with inline dtors */ +// special tag value is used for user data with inline dtors #define UTAG_IDTOR LUA_UTAG_LIMIT -/* special tag value is used for newproxy-created user data (all other user data objects are host-exposed) */ +// special tag value is used for newproxy-created user data (all other user data objects are host-exposed) #define UTAG_PROXY (LUA_UTAG_LIMIT + 1) #define sizeudata(len) (offsetof(Udata, data) + len) diff --git a/VM/src/lutf8lib.cpp b/VM/src/lutf8lib.cpp index 8bc8200a..837d0e12 100644 --- a/VM/src/lutf8lib.cpp +++ b/VM/src/lutf8lib.cpp @@ -8,8 +8,8 @@ #define iscont(p) ((*(p)&0xC0) == 0x80) -/* from strlib */ -/* translate a relative string position: negative means back from end */ +// from strlib +// translate a relative string position: negative means back from end static int u_posrelat(int pos, size_t len) { if (pos >= 0) @@ -28,28 +28,28 @@ static const char* utf8_decode(const char* o, int* val) static const unsigned int limits[] = {0xFF, 0x7F, 0x7FF, 0xFFFF}; const unsigned char* s = (const unsigned char*)o; unsigned int c = s[0]; - unsigned int res = 0; /* final result */ - if (c < 0x80) /* ascii? */ + unsigned int res = 0; // final result + if (c < 0x80) // ascii? res = c; else { - int count = 0; /* to count number of continuation bytes */ + int count = 0; // to count number of continuation bytes while (c & 0x40) - { /* still have continuation bytes? */ - int cc = s[++count]; /* read next byte */ - if ((cc & 0xC0) != 0x80) /* not a continuation byte? */ - return NULL; /* invalid byte sequence */ - res = (res << 6) | (cc & 0x3F); /* add lower 6 bits from cont. byte */ - c <<= 1; /* to test next bit */ + { // still have continuation bytes? + int cc = s[++count]; // read next byte + if ((cc & 0xC0) != 0x80) // not a continuation byte? + return NULL; // invalid byte sequence + res = (res << 6) | (cc & 0x3F); // add lower 6 bits from cont. byte + c <<= 1; // to test next bit } - res |= ((c & 0x7F) << (count * 5)); /* add first byte */ + res |= ((c & 0x7F) << (count * 5)); // add first byte if (count > 3 || res > MAXUNICODE || res <= limits[count]) - return NULL; /* invalid byte sequence */ - s += count; /* skip continuation bytes read */ + return NULL; // invalid byte sequence + s += count; // skip continuation bytes read } if (val) *val = res; - return (const char*)s + 1; /* +1 to include first byte */ + return (const char*)s + 1; // +1 to include first byte } /* @@ -70,9 +70,9 @@ static int utflen(lua_State* L) { const char* s1 = utf8_decode(s + posi, NULL); if (s1 == NULL) - { /* conversion error? */ - lua_pushnil(L); /* return nil ... */ - lua_pushinteger(L, posi + 1); /* ... and current position */ + { // conversion error? + lua_pushnil(L); // return nil ... + lua_pushinteger(L, posi + 1); // ... and current position return 2; } posi = (int)(s1 - s); @@ -97,8 +97,8 @@ static int codepoint(lua_State* L) luaL_argcheck(L, posi >= 1, 2, "out of range"); luaL_argcheck(L, pose <= (int)len, 3, "out of range"); if (posi > pose) - return 0; /* empty interval; return no values */ - if (pose - posi >= INT_MAX) /* (int -> int) overflow? */ + return 0; // empty interval; return no values + if (pose - posi >= INT_MAX) // (int -> int) overflow? luaL_error(L, "string slice too long"); n = (int)(pose - posi) + 1; luaL_checkstack(L, n, "string slice too long"); @@ -122,20 +122,20 @@ static int codepoint(lua_State* L) // from Lua 5.3 lobject.c, copied verbatim + static static int luaO_utf8esc(char* buff, unsigned long x) { - int n = 1; /* number of bytes put in buffer (backwards) */ + int n = 1; // number of bytes put in buffer (backwards) LUAU_ASSERT(x <= 0x10FFFF); - if (x < 0x80) /* ascii? */ + if (x < 0x80) // ascii? buff[UTF8BUFFSZ - 1] = cast_to(char, x); else - { /* need continuation bytes */ - unsigned int mfb = 0x3f; /* maximum that fits in first byte */ + { // need continuation bytes + unsigned int mfb = 0x3f; // maximum that fits in first byte do - { /* add continuation bytes */ + { // add continuation bytes buff[UTF8BUFFSZ - (n++)] = cast_to(char, 0x80 | (x & 0x3f)); - x >>= 6; /* remove added bits */ - mfb >>= 1; /* now there is one less bit available in first byte */ - } while (x > mfb); /* still needs continuation byte? */ - buff[UTF8BUFFSZ - n] = cast_to(char, (~mfb << 1) | x); /* add first byte */ + x >>= 6; // remove added bits + mfb >>= 1; // now there is one less bit available in first byte + } while (x > mfb); // still needs continuation byte? + buff[UTF8BUFFSZ - n] = cast_to(char, (~mfb << 1) | x); // add first byte } return n; } @@ -162,9 +162,9 @@ static int utfchar(lua_State* L) char buff[UTF8BUFFSZ]; const char* charstr; - int n = lua_gettop(L); /* number of arguments */ + int n = lua_gettop(L); // number of arguments if (n == 1) - { /* optimize common case of single char */ + { // optimize common case of single char int l = buffutfchar(L, 1, buff, &charstr); lua_pushlstring(L, charstr, l); } @@ -196,7 +196,7 @@ static int byteoffset(lua_State* L) luaL_argcheck(L, 1 <= posi && --posi <= (int)len, 3, "position out of range"); if (n == 0) { - /* find beginning of current byte sequence */ + // find beginning of current byte sequence while (posi > 0 && iscont(s + posi)) posi--; } @@ -207,9 +207,9 @@ static int byteoffset(lua_State* L) if (n < 0) { while (n < 0 && posi > 0) - { /* move back */ + { // move back do - { /* find beginning of previous character */ + { // find beginning of previous character posi--; } while (posi > 0 && iscont(s + posi)); n++; @@ -217,20 +217,20 @@ static int byteoffset(lua_State* L) } else { - n--; /* do not move for 1st character */ + n--; // do not move for 1st character while (n > 0 && posi < (int)len) { do - { /* find beginning of next character */ + { // find beginning of next character posi++; - } while (iscont(s + posi)); /* (cannot pass final '\0') */ + } while (iscont(s + posi)); // (cannot pass final '\0') n--; } } } - if (n == 0) /* did it find given character? */ + if (n == 0) // did it find given character? lua_pushinteger(L, posi + 1); - else /* no such character */ + else // no such character lua_pushnil(L); return 1; } @@ -240,16 +240,16 @@ static int iter_aux(lua_State* L) size_t len; const char* s = luaL_checklstring(L, 1, &len); int n = lua_tointeger(L, 2) - 1; - if (n < 0) /* first iteration? */ - n = 0; /* start from here */ + if (n < 0) // first iteration? + n = 0; // start from here else if (n < (int)len) { - n++; /* skip current byte */ + n++; // skip current byte while (iscont(s + n)) - n++; /* and its continuations */ + n++; // and its continuations } if (n >= (int)len) - return 0; /* no more codepoints */ + return 0; // no more codepoints else { int code; @@ -271,7 +271,7 @@ static int iter_codes(lua_State* L) return 3; } -/* pattern to match a single UTF-8 character */ +// pattern to match a single UTF-8 character #define UTF8PATT "[\0-\x7F\xC2-\xF4][\x80-\xBF]*" static const luaL_Reg funcs[] = { diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 85829ca1..376dd400 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,7 +16,7 @@ #include -LUAU_FASTFLAGVARIABLE(LuauLenTM, false) +LUAU_FASTFLAGVARIABLE(LuauNicerMethodErrors, false) // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ @@ -33,7 +33,7 @@ LUAU_FASTFLAGVARIABLE(LuauLenTM, false) // 3. VM_PROTECT macro saves savedpc and restores base for you; most external calls need to be wrapped into that. However, it does NOT restore // ra/rb/rc! // 4. When copying an object to any existing object as a field, generally speaking you need to call luaC_barrier! Be careful with all setobj calls -// 5. To make 4 easier to follow, please use setobj2s for copies to stack and setobj for other copies. +// 5. To make 4 easier to follow, please use setobj2s for copies to stack, setobj2t for writes to tables, and setobj for other copies. // 6. You can define HARDSTACKTESTS in llimits.h which will aggressively realloc stack; with address sanitizer this should be effective at finding // stack corruption bugs // 7. Many external Lua functions can call GC! GC will *not* traverse pointers to new objects that aren't reachable from Lua root. Be careful when @@ -110,7 +110,8 @@ LUAU_FASTFLAGVARIABLE(LuauLenTM, false) VM_DISPATCH_OP(LOP_FORGLOOP_NEXT), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_JUMPIFEQK), VM_DISPATCH_OP(LOP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \ - VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FORGPREP), + VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FORGPREP), VM_DISPATCH_OP(LOP_JUMPXEQKNIL), \ + VM_DISPATCH_OP(LOP_JUMPXEQKB), VM_DISPATCH_OP(LOP_JUMPXEQKN), VM_DISPATCH_OP(LOP_JUMPXEQKS), \ #if defined(__GNUC__) || defined(__clang__) #define VM_USE_CGOTO 1 @@ -158,7 +159,7 @@ LUAU_NOINLINE static bool luau_loopFORG(lua_State* L, int a, int c) setobjs2s(L, ra + 3 + 1, ra + 1); setobjs2s(L, ra + 3, ra); - L->top = ra + 3 + 3; /* func. + 2 args (state and index) */ + L->top = ra + 3 + 3; // func. + 2 args (state and index) LUAU_ASSERT(L->top <= L->stack_last); luaD_call(L, ra + 3, c); @@ -236,10 +237,10 @@ LUAU_NOINLINE static void luau_tryfuncTM(lua_State* L, StkId func) const TValue* tm = luaT_gettmbyobj(L, func, TM_CALL); if (!ttisfunction(tm)) luaG_typeerror(L, func, "call"); - for (StkId p = L->top; p > func; p--) /* open space for metamethod */ + for (StkId p = L->top; p > func; p--) // open space for metamethod setobjs2s(L, p, p - 1); - L->top++; /* stack space pre-allocated by the caller */ - setobj2s(L, func, tm); /* tag method is the new function to be called */ + L->top++; // stack space pre-allocated by the caller + setobj2s(L, func, tm); // tag method is the new function to be called } LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) @@ -256,7 +257,7 @@ LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) L->base = L->ci->base; } - luaD_checkstack(L, LUA_MINSTACK); /* ensure minimum stack size */ + luaD_checkstack(L, LUA_MINSTACK); // ensure minimum stack size L->ci->top = L->top + LUA_MINSTACK; LUAU_ASSERT(L->ci->top <= L->stack_last); @@ -458,7 +459,7 @@ static void luau_execute(lua_State* L) if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)) && !h->readonly)) { - setobj(L, gval(n), ra); + setobj2t(L, gval(n), ra); luaC_barriert(L, h, ra); VM_NEXT(); } @@ -640,20 +641,16 @@ static void luau_execute(lua_State* L) VM_PATCH_C(pc - 2, L->cachedslot); VM_NEXT(); } - else - { - // slow-path, may invoke Lua calls via __index metamethod - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - VM_NEXT(); - } - } - else - { - // slow-path, may invoke Lua calls via __index metamethod - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - VM_NEXT(); + + // fall through to slow path } + + // fall through to slow path } + + // slow-path, may invoke Lua calls via __index metamethod + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + VM_NEXT(); } VM_CASE(LOP_SETTABLEKS) @@ -676,7 +673,7 @@ static void luau_execute(lua_State* L) // fast-path: value is in expected slot if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)) && !h->readonly)) { - setobj(L, gval(n), ra); + setobj2t(L, gval(n), ra); luaC_barriert(L, h, ra); VM_NEXT(); } @@ -688,7 +685,7 @@ static void luau_execute(lua_State* L) int cachedslot = gval2slot(h, res); // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ VM_PATCH_C(pc - 2, cachedslot); - setobj(L, res, ra); + setobj2t(L, res, ra); luaC_barriert(L, h, ra); VM_NEXT(); } @@ -753,19 +750,13 @@ static void luau_execute(lua_State* L) setobj2s(L, ra, &h->array[unsigned(index - 1)]); VM_NEXT(); } - else - { - // slow-path: handles out of bounds array lookups and non-integer numeric keys - VM_PROTECT(luaV_gettable(L, rb, rc, ra)); - VM_NEXT(); - } - } - else - { - // slow-path: handles non-array table lookup as well as __index MT calls - VM_PROTECT(luaV_gettable(L, rb, rc, ra)); - VM_NEXT(); + + // fall through to slow path } + + // slow-path: handles out of bounds array lookups, non-integer numeric keys, non-array table lookup, __index MT calls + VM_PROTECT(luaV_gettable(L, rb, rc, ra)); + VM_NEXT(); } VM_CASE(LOP_SETTABLE) @@ -790,19 +781,13 @@ static void luau_execute(lua_State* L) luaC_barriert(L, h, ra); VM_NEXT(); } - else - { - // slow-path: handles out of bounds array assignments and non-integer numeric keys - VM_PROTECT(luaV_settable(L, rb, rc, ra)); - VM_NEXT(); - } - } - else - { - // slow-path: handles non-array table access as well as __newindex MT calls - VM_PROTECT(luaV_settable(L, rb, rc, ra)); - VM_NEXT(); + + // fall through to slow path } + + // slow-path: handles out of bounds array assignments, non-integer numeric keys, non-array table access, __newindex MT calls + VM_PROTECT(luaV_settable(L, rb, rc, ra)); + VM_NEXT(); } VM_CASE(LOP_GETTABLEN) @@ -822,6 +807,8 @@ static void luau_execute(lua_State* L) setobj2s(L, ra, &h->array[c]); VM_NEXT(); } + + // fall through to slow path } // slow-path: handles out of bounds array lookups @@ -849,6 +836,8 @@ static void luau_execute(lua_State* L) luaC_barriert(L, h, ra); VM_NEXT(); } + + // fall through to slow path } // slow-path: handles out of bounds array lookups @@ -941,6 +930,10 @@ static void luau_execute(lua_State* L) VM_PROTECT(luaV_gettable(L, rb, kv, ra)); // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ VM_PATCH_C(pc - 2, L->cachedslot); + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + if (FFlag::LuauNicerMethodErrors && ttisnil(ra)) + luaG_methoderror(L, ra + 1, tsvalue(kv)); } } else @@ -978,6 +971,10 @@ static void luau_execute(lua_State* L) VM_PROTECT(luaV_gettable(L, rb, kv, ra)); // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ VM_PATCH_C(pc - 2, L->cachedslot); + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + if (FFlag::LuauNicerMethodErrors && ttisnil(ra)) + luaG_methoderror(L, ra + 1, tsvalue(kv)); } } else @@ -985,6 +982,10 @@ static void luau_execute(lua_State* L) // slow-path: handles non-table __index setobj2s(L, ra + 1, rb); VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + if (FFlag::LuauNicerMethodErrors && ttisnil(ra)) + luaG_methoderror(L, ra + 1, tsvalue(kv)); } } @@ -1040,7 +1041,7 @@ static void luau_execute(lua_State* L) StkId argi = L->top; StkId argend = L->base + p->numparams; while (argi < argend) - setnilvalue(argi++); /* complete missing arguments */ + setnilvalue(argi++); // complete missing arguments L->top = p->is_vararg ? argi : ci->top; // reentry @@ -2086,7 +2087,7 @@ static void luau_execute(lua_State* L) { Table* h = hvalue(rb); - if (!FFlag::LuauLenTM || fastnotm(h->metatable, TM_LEN)) + if (fastnotm(h->metatable, TM_LEN)) { setnvalue(ra, cast_num(luaH_getn(h))); VM_NEXT(); @@ -2176,8 +2177,10 @@ static void luau_execute(lua_State* L) if (!ttisnumber(ra + 0) || !ttisnumber(ra + 1) || !ttisnumber(ra + 2)) { // slow-path: can convert arguments to numbers and trigger Lua errors - // Note: this doesn't reallocate stack so we don't need to recompute ra - VM_PROTECT(luau_prepareFORN(L, ra + 0, ra + 1, ra + 2)); + // Note: this doesn't reallocate stack so we don't need to recompute ra/base + VM_PROTECT_PC(); + + luau_prepareFORN(L, ra + 0, ra + 1, ra + 2); } double limit = nvalue(ra + 0); @@ -2224,7 +2227,7 @@ static void luau_execute(lua_State* L) if (ttisfunction(ra)) { - /* will be called during FORGLOOP */ + // will be called during FORGLOOP } else { @@ -2235,16 +2238,16 @@ static void luau_execute(lua_State* L) setobj2s(L, ra + 1, ra); setobj2s(L, ra, fn); - L->top = ra + 2; /* func + self arg */ + L->top = ra + 2; // func + self arg LUAU_ASSERT(L->top <= L->stack_last); VM_PROTECT(luaD_call(L, ra, 3)); L->top = L->ci->top; - /* recompute ra since stack might have been reallocated */ + // recompute ra since stack might have been reallocated ra = VM_REG(LUAU_INSN_A(insn)); - /* protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP */ + // protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP if (ttisnil(ra)) { VM_PROTECT(luaG_typeerror(L, ra, "call")); @@ -2252,12 +2255,12 @@ static void luau_execute(lua_State* L) } else if (fasttm(L, mt, TM_CALL)) { - /* table or userdata with __call, will be called during FORGLOOP */ - /* TODO: we might be able to stop supporting this depending on whether it's used in practice */ + // table or userdata with __call, will be called during FORGLOOP + // TODO: we might be able to stop supporting this depending on whether it's used in practice } else if (ttistable(ra)) { - /* set up registers for builtin iteration */ + // set up registers for builtin iteration setobj2s(L, ra + 1, ra); setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); setnilvalue(ra); @@ -2354,7 +2357,7 @@ static void luau_execute(lua_State* L) setobjs2s(L, ra + 3 + 1, ra + 1); setobjs2s(L, ra + 3, ra); - L->top = ra + 3 + 3; /* func + 2 args (state and index) */ + L->top = ra + 3 + 3; // func + 2 args (state and index) LUAU_ASSERT(L->top <= L->stack_last); VM_PROTECT(luaD_call(L, ra + 3, uint8_t(aux))); @@ -2382,7 +2385,7 @@ static void luau_execute(lua_State* L) if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) { setnilvalue(ra); - /* ra+1 is already the table */ + // ra+1 is already the table setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } else if (!ttisfunction(ra)) @@ -2454,7 +2457,7 @@ static void luau_execute(lua_State* L) if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) { setnilvalue(ra); - /* ra+1 is already the table */ + // ra+1 is already the table setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } else if (!ttisfunction(ra)) @@ -2629,8 +2632,8 @@ static void luau_execute(lua_State* L) LUAU_ASSERT(cast_int(L->top - base) >= numparams); // move fixed parameters to final position - StkId fixed = base; /* first fixed argument */ - base = L->top; /* final position of first argument */ + StkId fixed = base; // first fixed argument + base = L->top; // final position of first argument for (int i = 0; i < numparams; ++i) { @@ -2993,6 +2996,56 @@ static void luau_execute(lua_State* L) VM_CONTINUE(op); } + VM_CASE(LOP_JUMPXEQKNIL) + { + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + static_assert(LUA_TNIL == 0, "we expect type-1 to be negative iff type is nil"); + // condition is equivalent to: int(ttisnil(ra)) != (aux >> 31) + pc += int((ttype(ra) - 1) ^ aux) < 0 ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_JUMPXEQKB) + { + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + pc += int(ttisboolean(ra) && bvalue(ra) == int(aux & 1)) != (aux >> 31) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_JUMPXEQKN) + { + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(aux & 0xffffff); + LUAU_ASSERT(ttisnumber(kv)); + + pc += int(ttisnumber(ra) && nvalue(ra) == nvalue(kv)) != (aux >> 31) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_JUMPXEQKS) + { + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(aux & 0xffffff); + LUAU_ASSERT(ttisstring(kv)); + + pc += int(ttisstring(ra) && gcvalue(ra) == gcvalue(kv)) != (aux >> 31) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + #if !VM_USE_CGOTO default: LUAU_ASSERT(!"Unknown opcode"); @@ -3042,7 +3095,7 @@ int luau_precall(lua_State* L, StkId func, int nresults) StkId argi = L->top; StkId argend = L->base + ccl->l.p->numparams; while (argi < argend) - setnilvalue(argi++); /* complete missing arguments */ + setnilvalue(argi++); // complete missing arguments L->top = ccl->l.p->is_vararg ? argi : ci->top; L->ci->savedpc = ccl->l.p->code; diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index b9e762eb..8be241e0 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -12,9 +12,9 @@ #include #include -LUAU_FASTFLAG(LuauLenTM) +LUAU_FASTFLAGVARIABLE(LuauBetterNewindex, false) -/* limit for table tag-method chains (to avoid loops) */ +// limit for table tag-method chains (to avoid loops) #define MAXTAGLOOP 100 const TValue* luaV_tonumber(const TValue* obj, TValue* n) @@ -65,9 +65,9 @@ static StkId callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p // * during stack reallocation all of the allocated stack is copied (even beyond stack_last) so these // values will be preserved even if they go past stack_last LUAU_ASSERT((L->top + 3) < (L->stack + L->stacksize)); - setobj2s(L, L->top, f); /* push function */ - setobj2s(L, L->top + 1, p1); /* 1st argument */ - setobj2s(L, L->top + 2, p2); /* 2nd argument */ + setobj2s(L, L->top, f); // push function + setobj2s(L, L->top + 1, p1); // 1st argument + setobj2s(L, L->top + 2, p2); // 2nd argument luaD_checkstack(L, 3); L->top += 3; luaD_call(L, L->top - 3, 1); @@ -87,10 +87,10 @@ static void callTM(lua_State* L, const TValue* f, const TValue* p1, const TValue // * during stack reallocation all of the allocated stack is copied (even beyond stack_last) so these // values will be preserved even if they go past stack_last LUAU_ASSERT((L->top + 4) < (L->stack + L->stacksize)); - setobj2s(L, L->top, f); /* push function */ - setobj2s(L, L->top + 1, p1); /* 1st argument */ - setobj2s(L, L->top + 2, p2); /* 2nd argument */ - setobj2s(L, L->top + 3, p3); /* 3th argument */ + setobj2s(L, L->top, f); // push function + setobj2s(L, L->top + 1, p1); // 1st argument + setobj2s(L, L->top + 2, p2); // 2nd argument + setobj2s(L, L->top + 3, p3); // 3th argument luaD_checkstack(L, 4); L->top += 4; luaD_call(L, L->top - 4, 0); @@ -103,21 +103,21 @@ void luaV_gettable(lua_State* L, const TValue* t, TValue* key, StkId val) { const TValue* tm; if (ttistable(t)) - { /* `t' is a table? */ + { // `t' is a table? Table* h = hvalue(t); - const TValue* res = luaH_get(h, key); /* do a primitive get */ + const TValue* res = luaH_get(h, key); // do a primitive get if (res != luaO_nilobject) - L->cachedslot = gval2slot(h, res); /* remember slot to accelerate future lookups */ + L->cachedslot = gval2slot(h, res); // remember slot to accelerate future lookups - if (!ttisnil(res) /* result is no nil? */ + if (!ttisnil(res) // result is no nil? || (tm = fasttm(L, h->metatable, TM_INDEX)) == NULL) - { /* or no TM? */ + { // or no TM? setobj2s(L, val, res); return; } - /* t isn't a table, so see if it has an INDEX meta-method to look up the key with */ + // t isn't a table, so see if it has an INDEX meta-method to look up the key with } else if (ttisnil(tm = luaT_gettmbyobj(L, t, TM_INDEX))) luaG_indexerror(L, t, key); @@ -126,9 +126,9 @@ void luaV_gettable(lua_State* L, const TValue* t, TValue* key, StkId val) callTMres(L, val, tm, t, key); return; } - t = tm; /* else repeat with `tm' */ + t = tm; // else repeat with `tm' } - luaG_runerror(L, "loop in gettable"); + luaG_runerror(L, "'__index' chain too long; possible loop"); } void luaV_settable(lua_State* L, const TValue* t, TValue* key, StkId val) @@ -139,44 +139,70 @@ void luaV_settable(lua_State* L, const TValue* t, TValue* key, StkId val) { const TValue* tm; if (ttistable(t)) - { /* `t' is a table? */ + { // `t' is a table? Table* h = hvalue(t); - if (h->readonly) - luaG_runerror(L, "Attempt to modify a readonly table"); + if (FFlag::LuauBetterNewindex) + { + const TValue* oldval = luaH_get(h, key); - TValue* oldval = luaH_set(L, h, key); /* do a primitive set */ + // should we assign the key? (if key is valid or __newindex is not set) + if (!ttisnil(oldval) || (tm = fasttm(L, h->metatable, TM_NEWINDEX)) == NULL) + { + if (h->readonly) + luaG_readonlyerror(L); - L->cachedslot = gval2slot(h, oldval); /* remember slot to accelerate future lookups */ + // luaH_set would work but would repeat the lookup so we use luaH_setslot that can reuse oldval if it's safe + TValue* newval = luaH_setslot(L, h, oldval, key); - if (!ttisnil(oldval) || /* result is no nil? */ - (tm = fasttm(L, h->metatable, TM_NEWINDEX)) == NULL) - { /* or no TM? */ - setobj2t(L, oldval, val); - luaC_barriert(L, h, val); - return; + L->cachedslot = gval2slot(h, newval); // remember slot to accelerate future lookups + + setobj2t(L, newval, val); + luaC_barriert(L, h, val); + return; + } + + // fallthrough to metamethod + } + else + { + if (h->readonly) + luaG_readonlyerror(L); + + TValue* oldval = luaH_set(L, h, key); // do a primitive set + + L->cachedslot = gval2slot(h, oldval); // remember slot to accelerate future lookups + + if (!ttisnil(oldval) || // result is no nil? + (tm = fasttm(L, h->metatable, TM_NEWINDEX)) == NULL) + { // or no TM? + setobj2t(L, oldval, val); + luaC_barriert(L, h, val); + return; + } + // else will try the tag method } - /* else will try the tag method */ } else if (ttisnil(tm = luaT_gettmbyobj(L, t, TM_NEWINDEX))) luaG_indexerror(L, t, key); + if (ttisfunction(tm)) { callTM(L, tm, t, key, val); return; } - /* else repeat with `tm' */ - setobj(L, &temp, tm); /* avoid pointing inside table (may rehash) */ + // else repeat with `tm' + setobj(L, &temp, tm); // avoid pointing inside table (may rehash) t = &temp; } - luaG_runerror(L, "loop in settable"); + luaG_runerror(L, "'__newindex' chain too long; possible loop"); } static int call_binTM(lua_State* L, const TValue* p1, const TValue* p2, StkId res, TMS event) { - const TValue* tm = luaT_gettmbyobj(L, p1, event); /* try first operand */ + const TValue* tm = luaT_gettmbyobj(L, p1, event); // try first operand if (ttisnil(tm)) - tm = luaT_gettmbyobj(L, p2, event); /* try second operand */ + tm = luaT_gettmbyobj(L, p2, event); // try second operand if (ttisnil(tm)) return 0; callTMres(L, res, tm, p1, p2); @@ -188,13 +214,13 @@ static const TValue* get_compTM(lua_State* L, Table* mt1, Table* mt2, TMS event) const TValue* tm1 = fasttm(L, mt1, event); const TValue* tm2; if (tm1 == NULL) - return NULL; /* no metamethod */ + return NULL; // no metamethod if (mt1 == mt2) - return tm1; /* same metatables => same metamethods */ + return tm1; // same metatables => same metamethods tm2 = fasttm(L, mt2, event); if (tm2 == NULL) - return NULL; /* no metamethod */ - if (luaO_rawequalObj(tm1, tm2)) /* same metamethods? */ + return NULL; // no metamethod + if (luaO_rawequalObj(tm1, tm2)) // same metamethods? return tm1; return NULL; } @@ -204,9 +230,9 @@ static int call_orderTM(lua_State* L, const TValue* p1, const TValue* p2, TMS ev const TValue* tm1 = luaT_gettmbyobj(L, p1, event); const TValue* tm2; if (ttisnil(tm1)) - return -1; /* no metamethod? */ + return -1; // no metamethod? tm2 = luaT_gettmbyobj(L, p2, event); - if (!luaO_rawequalObj(tm1, tm2)) /* different metamethods? */ + if (!luaO_rawequalObj(tm1, tm2)) // different metamethods? return -1; callTMres(L, L->top, tm1, p1, p2); return !l_isfalse(L->top); @@ -253,9 +279,9 @@ int luaV_lessequal(lua_State* L, const TValue* l, const TValue* r) return luai_numle(nvalue(l), nvalue(r)); else if (ttisstring(l)) return luaV_strcmp(tsvalue(l), tsvalue(r)) <= 0; - else if ((res = call_orderTM(L, l, r, TM_LE)) != -1) /* first try `le' */ + else if ((res = call_orderTM(L, l, r, TM_LE)) != -1) // first try `le' return res; - else if ((res = call_orderTM(L, r, l, TM_LT)) == -1) /* error if not `lt' */ + else if ((res = call_orderTM(L, r, l, TM_LT)) == -1) // error if not `lt' luaG_ordererror(L, l, r, TM_LE); return !res; } @@ -273,7 +299,7 @@ int luaV_equalval(lua_State* L, const TValue* t1, const TValue* t2) case LUA_TVECTOR: return luai_veceq(vvalue(t1), vvalue(t2)); case LUA_TBOOLEAN: - return bvalue(t1) == bvalue(t2); /* true must be 1 !! */ + return bvalue(t1) == bvalue(t2); // true must be 1 !! case LUA_TLIGHTUSERDATA: return pvalue(t1) == pvalue(t2); case LUA_TUSERDATA: @@ -281,19 +307,19 @@ int luaV_equalval(lua_State* L, const TValue* t1, const TValue* t2) tm = get_compTM(L, uvalue(t1)->metatable, uvalue(t2)->metatable, TM_EQ); if (!tm) return uvalue(t1) == uvalue(t2); - break; /* will try TM */ + break; // will try TM } case LUA_TTABLE: { tm = get_compTM(L, hvalue(t1)->metatable, hvalue(t2)->metatable, TM_EQ); if (!tm) return hvalue(t1) == hvalue(t2); - break; /* will try TM */ + break; // will try TM } default: return gcvalue(t1) == gcvalue(t2); } - callTMres(L, L->top, tm, t1, t2); /* call TM */ + callTMres(L, L->top, tm, t1, t2); // call TM return !l_isfalse(L->top); } @@ -302,21 +328,21 @@ void luaV_concat(lua_State* L, int total, int last) do { StkId top = L->base + last + 1; - int n = 2; /* number of elements handled in this pass (at least 2) */ + int n = 2; // number of elements handled in this pass (at least 2) if (!(ttisstring(top - 2) || ttisnumber(top - 2)) || !tostring(L, top - 1)) { if (!call_binTM(L, top - 2, top - 1, top - 2, TM_CONCAT)) luaG_concaterror(L, top - 2, top - 1); } - else if (tsvalue(top - 1)->len == 0) /* second op is empty? */ - (void)tostring(L, top - 2); /* result is first op (as string) */ + else if (tsvalue(top - 1)->len == 0) // second op is empty? + (void)tostring(L, top - 2); // result is first op (as string) else { - /* at least two string values; get as many as possible */ + // at least two string values; get as many as possible size_t tl = tsvalue(top - 1)->len; char* buffer; int i; - /* collect total length */ + // collect total length for (n = 1; n < total && tostring(L, top - n - 1); n++) { size_t l = tsvalue(top - n - 1)->len; @@ -340,7 +366,7 @@ void luaV_concat(lua_State* L, int total, int last) tl = 0; for (i = n; i > 0; i--) - { /* concat all strings */ + { // concat all strings size_t l = tsvalue(top - i)->len; memcpy(buffer + tl, svalue(top - i), l); tl += l; @@ -355,9 +381,9 @@ void luaV_concat(lua_State* L, int total, int last) setsvalue2s(L, top - n, luaS_buffinish(L, ts)); } } - total -= n - 1; /* got `n' strings to create 1 new */ + total -= n - 1; // got `n' strings to create 1 new last -= n - 1; - } while (total > 1); /* repeat until only 1 result left */ + } while (total > 1); // repeat until only 1 result left } void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op) @@ -476,29 +502,6 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM void luaV_dolen(lua_State* L, StkId ra, const TValue* rb) { - if (!FFlag::LuauLenTM) - { - switch (ttype(rb)) - { - case LUA_TTABLE: - { - setnvalue(ra, cast_num(luaH_getn(hvalue(rb)))); - break; - } - case LUA_TSTRING: - { - setnvalue(ra, cast_num(tsvalue(rb)->len)); - break; - } - default: - { /* try metamethod */ - if (!call_binTM(L, rb, luaO_nilobject, ra, TM_LEN)) - luaG_typeerror(L, rb, "get length of"); - } - } - return; - } - const TValue* tm = NULL; switch (ttype(rb)) { @@ -527,5 +530,5 @@ void luaV_dolen(lua_State* L, StkId ra, const TValue* rb) StkId res = callTMres(L, ra, tm, rb, luaO_nilobject); if (!ttisnumber(res)) - luaG_runerror(L, "'__len' must return a number"); /* note, we can't access rb since stack may have been reallocated */ + luaG_runerror(L, "'__len' must return a number"); // note, we can't access rb since stack may have been reallocated } diff --git a/bench/bench.py b/bench/bench.py index e78e96a8..bb3ea5f7 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -101,8 +101,10 @@ def getVmOutput(cmd): elif arguments.callgrind: try: subprocess.check_call("valgrind --tool=callgrind --callgrind-out-file=callgrind.out --combine-dumps=yes --dump-line=no " + cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, cwd=scriptdir) - file = open(os.path.join(scriptdir, "callgrind.out"), "r") - lines = file.readlines() + path = os.path.join(scriptdir, "callgrind.out") + with open(path, "r") as file: + lines = file.readlines() + os.unlink(path) return getCallgrindOutput(lines) except: return "" @@ -402,12 +404,12 @@ def analyzeResult(subdir, main, comparisons): continue - pooledStdDev = math.sqrt((main.unbiasedEst + compare.unbiasedEst) / 2) + if main.count > 1 and stats: + pooledStdDev = math.sqrt((main.unbiasedEst + compare.unbiasedEst) / 2) - tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) - degreesOfFreedom = 2 * main.count - 2 + tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) + degreesOfFreedom = 2 * main.count - 2 - if stats: # Two-tailed distribution with 95% conf. tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) diff --git a/bench/gc/test_SunSpider_crypto-aes.lua b/bench/gc/test_SunSpider_crypto-aes.lua deleted file mode 100644 index 8537e3da..00000000 --- a/bench/gc/test_SunSpider_crypto-aes.lua +++ /dev/null @@ -1,436 +0,0 @@ ---[[ - * AES Cipher function: encrypt 'input' with Rijndael algorithm - * - * takes byte-array 'input' (16 bytes) - * 2D byte-array key schedule 'w' (Nr+1 x Nb bytes) - * - * applies Nr rounds (10/12/14) using key schedule w for 'add round key' stage - * - * returns byte-array encrypted value (16 bytes) - */]] - - local bench = script and require(script.Parent.bench_support) or require("bench_support") - -function test() - --- Sbox is pre-computed multiplicative inverse in GF(2^8) used in SubBytes and KeyExpansion [§5.1.1] -local Sbox = { 0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76, - 0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0, - 0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15, - 0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75, - 0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84, - 0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf, - 0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8, - 0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2, - 0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73, - 0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb, - 0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79, - 0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08, - 0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a, - 0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e, - 0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf, - 0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16 }; - --- Rcon is Round Constant used for the Key Expansion [1st col is 2^(r-1) in GF(2^8)] [§5.2] -local Rcon = { { 0x00, 0x00, 0x00, 0x00 }, - {0x01, 0x00, 0x00, 0x00}, - {0x02, 0x00, 0x00, 0x00}, - {0x04, 0x00, 0x00, 0x00}, - {0x08, 0x00, 0x00, 0x00}, - {0x10, 0x00, 0x00, 0x00}, - {0x20, 0x00, 0x00, 0x00}, - {0x40, 0x00, 0x00, 0x00}, - {0x80, 0x00, 0x00, 0x00}, - {0x1b, 0x00, 0x00, 0x00}, - {0x36, 0x00, 0x00, 0x00} }; - -function Cipher(input, w) -- main Cipher function [§5.1] - local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) - local Nr = #w / Nb - 1; -- no of rounds: 10/12/14 for 128/192/256-bit keys - - local state = {{},{},{},{}}; -- initialise 4xNb byte-array 'state' with input [§3.4] - for i = 0,4*Nb-1 do state[(i % 4) + 1][math.floor(i/4) + 1] = input[i + 1]; end - - state = AddRoundKey(state, w, 0, Nb); - - for round = 1,Nr-1 do - state = SubBytes(state, Nb); - state = ShiftRows(state, Nb); - state = MixColumns(state, Nb); - state = AddRoundKey(state, w, round, Nb); - end - - state = SubBytes(state, Nb); - state = ShiftRows(state, Nb); - state = AddRoundKey(state, w, Nr, Nb); - - local output = {} -- convert state to 1-d array before returning [§3.4] - for i = 0,4*Nb-1 do output[i + 1] = state[(i % 4) + 1][math.floor(i / 4) + 1]; end - - return output; -end - - -function SubBytes(s, Nb) -- apply SBox to state S [§5.1.1] - for r = 0,3 do - for c = 0,Nb-1 do s[r + 1][c + 1] = Sbox[s[r + 1][c + 1] + 1]; end - end - return s; -end - - -function ShiftRows(s, Nb) -- shift row r of state S left by r bytes [§5.1.2] - local t = {}; - for r = 1,3 do - for c = 0,3 do t[c + 1] = s[r + 1][((c + r) % Nb) + 1] end; -- shift into temp copy - for c = 0,3 do s[r + 1][c + 1] = t[c + 1]; end -- and copy back - end -- note that this will work for Nb=4,5,6, but not 7,8 (always 4 for AES): - return s; -- see fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf -end - - -function MixColumns(s, Nb) -- combine bytes of each col of state S [§5.1.3] - for c = 0,3 do - local a = {}; -- 'a' is a copy of the current column from 's' - local b = {}; -- 'b' is a•{02} in GF(2^8) - for i = 0,3 do - a[i + 1] = s[i + 1][c + 1]; - - if bit32.band(s[i + 1][c + 1], 0x80) ~= 0 then - b[i + 1] = bit32.bxor(bit32.lshift(s[i + 1][c + 1], 1), 0x011b); - else - b[i + 1] = bit32.lshift(s[i + 1][c + 1], 1); - end - end - -- a[n] ^ b[n] is a•{03} in GF(2^8) - s[1][c + 1] = bit32.bxor(bit32.bxor(bit32.bxor(b[1], a[2]), bit32.bxor(b[2], a[3])), a[4]); -- 2*a0 + 3*a1 + a2 + a3 - s[2][c + 1] = bit32.bxor(bit32.bxor(bit32.bxor(a[1], b[2]), bit32.bxor(a[3], b[3])), a[4]); -- a0 * 2*a1 + 3*a2 + a3 - s[3][c + 1] = bit32.bxor(bit32.bxor(bit32.bxor(a[1], a[2]), bit32.bxor(b[3], a[4])), b[4]); -- a0 + a1 + 2*a2 + 3*a3 - s[4][c + 1] = bit32.bxor(bit32.bxor(bit32.bxor(a[1], b[1]), bit32.bxor(a[2], a[3])), b[4]); -- 3*a0 + a1 + a2 + 2*a3 -end - return s; -end - - -function AddRoundKey(state, w, rnd, Nb) -- xor Round Key into state S [§5.1.4] - for r = 0,3 do - for c = 0,Nb-1 do state[r + 1][c + 1] = bit32.bxor(state[r + 1][c + 1], w[rnd*4+c + 1][r + 1]); end - end - return state; -end - - -function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from Key [§5.2] - local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) - local Nk = #key / 4 -- key length (in words): 4/6/8 for 128/192/256-bit keys - local Nr = Nk + 6; -- no of rounds: 10/12/14 for 128/192/256-bit keys - - local w = {}; - local temp = {}; - - for i = 0,Nk do - local r = { key[4*i + 1], key[4*i + 2], key[4*i + 3], key[4*i + 4] }; - w[i + 1] = r; - end - - for i = Nk,(Nb*(Nr+1)) - 1 do - w[i + 1] = {}; - for t = 0,3 do temp[t + 1] = w[i-1 + 1][t + 1]; end - if (i % Nk == 0) then - temp = SubWord(RotWord(temp)); - for t = 0,3 do temp[t + 1] = bit32.bxor(temp[t + 1], Rcon[i/Nk + 1][t + 1]); end - elseif (Nk > 6 and i % Nk == 4) then - temp = SubWord(temp); - end - for t = 0,3 do w[i + 1][t + 1] = bit32.bxor(w[i - Nk + 1][t + 1], temp[t + 1]); end - end - - return w; -end - -function SubWord(w) -- apply SBox to 4-byte word w - for i = 0,3 do w[i + 1] = Sbox[w[i + 1] + 1]; end - return w; -end - -function RotWord(w) -- rotate 4-byte word w left by one byte - w[5] = w[1]; - for i = 0,3 do w[i + 1] = w[i + 2]; end - return w; -end - - ---[[ - * Use AES to encrypt 'plaintext' with 'password' using 'nBits' key, in 'Counter' mode of operation - * - see http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf - * for each block - * - outputblock = cipher(counter, key) - * - cipherblock = plaintext xor outputblock - ]] - -function AESEncryptCtr(plaintext, password, nBits) - if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys - - -- for this example script, generate the key by applying Cipher to 1st 16/24/32 chars of password; - -- for real-world applications, a higher security approach would be to hash the password e.g. with SHA-1 - local nBytes = nBits/8; -- no bytes in key - local pwBytes = {}; - for i = 0,nBytes-1 do pwBytes[i + 1] = bit32.band(string.byte(password, i + 1), 0xff); end - local key = Cipher(pwBytes, KeyExpansion(pwBytes)); - - -- key is now 16/24/32 bytes long - for i = 1,nBytes-16 do - table.insert(key, key[i]) - end - - -- initialise counter block (NIST SP800-38A §B.2): millisecond time-stamp for nonce in 1st 8 bytes, - -- block counter in 2nd 8 bytes - local blockSize = 16; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES - local counterBlock = {}; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES - local nonce = 12564231564 -- (new Date()).getTime(); -- milliseconds since 1-Jan-1970 - - -- encode nonce in two stages to cater for JavaScript 32-bit limit on bitwise ops - for i = 0,3 do counterBlock[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff); end - for i = 0,3 do counterBlock[i + 4 + 1] = bit32.band(bit32.rshift(math.floor(nonce / 0x100000000), i*8), 0xff); end - - -- generate key schedule - an expansion of the key into distinct Key Rounds for each round - local keySchedule = KeyExpansion(key); - - local blockCount = math.ceil(#plaintext / blockSize); - local ciphertext = {}; -- ciphertext as array of strings - - for b = 0,blockCount-1 do - -- set counter (block #) in last 8 bytes of counter block (leaving nonce in 1st 8 bytes) - -- again done in two stages for 32-bit ops - for c = 0,3 do counterBlock[15-c + 1] = bit32.band(bit32.rshift(b, c*8), 0xff); end - for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.rshift(math.floor(b/0x100000000), c*8) end - - local cipherCntr = Cipher(counterBlock, keySchedule); -- -- encrypt counter block -- - - -- calculate length of final block: - local blockLength = nil - - if b= 0, strings that can be automatically coered to numbers that are >= 0, false and nil + -- Do note that empty regex patterns (comment-only patterns included) are never cached regardless + -- The default is 256 + cacheSize = 256, + + -- A boolean that determines whether this use unicode data + -- If this value evalulates to false, you can remove _unicodechar_category, _scripts and _xuc safely and it'll now error if: + -- - You try to compile a RegEx with unicode flag + -- - You try to use the \p pattern + -- The default is true + unicodeData = false, +}; + +-- +local u_categories = options.unicodeData and require(script:WaitForChild("_unicodechar_category")); +local chr_scripts = options.unicodeData and require(script:WaitForChild("_scripts")); +local xuc_chr = options.unicodeData and require(script:WaitForChild("_xuc")); +local proxy = setmetatable({ }, { __mode = 'k' }); +local re, re_m, match_m = { }, { }, { }; +local lockmsg; + +--[[ Functions ]]-- +local function to_str_arr(self, init) + if init then + self = string.sub(self, utf8.offset(self, init)); + end; + local len = utf8.len(self); + if len <= 1999 then + return { n = len, s = self, utf8.codepoint(self, 1, #self) }; + end; + local clen = math.ceil(len / 1999); + local ret = table.create(len); + local p = 1; + for i = 1, clen do + local c = table.pack(utf8.codepoint(self, utf8.offset(self, i * 1999 - 1998), utf8.offset(self, i * 1999 - (i == clen and 1998 - ((len - 1) % 1999 + 1) or - 1)) - 1)); + table.move(c, 1, c.n, p, ret); + p += c.n; + end; + ret.s, ret.n = self, len; + return ret; +end; + +local function from_str_arr(self) + local len = self.n or #self; + if len <= 7997 then + return utf8.char(table.unpack(self)); + end; + local clen = math.ceil(len / 7997); + local r = table.create(clen); + for i = 1, clen do + r[i] = utf8.char(table.unpack(self, i * 7997 - 7996, i * 7997 - (i == clen and 7997 - ((len - 1) % 7997 + 1) or 0))); + end; + return table.concat(r); +end; + +local function utf8_sub(self, i, j) + j = utf8.offset(self, j); + return string.sub(self, utf8.offset(self, i), j and j - 1); +end; + +-- +local flag_map = { + a = 'anchored', i = 'caseless', m = 'multiline', s = 'dotall', u = 'unicode', U = 'ungreedy', x ='extended', +}; + +local posix_class_names = { + alnum = true, alpha = true, ascii = true, blank = true, cntrl = true, digit = true, graph = true, lower = true, print = true, punct = true, space = true, upper = true, word = true, xdigit = true, +}; + +local escape_chars = { + -- grouped + -- digit, spaces and words + [0x44] = { "class", "digit", true }, [0x53] = { "class", "space", true }, [0x57] = { "class", "word", true }, + [0x64] = { "class", "digit", false }, [0x73] = { "class", "space", false }, [0x77] = { "class", "word", false }, + -- horizontal/vertical whitespace and newline + [0x48] = { "class", "blank", true }, [0x56] = { "class", "vertical_tab", true }, + [0x68] = { "class", "blank", false }, [0x76] = { "class", "vertical_tab", false }, + [0x4E] = { 0x4E }, [0x52] = { 0x52 }, + + -- not grouped + [0x42] = 0x08, + [0x6E] = 0x0A, [0x72] = 0x0D, [0x74] = 0x09, +}; + +local b_escape_chars = { + -- word boundary and not word boundary + [0x62] = { 0x62, { "class", "word", false } }, [0x42] = { 0x42, { "class", "word", false } }, + + -- keep match out + [0x4B] = { 0x4B }, + + -- start & end of string + [0x47] = { 0x47 }, [0x4A] = { 0x4A }, [0x5A] = { 0x5A }, [0x7A] = { 0x7A }, +}; + +local valid_categories = { + C = true, Cc = true, Cf = true, Cn = true, Co = true, Cs = true, + L = true, Ll = true, Lm = true, Lo = true, Lt = true, Lu = true, + M = true, Mc = true, Me = true, Mn = true, + N = true, Nd = true, Nl = true, No = true, + P = true, Pc = true, Pd = true, Pe = true, Pf = true, Pi = true, Po = true, Ps = true, + S = true, Sc = true, Sk = true, Sm = true, So = true, + Z = true, Zl = true, Zp = true, Zs = true, + + Xan = true, Xps = true, Xsp = true, Xuc = true, Xwd = true, +}; + +local class_ascii_punct = { + [0x21] = true, [0x22] = true, [0x23] = true, [0x24] = true, [0x25] = true, [0x26] = true, [0x27] = true, [0x28] = true, [0x29] = true, [0x2A] = true, [0x2B] = true, [0x2C] = true, [0x2D] = true, [0x2E] = true, [0x2F] = true, + [0x3A] = true, [0x3B] = true, [0x3C] = true, [0x3D] = true, [0x3E] = true, [0x3F] = true, [0x40] = true, [0x5B] = true, [0x5C] = true, [0x5D] = true, [0x5E] = true, [0x5F] = true, [0x60] = true, [0x7B] = true, [0x7C] = true, + [0x7D] = true, [0x7E] = true, +}; + +local end_str = { 0x24 }; +local dot = { 0x2E }; +local beginning_str = { 0x5E }; +local alternation = { 0x7C }; + +local function check_re(re_type, name, func) + if re_type == "Match" then + return function(...) + local arg_n = select('#', ...); + if arg_n < 1 then + error("missing argument #1 (Match expected)", 2); + end; + local arg0, arg1 = ...; + if not (proxy[arg0] and proxy[arg0].name == "Match") then + error(string.format("invalid argument #1 to %q (Match expected, got %s)", name, typeof(arg0)), 2); + else + arg0 = proxy[arg0]; + end; + if name == "group" or name == "span" then + if arg1 == nil then + arg1 = 0; + end; + end; + return func(arg0, arg1); + end; + end; + return function(...) + local arg_n = select('#', ...); + if arg_n < 1 then + error("missing argument #1 (RegEx expected)", 2); + elseif arg_n < 2 then + error("missing argument #2 (string expected)", 2); + end; + local arg0, arg1, arg2, arg3, arg4, arg5 = ...; + if not (proxy[arg0] and proxy[arg0].name == "RegEx") then + if type(arg0) ~= "string" and type(arg0) ~= "number" then + error(string.format("invalid argument #1 to %q (RegEx expected, got %s)", name, typeof(arg0)), 2); + end; + arg0 = re.fromstring(arg0); + elseif name == "sub" then + if type(arg2) == "number" then + arg2 ..= ''; + elseif type(arg2) ~= "string" then + error(string.format("invalid argument #3 to 'sub' (string expected, got %s)", typeof(arg2)), 2); + end; + elseif type(arg1) == "number" then + arg1 ..= ''; + elseif type(arg1) ~= "string" then + error(string.format("invalid argument #2 to %q (string expected, got %s)", name, typeof(arg1)), 2); + end; + if name ~= "sub" and name ~= "split" then + local init_type = typeof(arg2); + if init_type ~= 'nil' then + arg2 = tonumber(arg2); + if not arg2 then + error(string.format("invalid argument #3 to %q (number expected, got %s)", name, init_type), 2); + elseif arg2 < 0 then + arg2 = #arg1 + math.floor(arg2 + 0.5) + 1; + else + arg2 = math.max(math.floor(arg2 + 0.5), 1); + end; + end; + end; + arg0 = proxy[arg0]; + if name == "match" or name == "matchiter" then + arg3 = ...; + elseif name == "sub" then + arg5 = ...; + end; + return func(arg0, arg1, arg2, arg3, arg4, arg5); + end; +end; + +--[[ Matches ]]-- +local function match_tostr(self) + local spans = proxy[self].spans; + local s_start, s_end = spans[0][1], spans[0][2]; + if s_end <= s_start then + return string.format("Match (%d..%d, empty)", s_start, s_end - 1); + end; + return string.format("Match (%d..%d): %s", s_start, s_end - 1, utf8_sub(spans.input, s_start, s_end)); +end; + +local function new_match(span_arr, group_id, re, str) + span_arr.source, span_arr.input = re, str; + local object = newproxy(true); + local object_mt = getmetatable(object); + object_mt.__metatable = lockmsg; + object_mt.__index = setmetatable(span_arr, match_m); + object_mt.__tostring = match_tostr; + + proxy[object] = { name = "Match", spans = span_arr, group_id = group_id }; + return object; +end; + +match_m.group = check_re('Match', 'group', function(self, group_id) + local span = self.spans[type(group_id) == "number" and group_id or self.group_id[group_id]]; + if not span then + return nil; + end; + return utf8_sub(self.spans.input, span[1], span[2]); +end); + +match_m.span = check_re('Match', 'span', function(self, group_id) + local span = self.spans[type(group_id) == "number" and group_id or self.group_id[group_id]]; + if not span then + return nil; + end; + return span[1], span[2] - 1; +end); + +match_m.groups = check_re('Match', 'groups', function(self) + local spans = self.spans; + if spans.n > 0 then + local ret = table.create(spans.n); + for i = 0, spans.n do + local v = spans[i]; + if v then + ret[i] = utf8_sub(spans.input, v[1], v[2]); + end; + end; + return table.unpack(ret, 1, spans.n); + end; + return utf8_sub(spans.input, spans[0][1], spans[0][2]); +end); + +match_m.groupdict = check_re('Match', 'groupdict', function(self) + local spans = self.spans; + local ret = { }; + for k, v in pairs(self.group_id) do + v = spans[v]; + if v then + ret[k] = utf8_sub(spans.input, v[1], v[2]); + end; + end; + return ret; +end); + +match_m.grouparr = check_re('Match', 'groupdict', function(self) + local spans = self.spans; + local ret = table.create(spans.n); + for i = 0, spans.n do + local v = spans[i]; + if v then + ret[i] = utf8_sub(spans.input, v[1], v[2]); + end; + end; + ret.n = spans.n; + return ret; +end); + +-- +local line_verbs = { + CR = 0, LF = 1, CRLF = 2, ANYRLF = 3, ANY = 4, NUL = 5, +}; +local function is_newline(str_arr, i, verb_flags) + local line_verb_n = verb_flags.newline; + local chr = str_arr[i]; + if line_verb_n == 0 then + -- carriage return + return chr == 0x0D; + elseif line_verb_n == 2 then + -- carriage return followed by line feed + return chr == 0x0A and str_arr[i - 1] == 0x20; + elseif line_verb_n == 3 then + -- any of the above + return chr == 0x0A or chr == 0x0D; + elseif line_verb_n == 4 then + -- any of Unicode newlines + return chr == 0x0A or chr == 0x0B or chr == 0x0C or chr == 0x0D or chr == 0x85 or chr == 0x2028 or chr == 0x2029; + elseif line_verb_n == 5 then + -- null + return chr == 0; + end; + -- linefeed + return chr == 0x0A; +end; + + +local function tkn_char_match(tkn_part, str_arr, i, flags, verb_flags) + local chr = str_arr[i]; + if not chr then + return false; + elseif flags.ignoreCase and chr >= 0x61 and chr <= 0x7A then + chr -= 0x20; + end; + if type(tkn_part) == "number" then + return tkn_part == chr; + elseif tkn_part[1] == "charset" then + for _, v in ipairs(tkn_part[3]) do + if tkn_char_match(v, str_arr, i, flags, verb_flags) then + return not tkn_part[2]; + end; + end; + return tkn_part[2]; + elseif tkn_part[1] == "range" then + return chr >= tkn_part[2] and chr <= tkn_part[3] or flags.ignoreCase and chr >= 0x41 and chr <= 0x5A and (chr + 0x20) >= tkn_part[2] and (chr + 0x20) <= tkn_part[3]; + elseif tkn_part[1] == "class" then + local char_class = tkn_part[2]; + local negate = tkn_part[3]; + local match = false; + -- if and elseifs :( + -- Might make these into tables in the future + if char_class == "xdigit" then + match = chr >= 0x30 and chr <= 0x39 or chr >= 0x41 and chr <= 0x46 or chr >= 0x61 and chr <= 0x66; + elseif char_class == "ascii" then + match = chr <= 0x7F; + -- cannot be accessed through POSIX classes + elseif char_class == "vertical_tab" then + match = chr >= 0x0A and chr <= 0x0D or chr == 0x2028 or chr == 0x2029; + -- + elseif flags.unicode then + local current_category = u_categories[chr] or 'Cn'; + local first_category = current_category:sub(1, 1); + if char_class == "alnum" then + match = first_category == 'L' or current_category == 'Nl' or current_category == 'Nd'; + elseif char_class == "alpha" then + match = first_category == 'L' or current_category == 'Nl'; + elseif char_class == "blank" then + match = current_category == 'Zs' or chr == 0x09; + elseif char_class == "cntrl" then + match = current_category == 'Cc'; + elseif char_class == "digit" then + match = current_category == 'Nd'; + elseif char_class == "graph" then + match = first_category ~= 'P' and first_category ~= 'C'; + elseif char_class == "lower" then + match = current_category == 'Ll'; + elseif char_class == "print" then + match = first_category ~= 'C'; + elseif char_class == "punct" then + match = first_category == 'P'; + elseif char_class == "space" then + match = first_category == 'Z' or chr >= 0x09 and chr <= 0x0D; + elseif char_class == "upper" then + match = current_category == 'Lu'; + elseif char_class == "word" then + match = first_category == 'L' or current_category == 'Nl' or current_category == 'Nd' or current_category == 'Pc'; + end; + elseif char_class == "alnum" then + match = chr >= 0x30 and chr <= 0x39 or chr >= 0x41 and chr <= 0x5A or chr >= 0x61 and chr <= 0x7A; + elseif char_class == "alpha" then + match = chr >= 0x41 and chr <= 0x5A or chr >= 0x61 and chr <= 0x7A; + elseif char_class == "blank" then + match = chr == 0x09 or chr == 0x20; + elseif char_class == "cntrl" then + match = chr <= 0x1F or chr == 0x7F; + elseif char_class == "digit" then + match = chr >= 0x30 and chr <= 0x39; + elseif char_class == "graph" then + match = chr >= 0x21 and chr <= 0x7E; + elseif char_class == "lower" then + match = chr >= 0x61 and chr <= 0x7A; + elseif char_class == "print" then + match = chr >= 0x20 and chr <= 0x7E; + elseif char_class == "punct" then + match = class_ascii_punct[chr]; + elseif char_class == "space" then + match = chr >= 0x09 and chr <= 0x0D or chr == 0x20; + elseif char_class == "upper" then + match = chr >= 0x41 and chr <= 0x5A; + elseif char_class == "word" then + match = chr >= 0x30 and chr <= 0x39 or chr >= 0x41 and chr <= 0x5A or chr >= 0x61 and chr <= 0x7A or chr == 0x5F; + end; + if negate then + return not match; + end; + return match; + elseif tkn_part[1] == "category" then + local chr_category = u_categories[chr] or 'Cn'; + local category_v = tkn_part[3]; + local category_len = #category_v; + if category_len == 3 then + local match = false; + if category_v == "Xan" or category_v == "Xwd" then + match = chr_category:find("^[LN]") or category_v == "Xwd" and chr == 0x5F; + elseif category_v == "Xps" or category_v == "Xsp" then + match = chr_category:sub(1, 1) == 'Z' or chr >= 0x09 and chr <= 0x0D; + elseif category_v == "Xuc" then + match = tkn_char_match(xuc_chr, str_arr, i, flags, verb_flags); + end; + if tkn_part[2] then + return not match; + end + return match; + elseif chr_category:sub(1, category_len) == category_v then + return not tkn_part[2]; + end; + return tkn_part[2]; + elseif tkn_part[1] == 0x2E then + return flags.dotAll or not is_newline(str_arr, i, verb_flags); + elseif tkn_part[1] == 0x4E then + return not is_newline(str_arr, i, verb_flags); + elseif tkn_part[1] == 0x52 then + if verb_flags.newline_seq == 0 then + -- CR, LF or CRLF + return chr == 0x0A or chr == 0x0D; + end; + -- any unicode newline + return chr == 0x0A or chr == 0x0B or chr == 0x0C or chr == 0x0D or chr == 0x85 or chr == 0x2028 or chr == 0x2029; + end; + return false; +end; + +local function find_alternation(token, i, count) + while true do + local v = token[i]; + local is_table = type(v) == "table"; + if v == alternation then + return i, count; + elseif is_table and v[1] == 0x28 then + if count then + count += v.count; + end; + i = v[3]; + elseif is_table and v[1] == "quantifier" and type(v[5]) == "table" and v[5][1] == 0x28 then + if count then + count += v[5].count; + end; + i = v[5][3]; + elseif not v or is_table and v[1] == 0x29 then + return nil, count; + elseif count then + if is_table and v[1] == "quantifier" then + count += v[3]; + else + count += 1; + end; + end; + i += 1; + end; +end; + +local function re_rawfind(token, str_arr, init, flags, verb_flags, as_bool) + local tkn_i, str_i, start_i = 0, init, init; + local states = { }; + while tkn_i do + if tkn_i == 0 then + tkn_i += 1; + local next_alt = find_alternation(token, tkn_i); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + continue; + end; + local ctkn = token[tkn_i]; + local tkn_type = type(ctkn) == "table" and ctkn[1]; + if not ctkn then + break; + elseif ctkn == "ACCEPT" then + local not_lookaround = true; + local close_i = tkn_i; + repeat + close_i += 1; + local is_table = type(token[close_i]) == "table"; + local close_i_tkn = token[close_i]; + if is_table and (close_i_tkn[1] == 0x28 or close_i_tkn[1] == "quantifier" and type(close_i_tkn[5]) == "table" and close_i_tkn[5][1] == 0x28) then + close_i = close_i_tkn[1] == "quantifier" and close_i_tkn[5][3] or close_i_tkn[3]; + elseif is_table and close_i_tkn[1] == 0x29 and (close_i_tkn[4] == 0x21 or close_i_tkn[4] == 0x3D) then + not_lookaround = false; + tkn_i = close_i; + break; + end; + until not close_i_tkn; + if not_lookaround then + break; + end; + elseif ctkn == "PRUNE" or ctkn == "SKIP" then + table.insert(states, 1, { ctkn, str_i }); + tkn_i += 1; + elseif tkn_type == 0x28 then + table.insert(states, 1, { "group", tkn_i, str_i, nil, ctkn[2], ctkn[3], ctkn[4] }); + tkn_i += 1; + local next_alt, count = find_alternation(token, tkn_i, (ctkn[4] == 0x21 or ctkn[4] == 0x3D) and ctkn[5] and 0); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + if count then + str_i -= count; + end; + elseif tkn_type == 0x29 and ctkn[4] ~= 0x21 then + if ctkn[4] == 0x21 or ctkn[4] == 0x3D then + while true do + local selected_match_start; + local selected_state = table.remove(states, 1); + if selected_state[1] == "group" and selected_state[2] == ctkn[3] then + if (ctkn[4] == 0x21 or ctkn[4] == 0x3D) and not ctkn[5] then + str_i = selected_state[3]; + end; + if selected_match_start then + table.insert(states, 1, selected_match_start); + end; + break; + elseif selected_state[1] == "matchStart" and not selected_match_start and ctkn[4] == 0x3D then + selected_match_start = selected_state; + end; + end; + elseif ctkn[4] == 0x3E then + repeat + local selected_state = table.remove(states, 1); + until not selected_state or selected_state[1] == "group" and selected_state[2] == ctkn[3]; + else + for i, v in ipairs(states) do + if v[1] == "group" and v[2] == ctkn[3] then + if v.jmp then + -- recursive match + tkn_i = v.jmp; + end; + v[4] = str_i; + if v[7] == "quantifier" and v[10] + 1 < v[9] then + if token[ctkn[3]][4] ~= "lazy" or v[10] + 1 < v[8] then + tkn_i = ctkn[3]; + end; + local ctkn1 = token[ctkn[3]]; + local new_group = { "group", v[2], str_i, nil, ctkn1[5][2], ctkn1[5][3], "quantifier", ctkn1[2], ctkn1[3], v[10] + 1, v[11], ctkn1[4] }; + table.insert(states, 1, new_group); + if v[11] then + table.insert(states, 1, { "alternation", v[11], str_i }); + end; + end; + break; + end; + end; + end; + tkn_i += 1; + elseif tkn_type == 0x4B then + table.insert(states, 1, { "matchStart", str_i }); + tkn_i += 1; + elseif tkn_type == 0x7C then + local close_i = tkn_i; + repeat + close_i += 1; + local is_table = type(token[close_i]) == "table"; + local close_i_tkn = token[close_i]; + if is_table and (close_i_tkn[1] == 0x28 or close_i_tkn[1] == "quantifier" and type(close_i_tkn[5]) == "table" and close_i_tkn[5][1] == 0x28) then + close_i = close_i_tkn[1] == "quantifier" and close_i_tkn[5][3] or close_i_tkn[3]; + end; + until is_table and close_i_tkn[1] == 0x29 or not close_i_tkn; + if token[close_i] then + for _, v in ipairs(states) do + if v[1] == "group" and v[6] == close_i then + tkn_i = v[6]; + break; + end; + end; + else + tkn_i = close_i; + end; + elseif tkn_type == "recurmatch" then + table.insert(states, 1, { "group", ctkn[3], str_i, nil, nil, token[ctkn[3]][3], nil, jmp = tkn_i }); + tkn_i = ctkn[3] + 1; + local next_alt, count = find_alternation(token, tkn_i); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + else + local match; + if ctkn == "FAIL" then + match = false; + elseif tkn_type == 0x29 then + repeat + local selected_state = table.remove(states, 1); + until selected_state[1] == "group" and selected_state[2] == ctkn[3]; + elseif tkn_type == "quantifier" then + if type(ctkn[5]) == "table" and ctkn[5][1] == 0x28 then + local next_alt = find_alternation(token, tkn_i + 1); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + table.insert(states, next_alt and 2 or 1, { "group", tkn_i, str_i, nil, ctkn[5][2], ctkn[5][3], "quantifier", ctkn[2], ctkn[3], 0, next_alt, ctkn[4] }); + if ctkn[4] == "lazy" and ctkn[2] == 0 then + tkn_i = ctkn[5][3]; + end; + match = true; + else + local start_i, end_i; + local pattern_count = 1; + local is_backref = type(ctkn[5]) == "table" and ctkn[5][1] == "backref"; + if is_backref then + pattern_count = 0; + local group_n = ctkn[5][2]; + for _, v in ipairs(states) do + if v[1] == "group" and v[5] == group_n then + start_i, end_i = v[3], v[4]; + pattern_count = end_i - start_i; + break; + end; + end; + end; + local min_max_i = str_i + ctkn[2] * pattern_count; + local mcount = 0; + while mcount < ctkn[3] do + if is_backref then + if start_i and end_i then + local org_i = str_i; + if utf8_sub(str_arr.s, start_i, end_i) ~= utf8_sub(str_arr.s, org_i, str_i + pattern_count) then + break; + end; + else + break; + end; + elseif not tkn_char_match(ctkn[5], str_arr, str_i, flags, verb_flags) then + break; + end; + str_i += pattern_count; + mcount += 1; + end; + match = mcount >= ctkn[2]; + if match and ctkn[4] ~= "possessive" then + if ctkn[4] == "lazy" then + min_max_i, str_i = str_i, min_max_i; + end; + table.insert(states, 1, { "quantifier", tkn_i, str_i, math.min(min_max_i, str_arr.n + 1), (ctkn[4] == "lazy" and 1 or -1) * pattern_count }); + end; + end; + elseif tkn_type == "backref" then + local start_i, end_i; + local group_n = ctkn[2]; + for _, v in ipairs(states) do + if v[1] == "group" and v[5] == group_n then + start_i, end_i = v[3], v[4]; + break; + end; + end; + if start_i and end_i then + local org_i = str_i; + str_i += end_i - start_i; + match = utf8_sub(str_arr.s, start_i, end_i) == utf8_sub(str_arr.s, org_i, str_i); + end; + else + local chr = str_arr[str_i]; + if tkn_type == 0x24 or tkn_type == 0x5A or tkn_type == 0x7A then + match = str_i == str_arr.n + 1 or tkn_type == 0x24 and flags.multiline and is_newline(str_arr, str_i + 1, verb_flags) or tkn_type == 0x5A and str_i == str_arr.n and is_newline(str_arr, str_i, verb_flags); + elseif tkn_type == 0x5E or tkn_type == 0x41 or tkn_type == 0x47 then + match = str_i == 1 or tkn_type == 0x5E and flags.multiline and is_newline(str_arr, str_i - 1, verb_flags) or tkn_type == 0x47 and str_i == init; + elseif tkn_type == 0x42 or tkn_type == 0x62 then + local start_m = str_i == 1 or flags.multiline and is_newline(str_arr, str_i - 1, verb_flags); + local end_m = str_i == str_arr.n + 1 or flags.multiline and is_newline(str_arr, str_i, verb_flags); + local w_m = tkn_char_match(ctkn[2], str_arr[str_i - 1], flags) and 0 or tkn_char_match(ctkn[2], chr, flags) and 1; + if w_m == 0 then + match = end_m or not tkn_char_match(ctkn[2], chr, flags); + elseif w_m then + match = start_m or not tkn_char_match(ctkn[2], str_arr[str_i - 1], flags); + end; + if tkn_type == 0x42 then + match = not match; + end; + else + match = tkn_char_match(ctkn, str_arr, str_i, flags, verb_flags); + str_i += 1; + end; + end; + if not match then + while true do + local prev_type, prev_state = states[1] and states[1][1], states[1]; + if not prev_type or prev_type == "PRUNE" or prev_type == "SKIP" then + if prev_type then + table.clear(states); + end; + if start_i > str_arr.n then + if as_bool then + return false; + end; + return nil; + end; + start_i = prev_type == "SKIP" and prev_state[2] or start_i + 1; + tkn_i, str_i = 0, start_i; + break; + elseif prev_type == "alternation" then + tkn_i, str_i = prev_state[2], prev_state[3]; + local next_alt, count = find_alternation(token, tkn_i + 1); + if next_alt then + prev_state[2] = next_alt; + else + table.remove(states, 1); + end; + if count then + str_i -= count; + end; + break; + elseif prev_type == "group" then + if prev_state[7] == "quantifier" then + if prev_state[12] == "greedy" and prev_state[10] >= prev_state[8] + or prev_state[12] == "lazy" and prev_state[10] < prev_state[9] and not prev_state[13] then + tkn_i, str_i = prev_state[12] == "greedy" and prev_state[6] or prev_state[2], prev_state[3]; + if prev_state[12] == "greedy" then + table.remove(states, 1); + break; + elseif prev_state[10] >= prev_state[8] then + prev_state[13] = true; + break; + end; + end; + elseif prev_state[7] == 0x21 then + table.remove(states, 1); + tkn_i, str_i = prev_state[6], prev_state[3]; + break; + end; + elseif prev_type == "quantifier" then + if math.sign(prev_state[4] - prev_state[3]) == math.sign(prev_state[5]) then + prev_state[3] += prev_state[5]; + tkn_i, str_i = prev_state[2], prev_state[3]; + break; + end; + end; + -- keep match out state and recursive state, can be safely removed + -- prevents infinite loop + table.remove(states, 1); + end; + end; + tkn_i += 1; + end; + end; + if as_bool then + return true; + end; + local match_start_ran = false; + local span = table.create(token.group_n); + span[0], span.n = { start_i, str_i }, token.group_n; + for _, v in ipairs(states) do + if v[1] == "matchStart" and not match_start_ran then + span[0][1], match_start_ran = v[2], true; + elseif v[1] == "group" and v[5] and not span[v[5]] then + span[v[5]] = { v[3], v[4] }; + end; + end; + return span; +end; + +--[[ Methods ]]-- +re_m.test = check_re('RegEx', 'test', function(self, str, init) + return re_rawfind(self.token, to_str_arr(str, init), 1, self.flags, self.verb_flags, true); +end); + +re_m.match = check_re('RegEx', 'match', function(self, str, init, source) + local span = re_rawfind(self.token, to_str_arr(str, init), 1, self.flags, self.verb_flags, false); + if not span then + return nil; + end; + return new_match(span, self.group_id, source, str); +end); + +re_m.matchall = check_re('RegEx', 'matchall', function(self, str, init, source) + str = to_str_arr(str, init); + local i = 1; + return function() + local span = i <= str.n + 1 and re_rawfind(self.token, str, i, self.flags, self.verb_flags, false); + if not span then + return nil; + end; + i = span[0][2] + (span[0][1] >= span[0][2] and 1 or 0); + return new_match(span, self.group_id, source, str.s); + end; +end); + +local function insert_tokenized_sub(repl_r, str, span, tkn) + for _, v in ipairs(tkn) do + if type(v) == "table" then + if v[1] == "condition" then + if span[v[2]] then + if v[3] then + insert_tokenized_sub(repl_r, str, span, v[3]); + else + table.move(str, span[v[2]][1], span[v[2]][2] - 1, #repl_r + 1, repl_r); + end; + elseif v[4] then + insert_tokenized_sub(repl_r, str, span, v[4]); + end; + else + table.move(v, 1, #v, #repl_r + 1, repl_r); + end; + elseif span[v] then + table.move(str, span[v][1], span[v][2] - 1, #repl_r + 1, repl_r); + end; + end; + repl_r.n = #repl_r; + return repl_r; +end; + +re_m.sub = check_re('RegEx', 'sub', function(self, repl, str, n, repl_flag_str, source) + if repl_flag_str ~= nil and type(repl_flag_str) ~= "number" and type(repl_flag_str) ~= "string" then + error(string.format("invalid argument #5 to 'sub' (string expected, got %s)", typeof(repl_flag_str)), 3); + end + local repl_flags = { + l = false, o = false, u = false, + }; + for f in string.gmatch(repl_flag_str or '', utf8.charpattern) do + if repl_flags[f] ~= false then + error("invalid regular expression substitution flag " .. f, 3); + end; + repl_flags[f] = true; + end; + local repl_type = type(repl); + if repl_type == "number" then + repl ..= ''; + elseif repl_type ~= "string" and repl_type ~= "function" and (not repl_flags.o or repl_type ~= "table") then + error(string.format("invalid argument #2 to 'sub' (string/function%s expected, got %s)", repl_flags.o and "/table" or '', typeof(repl)), 3); + end; + if tonumber(n) then + n = tonumber(n); + if n <= -1 or n ~= n then + n = math.huge; + end; + elseif n ~= nil then + error(string.format("invalid argument #4 to 'sub' (number expected, got %s)", typeof(n)), 3); + else + n = math.huge; + end; + if n < 1 then + return str, 0; + end; + local min_repl_n = 0; + if repl_type == "string" then + repl = to_str_arr(repl); + if not repl_flags.l then + local i1 = 0; + local repl_r = table.create(3); + local group_n = self.token.group_n; + local conditional_c = { }; + while i1 < repl.n do + local i2 = i1; + repeat + i2 += 1; + until not repl[i2] or repl[i2] == 0x24 or repl[i2] == 0x5C or (repl[i2] == 0x3A or repl[i2] == 0x7D) and conditional_c[1]; + min_repl_n += i2 - i1 - 1; + if i2 - i1 > 1 then + table.insert(repl_r, table.move(repl, i1 + 1, i2 - 1, 1, table.create(i2 - i1 - 1))); + end; + if repl[i2] == 0x3A then + local current_conditional_c = conditional_c[1]; + if current_conditional_c[2] then + error("malformed substitution pattern", 3); + end; + current_conditional_c[2] = table.move(repl_r, current_conditional_c[3], #repl_r, 1, table.create(#repl_r + 1 - current_conditional_c[3])); + for i3 = #repl_r, current_conditional_c[3], -1 do + repl_r[i3] = nil; + end; + elseif repl[i2] == 0x7D then + local current_conditional_c = table.remove(conditional_c, 1); + local second_c = table.move(repl_r, current_conditional_c[3], #repl_r, 1, table.create(#repl_r + 1 - current_conditional_c[3])); + for i3 = #repl_r, current_conditional_c[3], -1 do + repl_r[i3] = nil; + end; + table.insert(repl_r, { "condition", current_conditional_c[1], current_conditional_c[2] ~= true and (current_conditional_c[2] or second_c), current_conditional_c[2] and second_c }); + elseif repl[i2] then + i2 += 1; + local subst_c = repl[i2]; + if not subst_c then + if repl[i2 - 1] == 0x5C then + error("replacement string must not end with a trailing backslash", 3); + end; + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, repl[i2 - 1]); + else + table.insert(repl_r, { repl[i2 - 1] }); + end; + elseif subst_c == 0x5C and repl[i2 - 1] == 0x24 then + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, 0x24); + else + table.insert(repl_r, { 0x24 }); + end; + i2 -= 1; + min_repl_n += 1; + elseif subst_c == 0x30 then + table.insert(repl_r, 0); + elseif subst_c > 0x30 and subst_c <= 0x39 then + local start_i2 = i2; + local group_i = subst_c - 0x30; + while repl[i2 + 1] and repl[i2 + 1] >= 0x30 and repl[i2 + 1] <= 0x39 do + group_i ..= repl[i2 + 1] - 0x30; + i2 += 1; + end; + group_i = tonumber(group_i); + if not repl_flags.u and group_i > group_n then + error("reference to non-existent subpattern", 3); + end; + table.insert(repl_r, group_i); + elseif subst_c == 0x7B and repl[i2 - 1] == 0x24 then + i2 += 1; + local start_i2 = i2; + while repl[i2] and + (repl[i2] >= 0x30 and repl[i2] <= 0x39 + or repl[i2] >= 0x41 and repl[i2] <= 0x5A + or repl[i2] >= 0x61 and repl[i2] <= 0x7A + or repl[i2] == 0x5F) do + i2 += 1; + end; + if (repl[i2] == 0x7D or repl[i2] == 0x3A and (repl[i2 + 1] == 0x2B or repl[i2 + 1] == 0x2D)) and i2 ~= start_i2 then + local group_k = utf8_sub(repl.s, start_i2, i2); + if repl[start_i2] >= 0x30 and repl[start_i2] <= 0x39 then + group_k = tonumber(group_k); + if not repl_flags.u and group_k > group_n then + error("reference to non-existent subpattern", 3); + end; + else + group_k = self.group_id[group_k]; + if not repl_flags.u and (not group_k or group_k > group_n) then + error("reference to non-existent subpattern", 3); + end; + end; + if repl[i2] == 0x3A then + i2 += 1; + table.insert(conditional_c, { group_k, repl[i2] == 0x2D, #repl_r + 1 }); + else + table.insert(repl_r, group_k); + end; + else + error("malformed substitution pattern", 3); + end; + else + local c_escape_char; + if repl[i2 - 1] == 0x24 then + if subst_c ~= 0x24 then + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, 0x24); + else + table.insert(repl_r, { 0x24 }); + end; + end; + else + c_escape_char = escape_chars[repl[i2]]; + if type(c_escape_char) ~= "number" then + c_escape_char = nil; + end; + end; + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, c_escape_char or repl[i2]); + else + table.insert(repl_r, { c_escape_char or repl[i2] }); + end; + min_repl_n += 1; + end; + end; + i1 = i2; + end; + if conditional_c[1] then + error("malformed substitution pattern", 3); + end; + if not repl_r[2] and type(repl_r[1]) == "table" and repl_r[1][1] ~= "condition" then + repl, repl.n = repl_r[1], #repl_r[1]; + else + repl, repl_type = repl_r, "subst_string"; + end; + end; + end; + str = to_str_arr(str); + local incr, i0, count = 0, 1, 0; + while i0 <= str.n + incr + 1 do + local span = re_rawfind(self.token, str, i0, self.flags, self.verb_flags, false); + if not span then + break; + end; + local repl_r; + if repl_type == "string" then + repl_r = repl; + elseif repl_type == "subst_string" then + repl_r = insert_tokenized_sub(table.create(min_repl_n), str, span, repl); + else + local re_match; + local repl_c; + if repl_type == "table" then + re_match = utf8_sub(str.s, span[0][1], span[0][2]); + repl_c = repl[re_match]; + else + re_match = new_match(span, self.group_id, source, str.s); + repl_c = repl(re_match); + end; + if repl_c == re_match or repl_flags.o and not repl_c then + local repl_n = span[0][2] - span[0][1]; + repl_r = table.move(str, span[0][1], span[0][2] - 1, 1, table.create(repl_n)); + repl_r.n = repl_n; + elseif type(repl_c) == "string" then + repl_r = to_str_arr(repl_c); + elseif type(repl_c) == "number" then + repl_r = to_str_arr(repl_c .. ''); + elseif repl_flags.o then + error(string.format("invalid replacement value (a %s)", type(repl_c)), 3); + else + repl_r = { n = 0 }; + end; + end; + local match_len = span[0][2] - span[0][1]; + local repl_len = math.min(repl_r.n, match_len); + for i1 = 0, repl_len - 1 do + str[span[0][1] + i1] = repl_r[i1 + 1]; + end; + local i1 = span[0][1] + repl_len; + i0 = span[0][2]; + if match_len > repl_r.n then + for i2 = 1, match_len - repl_r.n do + table.remove(str, i1); + incr -= 1; + i0 -= 1; + end; + elseif repl_r.n > match_len then + for i2 = 1, repl_r.n - match_len do + table.insert(str, i1 + i2 - 1, repl_r[repl_len + i2]); + incr += 1; + i0 += 1; + end; + end; + if match_len <= 0 then + i0 += 1; + end; + count += 1; + if n < count + 1 then + break; + end; + end; + return from_str_arr(str), count; +end); + +re_m.split = check_re('RegEx', 'split', function(self, str, n) + if tonumber(n) then + n = tonumber(n); + if n <= -1 or n ~= n then + n = math.huge; + end; + elseif n ~= nil then + error(string.format("invalid argument #3 to 'split' (number expected, got %s)", typeof(n)), 3); + else + n = math.huge; + end; + str = to_str_arr(str); + local i, count = 1, 0; + local ret = { }; + local prev_empty = 0; + while i <= str.n + 1 do + count += 1; + local span = n >= count and re_rawfind(self.token, str, i, self.flags, self.verb_flags, false); + if not span then + break; + end; + table.insert(ret, utf8_sub(str.s, i - prev_empty, span[0][1])); + prev_empty = span[0][1] >= span[0][2] and 1 or 0; + i = span[0][2] + prev_empty; + end; + table.insert(ret, string.sub(str.s, utf8.offset(str.s, i - prev_empty))); + return ret; +end); + +-- +local function re_index(self, index) + return re_m[index] or proxy[self].flags[index]; +end; + +local function re_tostr(self) + return proxy[self].pattern_repr .. proxy[self].flag_repr; +end; +-- + +local other_valid_group_char = { + -- non-capturing group + [0x3A] = true, + -- lookarounds + [0x21] = true, [0x3D] = true, + -- atomic + [0x3E] = true, + -- branch reset + [0x7C] = true, +}; + +local function tokenize_ptn(codes, flags) + if flags.unicode and not options.unicodeData then + return "options.unicodeData cannot be turned off while having unicode flag"; + end; + local i, len = 1, codes.n; + local group_n = 0; + local outln, group_id, verb_flags = { }, { }, { + newline = 1, newline_seq = 1, not_empty = 0, + }; + while i <= len do + local c = codes[i]; + if c == 0x28 then + -- Match + local ret; + if codes[i + 1] == 0x2A then + i += 2; + local start_i = i; + while codes[i] + and (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F or codes[i] == 0x3A) do + i += 1; + end; + if codes[i] ~= 0x29 and codes[i - 1] ~= 0x3A then + -- fallback as normal and ( can't be repeated + return "quantifier doesn't follow a repeatable pattern"; + end; + local selected_verb = utf8_sub(codes.s, start_i, i); + if selected_verb == "positive_lookahead:" or selected_verb == "negative_lookhead:" + or selected_verb == "positive_lookbehind:" or selected_verb == "negative_lookbehind:" + or selected_verb:find("^[pn]l[ab]:$") then + ret = { 0x28, nil, nil, selected_verb:find('^n') and 0x21 or 0x3D, selected_verb:find('b', 3, true) and 1 }; + elseif selected_verb == "atomic:" then + ret = { 0x28, nil, nil, 0x3E, nil }; + elseif selected_verb == "ACCEPT" or selected_verb == "FAIL" or selected_verb == 'F' or selected_verb == "PRUNE" or selected_verb == "SKIP" then + ret = selected_verb == 'F' and "FAIL" or selected_verb; + else + if line_verbs[selected_verb] then + verb_flags.newline = selected_verb; + elseif selected_verb == "BSR_ANYCRLF" or selected_verb == "BSR_UNICODE" then + verb_flags.newline_seq = selected_verb == "BSR_UNICODE" and 1 or 0; + elseif selected_verb == "NOTEMPTY" or selected_verb == "NOTEMPTY_ATSTART" then + verb_flags.not_empty = selected_verb == "NOTEMPTY" and 1 or 2; + else + return "unknown or malformed verb"; + end; + if outln[1] then + return "this verb must be placed at the beginning of the regex"; + end; + end; + elseif codes[i + 1] == 0x3F then + -- ? syntax + i += 2; + if codes[i] == 0x23 then + -- comments + i = table.find(codes, 0x29, i); + if not i then + return "unterminated parenthetical"; + end; + i += 1; + continue; + elseif not codes[i] then + return "unterminated parenthetical"; + end; + ret = { 0x28, nil, nil, codes[i], nil }; + if codes[i] == 0x30 and codes[i + 1] == 0x29 then + -- recursive match entire pattern + ret[1], ret[2], ret[3], ret[5] = "recurmatch", 0, 0, nil; + elseif codes[i] > 0x30 and codes[i] <= 0x39 then + -- recursive match + local org_i = i; + i += 1; + while codes[i] >= 0x30 and codes[i] <= 0x30 do + i += 1; + end; + if codes[i] ~= 0x29 then + return "invalid group structure"; + end; + ret[1], ret[2], ret[4] = "recurmatch", tonumber(utf8_sub(codes.s, org_i, i)), nil; + elseif codes[i] == 0x3C and codes[i + 1] == 0x21 or codes[i + 1] == 0x3D then + -- lookbehinds + i += 1; + ret[4], ret[5] = codes[i], 1; + elseif codes[i] == 0x7C then + -- branch reset + ret[5] = group_n; + elseif codes[i] == 0x50 or codes[i] == 0x3C or codes[i] == 0x27 then + if codes[i] == 0x50 then + i += 1; + end; + if codes[i] == 0x3D then + -- backref + local start_i = i + 1; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if not codes[i] then + return "unterminated parenthetical"; + elseif codes[i] ~= 0x29 or i == start_i then + return "invalid group structure"; + end; + ret = { "backref", utf8_sub(codes.s, start_i, i) }; + elseif codes[i] == 0x3C or codes[i - 1] ~= 0x50 and codes[i] == 0x27 then + -- named capture + local delimiter = codes[i] == 0x27 and 0x27 or 0x3E; + local start_i = i + 1; + i += 1; + if codes[i] == 0x29 then + return "missing character in subpattern"; + elseif codes[i] >= 0x30 and codes[i] <= 0x39 then + return "subpattern name must not begin with a digit"; + elseif not (codes[i] >= 0x41 and codes[i] <= 0x5A or codes[i] >= 0x61 and codes[i] <= 0x7A or codes[i] == 0x5F) then + return "invalid character in subpattern"; + end; + i += 1; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if not codes[i] then + return "unterminated parenthetical"; + elseif codes[i] ~= delimiter then + return "invalid character in subpattern"; + end; + local name = utf8_sub(codes.s, start_i, i); + group_n += 1; + if (group_id[name] or group_n) ~= group_n then + return "subpattern name already exists"; + end; + for name1, group_n1 in pairs(group_id) do + if name ~= name1 and group_n == group_n1 then + return "different names for subpatterns of the same number aren't permitted"; + end; + end; + group_id[name] = group_n; + ret[2], ret[4] = group_n, nil; + else + return "invalid group structure"; + end; + elseif not other_valid_group_char[codes[i]] then + return "invalid group structure"; + end; + else + group_n += 1; + ret = { 0x28, group_n, nil, nil }; + end; + if ret then + table.insert(outln, ret); + end; + elseif c == 0x29 then + -- Close parenthesis + local i1 = #outln + 1; + local lookbehind_c = -1; + local current_lookbehind_c = 0; + local max_c, group_c = 0, 0; + repeat + i1 -= 1; + local v, is_table = outln[i1], type(outln[i1]) == "table"; + if is_table and v[1] == 0x28 then + group_c += 1; + if current_lookbehind_c and v.count then + current_lookbehind_c += v.count; + end; + if not v[3] then + if v[4] == 0x7C then + group_n = v[5] + math.max(max_c, group_c); + end; + if current_lookbehind_c ~= lookbehind_c and lookbehind_c ~= -1 then + lookbehind_c = nil; + else + lookbehind_c = current_lookbehind_c; + end; + break; + end; + elseif v == alternation then + if current_lookbehind_c ~= lookbehind_c and lookbehind_c ~= -1 then + lookbehind_c, current_lookbehind_c = nil, nil; + else + lookbehind_c, current_lookbehind_c = current_lookbehind_c, 0; + end; + max_c, group_c = math.max(max_c, group_c), 0; + elseif current_lookbehind_c then + if is_table and v[1] == "quantifier" then + if v[2] == v[3] then + current_lookbehind_c += v[2]; + else + current_lookbehind_c = nil; + end; + else + current_lookbehind_c += 1; + end; + end; + until i1 < 1; + if i1 < 1 then + return "unmatched ) in regular expression"; + end; + local v = outln[i1]; + local outln_len_p_1 = #outln + 1; + local ret = { 0x29, v[2], i1, v[4], v[5], count = lookbehind_c }; + if (v[4] == 0x21 or v[4] == 0x3D) and v[5] and not lookbehind_c then + return "lookbehind assertion is not fixed width"; + end; + v[3] = outln_len_p_1; + table.insert(outln, ret); + elseif c == 0x2E then + table.insert(outln, dot); + elseif c == 0x5B then + -- Character set + local negate, char_class = false, nil; + i += 1; + local start_i = i; + if codes[i] == 0x5E then + negate = true; + i += 1; + elseif codes[i] == 0x2E or codes[i] == 0x3A or codes[i] == 0x3D then + -- POSIX character classes + char_class = codes[i]; + end; + local ret; + if codes[i] == 0x5B or codes[i] == 0x5C then + ret = { }; + else + ret = { codes[i] }; + i += 1; + end; + while codes[i] ~= 0x5D do + if not codes[i] then + return "unterminated character class"; + elseif codes[i] == 0x2D and ret[1] and type(ret[1]) == "number" then + if codes[i + 1] == 0x5D then + table.insert(ret, 1, 0x2D); + else + i += 1; + local ret_c = codes[i]; + if ret_c == 0x5B then + if codes[i + 1] == 0x2E or codes[i + 1] == 0x3A or codes[i + 1] == 0x3D then + -- Check for POSIX character class, name does not matter + local i1 = i + 2; + repeat + i1 = table.find(codes, 0x5D, i1); + until not i1 or codes[i1 - 1] ~= 0x5C; + if not i1 then + return "unterminated character class"; + elseif codes[i1 - 1] == codes[i + 1] and i1 - 1 ~= i + 1 then + return "invalid range in character class"; + end; + end; + if ret[1] > 0x5B then + return "invalid range in character class"; + end; + elseif ret_c == 0x5C then + i += 1; + if codes[i] == 0x78 then + local radix0, radix1; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix0 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix1 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + else + i -= 1; + end; + else + i -= 1; + end; + ret_c = radix0 and (radix1 and 16 * radix0 + radix1 or radix0) or 0; + elseif codes[i] >= 0x30 and codes[i] <= 0x37 then + local radix0, radix1, radix2 = codes[i] - 0x30, nil, nil; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix1 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix2 = codes[i] - 0x30; + else + i -= 1; + end; + else + i -= 1; + end; + ret_c = radix1 and (radix2 and 64 * radix0 + 8 * radix1 + radix2 or 8 * radix0 + radix1) or radix0; + else + ret_c = escape_chars[codes[i]] or codes[i]; + if type(ret_c) ~= "number" then + return "invalid range in character class"; + end; + end; + elseif ret[1] > ret_c then + return "invalid range in character class"; + end; + ret[1] = { "range", ret[1], ret_c }; + end; + elseif codes[i] == 0x5B then + if codes[i + 1] == 0x2E or codes[i + 1] == 0x3A or codes[i + 1] == 0x3D then + local i1 = i + 2; + repeat + i1 = table.find(codes, 0x5D, i1); + until not i1 or codes[i1 - 1] ~= 0x5C; + if not i1 then + return "unterminated character class"; + elseif codes[i1 - 1] ~= codes[i + 1] or i1 - 1 == i + 1 then + table.insert(ret, 1, 0x5B); + elseif codes[i1 - 1] == 0x2E or codes[i1 - 1] == 0x3D then + return "POSIX collating elements aren't supported"; + elseif codes[i1 - 1] == 0x3A then + -- I have no plans to support escape codes (\) in character class names + local negate = codes[i + 3] == 0x5E; + local class_name = utf8_sub(codes.s, i + (negate and 3 or 2), i1 - 1); + -- If not valid then throw an error + if not posix_class_names[class_name] then + return "unknown POSIX class name"; + end; + table.insert(ret, 1, { "class", class_name, negate }); + i = i1; + end; + else + table.insert(ret, 1, 0x5B); + end; + elseif codes[i] == 0x5C then + i += 1; + if codes[i] == 0x78 then + local radix0, radix1; + i += 1; + if codes[i] == 0x7B then + i += 1; + local org_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed hexadecimal character"; + elseif i - org_i > 4 then + return "character offset too large"; + end; + table.insert(ret, 1, tonumber(utf8_sub(codes.s, org_i, i), 16)); + else + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix0 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix1 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + else + i -= 1; + end; + else + i -= 1; + end; + table.insert(ret, 1, radix0 and (radix1 and 16 * radix0 + radix1 or radix0) or 0); + end; + elseif codes[i] >= 0x30 and codes[i] <= 0x37 then + local radix0, radix1, radix2 = codes[i] - 0x30, nil, nil; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix1 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix2 = codes[i] - 0x30; + else + i -= 1; + end; + else + i -= 1; + end; + table.insert(ret, 1, radix1 and (radix2 and 64 * radix0 + 8 * radix1 + radix2 or 8 * radix0 + radix1) or radix0); + elseif codes[i] == 0x45 then + -- intentionally left blank, \E that's not preceded \Q is ignored + elseif codes[i] == 0x51 then + local start_i = i + 1; + repeat + i = table.find(codes, 0x5C, i + 1); + until not i or codes[i + 1] == 0x45; + table.move(codes, start_i, i and i - 1 or #codes, #outln + 1, outln); + if not i then + break; + end; + i += 1; + elseif codes[i] == 0x4E then + if codes[i + 1] == 0x7B and codes[i + 2] == 0x55 and codes[i + 3] == 0x2B and flags.unicode then + i += 4; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == start_i then + return "malformed Unicode code point"; + end; + local code_point = tonumber(utf8_sub(codes.s, start_i, i)); + table.insert(ret, 1, code_point); + else + return "invalid escape sequence"; + end; + elseif codes[i] == 0x50 or codes[i] == 0x70 then + if not options.unicodeData then + return "options.unicodeData cannot be turned off when using \\p"; + end; + i += 1; + if codes[i] ~= 0x7B then + local c_name = utf8.char(codes[i] or 0); + if not valid_categories[c_name] then + return "unknown or malformed script name"; + end; + table.insert(ret, 1, { "category", false, c_name }); + else + local negate = codes[i] == 0x50; + i += 1; + if codes[i] == 0x5E then + i += 1; + negate = not negate; + end; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if codes[i] ~= 0x7D then + return "unknown or malformed script name"; + end; + local c_name = utf8_sub(codes.s, start_i, i); + local script_set = chr_scripts[c_name]; + if script_set then + table.insert(ret, 1, { "charset", negate, script_set }); + elseif not valid_categories[c_name] then + return "unknown or malformed script name"; + else + table.insert(ret, 1, { "category", negate, c_name }); + end; + end; + elseif codes[i] == 0x6F then + i += 1; + if codes[i] ~= 0x7B then + return "malformed octal code"; + end; + i += 1; + local org_i = i; + while codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed octal code"; + end; + local ret_chr = tonumber(utf8_sub(codes.s, org_i, i), 8); + if ret_chr > 0xFFFF then + return "character offset too large"; + end; + table.insert(ret, 1, ret_chr); + else + local esc_char = escape_chars[codes[i]]; + table.insert(ret, 1, type(esc_char) == "string" and { "class", esc_char, false } or esc_char or codes[i]); + end; + elseif flags.ignoreCase and codes[i] >= 0x61 and codes[i] <= 0x7A then + table.insert(ret, 1, codes[i] - 0x20); + else + table.insert(ret, 1, codes[i]); + end; + i += 1; + end; + if codes[i - 1] == char_class and i - 1 ~= start_i then + return char_class == 0x3A and "POSIX named classes are only support within a character set" or "POSIX collating elements aren't supported"; + end; + if not ret[2] and not negate then + table.insert(outln, ret[1]); + else + table.insert(outln, { "charset", negate, ret }); + end; + elseif c == 0x5C then + -- Escape char + i += 1; + local escape_c = codes[i]; + if not escape_c then + return "pattern may not end with a trailing backslash"; + elseif escape_c >= 0x30 and escape_c <= 0x39 then + local org_i = i; + while codes[i + 1] and codes[i + 1] >= 0x30 and codes[i + 1] <= 0x39 do + i += 1; + end; + local escape_d = tonumber(utf8_sub(codes.s, org_i, i + 1)); + if escape_d > group_n and i ~= org_i then + i = org_i; + local radix0, radix1, radix2; + if codes[i] <= 0x37 then + radix0 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix1 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix2 = codes[i] - 0x30; + else + i -= 1; + end; + else + i -= 1; + end; + end; + table.insert(outln, radix0 and (radix1 and (radix2 and 64 * radix0 + 8 * radix1 + radix2 or 8 * radix0 + radix1) or radix0) or codes[org_i]); + else + table.insert(outln, { "backref", escape_d }); + end; + elseif escape_c == 0x45 then + -- intentionally left blank, \E that's not preceded \Q is ignored + elseif escape_c == 0x51 then + local start_i = i + 1; + repeat + i = table.find(codes, 0x5C, i + 1); + until not i or codes[i + 1] == 0x45; + table.move(codes, start_i, i and i - 1 or #codes, #outln + 1, outln); + if not i then + break; + end; + i += 1; + elseif escape_c == 0x4E then + if codes[i + 1] == 0x7B and codes[i + 2] == 0x55 and codes[i + 3] == 0x2B and flags.unicode then + i += 4; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == start_i then + return "malformed Unicode code point"; + end; + local code_point = tonumber(utf8_sub(codes.s, start_i, i)); + table.insert(outln, code_point); + else + table.insert(outln, escape_chars[0x4E]); + end; + elseif escape_c == 0x50 or escape_c == 0x70 then + if not options.unicodeData then + return "options.unicodeData cannot be turned off when using \\p"; + end; + i += 1; + if codes[i] ~= 0x7B then + local c_name = utf8.char(codes[i] or 0); + if not valid_categories[c_name] then + return "unknown or malformed script name"; + end; + table.insert(outln, { "category", false, c_name }); + else + local negate = escape_c == 0x50; + i += 1; + if codes[i] == 0x5E then + i += 1; + negate = not negate; + end; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if codes[i] ~= 0x7D then + return "unknown or malformed script name"; + end; + local c_name = utf8_sub(codes.s, start_i, i); + local script_set = chr_scripts[c_name]; + if script_set then + table.insert(outln, { "charset", negate, script_set }); + elseif not valid_categories[c_name] then + return "unknown or malformed script name"; + else + table.insert(outln, { "category", negate, c_name }); + end; + end; + elseif escape_c == 0x67 and (codes[i + 1] == 0x7B or codes[i + 1] >= 0x30 and codes[i + 1] <= 0x39) then + local is_grouped = false; + i += 1; + if codes[i] == 0x7B then + i += 1; + is_grouped = true; + elseif codes[i] < 0x30 or codes[i] > 0x39 then + return "malformed reference code"; + end; + local org_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if is_grouped and codes[i] ~= 0x7D then + return "malformed reference code"; + end; + local ref_name = tonumber(utf8_sub(codes.s, org_i, i + (is_grouped and 0 or 1))); + table.insert(outln, { "backref", ref_name }); + if not is_grouped then + i -= 1; + end; + elseif escape_c == 0x6F then + i += 1; + if codes[i + 1] ~= 0x7B then + return "malformed octal code"; + end + i += 1; + local org_i = i; + while codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed octal code"; + end; + local ret_chr = tonumber(utf8_sub(codes.s, org_i, i), 8); + if ret_chr > 0xFFFF then + return "character offset too large"; + end; + table.insert(outln, ret_chr); + elseif escape_c == 0x78 then + local radix0, radix1; + i += 1; + if codes[i] == 0x7B then + i += 1; + local org_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed hexadecimal code"; + elseif i - org_i > 4 then + return "character offset too large"; + end; + table.insert(outln, tonumber(utf8_sub(codes.s, org_i, i), 16)); + else + if codes[i] and (codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66) then + radix0 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + i += 1; + if codes[i] and (codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66) then + radix1 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + else + i -= 1; + end; + else + i -= 1; + end; + table.insert(outln, radix0 and (radix1 and 16 * radix0 + radix1 or radix0) or 0); + end; + else + local esc_char = b_escape_chars[escape_c] or escape_chars[escape_c]; + table.insert(outln, esc_char or escape_c); + end; + elseif c == 0x2A or c == 0x2B or c == 0x3F or c == 0x7B then + -- Quantifier + local start_q, end_q; + if c == 0x7B then + local org_i = i + 1; + local start_i; + while codes[i + 1] and (codes[i + 1] >= 0x30 and codes[i + 1] <= 0x39 or codes[i + 1] == 0x2C and not start_i and i + 1 ~= org_i) do + i += 1; + if codes[i] == 0x2C then + start_i = i; + end; + end; + if codes[i + 1] == 0x7D then + i += 1; + if not start_i then + start_q = tonumber(utf8_sub(codes.s, org_i, i)); + end_q = start_q; + else + start_q, end_q = tonumber(utf8_sub(codes.s, org_i, start_i)), start_i + 1 == i and math.huge or tonumber(utf8_sub(codes.s, start_i + 1, i)); + if end_q < start_q then + return "numbers out of order in {} quantifier"; + end; + end; + else + table.move(codes, org_i - 1, i, #outln + 1, outln); + end; + else + start_q, end_q = c == 0x2B and 1 or 0, c == 0x3F and 1 or math.huge; + end; + if start_q then + local quantifier_type = flags.ungreedy and "lazy" or "greedy"; + if codes[i + 1] == 0x2B or codes[i + 1] == 0x3F then + i += 1; + quantifier_type = codes[i] == 0x2B and "possessive" or flags.ungreedy and "greedy" or "lazy"; + end; + local outln_len = #outln; + local last_outln_value = outln[outln_len]; + if not last_outln_value or type(last_outln_value) == "table" and (last_outln_value[1] == "quantifier" or last_outln_value[1] == 0x28 or b_escape_chars[last_outln_value[1]]) + or last_outln_value == alternation or type(last_outln_value) == "string" then + return "quantifier doesn't follow a repeatable pattern"; + end; + if end_q == 0 then + table.remove(outln); + elseif start_q ~= 1 or end_q ~= 1 then + if type(last_outln_value) == "table" and last_outln_value[1] == 0x29 then + outln_len = last_outln_value[3]; + end; + outln[outln_len] = { "quantifier", start_q, end_q, quantifier_type, outln[outln_len] }; + end; + end; + elseif c == 0x7C then + -- Alternation + table.insert(outln, alternation); + local i1 = #outln; + repeat + i1 -= 1; + local v1, is_table = outln[i1], type(outln[i1]) == "table"; + if is_table and v1[1] == 0x29 then + i1 = outln[i1][3]; + elseif is_table and v1[1] == 0x28 then + if v1[4] == 0x7C then + group_n = v1[5]; + end; + break; + end; + until not v1; + elseif c == 0x24 or c == 0x5E then + table.insert(outln, c == 0x5E and beginning_str or end_str); + elseif flags.ignoreCase and c >= 0x61 and c <= 0x7A then + table.insert(outln, c - 0x20); + elseif flags.extended and (c >= 0x09 and c <= 0x0D or c == 0x20 or c == 0x23) then + if c == 0x23 then + repeat + i += 1; + until not codes[i] or codes[i] == 0x0A or codes[i] == 0x0D; + end; + else + table.insert(outln, c); + end; + i += 1; + end; + local max_group_n = 0; + for i, v in ipairs(outln) do + if type(v) == "table" and (v[1] == 0x28 or v[1] == "quantifier" and type(v[5]) == "table" and v[5][1] == 0x28) then + if v[1] == "quantifier" then + v = v[5]; + end; + if not v[3] then + return "unterminated parenthetical"; + elseif v[2] then + max_group_n = math.max(max_group_n, v[2]); + end; + elseif type(v) == "table" and (v[1] == "backref" or v[1] == "recurmatch") then + if not group_id[v[2]] and (type(v[2]) ~= "number" or v[2] > group_n) then + return "reference to a non-existent or invalid subpattern"; + elseif v[1] == "recurmatch" and v[2] ~= 0 then + for i1, v1 in ipairs(outln) do + if type(v1) == "table" and v1[1] == 0x28 and v1[2] == v[2] then + v[3] = i1; + break; + end; + end; + elseif type(v[2]) == "string" then + v[2] = group_id[v[2]]; + end; + end; + end; + outln.group_n = max_group_n; + return outln, group_id, verb_flags; +end; + +if not tonumber(options.cacheSize) then + error(string.format("expected number for options.cacheSize, got %s", typeof(options.cacheSize)), 2); +end; +local cacheSize = math.floor(options.cacheSize or 0) ~= 0 and tonumber(options.cacheSize); +local cache_pattern, cache_pattern_names; +if not cacheSize then +elseif cacheSize < 0 or cacheSize ~= cacheSize then + error("cache size cannot be a negative number or a NaN", 2); +elseif cacheSize == math.huge then + cache_pattern, cache_pattern_names = { nil }, { nil }; +elseif cacheSize >= 2 ^ 32 then + error("cache size too large", 2); +else + cache_pattern, cache_pattern_names = table.create(options.cacheSize), table.create(options.cacheSize); +end; +if cacheSize then + function re.pruge() + table.clear(cache_pattern_names); + table.clear(cache_pattern); + end; +end; + +local function new_re(str_arr, flags, flag_repr, pattern_repr) + local tokenized_ptn, group_id, verb_flags; + local cache_format = cacheSize and string.format("%s|%s", str_arr.s, flag_repr); + local cached_token = cacheSize and cache_pattern[table.find(cache_pattern_names, cache_format)]; + if cached_token then + tokenized_ptn, group_id, verb_flags = table.unpack(cached_token, 1, 3); + else + tokenized_ptn, group_id, verb_flags = tokenize_ptn(str_arr, flags); + if type(tokenized_ptn) == "string" then + error(tokenized_ptn, 2); + end; + if cacheSize and tokenized_ptn[1] then + table.insert(cache_pattern_names, 1, cache_format); + table.insert(cache_pattern, 1, { tokenized_ptn, group_id, verb_flags }); + if cacheSize ~= math.huge then + table.remove(cache_pattern_names, cacheSize + 1); + table.remove(cache_pattern, cacheSize + 1); + end; + end; + end; + + local object = newproxy(true); + proxy[object] = { name = "RegEx", flags = flags, flag_repr = flag_repr, pattern_repr = pattern_repr, token = tokenized_ptn, group_id = group_id, verb_flags = verb_flags }; + local object_mt = getmetatable(object); + object_mt.__index = setmetatable(flags, re_m); + object_mt.__tostring = re_tostr; + object_mt.__metatable = lockmsg; + + return object; +end; + +local function escape_fslash(pre) + return (#pre % 2 == 0 and '\\' or '') .. pre .. '.'; +end; + +local function sort_flag_chr(a, b) + return a:lower() < b:lower(); +end; + +function re.new(...) + if select('#', ...) == 0 then + error("missing argument #1 (string expected)", 2); + end; + local ptn, flags_str = ...; + if type(ptn) == "number" then + ptn ..= ''; + elseif type(ptn) ~= "string" then + error(string.format("invalid argument #1 (string expected, got %s)", typeof(ptn)), 2); + end; + if type(flags_str) ~= "string" and type(flags_str) ~= "number" and flags_str ~= nil then + error(string.format("invalid argument #2 (string expected, got %s)", typeof(flags_str)), 2); + end; + + local flags = { + anchored = false, caseless = false, multiline = false, dotall = false, unicode = false, ungreedy = false, extended = false, + }; + local flag_repr = { }; + for f in string.gmatch(flags_str or '', utf8.charpattern) do + if flags[flag_map[f]] ~= false then + error("invalid regular expression flag " .. f, 3); + end; + flags[flag_map[f]] = true; + table.insert(flag_repr, f); + end; + table.sort(flag_repr, sort_flag_chr); + flag_repr = table.concat(flag_repr); + return new_re(to_str_arr(ptn), flags, flag_repr, string.format("/%s/", ptn:gsub("(\\*)/", escape_fslash))); +end; + +function re.fromstring(...) + if select('#', ...) == 0 then + error("missing argument #1 (string expected)", 2); + end; + local ptn = ...; + if type(ptn) == "number" then + ptn ..= ''; + elseif type(ptn) ~= "string" then + error(string.format("invalid argument #1 (string expected, got %s)", typeof(ptn), 2)); + end; + local str_arr = to_str_arr(ptn); + local delimiter = str_arr[1]; + if not delimiter then + error("empty regex", 2); + elseif delimiter == 0x5C or (delimiter >= 0x30 and delimiter <= 0x39) or (delimiter >= 0x41 and delimiter <= 0x5A) or (delimiter >= 0x61 and delimiter <= 0x7A) then + error("delimiter must not be alphanumeric or a backslash", 2); + end; + + local i0 = 1; + repeat + i0 = table.find(str_arr, delimiter, i0 + 1); + if not i0 then + error(string.format("no ending delimiter ('%s') found", utf8.char(delimiter)), 2); + end; + local escape_count = 1; + while str_arr[i0 - escape_count] == 0x5C do + escape_count += 1; + end; + until escape_count % 2 == 1; + + local flags = { + anchored = false, caseless = false, multiline = false, dotall = false, unicode = false, ungreedy = false, extended = false, + }; + local flag_repr = { }; + while str_arr.n > i0 do + local f = utf8.char(table.remove(str_arr)); + str_arr.n -= 1; + if flags[flag_map[f]] ~= false then + error("invalid regular expression flag " .. f, 3); + end; + flags[flag_map[f]] = true; + table.insert(flag_repr, f); + end; + table.sort(flag_repr, sort_flag_chr); + flag_repr = table.concat(flag_repr); + table.remove(str_arr, 1); + table.remove(str_arr); + str_arr.n -= 2; + str_arr.s = string.sub(str_arr.s, 2, 1 + str_arr.n); + return new_re(str_arr, flags, flag_repr, string.sub(ptn, 1, 2 + str_arr.n)); +end; + +local re_escape_line_chrs = { + ['\0'] = '\\x00', ['\n'] = '\\n', ['\t'] = '\\t', ['\r'] = '\\r', ['\f'] = '\\f', +}; + +function re.escape(...) + if select('#', ...) == 0 then + error("missing argument #1 (string expected)", 2); + end; + local str, extended, delimiter = ...; + if type(str) == "number" then + str ..= ''; + elseif type(str) ~= "string" then + error(string.format("invalid argument #1 to 'escape' (string expected, got %s)", typeof(str)), 2); + end; + if delimiter == nil then + delimiter = ''; + elseif type(delimiter) == "number" then + delimiter ..= ''; + elseif type(delimiter) ~= "string" then + error(string.format("invalid argument #3 to 'escape' (string expected, got %s)", typeof(delimiter)), 2); + end; + if utf8.len(delimiter) > 1 or delimiter:match("^[%a\\]$") then + error("delimiter have not be alphanumeric", 2); + end; + return (string.gsub(str, "[\0\f\n\r\t]", re_escape_line_chrs):gsub(string.format("[\\%s#()%%%%*+.?[%%]^{|%s]", extended and '%s' or '', (delimiter:find'^[%%%]]$' and '%' or '') .. delimiter), "\\%1")); +end; + +function re.type(...) + if select('#', ...) == 0 then + error("missing argument #1", 2); + end; + return proxy[...] and proxy[...].name; +end; + +for k, f in pairs(re_m) do + re[k] = f; +end; + +re_m = { __index = re_m }; + +lockmsg = re.fromstring([[/The\s*metatable\s*is\s*(?:locked|inaccessible)(?#Nice try :])/i]]); +getmetatable(lockmsg).__metatable = lockmsg; + +local function readonly_table() + error("Attempt to modify a readonly table", 2); +end; + +match_m = { + __index = match_m, + __metatable = lockmsg, + __newindex = readonly_table, +}; + +re.Match = setmetatable({ }, match_m); + +return setmetatable({ }, { + __index = re, + __metatable = lockmsg, + __newindex = readonly_table, +}); diff --git a/bench/tests/sunspider/crypto-aes.lua b/bench/tests/sunspider/crypto-aes.lua index 8dd0cec6..e3f54087 100644 --- a/bench/tests/sunspider/crypto-aes.lua +++ b/bench/tests/sunspider/crypto-aes.lua @@ -185,7 +185,7 @@ local function AESEncryptCtr(plaintext, password, nBits) -- for real-world applications, a higher security approach would be to hash the password e.g. with SHA-1 local nBytes = nBits/8; -- no bytes in key local pwBytes = {}; - for i = 0,nBytes-1 do pwBytes[i + 1] = bit32.band(string.byte(password, i + 1), 0xff); end + for i = 0,nBytes-1 do pwBytes[i + 1] = string.byte(password, i + 1); end local key = Cipher(pwBytes, KeyExpansion(pwBytes)); -- key is now 16/24/32 bytes long @@ -197,11 +197,11 @@ local function AESEncryptCtr(plaintext, password, nBits) -- block counter in 2nd 8 bytes local blockSize = 16; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES local counterBlock = {}; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES - local nonce = 12564231564 -- (new Date()).getTime(); -- milliseconds since 1-Jan-1970 + local nonce = os.clock() * 1000 -- (new Date()).getTime(); -- milliseconds since 1-Jan-1970 -- encode nonce in two stages to cater for JavaScript 32-bit limit on bitwise ops - for i = 0,3 do counterBlock[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff); end - for i = 0,3 do counterBlock[i + 4 + 1] = bit32.band(bit32.rshift(math.floor(nonce / 0x100000000), i*8), 0xff); end + for i = 0,3 do counterBlock[i + 1] = bit32.extract(nonce, i * 8, 8); end + for i = 0,3 do counterBlock[i + 4 + 1] = bit32.extract(math.floor(nonce / 0x100000000), i*8, 8); end -- generate key schedule - an expansion of the key into distinct Key Rounds for each round local keySchedule = KeyExpansion(key); @@ -212,8 +212,8 @@ local function AESEncryptCtr(plaintext, password, nBits) for b = 0,blockCount-1 do -- set counter (block #) in last 8 bytes of counter block (leaving nonce in 1st 8 bytes) -- again done in two stages for 32-bit ops - for c = 0,3 do counterBlock[15-c + 1] = bit32.band(bit32.rshift(b, c*8), 0xff); end - for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.rshift(math.floor(b/0x100000000), c*8) end + for c = 0,3 do counterBlock[15-c + 1] = bit32.extract(b, c*8, 8); end + for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.extract(math.floor(b/0x100000000), c*8, 8); end local cipherCntr = Cipher(counterBlock, keySchedule); -- -- encrypt counter block -- @@ -260,7 +260,7 @@ local function AESDecryptCtr(ciphertext, password, nBits) local nBytes = nBits/8; -- no bytes in key local pwBytes = {}; - for i = 0,nBytes-1 do pwBytes[i + 1] = bit32.band(string.byte(password, i + 1), 0xff); end + for i = 0,nBytes-1 do pwBytes[i + 1] = string.byte(password, i + 1); end local pwKeySchedule = KeyExpansion(pwBytes); local key = Cipher(pwBytes, pwKeySchedule); @@ -290,8 +290,8 @@ local function AESDecryptCtr(ciphertext, password, nBits) for b = 1,#ciphertext-1 do -- set counter (block #) in last 8 bytes of counter block (leaving nonce in 1st 8 bytes) - for c = 0,3 do counterBlock[15-c + 1] = bit32.band(bit32.rshift((b-1), c*8), 0xff); end - for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.band(bit32.rshift(math.floor((b-1)/0x100000000), c*8), 0xff); end + for c = 0,3 do counterBlock[15-c + 1] = bit32.extract(b-1, c*8, 8); end + for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.extract(math.floor((b-1)/0x100000000), c*8, 8); end local cipherCntr = Cipher(counterBlock, keySchedule); -- encrypt counter block diff --git a/bench/tests/tictactoe.lua b/bench/tests/tictactoe.lua index ae63f5f7..91d38f95 100644 --- a/bench/tests/tictactoe.lua +++ b/bench/tests/tictactoe.lua @@ -139,7 +139,7 @@ function test() for _, curr_qdr in pairs(negaMax.index_quadruplets) do -- iterate over all index quadruplets -- count the empty positions and positions occupied by the side whos move it is local player_plus_fields, player_minus_fields, empties = 0, 0, 0 - for _, index in pairs(curr_qdr) do -- iterate over all indices + for _, index in next, curr_qdr do -- iterate over all indices if board[index] == 0 then empties = empties + 1 elseif board[index] == 1 then @@ -225,4 +225,4 @@ function test() return t1-t0 end -bench.runCode(test, "tictactoe") \ No newline at end of file +bench.runCode(test, "tictactoe") diff --git a/docs/_pages/library.md b/docs/_pages/library.md index f419d2bf..d82b6f28 100644 --- a/docs/_pages/library.md +++ b/docs/_pages/library.md @@ -360,7 +360,7 @@ Returns `-1` if `n` is negative, `1` if `n` is positive, and `0` if `n` is zero function math.round(n: number): number ``` -Rounds `n` to the nearest integer boundary. +Rounds `n` to the nearest integer boundary. If `n` is exactly halfway between two integers, rounds `n` away from 0. ## table library @@ -683,7 +683,7 @@ Perform a bitwise `and` of all input numbers, and return `true` iff the result i function bit32.extract(n: number, f: number, w: number?): number ``` -Extracts bits at positions `[f..w]` and returns the resulting integer. `w` defaults to `f+1`, so a two-argument version of `extract` returns the bit value at position `f`. +Extracts bits of `n` at position `f` with a width of `w`, and returns the resulting integer. `w` defaults to `1`, so a two-argument version of `extract` returns the bit value at position `f`. Bits are indexed starting at 0. Errors if `f` and `f+w-1` are not between 0 and 31. ``` function bit32.lrotate(n: number, i: number): number @@ -701,7 +701,7 @@ Shifts `n` to the left by `i` bits (if `i` is negative, a right shift is perform function bit32.replace(n: number, r: number, f: number, w: number?): number ``` -Replaces bits at positions `[f..w]` of number `n` with `r` and returns the resulting integer. `w` defaults to `f+1`, so a three-argument version of `replace` changes one bit at position `f` to `r` (which should be 0 or 1) and returns the result. +Replaces bits of `n` at position `f` and width `w` with `r`, and returns the resulting integer. `w` defaults to `1`, so a three-argument version of `replace` changes one bit at position `f` to `r` (which should be 0 or 1) and returns the result. Bits are indexed starting at 0. Errors if `f` and `f+w-1` are not between 0 and 31. ``` function bit32.rrotate(n: number, i: number): number diff --git a/docs/_pages/performance.md b/docs/_pages/performance.md index 34b24b03..73c7f606 100644 --- a/docs/_pages/performance.md +++ b/docs/_pages/performance.md @@ -27,9 +27,9 @@ Of course the interpreter isn't typical C code - it uses many tricks to achieve Unlike Lua and LuaJIT, Luau uses a multi-pass compiler with a frontend that parses source into an AST and a backend that generates bytecode from it. This carries a small penalty in terms of compilation time, but results in more flexible code and, crucially, makes it easier to optimize the generated bytecode. -> Note: Compilation throughput isn't the main focus in Luau, but our compiler is reasonably fast; with all currently implemented optimizations enabled, it compiles 400K lines of Luau code in 0.5 seconds on a single core of a desktop Core i7 CPU, producing bytecode and debug information. +> Note: Compilation throughput isn't the main focus in Luau, but our compiler is reasonably fast; with all currently implemented optimizations enabled, it compiles 950K lines of Luau code in 1 second on a single core of a desktop Ryzen 5900X CPU, producing bytecode and debug information. -While bytecode optimizations are limited due to the flexibility of Luau code (e.g. `a * 1` may not be equivalent to `a` if `*` is overloaded through metatables), even in absence of type information Luau compiler can perform some optimizations such as "deep" constant folding across functions and local variables, perform upvalue optimizations for upvalues that aren't mutated, do analysis of builtin function usage, and some peephole optimizations on the resulting bytecode. In the future we plan to do bytecode-level inlining and possibly other code transformation. +While bytecode optimizations are limited due to the flexibility of Luau code (e.g. `a * 1` may not be equivalent to `a` if `*` is overloaded through metatables), even in absence of type information Luau compiler can perform some optimizations such as "deep" constant folding across functions and local variables, perform upvalue optimizations for upvalues that aren't mutated, do analysis of builtin function usage, and some peephole optimizations on the resulting bytecode. The compiler can also be instructed to use more aggressive optimizations by enabling optimization level 2 (`-O2` in CLI tools), some of which are documented further on this page. Luau compiler currently doesn't use type information to do further optimizations, however early experiments suggest that we can extract further wins. Because we control the entire stack (unlike e.g. TypeScript where the type information is discarded completely before reaching the VM), we have more flexibility there and can make some tradeoffs during codegen even if the type system isn't completely sound. For example, it might be reasonable to assume that in presence of known types, we can infer absence of side effects for arithmetic operations and builtins - if the runtime types mismatch due to intentional violation of the type safety through global injection, the code will still be safely sandboxed; this may unlock optimizations such as common subexpression elimination and allocation hoisting without a JIT. This is speculative pending further research. @@ -90,6 +90,8 @@ As a result, builtin calls are very fast in Luau - they are still slightly slowe > Note: The partial specialization mechanism is cute in that for `assert`, it only specializes on truthy conditions; hopefully performance of `assert(false)` isn't crucial for most code! +In addition to runtime optimizations for builtin calls, many builtin calls can also be constant-folded by the bytecode compiler when using aggressive optimizations (level 2); this currently applies to most builtin calls with constant arguments and a single return value. + ## Optimized table iteration Luau implements a fully generic iteration protocol; however, for iteration through tables in addition to generalized iteration (`for .. in t`) it recognizes three common idioms (`for .. in ipairs(t)`, `for .. in pairs(t)` and `for .. in next, t`) and emits specialized bytecode that is carefully optimized using custom internal iterators. diff --git a/docs/_pages/typecheck.md b/docs/_pages/typecheck.md index 363056c5..8e032da2 100644 --- a/docs/_pages/typecheck.md +++ b/docs/_pages/typecheck.md @@ -104,15 +104,15 @@ From the type checker perspective, each table can be in one of three states. The ### Unsealed tables -An unsealed table is a table whose properties could still be tacked on. This occurs when the table constructor literal had zero expressions. This is one way to accumulate knowledge of the shape of this table. +An unsealed table is a table which supports adding new properties, which updates the tables type. Unsealed tables are created using table literals. This is one way to accumulate knowledge of the shape of this table. ```lua -local t = {} -- {} -t.x = 1 -- {x: number} -t.y = 2 -- {x: number, y: number} +local t = {x = 1} -- {x: number} +t.y = 2 -- {x: number, y: number} +t.z = 3 -- {x: number, y: number, z: number} ``` -However, if this local were written as `local t: {} = {}`, it ends up sealing the table, so the two assignments henceforth will not be ok. +However, if this local were written as `local t: { x: number } = { x = 1 }`, it ends up sealing the table, so the two assignments henceforth will not be ok. Furthermore, once we exit the scope where this unsealed table was created in, we seal it. @@ -128,16 +128,25 @@ local v2 = vec2(1, 2) v2.z = 3 -- not ok ``` -### Sealed tables - -A sealed table is a table that is now locked down. This occurs when the table constructor literal had 1 or more expression, or when the table type is spelt out explicitly via a type annotation. +Unsealed tables are *exact* in that any property of the table must be named by the type. Since Luau treats missing properties as having value `nil`, this means that we can treat an unsealed table which does not mention a property as if it mentioned the property, as long as that property is optional. ```lua -local t = {x = 1} -- {x: number} -t.y = 2 -- not ok +local t = {x = 1} +local u : { x : number, y : number? } = t -- ok because y is optional +local v : { x : number, z : number } = t -- not ok because z is not optional ``` -Sealed tables support *width subtyping*, which allows a table with more properties to be used as a table with fewer +### Sealed tables + +A sealed table is a table that is now locked down. This occurs when the table type is spelled out explicitly via a type annotation, or if it is returned from a function. + +```lua +local t : { x: number } = {x = 1} +t.y = 2 -- not ok +``` + +Sealed tables are *inexact* in that the table may have properties which are not mentioned in the type. +As a result, sealed tables support *width subtyping*, which allows a table with more properties to be used as a table with fewer ```lua type Point1D = { x : number } diff --git a/docs/_posts/2022-07-07-luau-recap-june-2022.md b/docs/_posts/2022-07-07-luau-recap-june-2022.md new file mode 100644 index 00000000..1f58d892 --- /dev/null +++ b/docs/_posts/2022-07-07-luau-recap-june-2022.md @@ -0,0 +1,88 @@ +--- +layout: single +title: "Luau Recap: June 2022" +--- + +Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). + +[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-june-2022/).] + +# Lower bounds calculation + +A common problem that Luau has is that it primarily works by inspecting expressions in your program and narrowing the _upper bounds_ of the values that can inhabit particular variables. In other words, each time we see a variable used, we eliminate possible sets of values from that variable's domain. + +There are some important cases where this doesn't produce a helpful result. Take this function for instance: + +```lua +function find_first_if(vec, f) + for i, e in ipairs(vec) do + if f(e) then + return i + end + end + + return nil +end +``` + +Luau scans the function from top to bottom and first sees the line `return i`. It draws from this the inference that `find_first_if` must return the type of `i`, namely `number`. + +This is fine, but things go sour when we see the line `return nil`. Since we are always narrowing, we take from this line the judgement that the return type of the function is `nil`. Since we have already concluded that the function must return `number`, Luau reports an error. + +What we actually want to do in this case is to take these `return` statements as inferences about the _lower_ bound of the function's return type. Instead of saying "this function must return values of type `nil`," we should instead say "this function may _also_ return values of type `nil`." + +Lower bounds calculation does precisely this. Moving forward, Luau will instead infer the type `number?` for the above function. + +This does have one unfortunate consequence: If a function has no return type annotation, we will no longer ever report a type error on a `return` statement. We think this is the right balance but we'll be keeping an eye on things just to be sure. + +Lower-bounds calculation is larger and a little bit riskier than other things we've been working on so we've set up a beta feature in Roblox Studio to enable them. It is called "Experimental Luau language features." + +Please try it out and let us know what you think! + +## Known bug + +We have a known bug with certain kinds of cyclic types when lower-bounds calculation is enabled. The following, for instance, is known to be problematic. + +```lua +type T = {T?}? -- spuriously reduces to {nil}? +``` + +We hope to have this fixed soon. + +# All table literals now result in unsealed tables + +Previously, the only way to create a sealed table was by with a literal empty table. We have relaxed this somewhat: Any table created by a `{}` expression is considered to be unsealed within the scope where it was created: + +```lua +local T = {} +T.x = 5 -- OK + +local V = {x=5} +V.y = 2 -- previously disallowed. Now OK. + +function mkTable() + return {x = 5} +end + +local U = mkTable() +U.y = 2 -- Still disallowed: U is sealed +``` + +# Other fixes + +* Adjust indentation and whitespace when creating multiline string representations of types, resulting in types that are easier to read. +* Some small bugfixes to autocomplete +* Fix a case where accessing a nonexistent property of a table would not result in an error being reported. +* Improve parser recovery for the incorrect code `function foo() -> ReturnType` (the correct syntax is `function foo(): ReturnType`) +* Improve the parse error offered for code that improperly uses the `function` keyword to start a type eg `type T = function` +* Some small crash fixes and performance improvements + +# Thanks! + +A very special thanks to all of our open source contributors: + +* [Allan N Jeremy](https://github.com/AllanJeremy) +* [Daniel Nachun](https://github.com/danielnachun) +* [JohnnyMorganz](https://github.com/JohnnyMorganz/) +* [Petri Häkkinen](https://github.com/petrihakkinen) +* [Qualadore](https://github.com/Qualadore) diff --git a/extern/doctest.h b/extern/doctest.h index f9e9c5c4..aa2724c7 100644 --- a/extern/doctest.h +++ b/extern/doctest.h @@ -11,7 +11,7 @@ // https://opensource.org/licenses/MIT // // The documentation can be found at the library's page: -// https://github.com/onqtam/doctest/blob/master/doc/markdown/readme.md +// https://github.com/doctest/doctest/blob/master/doc/markdown/readme.md // // ================================================================================================= // ================================================================================================= @@ -48,8 +48,16 @@ #define DOCTEST_VERSION_MAJOR 2 #define DOCTEST_VERSION_MINOR 4 -#define DOCTEST_VERSION_PATCH 6 -#define DOCTEST_VERSION_STR "2.4.6" +#define DOCTEST_VERSION_PATCH 9 + +// util we need here +#define DOCTEST_TOSTR_IMPL(x) #x +#define DOCTEST_TOSTR(x) DOCTEST_TOSTR_IMPL(x) + +#define DOCTEST_VERSION_STR \ + DOCTEST_TOSTR(DOCTEST_VERSION_MAJOR) "." \ + DOCTEST_TOSTR(DOCTEST_VERSION_MINOR) "." \ + DOCTEST_TOSTR(DOCTEST_VERSION_PATCH) #define DOCTEST_VERSION \ (DOCTEST_VERSION_MAJOR * 10000 + DOCTEST_VERSION_MINOR * 100 + DOCTEST_VERSION_PATCH) @@ -60,6 +68,12 @@ // ideas for the version stuff are taken from here: https://github.com/cxxstuff/cxx_detect +#ifdef _MSC_VER +#define DOCTEST_CPLUSPLUS _MSVC_LANG +#else +#define DOCTEST_CPLUSPLUS __cplusplus +#endif + #define DOCTEST_COMPILER(MAJOR, MINOR, PATCH) ((MAJOR)*10000000 + (MINOR)*100000 + (PATCH)) // GCC/Clang and GCC/MSVC are mutually exclusive, but Clang/MSVC are not because of clang-cl... @@ -137,85 +151,92 @@ // == COMPILER WARNINGS ============================================================================ // ================================================================================================= +// both the header and the implementation suppress all of these, +// so it only makes sense to aggregrate them like so +#define DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH \ + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wunknown-pragmas") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wweak-vtables") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wpadded") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-prototypes") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic") \ + \ + DOCTEST_GCC_SUPPRESS_WARNING_PUSH \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wunknown-pragmas") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wpragmas") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Weffc++") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-overflow") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-aliasing") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-declarations") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wuseless-cast") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wnoexcept") \ + \ + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH \ + /* these 4 also disabled globally via cmake: */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4514) /* unreferenced inline function has been removed */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4571) /* SEH related */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4710) /* function not inlined */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4711) /* function selected for inline expansion*/ \ + /* */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4616) /* invalid compiler warning */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4619) /* invalid compiler warning */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4996) /* The compiler encountered a deprecated declaration */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4706) /* assignment within conditional expression */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4512) /* 'class' : assignment operator could not be generated */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4127) /* conditional expression is constant */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4820) /* padding */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4625) /* copy constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4626) /* assignment operator was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5027) /* move assignment operator implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5026) /* move constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4640) /* construction of local static object not thread-safe */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5045) /* Spectre mitigation for memory load */ \ + /* static analysis */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26439) /* Function may not throw. Declare it 'noexcept' */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26495) /* Always initialize a member variable */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26451) /* Arithmetic overflow ... */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26444) /* Avoid unnamed objects with custom ctor and dtor... */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26812) /* Prefer 'enum class' over 'enum' */ + +#define DOCTEST_SUPPRESS_COMMON_WARNINGS_POP \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP \ + DOCTEST_GCC_SUPPRESS_WARNING_POP \ + DOCTEST_MSVC_SUPPRESS_WARNING_POP + +DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH -DOCTEST_CLANG_SUPPRESS_WARNING("-Wunknown-pragmas") DOCTEST_CLANG_SUPPRESS_WARNING("-Wnon-virtual-dtor") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wweak-vtables") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wpadded") DOCTEST_CLANG_SUPPRESS_WARNING("-Wdeprecated") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-prototypes") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-local-typedef") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic") DOCTEST_GCC_SUPPRESS_WARNING_PUSH -DOCTEST_GCC_SUPPRESS_WARNING("-Wunknown-pragmas") -DOCTEST_GCC_SUPPRESS_WARNING("-Wpragmas") -DOCTEST_GCC_SUPPRESS_WARNING("-Weffc++") -DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-overflow") -DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-aliasing") DOCTEST_GCC_SUPPRESS_WARNING("-Wctor-dtor-privacy") -DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-declarations") DOCTEST_GCC_SUPPRESS_WARNING("-Wnon-virtual-dtor") -DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-local-typedefs") -DOCTEST_GCC_SUPPRESS_WARNING("-Wuseless-cast") -DOCTEST_GCC_SUPPRESS_WARNING("-Wnoexcept") DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-promo") DOCTEST_MSVC_SUPPRESS_WARNING_PUSH -DOCTEST_MSVC_SUPPRESS_WARNING(4616) // invalid compiler warning -DOCTEST_MSVC_SUPPRESS_WARNING(4619) // invalid compiler warning -DOCTEST_MSVC_SUPPRESS_WARNING(4996) // The compiler encountered a deprecated declaration -DOCTEST_MSVC_SUPPRESS_WARNING(4706) // assignment within conditional expression -DOCTEST_MSVC_SUPPRESS_WARNING(4512) // 'class' : assignment operator could not be generated -DOCTEST_MSVC_SUPPRESS_WARNING(4127) // conditional expression is constant -DOCTEST_MSVC_SUPPRESS_WARNING(4820) // padding -DOCTEST_MSVC_SUPPRESS_WARNING(4625) // copy constructor was implicitly defined as deleted -DOCTEST_MSVC_SUPPRESS_WARNING(4626) // assignment operator was implicitly defined as deleted -DOCTEST_MSVC_SUPPRESS_WARNING(5027) // move assignment operator was implicitly defined as deleted -DOCTEST_MSVC_SUPPRESS_WARNING(5026) // move constructor was implicitly defined as deleted DOCTEST_MSVC_SUPPRESS_WARNING(4623) // default constructor was implicitly defined as deleted -DOCTEST_MSVC_SUPPRESS_WARNING(4640) // construction of local static object is not thread-safe -// static analysis -DOCTEST_MSVC_SUPPRESS_WARNING(26439) // This kind of function may not throw. Declare it 'noexcept' -DOCTEST_MSVC_SUPPRESS_WARNING(26495) // Always initialize a member variable -DOCTEST_MSVC_SUPPRESS_WARNING(26451) // Arithmetic overflow ... -DOCTEST_MSVC_SUPPRESS_WARNING(26444) // Avoid unnamed objects with custom construction and dtr... -DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' - -// 4548 - expression before comma has no effect; expected expression with side - effect -// 4265 - class has virtual functions, but destructor is not virtual -// 4986 - exception specification does not match previous declaration -// 4350 - behavior change: 'member1' called instead of 'member2' -// 4668 - 'x' is not defined as a preprocessor macro, replacing with '0' for '#if/#elif' -// 4365 - conversion from 'int' to 'unsigned long', signed/unsigned mismatch -// 4774 - format string expected in argument 'x' is not a string literal -// 4820 - padding in structs - -// only 4 should be disabled globally: -// - 4514 # unreferenced inline function has been removed -// - 4571 # SEH related -// - 4710 # function not inlined -// - 4711 # function 'x' selected for automatic inline expansion #define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN \ DOCTEST_MSVC_SUPPRESS_WARNING_PUSH \ - DOCTEST_MSVC_SUPPRESS_WARNING(4548) \ - DOCTEST_MSVC_SUPPRESS_WARNING(4265) \ - DOCTEST_MSVC_SUPPRESS_WARNING(4986) \ - DOCTEST_MSVC_SUPPRESS_WARNING(4350) \ - DOCTEST_MSVC_SUPPRESS_WARNING(4668) \ - DOCTEST_MSVC_SUPPRESS_WARNING(4365) \ - DOCTEST_MSVC_SUPPRESS_WARNING(4774) \ - DOCTEST_MSVC_SUPPRESS_WARNING(4820) \ - DOCTEST_MSVC_SUPPRESS_WARNING(4625) \ - DOCTEST_MSVC_SUPPRESS_WARNING(4626) \ - DOCTEST_MSVC_SUPPRESS_WARNING(5027) \ - DOCTEST_MSVC_SUPPRESS_WARNING(5026) \ - DOCTEST_MSVC_SUPPRESS_WARNING(4623) \ - DOCTEST_MSVC_SUPPRESS_WARNING(5039) \ - DOCTEST_MSVC_SUPPRESS_WARNING(5045) \ - DOCTEST_MSVC_SUPPRESS_WARNING(5105) + DOCTEST_MSVC_SUPPRESS_WARNING(4548) /* before comma no effect; expected side - effect */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4265) /* virtual functions, but destructor is not virtual */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4986) /* exception specification does not match previous */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4350) /* 'member1' called instead of 'member2' */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4668) /* not defined as a preprocessor macro */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4365) /* signed/unsigned mismatch */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4774) /* format string not a string literal */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4820) /* padding */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4625) /* copy constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4626) /* assignment operator was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5027) /* move assignment operator implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5026) /* move constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4623) /* default constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5039) /* pointer to pot. throwing function passed to extern C */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5045) /* Spectre mitigation for memory load */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5105) /* macro producing 'defined' has undefined behavior */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4738) /* storing float result in memory, loss of performance */ #define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END DOCTEST_MSVC_SUPPRESS_WARNING_POP @@ -228,6 +249,7 @@ DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' // GCC C++11 feature support table: https://gcc.gnu.org/projects/cxx-status.html // MSVC version table: // https://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B#Internal_version_numbering +// MSVC++ 14.3 (17) _MSC_VER == 1930 (Visual Studio 2022) // MSVC++ 14.2 (16) _MSC_VER == 1920 (Visual Studio 2019) // MSVC++ 14.1 (15) _MSC_VER == 1910 (Visual Studio 2017) // MSVC++ 14.0 _MSC_VER == 1900 (Visual Studio 2015) @@ -237,6 +259,10 @@ DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' // MSVC++ 9.0 _MSC_VER == 1500 (Visual Studio 2008) // MSVC++ 8.0 _MSC_VER == 1400 (Visual Studio 2005) +// Universal Windows Platform support +#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) +#define DOCTEST_CONFIG_NO_WINDOWS_SEH +#endif // WINAPI_FAMILY #if DOCTEST_MSVC && !defined(DOCTEST_CONFIG_WINDOWS_SEH) #define DOCTEST_CONFIG_WINDOWS_SEH #endif // MSVC @@ -245,7 +271,7 @@ DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' #endif // DOCTEST_CONFIG_NO_WINDOWS_SEH #if !defined(_WIN32) && !defined(__QNX__) && !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && \ - !defined(__EMSCRIPTEN__) + !defined(__EMSCRIPTEN__) && !defined(__wasi__) #define DOCTEST_CONFIG_POSIX_SIGNALS #endif // _WIN32 #if defined(DOCTEST_CONFIG_NO_POSIX_SIGNALS) && defined(DOCTEST_CONFIG_POSIX_SIGNALS) @@ -253,7 +279,8 @@ DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' #endif // DOCTEST_CONFIG_NO_POSIX_SIGNALS #ifndef DOCTEST_CONFIG_NO_EXCEPTIONS -#if !defined(__cpp_exceptions) && !defined(__EXCEPTIONS) && !defined(_CPPUNWIND) +#if !defined(__cpp_exceptions) && !defined(__EXCEPTIONS) && !defined(_CPPUNWIND) \ + || defined(__wasi__) #define DOCTEST_CONFIG_NO_EXCEPTIONS #endif // no exceptions #endif // DOCTEST_CONFIG_NO_EXCEPTIONS @@ -268,6 +295,10 @@ DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' #define DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS #endif // DOCTEST_CONFIG_NO_EXCEPTIONS && !DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#ifdef __wasi__ +#define DOCTEST_CONFIG_NO_MULTITHREADING +#endif + #if defined(DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN) && !defined(DOCTEST_CONFIG_IMPLEMENT) #define DOCTEST_CONFIG_IMPLEMENT #endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN @@ -295,6 +326,16 @@ DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' #define DOCTEST_INTERFACE #endif // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL +// needed for extern template instantiations +// see https://github.com/fmtlib/fmt/issues/2228 +#if DOCTEST_MSVC +#define DOCTEST_INTERFACE_DECL +#define DOCTEST_INTERFACE_DEF DOCTEST_INTERFACE +#else // DOCTEST_MSVC +#define DOCTEST_INTERFACE_DECL DOCTEST_INTERFACE +#define DOCTEST_INTERFACE_DEF +#endif // DOCTEST_MSVC + #define DOCTEST_EMPTY #if DOCTEST_MSVC @@ -312,17 +353,46 @@ DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' #endif #ifndef DOCTEST_NORETURN +#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) +#define DOCTEST_NORETURN +#else // DOCTEST_MSVC #define DOCTEST_NORETURN [[noreturn]] +#endif // DOCTEST_MSVC #endif // DOCTEST_NORETURN #ifndef DOCTEST_NOEXCEPT +#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) +#define DOCTEST_NOEXCEPT +#else // DOCTEST_MSVC #define DOCTEST_NOEXCEPT noexcept +#endif // DOCTEST_MSVC #endif // DOCTEST_NOEXCEPT +#ifndef DOCTEST_CONSTEXPR +#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) +#define DOCTEST_CONSTEXPR const +#define DOCTEST_CONSTEXPR_FUNC inline +#else // DOCTEST_MSVC +#define DOCTEST_CONSTEXPR constexpr +#define DOCTEST_CONSTEXPR_FUNC constexpr +#endif // DOCTEST_MSVC +#endif // DOCTEST_CONSTEXPR + // ================================================================================================= // == FEATURE DETECTION END ======================================================================== // ================================================================================================= +#define DOCTEST_DECLARE_INTERFACE(name) \ + virtual ~name(); \ + name() = default; \ + name(const name&) = delete; \ + name(name&&) = delete; \ + name& operator=(const name&) = delete; \ + name& operator=(name&&) = delete; + +#define DOCTEST_DEFINE_INTERFACE(name) \ + name::~name() = default; + // internal macros for string concatenation and anonymous variable name generation #define DOCTEST_CAT_IMPL(s1, s2) s1##s2 #define DOCTEST_CAT(s1, s2) DOCTEST_CAT_IMPL(s1, s2) @@ -332,8 +402,6 @@ DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' #define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __LINE__) #endif // __COUNTER__ -#define DOCTEST_TOSTR(x) #x - #ifndef DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE #define DOCTEST_REF_WRAP(x) x& #else // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE @@ -347,31 +415,39 @@ DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' #define DOCTEST_PLATFORM_IPHONE #elif defined(_WIN32) #define DOCTEST_PLATFORM_WINDOWS +#elif defined(__wasi__) +#define DOCTEST_PLATFORM_WASI #else // DOCTEST_PLATFORM #define DOCTEST_PLATFORM_LINUX #endif // DOCTEST_PLATFORM -#define DOCTEST_GLOBAL_NO_WARNINGS(var) \ - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wglobal-constructors") \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-variable") \ - static const int var DOCTEST_UNUSED // NOLINT(fuchsia-statically-constructed-objects,cert-err58-cpp) -#define DOCTEST_GLOBAL_NO_WARNINGS_END() DOCTEST_CLANG_SUPPRESS_WARNING_POP +namespace doctest { namespace detail { + static DOCTEST_CONSTEXPR int consume(const int*, int) noexcept { return 0; } +}} + +#define DOCTEST_GLOBAL_NO_WARNINGS(var, ...) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wglobal-constructors") \ + static const int var = doctest::detail::consume(&var, __VA_ARGS__); \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP #ifndef DOCTEST_BREAK_INTO_DEBUGGER // should probably take a look at https://github.com/scottt/debugbreak #ifdef DOCTEST_PLATFORM_LINUX #if defined(__GNUC__) && (defined(__i386) || defined(__x86_64)) // Break at the location of the failing check if possible -#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT (hicpp-no-assembler) +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT(hicpp-no-assembler) #else #include #define DOCTEST_BREAK_INTO_DEBUGGER() raise(SIGTRAP) #endif #elif defined(DOCTEST_PLATFORM_MAC) #if defined(__x86_64) || defined(__x86_64__) || defined(__amd64__) || defined(__i386) -#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT (hicpp-no-assembler) +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT(hicpp-no-assembler) +#elif defined(__ppc__) || defined(__ppc64__) +// https://www.cocoawithlove.com/2008/03/break-into-debugger.html +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("li r0, 20\nsc\nnop\nli r0, 37\nli r4, 2\nsc\nnop\n": : : "memory","r0","r3","r4") // NOLINT(hicpp-no-assembler) #else -#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("brk #0"); // NOLINT (hicpp-no-assembler) +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("brk #0"); // NOLINT(hicpp-no-assembler) #endif #elif DOCTEST_MSVC #define DOCTEST_BREAK_INTO_DEBUGGER() __debugbreak() @@ -387,54 +463,66 @@ DOCTEST_GCC_SUPPRESS_WARNING_POP // this is kept here for backwards compatibility since the config option was changed #ifdef DOCTEST_CONFIG_USE_IOSFWD +#ifndef DOCTEST_CONFIG_USE_STD_HEADERS #define DOCTEST_CONFIG_USE_STD_HEADERS +#endif #endif // DOCTEST_CONFIG_USE_IOSFWD +// for clang - always include ciso646 (which drags some std stuff) because +// we want to check if we are using libc++ with the _LIBCPP_VERSION macro in +// which case we don't want to forward declare stuff from std - for reference: +// https://github.com/doctest/doctest/issues/126 +// https://github.com/doctest/doctest/issues/356 +#if DOCTEST_CLANG +#include +#ifdef _LIBCPP_VERSION +#ifndef DOCTEST_CONFIG_USE_STD_HEADERS +#define DOCTEST_CONFIG_USE_STD_HEADERS +#endif +#endif // _LIBCPP_VERSION +#endif // clang + #ifdef DOCTEST_CONFIG_USE_STD_HEADERS #ifndef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS #define DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS #endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS -#include +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN #include #include +#include +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END #else // DOCTEST_CONFIG_USE_STD_HEADERS -#if DOCTEST_CLANG -// to detect if libc++ is being used with clang (the _LIBCPP_VERSION identifier) -#include -#endif // clang - -#ifdef _LIBCPP_VERSION -#define DOCTEST_STD_NAMESPACE_BEGIN _LIBCPP_BEGIN_NAMESPACE_STD -#define DOCTEST_STD_NAMESPACE_END _LIBCPP_END_NAMESPACE_STD -#else // _LIBCPP_VERSION -#define DOCTEST_STD_NAMESPACE_BEGIN namespace std { -#define DOCTEST_STD_NAMESPACE_END } -#endif // _LIBCPP_VERSION - // Forward declaring 'X' in namespace std is not permitted by the C++ Standard. DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4643) -DOCTEST_STD_NAMESPACE_BEGIN // NOLINT (cert-dcl58-cpp) -typedef decltype(nullptr) nullptr_t; +namespace std { // NOLINT(cert-dcl58-cpp) +typedef decltype(nullptr) nullptr_t; // NOLINT(modernize-use-using) +typedef decltype(sizeof(void*)) size_t; // NOLINT(modernize-use-using) template struct char_traits; template <> struct char_traits; template -class basic_ostream; -typedef basic_ostream> ostream; +class basic_ostream; // NOLINT(fuchsia-virtual-inheritance) +typedef basic_ostream> ostream; // NOLINT(modernize-use-using) +template +// NOLINTNEXTLINE +basic_ostream& operator<<(basic_ostream&, const char*); +template +class basic_istream; +typedef basic_istream> istream; // NOLINT(modernize-use-using) template class tuple; #if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) -// see this issue on why this is needed: https://github.com/onqtam/doctest/issues/183 -template +// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 +template class allocator; -template +template class basic_string; using string = basic_string, allocator>; #endif // VS 2019 -DOCTEST_STD_NAMESPACE_END +} // namespace std DOCTEST_MSVC_SUPPRESS_WARNING_POP @@ -446,8 +534,14 @@ DOCTEST_MSVC_SUPPRESS_WARNING_POP namespace doctest { +using std::size_t; + DOCTEST_INTERFACE extern bool is_running_in_test; +#ifndef DOCTEST_CONFIG_STRING_SIZE_TYPE +#define DOCTEST_CONFIG_STRING_SIZE_TYPE unsigned +#endif + // A 24 byte string class (can be as small as 17 for x64 and 13 for x86) that can hold strings with length // of up to 23 chars on the stack before going on the heap - the last byte of the buffer is used for: // - "is small" bit - the highest bit - if "0" then it is small - otherwise its "1" (128) @@ -460,7 +554,6 @@ DOCTEST_INTERFACE extern bool is_running_in_test; // TODO: // - optimizations - like not deleting memory unnecessarily in operator= and etc. // - resize/reserve/clear -// - substr // - replace // - back/front // - iterator stuff @@ -470,63 +563,84 @@ DOCTEST_INTERFACE extern bool is_running_in_test; // - relational operators as free functions - taking const char* as one of the params class DOCTEST_INTERFACE String { - static const unsigned len = 24; //!OCLINT avoid private static members - static const unsigned last = len - 1; //!OCLINT avoid private static members +public: + using size_type = DOCTEST_CONFIG_STRING_SIZE_TYPE; + +private: + static DOCTEST_CONSTEXPR size_type len = 24; //!OCLINT avoid private static members + static DOCTEST_CONSTEXPR size_type last = len - 1; //!OCLINT avoid private static members struct view // len should be more than sizeof(view) - because of the final byte for flags { char* ptr; - unsigned size; - unsigned capacity; + size_type size; + size_type capacity; }; union { - char buf[len]; + char buf[len]; // NOLINT(*-avoid-c-arrays) view data; }; - bool isOnStack() const { return (buf[last] & 128) == 0; } - void setOnHeap(); - void setLast(unsigned in = last); + char* allocate(size_type sz); + + bool isOnStack() const noexcept { return (buf[last] & 128) == 0; } + void setOnHeap() noexcept; + void setLast(size_type in = last) noexcept; + void setSize(size_type sz) noexcept; void copy(const String& other); public: - String(); + static DOCTEST_CONSTEXPR size_type npos = static_cast(-1); + + String() noexcept; ~String(); // cppcheck-suppress noExplicitConstructor String(const char* in); - String(const char* in, unsigned in_size); + String(const char* in, size_type in_size); + + String(std::istream& in, size_type in_size); String(const String& other); String& operator=(const String& other); String& operator+=(const String& other); - String operator+(const String& other) const; - String(String&& other); - String& operator=(String&& other); + String(String&& other) noexcept; + String& operator=(String&& other) noexcept; - char operator[](unsigned i) const; - char& operator[](unsigned i); + char operator[](size_type i) const; + char& operator[](size_type i); // the only functions I'm willing to leave in the interface - available for inlining const char* c_str() const { return const_cast(this)->c_str(); } // NOLINT char* c_str() { - if(isOnStack()) + if (isOnStack()) { return reinterpret_cast(buf); + } return data.ptr; } - unsigned size() const; - unsigned capacity() const; + size_type size() const; + size_type capacity() const; + + String substr(size_type pos, size_type cnt = npos) &&; + String substr(size_type pos, size_type cnt = npos) const &; + + size_type find(char ch, size_type pos = 0) const; + size_type rfind(char ch, size_type pos = npos) const; int compare(const char* other, bool no_case = false) const; int compare(const String& other, bool no_case = false) const; + +friend DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, const String& in); }; +DOCTEST_INTERFACE String operator+(const String& lhs, const String& rhs); + DOCTEST_INTERFACE bool operator==(const String& lhs, const String& rhs); DOCTEST_INTERFACE bool operator!=(const String& lhs, const String& rhs); DOCTEST_INTERFACE bool operator<(const String& lhs, const String& rhs); @@ -534,7 +648,21 @@ DOCTEST_INTERFACE bool operator>(const String& lhs, const String& rhs); DOCTEST_INTERFACE bool operator<=(const String& lhs, const String& rhs); DOCTEST_INTERFACE bool operator>=(const String& lhs, const String& rhs); -DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, const String& in); +class DOCTEST_INTERFACE Contains { +public: + explicit Contains(const String& string); + + bool checkWith(const String& other) const; + + String string; +}; + +DOCTEST_INTERFACE String toString(const Contains& in); + +DOCTEST_INTERFACE bool operator==(const String& lhs, const Contains& rhs); +DOCTEST_INTERFACE bool operator==(const Contains& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator!=(const String& lhs, const Contains& rhs); +DOCTEST_INTERFACE bool operator!=(const Contains& lhs, const String& rhs); namespace Color { enum Enum @@ -607,7 +735,7 @@ namespace assertType { DT_WARN_THROWS_WITH = is_throws_with | is_warn, DT_CHECK_THROWS_WITH = is_throws_with | is_check, DT_REQUIRE_THROWS_WITH = is_throws_with | is_require, - + DT_WARN_THROWS_WITH_AS = is_throws_with | is_throws_as | is_warn, DT_CHECK_THROWS_WITH_AS = is_throws_with | is_throws_as | is_check, DT_REQUIRE_THROWS_WITH_AS = is_throws_with | is_throws_as | is_require, @@ -688,9 +816,27 @@ struct DOCTEST_INTERFACE AssertData String m_decomp; // for specific exception-related asserts - bool m_threw_as; - const char* m_exception_type; - const char* m_exception_string; + bool m_threw_as; + const char* m_exception_type; + + class DOCTEST_INTERFACE StringContains { + private: + Contains content; + bool isContains; + + public: + StringContains(const String& str) : content(str), isContains(false) { } + StringContains(Contains cntn) : content(static_cast(cntn)), isContains(true) { } + + bool check(const String& str) { return isContains ? (content == str) : (content.string == str); } + + operator const String&() const { return content.string; } + + const char* c_str() const { return content.string.c_str(); } + } m_exception_string; + + AssertData(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const StringContains& exception_string); }; struct DOCTEST_INTERFACE MessageData @@ -707,13 +853,13 @@ struct DOCTEST_INTERFACE SubcaseSignature const char* m_file; int m_line; + bool operator==(const SubcaseSignature& other) const; bool operator<(const SubcaseSignature& other) const; }; struct DOCTEST_INTERFACE IContextScope { - IContextScope(); - virtual ~IContextScope(); + DOCTEST_DECLARE_INTERFACE(IContextScope) virtual void stringify(std::ostream*) const = 0; }; @@ -723,9 +869,8 @@ namespace detail { struct ContextOptions //!OCLINT too many fields { - std::ostream* cout; // stdout stream - std::cout by default - std::ostream* cerr; // stderr stream - std::cerr by default - String binary_name; // the test binary name + std::ostream* cout = nullptr; // stdout stream + String binary_name; // the test binary name const detail::TestCase* currentTest = nullptr; @@ -744,9 +889,12 @@ struct ContextOptions //!OCLINT too many fields bool case_sensitive; // if filtering should be case sensitive bool exit; // if the program should be exited after the tests are ran/whatever bool duration; // print the time duration of each test case + bool minimal; // minimal console output (only test failures) + bool quiet; // no console output bool no_throw; // to skip exceptions-related assertion macros bool no_exitcode; // if the framework should return 0 as the exitcode bool no_run; // to not run the tests at all (can be done with an "*" exclude) + bool no_intro; // to not print the intro of the framework bool no_version; // to not print the version of the framework bool no_colors; // if output to the console should be colorized bool force_colors; // forces the use of colors even when a tty cannot be detected @@ -768,150 +916,184 @@ struct ContextOptions //!OCLINT too many fields }; namespace detail { - template - struct enable_if - {}; - - template - struct enable_if - { typedef TYPE type; }; - - // clang-format off - template struct remove_reference { typedef T type; }; - template struct remove_reference { typedef T type; }; - template struct remove_reference { typedef T type; }; - - template U declval(int); - - template T declval(long); - - template auto declval() DOCTEST_NOEXCEPT -> decltype(declval(0)) ; - - template struct is_lvalue_reference { const static bool value=false; }; - template struct is_lvalue_reference { const static bool value=true; }; - - template - inline T&& forward(typename remove_reference::type& t) DOCTEST_NOEXCEPT - { - return static_cast(t); - } - - template - inline T&& forward(typename remove_reference::type&& t) DOCTEST_NOEXCEPT - { - static_assert(!is_lvalue_reference::value, - "Can not forward an rvalue as an lvalue."); - return static_cast(t); - } - - template struct remove_const { typedef T type; }; - template struct remove_const { typedef T type; }; + namespace types { #ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - template struct is_enum : public std::is_enum {}; - template struct underlying_type : public std::underlying_type {}; + using namespace std; #else - // Use compiler intrinsics - template struct is_enum { constexpr static bool value = __is_enum(T); }; - template struct underlying_type { typedef __underlying_type(T) type; }; + template + struct enable_if { }; + + template + struct enable_if { using type = T; }; + + struct true_type { static DOCTEST_CONSTEXPR bool value = true; }; + struct false_type { static DOCTEST_CONSTEXPR bool value = false; }; + + template struct remove_reference { using type = T; }; + template struct remove_reference { using type = T; }; + template struct remove_reference { using type = T; }; + + template struct is_rvalue_reference : false_type { }; + template struct is_rvalue_reference : true_type { }; + + template struct remove_const { using type = T; }; + template struct remove_const { using type = T; }; + + // Compiler intrinsics + template struct is_enum { static DOCTEST_CONSTEXPR bool value = __is_enum(T); }; + template struct underlying_type { using type = __underlying_type(T); }; + + template struct is_pointer : false_type { }; + template struct is_pointer : true_type { }; + + template struct is_array : false_type { }; + // NOLINTNEXTLINE(*-avoid-c-arrays) + template struct is_array : true_type { }; #endif - // clang-format on + } + + // + template + T&& declval(); + + template + DOCTEST_CONSTEXPR_FUNC T&& forward(typename types::remove_reference::type& t) DOCTEST_NOEXCEPT { + return static_cast(t); + } + + template + DOCTEST_CONSTEXPR_FUNC T&& forward(typename types::remove_reference::type&& t) DOCTEST_NOEXCEPT { + return static_cast(t); + } template - struct deferred_false - // cppcheck-suppress unusedStructMember - { static const bool value = false; }; + struct deferred_false : types::false_type { }; - namespace has_insertion_operator_impl { - std::ostream &os(); - template - DOCTEST_REF_WRAP(T) val(); +// MSVS 2015 :( +#if defined(_MSC_VER) && _MSC_VER <= 1900 + template + struct has_global_insertion_operator : types::false_type { }; - template - struct check { - static constexpr bool value = false; - }; + template + struct has_global_insertion_operator(), declval()), void())> : types::true_type { }; - template - struct check(), void())> { - static constexpr bool value = true; - }; - } // namespace has_insertion_operator_impl + template + struct has_insertion_operator { static DOCTEST_CONSTEXPR bool value = has_global_insertion_operator::value; }; - template - using has_insertion_operator = has_insertion_operator_impl::check; + template + struct insert_hack; - DOCTEST_INTERFACE void my_memcpy(void* dest, const void* src, unsigned num); + template + struct insert_hack { + static void insert(std::ostream& os, const T& t) { ::operator<<(os, t); } + }; - DOCTEST_INTERFACE std::ostream* getTlsOss(); // returns a thread-local ostringstream - DOCTEST_INTERFACE String getTlsOssResult(); + template + struct insert_hack { + static void insert(std::ostream& os, const T& t) { operator<<(os, t); } + }; + + template + using insert_hack_t = insert_hack::value>; +#else + template + struct has_insertion_operator : types::false_type { }; +#endif + +template +struct has_insertion_operator(), declval()), void())> : types::true_type { }; + + DOCTEST_INTERFACE std::ostream* tlssPush(); + DOCTEST_INTERFACE String tlssPop(); template - struct StringMakerBase - { + struct StringMakerBase { template static String convert(const DOCTEST_REF_WRAP(T)) { +#ifdef DOCTEST_CONFIG_REQUIRE_STRINGIFICATION_FOR_ALL_USED_TYPES + static_assert(deferred_false::value, "No stringification detected for type T. See string conversion manual"); +#endif return "{?}"; } }; + template + struct filldata; + + template + void filloss(std::ostream* stream, const T& in) { + filldata::fill(stream, in); + } + + template + void filloss(std::ostream* stream, const T (&in)[N]) { // NOLINT(*-avoid-c-arrays) + // T[N], T(&)[N], T(&&)[N] have same behaviour. + // Hence remove reference. + filloss::type>(stream, in); + } + + template + String toStream(const T& in) { + std::ostream* stream = tlssPush(); + filloss(stream, in); + return tlssPop(); + } + template <> - struct StringMakerBase - { + struct StringMakerBase { template static String convert(const DOCTEST_REF_WRAP(T) in) { - *getTlsOss() << in; - return getTlsOssResult(); + return toStream(in); } }; - - DOCTEST_INTERFACE String rawMemoryToString(const void* object, unsigned size); - - template - String rawMemoryToString(const DOCTEST_REF_WRAP(T) object) { - return rawMemoryToString(&object, sizeof(object)); - } - - template - const char* type_to_string() { - return "<>"; - } } // namespace detail template -struct StringMaker : public detail::StringMakerBase::value> +struct StringMaker : public detail::StringMakerBase< + detail::has_insertion_operator::value || detail::types::is_pointer::value || detail::types::is_array::value> {}; +#ifndef DOCTEST_STRINGIFY +#ifdef DOCTEST_CONFIG_DOUBLE_STRINGIFY +#define DOCTEST_STRINGIFY(...) toString(toString(__VA_ARGS__)) +#else +#define DOCTEST_STRINGIFY(...) toString(__VA_ARGS__) +#endif +#endif + template -struct StringMaker -{ - template - static String convert(U* p) { - if(p) - return detail::rawMemoryToString(p); - return "NULL"; - } -}; +String toString() { +#if DOCTEST_MSVC >= 0 && DOCTEST_CLANG == 0 && DOCTEST_GCC == 0 + String ret = __FUNCSIG__; // class doctest::String __cdecl doctest::toString(void) + String::size_type beginPos = ret.find('<'); + return ret.substr(beginPos + 1, ret.size() - beginPos - static_cast(sizeof(">(void)"))); +#else + String ret = __PRETTY_FUNCTION__; // doctest::String toString() [with T = TYPE] + String::size_type begin = ret.find('=') + 2; + return ret.substr(begin, ret.size() - begin - 1); +#endif +} -template -struct StringMaker -{ - static String convert(R C::*p) { - if(p) - return detail::rawMemoryToString(p); - return "NULL"; - } -}; - -template ::value, bool>::type = true> +template ::value, bool>::type = true> String toString(const DOCTEST_REF_WRAP(T) value) { return StringMaker::convert(value); } #ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -DOCTEST_INTERFACE String toString(char* in); DOCTEST_INTERFACE String toString(const char* in); #endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 +DOCTEST_INTERFACE String toString(const std::string& in); +#endif // VS 2019 + +DOCTEST_INTERFACE String toString(String in); + +DOCTEST_INTERFACE String toString(std::nullptr_t); + DOCTEST_INTERFACE String toString(bool in); + DOCTEST_INTERFACE String toString(float in); DOCTEST_INTERFACE String toString(double in); DOCTEST_INTERFACE String toString(double long in); @@ -919,40 +1101,85 @@ DOCTEST_INTERFACE String toString(double long in); DOCTEST_INTERFACE String toString(char in); DOCTEST_INTERFACE String toString(char signed in); DOCTEST_INTERFACE String toString(char unsigned in); -DOCTEST_INTERFACE String toString(int short in); -DOCTEST_INTERFACE String toString(int short unsigned in); -DOCTEST_INTERFACE String toString(int in); -DOCTEST_INTERFACE String toString(int unsigned in); -DOCTEST_INTERFACE String toString(int long in); -DOCTEST_INTERFACE String toString(int long unsigned in); -DOCTEST_INTERFACE String toString(int long long in); -DOCTEST_INTERFACE String toString(int long long unsigned in); -DOCTEST_INTERFACE String toString(std::nullptr_t in); +DOCTEST_INTERFACE String toString(short in); +DOCTEST_INTERFACE String toString(short unsigned in); +DOCTEST_INTERFACE String toString(signed in); +DOCTEST_INTERFACE String toString(unsigned in); +DOCTEST_INTERFACE String toString(long in); +DOCTEST_INTERFACE String toString(long unsigned in); +DOCTEST_INTERFACE String toString(long long in); +DOCTEST_INTERFACE String toString(long long unsigned in); -template ::value, bool>::type = true> +template ::value, bool>::type = true> String toString(const DOCTEST_REF_WRAP(T) value) { - typedef typename detail::underlying_type::type UT; - return toString(static_cast(value)); + using UT = typename detail::types::underlying_type::type; + return (DOCTEST_STRINGIFY(static_cast(value))); } -#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) -// see this issue on why this is needed: https://github.com/onqtam/doctest/issues/183 -DOCTEST_INTERFACE String toString(const std::string& in); -#endif // VS 2019 +namespace detail { + template + struct filldata + { + static void fill(std::ostream* stream, const T& in) { +#if defined(_MSC_VER) && _MSC_VER <= 1900 + insert_hack_t::insert(*stream, in); +#else + operator<<(*stream, in); +#endif + } + }; -class DOCTEST_INTERFACE Approx +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4866) +// NOLINTBEGIN(*-avoid-c-arrays) + template + struct filldata { + static void fill(std::ostream* stream, const T(&in)[N]) { + *stream << "["; + for (size_t i = 0; i < N; i++) { + if (i != 0) { *stream << ", "; } + *stream << (DOCTEST_STRINGIFY(in[i])); + } + *stream << "]"; + } + }; +// NOLINTEND(*-avoid-c-arrays) +DOCTEST_MSVC_SUPPRESS_WARNING_POP + + // Specialized since we don't want the terminating null byte! +// NOLINTBEGIN(*-avoid-c-arrays) + template + struct filldata { + static void fill(std::ostream* stream, const char (&in)[N]) { + *stream << String(in, in[N - 1] ? N : N - 1); + } // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) + }; +// NOLINTEND(*-avoid-c-arrays) + + template <> + struct filldata { + static void fill(std::ostream* stream, const void* in); + }; + + template + struct filldata { + static void fill(std::ostream* stream, const T* in) { + filldata::fill(stream, in); + } + }; +} + +struct DOCTEST_INTERFACE Approx { -public: - explicit Approx(double value); + Approx(double value); Approx operator()(double value) const; #ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS template explicit Approx(const T& value, - typename detail::enable_if::value>::type* = + typename detail::types::enable_if::value>::type* = static_cast(nullptr)) { - *this = Approx(static_cast(value)); + *this = static_cast(value); } #endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS @@ -960,7 +1187,7 @@ public: #ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS template - typename detail::enable_if::value, Approx&>::type epsilon( + typename std::enable_if::value, Approx&>::type epsilon( const T& newEpsilon) { m_epsilon = static_cast(newEpsilon); return *this; @@ -971,7 +1198,7 @@ public: #ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS template - typename detail::enable_if::value, Approx&>::type scale( + typename std::enable_if::value, Approx&>::type scale( const T& newScale) { m_scale = static_cast(newScale); return *this; @@ -992,30 +1219,27 @@ public: DOCTEST_INTERFACE friend bool operator> (double lhs, const Approx & rhs); DOCTEST_INTERFACE friend bool operator> (const Approx & lhs, double rhs); - DOCTEST_INTERFACE friend String toString(const Approx& in); - #ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS #define DOCTEST_APPROX_PREFIX \ - template friend typename detail::enable_if::value, bool>::type + template friend typename std::enable_if::value, bool>::type - DOCTEST_APPROX_PREFIX operator==(const T& lhs, const Approx& rhs) { return operator==(double(lhs), rhs); } + DOCTEST_APPROX_PREFIX operator==(const T& lhs, const Approx& rhs) { return operator==(static_cast(lhs), rhs); } DOCTEST_APPROX_PREFIX operator==(const Approx& lhs, const T& rhs) { return operator==(rhs, lhs); } DOCTEST_APPROX_PREFIX operator!=(const T& lhs, const Approx& rhs) { return !operator==(lhs, rhs); } DOCTEST_APPROX_PREFIX operator!=(const Approx& lhs, const T& rhs) { return !operator==(rhs, lhs); } - DOCTEST_APPROX_PREFIX operator<=(const T& lhs, const Approx& rhs) { return double(lhs) < rhs.m_value || lhs == rhs; } - DOCTEST_APPROX_PREFIX operator<=(const Approx& lhs, const T& rhs) { return lhs.m_value < double(rhs) || lhs == rhs; } - DOCTEST_APPROX_PREFIX operator>=(const T& lhs, const Approx& rhs) { return double(lhs) > rhs.m_value || lhs == rhs; } - DOCTEST_APPROX_PREFIX operator>=(const Approx& lhs, const T& rhs) { return lhs.m_value > double(rhs) || lhs == rhs; } - DOCTEST_APPROX_PREFIX operator< (const T& lhs, const Approx& rhs) { return double(lhs) < rhs.m_value && lhs != rhs; } - DOCTEST_APPROX_PREFIX operator< (const Approx& lhs, const T& rhs) { return lhs.m_value < double(rhs) && lhs != rhs; } - DOCTEST_APPROX_PREFIX operator> (const T& lhs, const Approx& rhs) { return double(lhs) > rhs.m_value && lhs != rhs; } - DOCTEST_APPROX_PREFIX operator> (const Approx& lhs, const T& rhs) { return lhs.m_value > double(rhs) && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator<=(const T& lhs, const Approx& rhs) { return static_cast(lhs) < rhs.m_value || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator<=(const Approx& lhs, const T& rhs) { return lhs.m_value < static_cast(rhs) || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator>=(const T& lhs, const Approx& rhs) { return static_cast(lhs) > rhs.m_value || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator>=(const Approx& lhs, const T& rhs) { return lhs.m_value > static_cast(rhs) || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator< (const T& lhs, const Approx& rhs) { return static_cast(lhs) < rhs.m_value && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator< (const Approx& lhs, const T& rhs) { return lhs.m_value < static_cast(rhs) && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator> (const T& lhs, const Approx& rhs) { return static_cast(lhs) > rhs.m_value && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator> (const Approx& lhs, const T& rhs) { return lhs.m_value > static_cast(rhs) && lhs != rhs; } #undef DOCTEST_APPROX_PREFIX #endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS // clang-format on -private: double m_epsilon; double m_scale; double m_value; @@ -1025,18 +1249,35 @@ DOCTEST_INTERFACE String toString(const Approx& in); DOCTEST_INTERFACE const ContextOptions* getContextOptions(); -#if !defined(DOCTEST_CONFIG_DISABLE) +template +struct DOCTEST_INTERFACE_DECL IsNaN +{ + F value; bool flipped; + IsNaN(F f, bool flip = false) : value(f), flipped(flip) { } + IsNaN operator!() const { return { value, !flipped }; } + operator bool() const; +}; +#ifndef __MINGW32__ +extern template struct DOCTEST_INTERFACE_DECL IsNaN; +extern template struct DOCTEST_INTERFACE_DECL IsNaN; +extern template struct DOCTEST_INTERFACE_DECL IsNaN; +#endif +DOCTEST_INTERFACE String toString(IsNaN in); +DOCTEST_INTERFACE String toString(IsNaN in); +DOCTEST_INTERFACE String toString(IsNaN in); + +#ifndef DOCTEST_CONFIG_DISABLE namespace detail { // clang-format off #ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - template struct decay_array { typedef T type; }; - template struct decay_array { typedef T* type; }; - template struct decay_array { typedef T* type; }; + template struct decay_array { using type = T; }; + template struct decay_array { using type = T*; }; + template struct decay_array { using type = T*; }; - template struct not_char_pointer { enum { value = 1 }; }; - template<> struct not_char_pointer { enum { value = 0 }; }; - template<> struct not_char_pointer { enum { value = 0 }; }; + template struct not_char_pointer { static DOCTEST_CONSTEXPR value = 1; }; + template<> struct not_char_pointer { static DOCTEST_CONSTEXPR value = 0; }; + template<> struct not_char_pointer { static DOCTEST_CONSTEXPR value = 0; }; template struct can_use_op : public not_char_pointer::type> {}; #endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING @@ -1059,16 +1300,22 @@ namespace detail { bool m_entered = false; Subcase(const String& name, const char* file, int line); + Subcase(const Subcase&) = delete; + Subcase(Subcase&&) = delete; + Subcase& operator=(const Subcase&) = delete; + Subcase& operator=(Subcase&&) = delete; ~Subcase(); operator bool() const; + + private: + bool checkFilters(); }; template String stringifyBinaryExpr(const DOCTEST_REF_WRAP(L) lhs, const char* op, const DOCTEST_REF_WRAP(R) rhs) { - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - return doctest::toString(lhs) + op + doctest::toString(rhs); + return (DOCTEST_STRINGIFY(lhs)) + op + (DOCTEST_STRINGIFY(rhs)); } #if DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 6, 0) @@ -1079,12 +1326,12 @@ DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-comparison") // If not it doesn't find the operator or if the operator at global scope is defined after // this template, the template won't be instantiated due to SFINAE. Once the template is not // instantiated it can look for global operator using normal conversions. -#define SFINAE_OP(ret,op) decltype((void)(doctest::detail::declval() op doctest::detail::declval()),static_cast(0)) +#define SFINAE_OP(ret,op) decltype((void)(doctest::detail::declval() op doctest::detail::declval()),ret{}) #define DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(op, op_str, op_macro) \ template \ - DOCTEST_NOINLINE SFINAE_OP(Result,op) operator op(R&& rhs) { \ - bool res = op_macro(doctest::detail::forward(lhs), doctest::detail::forward(rhs)); \ + DOCTEST_NOINLINE SFINAE_OP(Result,op) operator op(R&& rhs) { \ + bool res = op_macro(doctest::detail::forward(lhs), doctest::detail::forward(rhs)); \ if(m_at & assertType::is_false) \ res = !res; \ if(!res || doctest::getContextOptions()->success) \ @@ -1103,11 +1350,12 @@ DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-comparison") return *this; \ } - struct DOCTEST_INTERFACE Result + struct DOCTEST_INTERFACE Result // NOLINT(*-member-init) { bool m_passed; String m_decomp; + Result() = default; // TODO: Why do we need this? (To remove NOLINT) Result(bool passed, const String& decomposition = String()); // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence @@ -1164,8 +1412,7 @@ DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-comparison") #ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING #define DOCTEST_COMPARISON_RETURN_TYPE bool #else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -#define DOCTEST_COMPARISON_RETURN_TYPE typename enable_if::value || can_use_op::value, bool>::type - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) +#define DOCTEST_COMPARISON_RETURN_TYPE typename types::enable_if::value || can_use_op::value, bool>::type inline bool eq(const char* lhs, const char* rhs) { return String(lhs) == String(rhs); } inline bool ne(const char* lhs, const char* rhs) { return String(lhs) != String(rhs); } inline bool lt(const char* lhs, const char* rhs) { return String(lhs) < String(rhs); } @@ -1213,26 +1460,26 @@ DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-comparison") assertType::Enum m_at; explicit Expression_lhs(L&& in, assertType::Enum at) - : lhs(doctest::detail::forward(in)) + : lhs(static_cast(in)) , m_at(at) {} DOCTEST_NOINLINE operator Result() { -// this is needed only foc MSVC 2015: -// https://ci.appveyor.com/project/onqtam/doctest/builds/38181202 +// this is needed only for MSVC 2015 DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4800) // 'int': forcing value to bool bool res = static_cast(lhs); DOCTEST_MSVC_SUPPRESS_WARNING_POP - if(m_at & assertType::is_false) //!OCLINT bitwise operator in conditional + if(m_at & assertType::is_false) { //!OCLINT bitwise operator in conditional res = !res; + } - if(!res || getContextOptions()->success) - return Result(res, doctest::toString(lhs)); - return Result(res); + if(!res || getContextOptions()->success) { + return { res, (DOCTEST_STRINGIFY(lhs)) }; + } + return { res }; } - /* This is required for user-defined conversions from Expression_lhs to L */ - //operator L() const { return lhs; } - operator L() const { return lhs; } + /* This is required for user-defined conversions from Expression_lhs to L */ + operator L() const { return lhs; } // clang-format off DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(==, " == ", DOCTEST_CMP_EQ) //!OCLINT bitwise operator in conditional @@ -1289,22 +1536,27 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP // https://github.com/catchorg/Catch2/issues/870 // https://github.com/catchorg/Catch2/issues/565 template - Expression_lhs operator<<(L &&operand) { - return Expression_lhs(doctest::detail::forward(operand), m_at); + Expression_lhs operator<<(L&& operand) { + return Expression_lhs(static_cast(operand), m_at); + } + + template ::value,void >::type* = nullptr> + Expression_lhs operator<<(const L &operand) { + return Expression_lhs(operand, m_at); } }; struct DOCTEST_INTERFACE TestSuite { - const char* m_test_suite; - const char* m_description; - bool m_skip; - bool m_no_breaks; - bool m_no_output; - bool m_may_fail; - bool m_should_fail; - int m_expected_failures; - double m_timeout; + const char* m_test_suite = nullptr; + const char* m_description = nullptr; + bool m_skip = false; + bool m_no_breaks = false; + bool m_no_output = false; + bool m_may_fail = false; + bool m_should_fail = false; + int m_expected_failures = 0; + double m_timeout = 0; TestSuite& operator*(const char* in); @@ -1315,25 +1567,28 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP } }; - typedef void (*funcType)(); + using funcType = void (*)(); struct DOCTEST_INTERFACE TestCase : public TestCaseData { funcType m_test; // a function pointer to the test case - const char* m_type; // for templated test cases - gets appended to the real name + String m_type; // for templated test cases - gets appended to the real name int m_template_id; // an ID used to distinguish between the different versions of a templated test case String m_full_name; // contains the name (only for templated test cases!) + the template type TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, - const char* type = "", int template_id = -1); + const String& type = String(), int template_id = -1); TestCase(const TestCase& other); + TestCase(TestCase&&) = delete; DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function TestCase& operator=(const TestCase& other); DOCTEST_MSVC_SUPPRESS_WARNING_POP + TestCase& operator=(TestCase&&) = delete; + TestCase& operator*(const char* in); template @@ -1343,6 +1598,8 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP } bool operator<(const TestCase& other) const; + + ~TestCase() = default; }; // forward declarations of functions used by the macros @@ -1382,27 +1639,36 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP struct DOCTEST_INTERFACE ResultBuilder : public AssertData { ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type = "", const char* exception_string = ""); + const char* exception_type = "", const String& exception_string = ""); + + ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const Contains& exception_string); void setResult(const Result& res); template - DOCTEST_NOINLINE void binary_assert(const DOCTEST_REF_WRAP(L) lhs, + DOCTEST_NOINLINE bool binary_assert(const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) { m_failed = !RelationalComparator()(lhs, rhs); - if(m_failed || getContextOptions()->success) + if (m_failed || getContextOptions()->success) { m_decomp = stringifyBinaryExpr(lhs, ", ", rhs); + } + return !m_failed; } template - DOCTEST_NOINLINE void unary_assert(const DOCTEST_REF_WRAP(L) val) { + DOCTEST_NOINLINE bool unary_assert(const DOCTEST_REF_WRAP(L) val) { m_failed = !val; - if(m_at & assertType::is_false) //!OCLINT bitwise operator in conditional + if (m_at & assertType::is_false) { //!OCLINT bitwise operator in conditional m_failed = !m_failed; + } - if(m_failed || getContextOptions()->success) - m_decomp = toString(val); + if (m_failed || getContextOptions()->success) { + m_decomp = (DOCTEST_STRINGIFY(val)); + } + + return !m_failed; } void translateException(); @@ -1422,8 +1688,8 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP DOCTEST_INTERFACE void failed_out_of_a_testing_context(const AssertData& ad); - DOCTEST_INTERFACE void decomp_assert(assertType::Enum at, const char* file, int line, - const char* expr, Result result); + DOCTEST_INTERFACE bool decomp_assert(assertType::Enum at, const char* file, int line, + const char* expr, const Result& result); #define DOCTEST_ASSERT_OUT_OF_TESTS(decomp) \ do { \ @@ -1438,7 +1704,7 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP if(checkIfShouldThrow(at)) \ throwException(); \ } \ - return; \ + return !failed; \ } \ } while(false) @@ -1453,7 +1719,7 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP throwException() template - DOCTEST_NOINLINE void binary_assert(assertType::Enum at, const char* file, int line, + DOCTEST_NOINLINE bool binary_assert(assertType::Enum at, const char* file, int line, const char* expr, const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) { bool failed = !RelationalComparator()(lhs, rhs); @@ -1464,10 +1730,11 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP // ################################################################################### DOCTEST_ASSERT_OUT_OF_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); DOCTEST_ASSERT_IN_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); + return !failed; } template - DOCTEST_NOINLINE void unary_assert(assertType::Enum at, const char* file, int line, + DOCTEST_NOINLINE bool unary_assert(assertType::Enum at, const char* file, int line, const char* expr, const DOCTEST_REF_WRAP(L) val) { bool failed = !val; @@ -1478,14 +1745,14 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED // ################################################################################### - DOCTEST_ASSERT_OUT_OF_TESTS(toString(val)); - DOCTEST_ASSERT_IN_TESTS(toString(val)); + DOCTEST_ASSERT_OUT_OF_TESTS((DOCTEST_STRINGIFY(val))); + DOCTEST_ASSERT_IN_TESTS((DOCTEST_STRINGIFY(val))); + return !failed; } struct DOCTEST_INTERFACE IExceptionTranslator { - IExceptionTranslator(); - virtual ~IExceptionTranslator(); + DOCTEST_DECLARE_INTERFACE(IExceptionTranslator) virtual bool translate(String&) const = 0; }; @@ -1501,7 +1768,7 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP try { throw; // lgtm [cpp/rethrow-no-exception] // cppcheck-suppress catchExceptionByValue - } catch(T ex) { // NOLINT + } catch(const T& ex) { res = m_translateFunction(ex); //!OCLINT parameter reassignment return true; } catch(...) {} //!OCLINT - empty catch statement @@ -1516,95 +1783,70 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP DOCTEST_INTERFACE void registerExceptionTranslatorImpl(const IExceptionTranslator* et); - template - struct StringStreamBase - { - template - static void convert(std::ostream* s, const T& in) { - *s << toString(in); - } - - // always treat char* as a string in this context - no matter - // if DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING is defined - static void convert(std::ostream* s, const char* in) { *s << String(in); } - }; - - template <> - struct StringStreamBase - { - template - static void convert(std::ostream* s, const T& in) { - *s << in; - } - }; - - template - struct StringStream : public StringStreamBase::value> - {}; - - template - void toStream(std::ostream* s, const T& value) { - StringStream::convert(s, value); - } - -#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - DOCTEST_INTERFACE void toStream(std::ostream* s, char* in); - DOCTEST_INTERFACE void toStream(std::ostream* s, const char* in); -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - DOCTEST_INTERFACE void toStream(std::ostream* s, bool in); - DOCTEST_INTERFACE void toStream(std::ostream* s, float in); - DOCTEST_INTERFACE void toStream(std::ostream* s, double in); - DOCTEST_INTERFACE void toStream(std::ostream* s, double long in); - - DOCTEST_INTERFACE void toStream(std::ostream* s, char in); - DOCTEST_INTERFACE void toStream(std::ostream* s, char signed in); - DOCTEST_INTERFACE void toStream(std::ostream* s, char unsigned in); - DOCTEST_INTERFACE void toStream(std::ostream* s, int short in); - DOCTEST_INTERFACE void toStream(std::ostream* s, int short unsigned in); - DOCTEST_INTERFACE void toStream(std::ostream* s, int in); - DOCTEST_INTERFACE void toStream(std::ostream* s, int unsigned in); - DOCTEST_INTERFACE void toStream(std::ostream* s, int long in); - DOCTEST_INTERFACE void toStream(std::ostream* s, int long unsigned in); - DOCTEST_INTERFACE void toStream(std::ostream* s, int long long in); - DOCTEST_INTERFACE void toStream(std::ostream* s, int long long unsigned in); - - // ContextScope base class used to allow implementing methods of ContextScope + // ContextScope base class used to allow implementing methods of ContextScope // that don't depend on the template parameter in doctest.cpp. - class DOCTEST_INTERFACE ContextScopeBase : public IContextScope { + struct DOCTEST_INTERFACE ContextScopeBase : public IContextScope { + ContextScopeBase(const ContextScopeBase&) = delete; + + ContextScopeBase& operator=(const ContextScopeBase&) = delete; + ContextScopeBase& operator=(ContextScopeBase&&) = delete; + + ~ContextScopeBase() override = default; + protected: ContextScopeBase(); + ContextScopeBase(ContextScopeBase&& other) noexcept; void destroy(); + bool need_to_destroy{true}; }; template class ContextScope : public ContextScopeBase { - const L lambda_; + L lambda_; public: explicit ContextScope(const L &lambda) : lambda_(lambda) {} + explicit ContextScope(L&& lambda) : lambda_(static_cast(lambda)) { } - ContextScope(ContextScope &&other) : lambda_(other.lambda_) {} + ContextScope(const ContextScope&) = delete; + ContextScope(ContextScope&&) noexcept = default; + + ContextScope& operator=(const ContextScope&) = delete; + ContextScope& operator=(ContextScope&&) = delete; void stringify(std::ostream* s) const override { lambda_(s); } - ~ContextScope() override { destroy(); } + ~ContextScope() override { + if (need_to_destroy) { + destroy(); + } + } }; struct DOCTEST_INTERFACE MessageBuilder : public MessageData { std::ostream* m_stream; + bool logged = false; MessageBuilder(const char* file, int line, assertType::Enum severity); - MessageBuilder() = delete; + + MessageBuilder(const MessageBuilder&) = delete; + MessageBuilder(MessageBuilder&&) = delete; + + MessageBuilder& operator=(const MessageBuilder&) = delete; + MessageBuilder& operator=(MessageBuilder&&) = delete; + ~MessageBuilder(); // the preferred way of chaining parameters for stringification +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4866) template MessageBuilder& operator,(const T& in) { - toStream(m_stream, in); + *m_stream << (DOCTEST_STRINGIFY(in)); return *this; } +DOCTEST_MSVC_SUPPRESS_WARNING_POP // kept here just for backwards-compatibility - the comma operator should be preferred now template @@ -1620,7 +1862,7 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP bool log(); void react(); }; - + template ContextScope MakeContextScope(const L &lambda) { return ContextScope(lambda); @@ -1673,7 +1915,7 @@ int registerExceptionTranslator(String (*)(T)) { #endif // DOCTEST_CONFIG_DISABLE namespace detail { - typedef void (*assert_handler)(const AssertData&); + using assert_handler = void (*)(const AssertData&); struct ContextState; } // namespace detail @@ -1686,12 +1928,19 @@ class DOCTEST_INTERFACE Context public: explicit Context(int argc = 0, const char* const* argv = nullptr); - ~Context(); + Context(const Context&) = delete; + Context(Context&&) = delete; + + Context& operator=(const Context&) = delete; + Context& operator=(Context&&) = delete; + + ~Context(); // NOLINT(performance-trivially-destructible) void applyCommandLine(int argc, const char* const* argv); void addFilter(const char* filter, const char* value); void clearFilters(); + void setOption(const char* option, bool value); void setOption(const char* option, int value); void setOption(const char* option, const char* value); @@ -1701,6 +1950,8 @@ public: void setAssertHandler(detail::assert_handler ah); + void setCout(std::ostream* out); + int run(); }; @@ -1727,6 +1978,7 @@ struct DOCTEST_INTERFACE CurrentTestCaseStats int numAssertsFailedCurrentTest; double seconds; int failure_flags; // use TestCaseFailureReason::Enum + bool testCaseSuccess; }; struct DOCTEST_INTERFACE TestCaseException @@ -1790,8 +2042,7 @@ struct DOCTEST_INTERFACE IReporter // or isn't in the execution range (between first and last) (safe to cache a pointer to the input) virtual void test_case_skipped(const TestCaseData&) = 0; - // doctest will not be managing the lifetimes of reporters given to it but this would still be nice to have - virtual ~IReporter(); + DOCTEST_DECLARE_INTERFACE(IReporter) // can obtain all currently active contexts and stringify them if one wishes to do so static int get_num_active_contexts(); @@ -1803,7 +2054,7 @@ struct DOCTEST_INTERFACE IReporter }; namespace detail { - typedef IReporter* (*reporterCreatorFunc)(const ContextOptions&); + using reporterCreatorFunc = IReporter* (*)(const ContextOptions&); DOCTEST_INTERFACE void registerReporterImpl(const char* name, int prio, reporterCreatorFunc c, bool isReporter); @@ -1820,14 +2071,30 @@ int registerReporter(const char* name, int priority, bool isReporter) { } } // namespace doctest +#ifdef DOCTEST_CONFIG_ASSERTS_RETURN_VALUES +#define DOCTEST_FUNC_EMPTY [] { return false; }() +#else +#define DOCTEST_FUNC_EMPTY (void)0 +#endif + // if registering is not disabled -#if !defined(DOCTEST_CONFIG_DISABLE) +#ifndef DOCTEST_CONFIG_DISABLE + +#ifdef DOCTEST_CONFIG_ASSERTS_RETURN_VALUES +#define DOCTEST_FUNC_SCOPE_BEGIN [&] +#define DOCTEST_FUNC_SCOPE_END () +#define DOCTEST_FUNC_SCOPE_RET(v) return v +#else +#define DOCTEST_FUNC_SCOPE_BEGIN do +#define DOCTEST_FUNC_SCOPE_END while(false) +#define DOCTEST_FUNC_SCOPE_RET(v) (void)0 +#endif // common code in asserts - for convenience -#define DOCTEST_ASSERT_LOG_AND_REACT(b) \ - if(b.log()) \ - DOCTEST_BREAK_INTO_DEBUGGER(); \ - b.react() +#define DOCTEST_ASSERT_LOG_REACT_RETURN(b) \ + if(b.log()) DOCTEST_BREAK_INTO_DEBUGGER(); \ + b.react(); \ + DOCTEST_FUNC_SCOPE_RET(!b.m_failed) #ifdef DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS #define DOCTEST_WRAP_IN_TRY(x) x; @@ -1835,7 +2102,7 @@ int registerReporter(const char* name, int priority, bool isReporter) { #define DOCTEST_WRAP_IN_TRY(x) \ try { \ x; \ - } catch(...) { _DOCTEST_RB.translateException(); } + } catch(...) { DOCTEST_RB.translateException(); } #endif // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS #ifdef DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS @@ -1849,27 +2116,26 @@ int registerReporter(const char* name, int priority, bool isReporter) { // registers the test by initializing a dummy var with a function #define DOCTEST_REGISTER_FUNCTION(global_prefix, f, decorators) \ - global_prefix DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_VAR_)) = \ + global_prefix DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT */ \ doctest::detail::regTest( \ doctest::detail::TestCase( \ f, __FILE__, __LINE__, \ doctest_detail_test_suite_ns::getCurrentTestSuite()) * \ - decorators); \ - DOCTEST_GLOBAL_NO_WARNINGS_END() + decorators)) #define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, decorators) \ - namespace { \ + namespace { /* NOLINT */ \ struct der : public base \ { \ void f(); \ }; \ - static void func() { \ + static inline DOCTEST_NOINLINE void func() { \ der v; \ v.f(); \ } \ DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, func, decorators) \ } \ - inline DOCTEST_NOINLINE void der::f() + inline DOCTEST_NOINLINE void der::f() // NOLINT(misc-definitions-in-headers) #define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, decorators) \ static void f(); \ @@ -1878,18 +2144,18 @@ int registerReporter(const char* name, int priority, bool isReporter) { #define DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(f, proxy, decorators) \ static doctest::detail::funcType proxy() { return f; } \ - DOCTEST_REGISTER_FUNCTION(inline const, proxy(), decorators) \ + DOCTEST_REGISTER_FUNCTION(inline, proxy(), decorators) \ static void f() // for registering tests #define DOCTEST_TEST_CASE(decorators) \ - DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), decorators) + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), decorators) // for registering tests in classes - requires C++17 for inline variables! -#if __cplusplus >= 201703L || (DOCTEST_MSVC >= DOCTEST_COMPILER(19, 12, 0) && _MSVC_LANG >= 201703L) +#if DOCTEST_CPLUSPLUS >= 201703L #define DOCTEST_TEST_CASE_CLASS(decorators) \ - DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), \ - DOCTEST_ANONYMOUS(_DOCTEST_ANON_PROXY_), \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), \ + DOCTEST_ANONYMOUS(DOCTEST_ANON_PROXY_), \ decorators) #else // DOCTEST_TEST_CASE_CLASS #define DOCTEST_TEST_CASE_CLASS(...) \ @@ -1898,26 +2164,25 @@ int registerReporter(const char* name, int priority, bool isReporter) { // for registering tests with a fixture #define DOCTEST_TEST_CASE_FIXTURE(c, decorators) \ - DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(_DOCTEST_ANON_CLASS_), c, \ - DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), decorators) + DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), c, \ + DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), decorators) // for converting types to strings without the header and demangling -#define DOCTEST_TYPE_TO_STRING_IMPL(...) \ - template <> \ - inline const char* type_to_string<__VA_ARGS__>() { \ - return "<" #__VA_ARGS__ ">"; \ - } -#define DOCTEST_TYPE_TO_STRING(...) \ - namespace doctest { namespace detail { \ - DOCTEST_TYPE_TO_STRING_IMPL(__VA_ARGS__) \ +#define DOCTEST_TYPE_TO_STRING_AS(str, ...) \ + namespace doctest { \ + template <> \ + inline String toString<__VA_ARGS__>() { \ + return str; \ } \ } \ - typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + static_assert(true, "") + +#define DOCTEST_TYPE_TO_STRING(...) DOCTEST_TYPE_TO_STRING_AS(#__VA_ARGS__, __VA_ARGS__) #define DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, iter, func) \ template \ static void func(); \ - namespace { \ + namespace { /* NOLINT */ \ template \ struct iter; \ template \ @@ -1926,7 +2191,7 @@ int registerReporter(const char* name, int priority, bool isReporter) { iter(const char* file, unsigned line, int index) { \ doctest::detail::regTest(doctest::detail::TestCase(func, file, line, \ doctest_detail_test_suite_ns::getCurrentTestSuite(), \ - doctest::detail::type_to_string(), \ + doctest::toString(), \ int(line) * 1000 + index) \ * dec); \ iter>(file, line, index + 1); \ @@ -1943,20 +2208,20 @@ int registerReporter(const char* name, int priority, bool isReporter) { #define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(dec, T, id) \ DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(id, ITERATOR), \ - DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_)) + DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)) #define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, anon, ...) \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_CAT(anon, DUMMY)) = \ - doctest::detail::instantiationHelper(DOCTEST_CAT(id, ITERATOR)<__VA_ARGS__>(__FILE__, __LINE__, 0));\ - DOCTEST_GLOBAL_NO_WARNINGS_END() + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_CAT(anon, DUMMY), /* NOLINT(cert-err58-cpp, fuchsia-statically-constructed-objects) */ \ + doctest::detail::instantiationHelper( \ + DOCTEST_CAT(id, ITERATOR)<__VA_ARGS__>(__FILE__, __LINE__, 0))) #define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) \ - DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_), std::tuple<__VA_ARGS__>) \ - typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), std::tuple<__VA_ARGS__>) \ + static_assert(true, "") #define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) \ - DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_), __VA_ARGS__) \ - typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), __VA_ARGS__) \ + static_assert(true, "") #define DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, anon, ...) \ DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(anon, ITERATOR), anon); \ @@ -1965,17 +2230,17 @@ int registerReporter(const char* name, int priority, bool isReporter) { static void anon() #define DOCTEST_TEST_CASE_TEMPLATE(dec, T, ...) \ - DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_), __VA_ARGS__) + DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), __VA_ARGS__) // for subcases #define DOCTEST_SUBCASE(name) \ - if(const doctest::detail::Subcase & DOCTEST_ANONYMOUS(_DOCTEST_ANON_SUBCASE_) DOCTEST_UNUSED = \ + if(const doctest::detail::Subcase & DOCTEST_ANONYMOUS(DOCTEST_ANON_SUBCASE_) DOCTEST_UNUSED = \ doctest::detail::Subcase(name, __FILE__, __LINE__)) // for grouping tests in test suites by using code blocks #define DOCTEST_TEST_SUITE_IMPL(decorators, ns_name) \ namespace ns_name { namespace doctest_detail_test_suite_ns { \ - static DOCTEST_NOINLINE doctest::detail::TestSuite& getCurrentTestSuite() { \ + static DOCTEST_NOINLINE doctest::detail::TestSuite& getCurrentTestSuite() noexcept { \ DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4640) \ DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") \ DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmissing-field-initializers") \ @@ -1995,53 +2260,53 @@ int registerReporter(const char* name, int priority, bool isReporter) { namespace ns_name #define DOCTEST_TEST_SUITE(decorators) \ - DOCTEST_TEST_SUITE_IMPL(decorators, DOCTEST_ANONYMOUS(_DOCTEST_ANON_SUITE_)) + DOCTEST_TEST_SUITE_IMPL(decorators, DOCTEST_ANONYMOUS(DOCTEST_ANON_SUITE_)) // for starting a testsuite block #define DOCTEST_TEST_SUITE_BEGIN(decorators) \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_VAR_)) = \ - doctest::detail::setTestSuite(doctest::detail::TestSuite() * decorators); \ - DOCTEST_GLOBAL_NO_WARNINGS_END() \ - typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT(cert-err58-cpp) */ \ + doctest::detail::setTestSuite(doctest::detail::TestSuite() * decorators)) \ + static_assert(true, "") // for ending a testsuite block #define DOCTEST_TEST_SUITE_END \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_VAR_)) = \ - doctest::detail::setTestSuite(doctest::detail::TestSuite() * ""); \ - DOCTEST_GLOBAL_NO_WARNINGS_END() \ - typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT(cert-err58-cpp) */ \ + doctest::detail::setTestSuite(doctest::detail::TestSuite() * "")) \ + using DOCTEST_ANONYMOUS(DOCTEST_ANON_FOR_SEMICOLON_) = int // for registering exception translators #define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(translatorName, signature) \ inline doctest::String translatorName(signature); \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_TRANSLATOR_)) = \ - doctest::registerExceptionTranslator(translatorName); \ - DOCTEST_GLOBAL_NO_WARNINGS_END() \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_), /* NOLINT(cert-err58-cpp) */ \ + doctest::registerExceptionTranslator(translatorName)) \ doctest::String translatorName(signature) #define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ - DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(DOCTEST_ANONYMOUS(_DOCTEST_ANON_TRANSLATOR_), \ + DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_), \ signature) // for registering reporters #define DOCTEST_REGISTER_REPORTER(name, priority, reporter) \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_REPORTER_)) = \ - doctest::registerReporter(name, priority, true); \ - DOCTEST_GLOBAL_NO_WARNINGS_END() typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_REPORTER_), /* NOLINT(cert-err58-cpp) */ \ + doctest::registerReporter(name, priority, true)) \ + static_assert(true, "") // for registering listeners #define DOCTEST_REGISTER_LISTENER(name, priority, reporter) \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_REPORTER_)) = \ - doctest::registerReporter(name, priority, false); \ - DOCTEST_GLOBAL_NO_WARNINGS_END() typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_REPORTER_), /* NOLINT(cert-err58-cpp) */ \ + doctest::registerReporter(name, priority, false)) \ + static_assert(true, "") -// for logging +// clang-format off +// for logging - disabling formatting because it's important to have these on 2 separate lines - see PR #557 #define DOCTEST_INFO(...) \ - DOCTEST_INFO_IMPL(DOCTEST_ANONYMOUS(_DOCTEST_CAPTURE_), DOCTEST_ANONYMOUS(_DOCTEST_CAPTURE_), \ + DOCTEST_INFO_IMPL(DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_), \ + DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_OTHER_), \ __VA_ARGS__) +// clang-format on #define DOCTEST_INFO_IMPL(mb_name, s_name, ...) \ - auto DOCTEST_ANONYMOUS(_DOCTEST_CAPTURE_) = doctest::detail::MakeContextScope( \ + auto DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_) = doctest::detail::MakeContextScope( \ [&](std::ostream* s_name) { \ doctest::detail::MessageBuilder mb_name(__FILE__, __LINE__, doctest::assertType::is_warn); \ mb_name.m_stream = s_name; \ @@ -2051,16 +2316,18 @@ int registerReporter(const char* name, int priority, bool isReporter) { #define DOCTEST_CAPTURE(x) DOCTEST_INFO(#x " := ", x) #define DOCTEST_ADD_AT_IMPL(type, file, line, mb, ...) \ - do { \ + DOCTEST_FUNC_SCOPE_BEGIN { \ doctest::detail::MessageBuilder mb(file, line, doctest::assertType::type); \ mb * __VA_ARGS__; \ - DOCTEST_ASSERT_LOG_AND_REACT(mb); \ - } while(false) + if(mb.log()) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + mb.react(); \ + } DOCTEST_FUNC_SCOPE_END // clang-format off -#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_warn, file, line, DOCTEST_ANONYMOUS(_DOCTEST_MESSAGE_), __VA_ARGS__) -#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_check, file, line, DOCTEST_ANONYMOUS(_DOCTEST_MESSAGE_), __VA_ARGS__) -#define DOCTEST_ADD_FAIL_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_require, file, line, DOCTEST_ANONYMOUS(_DOCTEST_MESSAGE_), __VA_ARGS__) +#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_warn, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) +#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_check, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) +#define DOCTEST_ADD_FAIL_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_require, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) // clang-format on #define DOCTEST_MESSAGE(...) DOCTEST_ADD_MESSAGE_AT(__FILE__, __LINE__, __VA_ARGS__) @@ -2073,18 +2340,37 @@ int registerReporter(const char* name, int priority, bool isReporter) { #define DOCTEST_ASSERT_IMPLEMENT_2(assert_type, ...) \ DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ - doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + /* NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) */ \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ __LINE__, #__VA_ARGS__); \ - DOCTEST_WRAP_IN_TRY(_DOCTEST_RB.setResult( \ + DOCTEST_WRAP_IN_TRY(DOCTEST_RB.setResult( \ doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ - << __VA_ARGS__)) \ - DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB) \ + << __VA_ARGS__)) /* NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) */ \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB) \ DOCTEST_CLANG_SUPPRESS_WARNING_POP #define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) \ - do { \ + DOCTEST_FUNC_SCOPE_BEGIN { \ DOCTEST_ASSERT_IMPLEMENT_2(assert_type, __VA_ARGS__); \ - } while(false) + } DOCTEST_FUNC_SCOPE_END // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) + +#define DOCTEST_BINARY_ASSERT(assert_type, comp, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY( \ + DOCTEST_RB.binary_assert( \ + __VA_ARGS__)) \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } DOCTEST_FUNC_SCOPE_END + +#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY(DOCTEST_RB.unary_assert(__VA_ARGS__)) \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } DOCTEST_FUNC_SCOPE_END #else // DOCTEST_CONFIG_SUPER_FAST_ASSERTS @@ -2098,6 +2384,14 @@ int registerReporter(const char* name, int priority, bool isReporter) { doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ << __VA_ARGS__) DOCTEST_CLANG_SUPPRESS_WARNING_POP +#define DOCTEST_BINARY_ASSERT(assert_type, comparison, ...) \ + doctest::detail::binary_assert( \ + doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, __VA_ARGS__) + +#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ + doctest::detail::unary_assert(doctest::assertType::assert_type, __FILE__, __LINE__, \ + #__VA_ARGS__, __VA_ARGS__) + #endif // DOCTEST_CONFIG_SUPER_FAST_ASSERTS #define DOCTEST_WARN(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN, __VA_ARGS__) @@ -2108,122 +2402,14 @@ int registerReporter(const char* name, int priority, bool isReporter) { #define DOCTEST_REQUIRE_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE_FALSE, __VA_ARGS__) // clang-format off -#define DOCTEST_WARN_MESSAGE(cond, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN, cond); } while(false) -#define DOCTEST_CHECK_MESSAGE(cond, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK, cond); } while(false) -#define DOCTEST_REQUIRE_MESSAGE(cond, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE, cond); } while(false) -#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN_FALSE, cond); } while(false) -#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK_FALSE, cond); } while(false) -#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE_FALSE, cond); } while(false) +#define DOCTEST_WARN_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN_FALSE, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK_FALSE, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE_FALSE, cond); } DOCTEST_FUNC_SCOPE_END // clang-format on -#define DOCTEST_ASSERT_THROWS_AS(expr, assert_type, message, ...) \ - do { \ - if(!doctest::getContextOptions()->no_throw) { \ - doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #expr, #__VA_ARGS__, message); \ - try { \ - DOCTEST_CAST_TO_VOID(expr) \ - } catch(const typename doctest::detail::remove_const< \ - typename doctest::detail::remove_reference<__VA_ARGS__>::type>::type&) { \ - _DOCTEST_RB.translateException(); \ - _DOCTEST_RB.m_threw_as = true; \ - } catch(...) { _DOCTEST_RB.translateException(); } \ - DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ - } \ - } while(false) - -#define DOCTEST_ASSERT_THROWS_WITH(expr, expr_str, assert_type, ...) \ - do { \ - if(!doctest::getContextOptions()->no_throw) { \ - doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, expr_str, "", __VA_ARGS__); \ - try { \ - DOCTEST_CAST_TO_VOID(expr) \ - } catch(...) { _DOCTEST_RB.translateException(); } \ - DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ - } \ - } while(false) - -#define DOCTEST_ASSERT_NOTHROW(assert_type, ...) \ - do { \ - doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #__VA_ARGS__); \ - try { \ - DOCTEST_CAST_TO_VOID(__VA_ARGS__) \ - } catch(...) { _DOCTEST_RB.translateException(); } \ - DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ - } while(false) - -// clang-format off -#define DOCTEST_WARN_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_WARN_THROWS, "") -#define DOCTEST_CHECK_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_CHECK_THROWS, "") -#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_REQUIRE_THROWS, "") - -#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_AS, "", __VA_ARGS__) -#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_AS, "", __VA_ARGS__) -#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_AS, "", __VA_ARGS__) - -#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_WARN_THROWS_WITH, __VA_ARGS__) -#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_CHECK_THROWS_WITH, __VA_ARGS__) -#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_REQUIRE_THROWS_WITH, __VA_ARGS__) - -#define DOCTEST_WARN_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_WITH_AS, message, __VA_ARGS__) -#define DOCTEST_CHECK_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_WITH_AS, message, __VA_ARGS__) -#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_WITH_AS, message, __VA_ARGS__) - -#define DOCTEST_WARN_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_WARN_NOTHROW, __VA_ARGS__) -#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_CHECK_NOTHROW, __VA_ARGS__) -#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_REQUIRE_NOTHROW, __VA_ARGS__) - -#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS(expr); } while(false) -#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS(expr); } while(false) -#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS(expr); } while(false) -#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_AS(expr, ex); } while(false) -#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_AS(expr, ex); } while(false) -#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_AS(expr, ex); } while(false) -#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH(expr, with); } while(false) -#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH(expr, with); } while(false) -#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH(expr, with); } while(false) -#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex); } while(false) -#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex); } while(false) -#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex); } while(false) -#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_NOTHROW(expr); } while(false) -#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_NOTHROW(expr); } while(false) -#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_NOTHROW(expr); } while(false) -// clang-format on - -#ifndef DOCTEST_CONFIG_SUPER_FAST_ASSERTS - -#define DOCTEST_BINARY_ASSERT(assert_type, comp, ...) \ - do { \ - doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #__VA_ARGS__); \ - DOCTEST_WRAP_IN_TRY( \ - _DOCTEST_RB.binary_assert( \ - __VA_ARGS__)) \ - DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ - } while(false) - -#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ - do { \ - doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #__VA_ARGS__); \ - DOCTEST_WRAP_IN_TRY(_DOCTEST_RB.unary_assert(__VA_ARGS__)) \ - DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ - } while(false) - -#else // DOCTEST_CONFIG_SUPER_FAST_ASSERTS - -#define DOCTEST_BINARY_ASSERT(assert_type, comparison, ...) \ - doctest::detail::binary_assert( \ - doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, __VA_ARGS__) - -#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ - doctest::detail::unary_assert(doctest::assertType::assert_type, __FILE__, __LINE__, \ - #__VA_ARGS__, __VA_ARGS__) - -#endif // DOCTEST_CONFIG_SUPER_FAST_ASSERTS - #define DOCTEST_WARN_EQ(...) DOCTEST_BINARY_ASSERT(DT_WARN_EQ, eq, __VA_ARGS__) #define DOCTEST_CHECK_EQ(...) DOCTEST_BINARY_ASSERT(DT_CHECK_EQ, eq, __VA_ARGS__) #define DOCTEST_REQUIRE_EQ(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_EQ, eq, __VA_ARGS__) @@ -2250,75 +2436,350 @@ int registerReporter(const char* name, int priority, bool isReporter) { #define DOCTEST_CHECK_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY_FALSE, __VA_ARGS__) #define DOCTEST_REQUIRE_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY_FALSE, __VA_ARGS__) +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + +#define DOCTEST_ASSERT_THROWS_AS(expr, assert_type, message, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + if(!doctest::getContextOptions()->no_throw) { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #expr, #__VA_ARGS__, message); \ + try { \ + DOCTEST_CAST_TO_VOID(expr) \ + } catch(const typename doctest::detail::types::remove_const< \ + typename doctest::detail::types::remove_reference<__VA_ARGS__>::type>::type&) {\ + DOCTEST_RB.translateException(); \ + DOCTEST_RB.m_threw_as = true; \ + } catch(...) { DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } else { /* NOLINT(*-else-after-return) */ \ + DOCTEST_FUNC_SCOPE_RET(false); \ + } \ + } DOCTEST_FUNC_SCOPE_END + +#define DOCTEST_ASSERT_THROWS_WITH(expr, expr_str, assert_type, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + if(!doctest::getContextOptions()->no_throw) { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, expr_str, "", __VA_ARGS__); \ + try { \ + DOCTEST_CAST_TO_VOID(expr) \ + } catch(...) { DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } else { /* NOLINT(*-else-after-return) */ \ + DOCTEST_FUNC_SCOPE_RET(false); \ + } \ + } DOCTEST_FUNC_SCOPE_END + +#define DOCTEST_ASSERT_NOTHROW(assert_type, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + try { \ + DOCTEST_CAST_TO_VOID(__VA_ARGS__) \ + } catch(...) { DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } DOCTEST_FUNC_SCOPE_END + +// clang-format off +#define DOCTEST_WARN_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_WARN_THROWS, "") +#define DOCTEST_CHECK_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_CHECK_THROWS, "") +#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_REQUIRE_THROWS, "") + +#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_AS, "", __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_AS, "", __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_AS, "", __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_WARN_THROWS_WITH, __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_CHECK_THROWS_WITH, __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_REQUIRE_THROWS_WITH, __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_WITH_AS, message, __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_WITH_AS, message, __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_WITH_AS, message, __VA_ARGS__) + +#define DOCTEST_WARN_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_WARN_NOTHROW, __VA_ARGS__) +#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_CHECK_NOTHROW, __VA_ARGS__) +#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_REQUIRE_NOTHROW, __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END +// clang-format on + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +// ================================================================================================= +// == WHAT FOLLOWS IS VERSIONS OF THE MACROS THAT DO NOT DO ANY REGISTERING! == +// == THIS CAN BE ENABLED BY DEFINING DOCTEST_CONFIG_DISABLE GLOBALLY! == +// ================================================================================================= +#else // DOCTEST_CONFIG_DISABLE + +#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, name) \ + namespace /* NOLINT */ { \ + template \ + struct der : public base \ + { void f(); }; \ + } \ + template \ + inline void der::f() + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, name) \ + template \ + static inline void f() + +// for registering tests +#define DOCTEST_TEST_CASE(name) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) + +// for registering tests in classes +#define DOCTEST_TEST_CASE_CLASS(name) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) + +// for registering tests with a fixture +#define DOCTEST_TEST_CASE_FIXTURE(x, name) \ + DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), x, \ + DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) + +// for converting types to strings without the header and demangling +#define DOCTEST_TYPE_TO_STRING_AS(str, ...) static_assert(true, "") +#define DOCTEST_TYPE_TO_STRING(...) static_assert(true, "") + +// for typed tests +#define DOCTEST_TEST_CASE_TEMPLATE(name, type, ...) \ + template \ + inline void DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)() + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, type, id) \ + template \ + inline void DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)() + +#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) static_assert(true, "") +#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) static_assert(true, "") + +// for subcases +#define DOCTEST_SUBCASE(name) + +// for a testsuite block +#define DOCTEST_TEST_SUITE(name) namespace // NOLINT + +// for starting a testsuite block +#define DOCTEST_TEST_SUITE_BEGIN(name) static_assert(true, "") + +// for ending a testsuite block +#define DOCTEST_TEST_SUITE_END using DOCTEST_ANONYMOUS(DOCTEST_ANON_FOR_SEMICOLON_) = int + +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ + template \ + static inline doctest::String DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_)(signature) + +#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) +#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) + +#define DOCTEST_INFO(...) (static_cast(0)) +#define DOCTEST_CAPTURE(x) (static_cast(0)) +#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) (static_cast(0)) +#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) (static_cast(0)) +#define DOCTEST_ADD_FAIL_AT(file, line, ...) (static_cast(0)) +#define DOCTEST_MESSAGE(...) (static_cast(0)) +#define DOCTEST_FAIL_CHECK(...) (static_cast(0)) +#define DOCTEST_FAIL(...) (static_cast(0)) + +#if defined(DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED) \ + && defined(DOCTEST_CONFIG_ASSERTS_RETURN_VALUES) + +#define DOCTEST_WARN(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_CHECK(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_REQUIRE(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_WARN_FALSE(...) [&] { return !(__VA_ARGS__); }() +#define DOCTEST_CHECK_FALSE(...) [&] { return !(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_FALSE(...) [&] { return !(__VA_ARGS__); }() + +#define DOCTEST_WARN_MESSAGE(cond, ...) [&] { return cond; }() +#define DOCTEST_CHECK_MESSAGE(cond, ...) [&] { return cond; }() +#define DOCTEST_REQUIRE_MESSAGE(cond, ...) [&] { return cond; }() +#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() + +namespace doctest { +namespace detail { +#define DOCTEST_RELATIONAL_OP(name, op) \ + template \ + bool name(const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) { return lhs op rhs; } + + DOCTEST_RELATIONAL_OP(eq, ==) + DOCTEST_RELATIONAL_OP(ne, !=) + DOCTEST_RELATIONAL_OP(lt, <) + DOCTEST_RELATIONAL_OP(gt, >) + DOCTEST_RELATIONAL_OP(le, <=) + DOCTEST_RELATIONAL_OP(ge, >=) +} // namespace detail +} // namespace doctest + +#define DOCTEST_WARN_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() +#define DOCTEST_CHECK_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() +#define DOCTEST_WARN_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() +#define DOCTEST_CHECK_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() +#define DOCTEST_WARN_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() +#define DOCTEST_CHECK_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() +#define DOCTEST_WARN_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() +#define DOCTEST_CHECK_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() +#define DOCTEST_WARN_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() +#define DOCTEST_CHECK_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() +#define DOCTEST_WARN_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() +#define DOCTEST_CHECK_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() +#define DOCTEST_WARN_UNARY(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_CHECK_UNARY(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_REQUIRE_UNARY(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_WARN_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() +#define DOCTEST_CHECK_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + +#define DOCTEST_WARN_THROWS_WITH(expr, with, ...) [] { static_assert(false, "Exception translation is not available when doctest is disabled."); return false; }() +#define DOCTEST_CHECK_THROWS_WITH(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_REQUIRE_THROWS_WITH(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) + +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) + +#define DOCTEST_WARN_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_CHECK_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_REQUIRE_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_WARN_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_CHECK_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_WARN_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() +#define DOCTEST_CHECK_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() +#define DOCTEST_REQUIRE_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +#else // DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED + +#define DOCTEST_WARN(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_FALSE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_FALSE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_FALSE(...) DOCTEST_FUNC_EMPTY + +#define DOCTEST_WARN_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY + +#define DOCTEST_WARN_EQ(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_EQ(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_EQ(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_NE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_NE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_NE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_GT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_GT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_GT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_LT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_LT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_LT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_GE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_GE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_GE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_LE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_LE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_LE(...) DOCTEST_FUNC_EMPTY + +#define DOCTEST_WARN_UNARY(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_UNARY(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_UNARY(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + +#define DOCTEST_WARN_THROWS(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_NOTHROW(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_FUNC_EMPTY + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +#endif // DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED + +#endif // DOCTEST_CONFIG_DISABLE + #ifdef DOCTEST_CONFIG_NO_EXCEPTIONS -#undef DOCTEST_WARN_THROWS -#undef DOCTEST_CHECK_THROWS -#undef DOCTEST_REQUIRE_THROWS -#undef DOCTEST_WARN_THROWS_AS -#undef DOCTEST_CHECK_THROWS_AS -#undef DOCTEST_REQUIRE_THROWS_AS -#undef DOCTEST_WARN_THROWS_WITH -#undef DOCTEST_CHECK_THROWS_WITH -#undef DOCTEST_REQUIRE_THROWS_WITH -#undef DOCTEST_WARN_THROWS_WITH_AS -#undef DOCTEST_CHECK_THROWS_WITH_AS -#undef DOCTEST_REQUIRE_THROWS_WITH_AS -#undef DOCTEST_WARN_NOTHROW -#undef DOCTEST_CHECK_NOTHROW -#undef DOCTEST_REQUIRE_NOTHROW - -#undef DOCTEST_WARN_THROWS_MESSAGE -#undef DOCTEST_CHECK_THROWS_MESSAGE -#undef DOCTEST_REQUIRE_THROWS_MESSAGE -#undef DOCTEST_WARN_THROWS_AS_MESSAGE -#undef DOCTEST_CHECK_THROWS_AS_MESSAGE -#undef DOCTEST_REQUIRE_THROWS_AS_MESSAGE -#undef DOCTEST_WARN_THROWS_WITH_MESSAGE -#undef DOCTEST_CHECK_THROWS_WITH_MESSAGE -#undef DOCTEST_REQUIRE_THROWS_WITH_MESSAGE -#undef DOCTEST_WARN_THROWS_WITH_AS_MESSAGE -#undef DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE -#undef DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE -#undef DOCTEST_WARN_NOTHROW_MESSAGE -#undef DOCTEST_CHECK_NOTHROW_MESSAGE -#undef DOCTEST_REQUIRE_NOTHROW_MESSAGE - #ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS - -#define DOCTEST_WARN_THROWS(...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS(...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS(...) (static_cast(0)) -#define DOCTEST_WARN_THROWS_AS(expr, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_AS(expr, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) (static_cast(0)) -#define DOCTEST_WARN_THROWS_WITH(expr, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_WITH(expr, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) (static_cast(0)) -#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) (static_cast(0)) -#define DOCTEST_WARN_NOTHROW(...) (static_cast(0)) -#define DOCTEST_CHECK_NOTHROW(...) (static_cast(0)) -#define DOCTEST_REQUIRE_NOTHROW(...) (static_cast(0)) - -#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) (static_cast(0)) -#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) (static_cast(0)) -#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) (static_cast(0)) -#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) (static_cast(0)) -#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) (static_cast(0)) -#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) (static_cast(0)) - +#define DOCTEST_EXCEPTION_EMPTY_FUNC DOCTEST_FUNC_EMPTY #else // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#define DOCTEST_EXCEPTION_EMPTY_FUNC [] { static_assert(false, "Exceptions are disabled! " \ + "Use DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS if you want to compile with exceptions disabled."); return false; }() #undef DOCTEST_REQUIRE #undef DOCTEST_REQUIRE_FALSE @@ -2333,163 +2794,55 @@ int registerReporter(const char* name, int priority, bool isReporter) { #undef DOCTEST_REQUIRE_UNARY #undef DOCTEST_REQUIRE_UNARY_FALSE +#define DOCTEST_REQUIRE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_FALSE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_MESSAGE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_FALSE_MESSAGE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_EQ DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_NE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_GT DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_LT DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_GE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_LE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_UNARY DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_UNARY_FALSE DOCTEST_EXCEPTION_EMPTY_FUNC + #endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#define DOCTEST_WARN_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC + #endif // DOCTEST_CONFIG_NO_EXCEPTIONS -// ================================================================================================= -// == WHAT FOLLOWS IS VERSIONS OF THE MACROS THAT DO NOT DO ANY REGISTERING! == -// == THIS CAN BE ENABLED BY DEFINING DOCTEST_CONFIG_DISABLE GLOBALLY! == -// ================================================================================================= -#else // DOCTEST_CONFIG_DISABLE - -#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, name) \ - namespace { \ - template \ - struct der : public base \ - { void f(); }; \ - } \ - template \ - inline void der::f() - -#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, name) \ - template \ - static inline void f() - -// for registering tests -#define DOCTEST_TEST_CASE(name) \ - DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), name) - -// for registering tests in classes -#define DOCTEST_TEST_CASE_CLASS(name) \ - DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), name) - -// for registering tests with a fixture -#define DOCTEST_TEST_CASE_FIXTURE(x, name) \ - DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(_DOCTEST_ANON_CLASS_), x, \ - DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), name) - -// for converting types to strings without the header and demangling -#define DOCTEST_TYPE_TO_STRING(...) typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) -#define DOCTEST_TYPE_TO_STRING_IMPL(...) - -// for typed tests -#define DOCTEST_TEST_CASE_TEMPLATE(name, type, ...) \ - template \ - inline void DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_)() - -#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, type, id) \ - template \ - inline void DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_)() - -#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) \ - typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) - -#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) \ - typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) - -// for subcases -#define DOCTEST_SUBCASE(name) - -// for a testsuite block -#define DOCTEST_TEST_SUITE(name) namespace - -// for starting a testsuite block -#define DOCTEST_TEST_SUITE_BEGIN(name) typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) - -// for ending a testsuite block -#define DOCTEST_TEST_SUITE_END typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) - -#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ - template \ - static inline doctest::String DOCTEST_ANONYMOUS(_DOCTEST_ANON_TRANSLATOR_)(signature) - -#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) -#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) - -#define DOCTEST_INFO(...) (static_cast(0)) -#define DOCTEST_CAPTURE(x) (static_cast(0)) -#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) (static_cast(0)) -#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) (static_cast(0)) -#define DOCTEST_ADD_FAIL_AT(file, line, ...) (static_cast(0)) -#define DOCTEST_MESSAGE(...) (static_cast(0)) -#define DOCTEST_FAIL_CHECK(...) (static_cast(0)) -#define DOCTEST_FAIL(...) (static_cast(0)) - -#define DOCTEST_WARN(...) (static_cast(0)) -#define DOCTEST_CHECK(...) (static_cast(0)) -#define DOCTEST_REQUIRE(...) (static_cast(0)) -#define DOCTEST_WARN_FALSE(...) (static_cast(0)) -#define DOCTEST_CHECK_FALSE(...) (static_cast(0)) -#define DOCTEST_REQUIRE_FALSE(...) (static_cast(0)) - -#define DOCTEST_WARN_MESSAGE(cond, ...) (static_cast(0)) -#define DOCTEST_CHECK_MESSAGE(cond, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_MESSAGE(cond, ...) (static_cast(0)) -#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) (static_cast(0)) -#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) (static_cast(0)) - -#define DOCTEST_WARN_THROWS(...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS(...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS(...) (static_cast(0)) -#define DOCTEST_WARN_THROWS_AS(expr, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_AS(expr, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) (static_cast(0)) -#define DOCTEST_WARN_THROWS_WITH(expr, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_WITH(expr, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) (static_cast(0)) -#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) (static_cast(0)) -#define DOCTEST_WARN_NOTHROW(...) (static_cast(0)) -#define DOCTEST_CHECK_NOTHROW(...) (static_cast(0)) -#define DOCTEST_REQUIRE_NOTHROW(...) (static_cast(0)) - -#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) (static_cast(0)) -#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) (static_cast(0)) -#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) (static_cast(0)) -#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) (static_cast(0)) -#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) (static_cast(0)) -#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) (static_cast(0)) -#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) (static_cast(0)) -#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) (static_cast(0)) - -#define DOCTEST_WARN_EQ(...) (static_cast(0)) -#define DOCTEST_CHECK_EQ(...) (static_cast(0)) -#define DOCTEST_REQUIRE_EQ(...) (static_cast(0)) -#define DOCTEST_WARN_NE(...) (static_cast(0)) -#define DOCTEST_CHECK_NE(...) (static_cast(0)) -#define DOCTEST_REQUIRE_NE(...) (static_cast(0)) -#define DOCTEST_WARN_GT(...) (static_cast(0)) -#define DOCTEST_CHECK_GT(...) (static_cast(0)) -#define DOCTEST_REQUIRE_GT(...) (static_cast(0)) -#define DOCTEST_WARN_LT(...) (static_cast(0)) -#define DOCTEST_CHECK_LT(...) (static_cast(0)) -#define DOCTEST_REQUIRE_LT(...) (static_cast(0)) -#define DOCTEST_WARN_GE(...) (static_cast(0)) -#define DOCTEST_CHECK_GE(...) (static_cast(0)) -#define DOCTEST_REQUIRE_GE(...) (static_cast(0)) -#define DOCTEST_WARN_LE(...) (static_cast(0)) -#define DOCTEST_CHECK_LE(...) (static_cast(0)) -#define DOCTEST_REQUIRE_LE(...) (static_cast(0)) - -#define DOCTEST_WARN_UNARY(...) (static_cast(0)) -#define DOCTEST_CHECK_UNARY(...) (static_cast(0)) -#define DOCTEST_REQUIRE_UNARY(...) (static_cast(0)) -#define DOCTEST_WARN_UNARY_FALSE(...) (static_cast(0)) -#define DOCTEST_CHECK_UNARY_FALSE(...) (static_cast(0)) -#define DOCTEST_REQUIRE_UNARY_FALSE(...) (static_cast(0)) - -#endif // DOCTEST_CONFIG_DISABLE - // clang-format off // KEPT FOR BACKWARDS COMPATIBILITY - FORWARDING TO THE RIGHT MACROS #define DOCTEST_FAST_WARN_EQ DOCTEST_WARN_EQ @@ -2536,11 +2889,12 @@ int registerReporter(const char* name, int priority, bool isReporter) { // clang-format on // == SHORT VERSIONS OF THE MACROS -#if !defined(DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES) +#ifndef DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES #define TEST_CASE(name) DOCTEST_TEST_CASE(name) #define TEST_CASE_CLASS(name) DOCTEST_TEST_CASE_CLASS(name) #define TEST_CASE_FIXTURE(x, name) DOCTEST_TEST_CASE_FIXTURE(x, name) +#define TYPE_TO_STRING_AS(str, ...) DOCTEST_TYPE_TO_STRING_AS(str, __VA_ARGS__) #define TYPE_TO_STRING(...) DOCTEST_TYPE_TO_STRING(__VA_ARGS__) #define TEST_CASE_TEMPLATE(name, T, ...) DOCTEST_TEST_CASE_TEMPLATE(name, T, __VA_ARGS__) #define TEST_CASE_TEMPLATE_DEFINE(name, T, id) DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, T, id) @@ -2673,39 +3027,19 @@ int registerReporter(const char* name, int priority, bool isReporter) { #endif // DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES -#if !defined(DOCTEST_CONFIG_DISABLE) +#ifndef DOCTEST_CONFIG_DISABLE // this is here to clear the 'current test suite' for the current translation unit - at the top DOCTEST_TEST_SUITE_END(); -// add stringification for primitive/fundamental types -namespace doctest { namespace detail { - DOCTEST_TYPE_TO_STRING_IMPL(bool) - DOCTEST_TYPE_TO_STRING_IMPL(float) - DOCTEST_TYPE_TO_STRING_IMPL(double) - DOCTEST_TYPE_TO_STRING_IMPL(long double) - DOCTEST_TYPE_TO_STRING_IMPL(char) - DOCTEST_TYPE_TO_STRING_IMPL(signed char) - DOCTEST_TYPE_TO_STRING_IMPL(unsigned char) -#if !DOCTEST_MSVC || defined(_NATIVE_WCHAR_T_DEFINED) - DOCTEST_TYPE_TO_STRING_IMPL(wchar_t) -#endif // not MSVC or wchar_t support enabled - DOCTEST_TYPE_TO_STRING_IMPL(short int) - DOCTEST_TYPE_TO_STRING_IMPL(unsigned short int) - DOCTEST_TYPE_TO_STRING_IMPL(int) - DOCTEST_TYPE_TO_STRING_IMPL(unsigned int) - DOCTEST_TYPE_TO_STRING_IMPL(long int) - DOCTEST_TYPE_TO_STRING_IMPL(unsigned long int) - DOCTEST_TYPE_TO_STRING_IMPL(long long int) - DOCTEST_TYPE_TO_STRING_IMPL(unsigned long long int) -}} // namespace doctest::detail - #endif // DOCTEST_CONFIG_DISABLE DOCTEST_CLANG_SUPPRESS_WARNING_POP DOCTEST_MSVC_SUPPRESS_WARNING_POP DOCTEST_GCC_SUPPRESS_WARNING_POP +DOCTEST_SUPPRESS_COMMON_WARNINGS_POP + #endif // DOCTEST_LIBRARY_INCLUDED #ifndef DOCTEST_SINGLE_HEADER @@ -2725,13 +3059,11 @@ DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-macros") DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH -DOCTEST_CLANG_SUPPRESS_WARNING("-Wunknown-pragmas") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wpadded") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wweak-vtables") DOCTEST_CLANG_SUPPRESS_WARNING("-Wglobal-constructors") DOCTEST_CLANG_SUPPRESS_WARNING("-Wexit-time-destructors") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-prototypes") DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") DOCTEST_CLANG_SUPPRESS_WARNING("-Wshorten-64-to-32") DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-variable-declarations") @@ -2739,65 +3071,35 @@ DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch") DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch-enum") DOCTEST_CLANG_SUPPRESS_WARNING("-Wcovered-switch-default") DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-noreturn") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-local-typedef") DOCTEST_CLANG_SUPPRESS_WARNING("-Wdisabled-macro-expansion") DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-braces") DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-field-initializers") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic") DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-member-function") DOCTEST_CLANG_SUPPRESS_WARNING("-Wnonportable-system-include-path") DOCTEST_GCC_SUPPRESS_WARNING_PUSH -DOCTEST_GCC_SUPPRESS_WARNING("-Wunknown-pragmas") -DOCTEST_GCC_SUPPRESS_WARNING("-Wpragmas") DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") -DOCTEST_GCC_SUPPRESS_WARNING("-Weffc++") DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") -DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-overflow") -DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-aliasing") DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-field-initializers") DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-braces") -DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-declarations") DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch") DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-enum") DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-default") DOCTEST_GCC_SUPPRESS_WARNING("-Wunsafe-loop-optimizations") DOCTEST_GCC_SUPPRESS_WARNING("-Wold-style-cast") -DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-local-typedefs") -DOCTEST_GCC_SUPPRESS_WARNING("-Wuseless-cast") DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-function") DOCTEST_GCC_SUPPRESS_WARNING("-Wmultiple-inheritance") -DOCTEST_GCC_SUPPRESS_WARNING("-Wnoexcept") DOCTEST_GCC_SUPPRESS_WARNING("-Wsuggest-attribute") DOCTEST_MSVC_SUPPRESS_WARNING_PUSH -DOCTEST_MSVC_SUPPRESS_WARNING(4616) // invalid compiler warning -DOCTEST_MSVC_SUPPRESS_WARNING(4619) // invalid compiler warning -DOCTEST_MSVC_SUPPRESS_WARNING(4996) // The compiler encountered a deprecated declaration DOCTEST_MSVC_SUPPRESS_WARNING(4267) // 'var' : conversion from 'x' to 'y', possible loss of data -DOCTEST_MSVC_SUPPRESS_WARNING(4706) // assignment within conditional expression -DOCTEST_MSVC_SUPPRESS_WARNING(4512) // 'class' : assignment operator could not be generated -DOCTEST_MSVC_SUPPRESS_WARNING(4127) // conditional expression is constant DOCTEST_MSVC_SUPPRESS_WARNING(4530) // C++ exception handler used, but unwind semantics not enabled DOCTEST_MSVC_SUPPRESS_WARNING(4577) // 'noexcept' used with no exception handling mode specified DOCTEST_MSVC_SUPPRESS_WARNING(4774) // format string expected in argument is not a string literal DOCTEST_MSVC_SUPPRESS_WARNING(4365) // conversion from 'int' to 'unsigned', signed/unsigned mismatch -DOCTEST_MSVC_SUPPRESS_WARNING(4820) // padding in structs -DOCTEST_MSVC_SUPPRESS_WARNING(4640) // construction of local static object is not thread-safe DOCTEST_MSVC_SUPPRESS_WARNING(5039) // pointer to potentially throwing function passed to extern C -DOCTEST_MSVC_SUPPRESS_WARNING(5045) // Spectre mitigation stuff -DOCTEST_MSVC_SUPPRESS_WARNING(4626) // assignment operator was implicitly defined as deleted -DOCTEST_MSVC_SUPPRESS_WARNING(5027) // move assignment operator was implicitly defined as deleted -DOCTEST_MSVC_SUPPRESS_WARNING(5026) // move constructor was implicitly defined as deleted -DOCTEST_MSVC_SUPPRESS_WARNING(4625) // copy constructor was implicitly defined as deleted DOCTEST_MSVC_SUPPRESS_WARNING(4800) // forcing value to bool 'true' or 'false' (performance warning) -// static analysis -DOCTEST_MSVC_SUPPRESS_WARNING(26439) // This kind of function may not throw. Declare it 'noexcept' -DOCTEST_MSVC_SUPPRESS_WARNING(26495) // Always initialize a member variable -DOCTEST_MSVC_SUPPRESS_WARNING(26451) // Arithmetic overflow ... -DOCTEST_MSVC_SUPPRESS_WARNING(26444) // Avoid unnamed objects with custom construction and dtor... -DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' +DOCTEST_MSVC_SUPPRESS_WARNING(5245) // unreferenced function with internal linkage has been removed DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN @@ -2805,7 +3107,7 @@ DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN #include #include #include -// borland (Embarcadero) compiler requires math.h and not cmath - https://github.com/onqtam/doctest/pull/37 +// borland (Embarcadero) compiler requires math.h and not cmath - https://github.com/doctest/doctest/pull/37 #ifdef __BORLANDC__ #include #endif // __BORLANDC__ @@ -2821,16 +3123,27 @@ DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN #include #include #include +#ifndef DOCTEST_CONFIG_NO_MULTITHREADING #include #include +#define DOCTEST_DECLARE_MUTEX(name) std::mutex name; +#define DOCTEST_DECLARE_STATIC_MUTEX(name) static DOCTEST_DECLARE_MUTEX(name) +#define DOCTEST_LOCK_MUTEX(name) std::lock_guard DOCTEST_ANONYMOUS(DOCTEST_ANON_LOCK_)(name); +#else // DOCTEST_CONFIG_NO_MULTITHREADING +#define DOCTEST_DECLARE_MUTEX(name) +#define DOCTEST_DECLARE_STATIC_MUTEX(name) +#define DOCTEST_LOCK_MUTEX(name) +#endif // DOCTEST_CONFIG_NO_MULTITHREADING #include #include +#include #include #include #include #include #include #include +#include #ifdef DOCTEST_PLATFORM_MAC #include @@ -2863,7 +3176,7 @@ DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN #endif // DOCTEST_PLATFORM_WINDOWS -// this is a fix for https://github.com/onqtam/doctest/issues/348 +// this is a fix for https://github.com/doctest/doctest/issues/348 // https://mail.gnome.org/archives/xml/2012-January/msg00000.html #if !defined(HAVE_UNISTD_H) && !defined(STDOUT_FILENO) #define STDOUT_FILENO fileno(stdout) @@ -2885,8 +3198,12 @@ DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END #endif #ifndef DOCTEST_THREAD_LOCAL +#if defined(DOCTEST_CONFIG_NO_MULTITHREADING) || DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) +#define DOCTEST_THREAD_LOCAL +#else // DOCTEST_MSVC #define DOCTEST_THREAD_LOCAL thread_local -#endif +#endif // DOCTEST_MSVC +#endif // DOCTEST_THREAD_LOCAL #ifndef DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES #define DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES 32 @@ -2906,12 +3223,34 @@ DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END #define DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS #endif +#ifndef DOCTEST_CDECL +#define DOCTEST_CDECL __cdecl +#endif + namespace doctest { bool is_running_in_test = false; namespace { using namespace detail; + + template + DOCTEST_NORETURN void throw_exception(Ex const& e) { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + throw e; +#else // DOCTEST_CONFIG_NO_EXCEPTIONS + std::cerr << "doctest will terminate because it needed to throw an exception.\n" + << "The message was: " << e.what() << '\n'; + std::terminate(); +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + } + +#ifndef DOCTEST_INTERNAL_ERROR +#define DOCTEST_INTERNAL_ERROR(msg) \ + throw_exception(std::logic_error( \ + __FILE__ ":" DOCTEST_TOSTR(__LINE__) ": Internal doctest error: " msg)) +#endif // DOCTEST_INTERNAL_ERROR + // case insensitive strcmp int stricmp(const char* a, const char* b) { for(;; a++, b++) { @@ -2921,20 +3260,6 @@ namespace { } } - template - String fpToString(T value, int precision) { - std::ostringstream oss; - oss << std::setprecision(precision) << std::fixed << value; - std::string d = oss.str(); - size_t i = d.find_last_not_of('0'); - if(i != std::string::npos && i != d.size() - 1) { - if(d[i] == '.') - i++; - d = d.substr(0, i + 1); - } - return d.c_str(); - } - struct Endianness { enum Arch @@ -2955,36 +3280,35 @@ namespace { } // namespace namespace detail { - void my_memcpy(void* dest, const void* src, unsigned num) { memcpy(dest, src, num); } + DOCTEST_THREAD_LOCAL class + { + std::vector stack; + std::stringstream ss; - String rawMemoryToString(const void* object, unsigned size) { - // Reverse order for little endian architectures - int i = 0, end = static_cast(size), inc = 1; - if(Endianness::which() == Endianness::Little) { - i = end - 1; - end = inc = -1; + public: + std::ostream* push() { + stack.push_back(ss.tellp()); + return &ss; } - unsigned const char* bytes = static_cast(object); - std::ostringstream oss; - oss << "0x" << std::setfill('0') << std::hex; - for(; i != end; i += inc) - oss << std::setw(2) << static_cast(bytes[i]); - return oss.str().c_str(); + String pop() { + if (stack.empty()) + DOCTEST_INTERNAL_ERROR("TLSS was empty when trying to pop!"); + + std::streampos pos = stack.back(); + stack.pop_back(); + unsigned sz = static_cast(ss.tellp() - pos); + ss.rdbuf()->pubseekpos(pos, std::ios::in | std::ios::out); + return String(ss, sz); + } + } g_oss; + + std::ostream* tlssPush() { + return g_oss.push(); } - DOCTEST_THREAD_LOCAL std::ostringstream g_oss; // NOLINT(cert-err58-cpp) - - std::ostream* getTlsOss() { - g_oss.clear(); // there shouldn't be anything worth clearing in the flags - g_oss.str(""); // the slow way of resetting a string stream - //g_oss.seekp(0); // optimal reset - as seen here: https://stackoverflow.com/a/624291/3162383 - return &g_oss; - } - - String getTlsOssResult() { - //g_oss << std::ends; // needed - as shown here: https://stackoverflow.com/a/624291/3162383 - return g_oss.str().c_str(); + String tlssPop() { + return g_oss.pop(); } #ifndef DOCTEST_CONFIG_DISABLE @@ -2993,20 +3317,19 @@ namespace timer_large_integer { #if defined(DOCTEST_PLATFORM_WINDOWS) - typedef ULONGLONG type; + using type = ULONGLONG; #else // DOCTEST_PLATFORM_WINDOWS - using namespace std; - typedef uint64_t type; + using type = std::uint64_t; #endif // DOCTEST_PLATFORM_WINDOWS } -typedef timer_large_integer::type ticks_t; +using ticks_t = timer_large_integer::type; #ifdef DOCTEST_CONFIG_GETCURRENTTICKS ticks_t getCurrentTicks() { return DOCTEST_CONFIG_GETCURRENTTICKS(); } #elif defined(DOCTEST_PLATFORM_WINDOWS) ticks_t getCurrentTicks() { - static LARGE_INTEGER hz = {0}, hzo = {0}; + static LARGE_INTEGER hz = { {0} }, hzo = { {0} }; if(!hz.QuadPart) { QueryPerformanceFrequency(&hz); QueryPerformanceCounter(&hzo); @@ -3038,9 +3361,17 @@ typedef timer_large_integer::type ticks_t; ticks_t m_ticks = 0; }; -#ifdef DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS +#ifdef DOCTEST_CONFIG_NO_MULTITHREADING template - using AtomicOrMultiLaneAtomic = std::atomic; + using Atomic = T; +#else // DOCTEST_CONFIG_NO_MULTITHREADING + template + using Atomic = std::atomic; +#endif // DOCTEST_CONFIG_NO_MULTITHREADING + +#if defined(DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS) || defined(DOCTEST_CONFIG_NO_MULTITHREADING) + template + using MultiLaneAtomic = Atomic; #else // DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS // Provides a multilane implementation of an atomic variable that supports add, sub, load, // store. Instead of using a single atomic variable, this splits up into multiple ones, @@ -3057,8 +3388,8 @@ typedef timer_large_integer::type ticks_t; { struct CacheLineAlignedAtomic { - std::atomic atomic{}; - char padding[DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE - sizeof(std::atomic)]; + Atomic atomic{}; + char padding[DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE - sizeof(Atomic)]; }; CacheLineAlignedAtomic m_atomics[DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES]; @@ -3088,7 +3419,7 @@ typedef timer_large_integer::type ticks_t; return result; } - T operator=(T desired) DOCTEST_NOEXCEPT { + T operator=(T desired) DOCTEST_NOEXCEPT { // lgtm [cpp/assignment-does-not-return-this] store(desired); return desired; } @@ -3103,7 +3434,7 @@ typedef timer_large_integer::type ticks_t; private: // Each thread has a different atomic that it operates on. If more than NumLanes threads - // use this, some will use the same atomic. So performance will degrate a bit, but still + // use this, some will use the same atomic. So performance will degrade a bit, but still // everything will work. // // The logic here is a bit tricky. The call should be as fast as possible, so that there @@ -3114,24 +3445,21 @@ typedef timer_large_integer::type ticks_t; // assigned in a round-robin fashion. // 3. This tlsLaneIdx is stored in the thread local data, so it is directly available with // little overhead. - std::atomic& myAtomic() DOCTEST_NOEXCEPT { - static std::atomic laneCounter; + Atomic& myAtomic() DOCTEST_NOEXCEPT { + static Atomic laneCounter; DOCTEST_THREAD_LOCAL size_t tlsLaneIdx = laneCounter++ % DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES; return m_atomics[tlsLaneIdx].atomic; } }; - - template - using AtomicOrMultiLaneAtomic = MultiLaneAtomic; #endif // DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS // this holds both parameters from the command line and runtime data for tests struct ContextState : ContextOptions, TestRunStats, CurrentTestCaseStats { - AtomicOrMultiLaneAtomic numAssertsCurrentTest_atomic; - AtomicOrMultiLaneAtomic numAssertsFailedCurrentTest_atomic; + MultiLaneAtomic numAssertsCurrentTest_atomic; + MultiLaneAtomic numAssertsFailedCurrentTest_atomic; std::vector> filters = decltype(filters)(9); // 9 different filters @@ -3144,11 +3472,12 @@ typedef timer_large_integer::type ticks_t; std::vector stringifiedContexts; // logging from INFO() due to an exception // stuff for subcases - std::vector subcasesStack; - std::set subcasesPassed; - int subcasesCurrentMaxLevel; - bool should_reenter; - std::atomic shouldLogCurrentException; + bool reachedLeaf; + std::vector subcaseStack; + std::vector nextSubcaseStack; + std::unordered_set fullyTraversedSubcases; + size_t currentSubcaseDepth; + Atomic shouldLogCurrentException; void resetRunData() { numTestCases = 0; @@ -3198,7 +3527,8 @@ typedef timer_large_integer::type ticks_t; (TestCaseFailureReason::FailedExactlyNumTimes & failure_flags); // if any subcase has failed - the whole test case has failed - if(failure_flags && !ok_to_fail) + testCaseSuccess = !(failure_flags && !ok_to_fail); + if(!testCaseSuccess) numTestCasesFailed++; } }; @@ -3213,23 +3543,37 @@ typedef timer_large_integer::type ticks_t; #endif // DOCTEST_CONFIG_DISABLE } // namespace detail -void String::setOnHeap() { *reinterpret_cast(&buf[last]) = 128; } -void String::setLast(unsigned in) { buf[last] = char(in); } - -void String::copy(const String& other) { - using namespace std; - if(other.isOnStack()) { - memcpy(buf, other.buf, len); +char* String::allocate(size_type sz) { + if (sz <= last) { + buf[sz] = '\0'; + setLast(last - sz); + return buf; } else { setOnHeap(); - data.size = other.data.size; + data.size = sz; data.capacity = data.size + 1; - data.ptr = new char[data.capacity]; - memcpy(data.ptr, other.data.ptr, data.size + 1); + data.ptr = new char[data.capacity]; + data.ptr[sz] = '\0'; + return data.ptr; } } -String::String() { +void String::setOnHeap() noexcept { *reinterpret_cast(&buf[last]) = 128; } +void String::setLast(size_type in) noexcept { buf[last] = char(in); } +void String::setSize(size_type sz) noexcept { + if (isOnStack()) { buf[sz] = '\0'; setLast(last - sz); } + else { data.ptr[sz] = '\0'; data.size = sz; } +} + +void String::copy(const String& other) { + if(other.isOnStack()) { + memcpy(buf, other.buf, len); + } else { + memcpy(allocate(other.data.size), other.data.ptr, other.data.size); + } +} + +String::String() noexcept { buf[0] = '\0'; setLast(); } @@ -3237,26 +3581,17 @@ String::String() { String::~String() { if(!isOnStack()) delete[] data.ptr; - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) -} +} // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) String::String(const char* in) : String(in, strlen(in)) {} -String::String(const char* in, unsigned in_size) { - using namespace std; - if(in_size <= last) { - memcpy(buf, in, in_size); - buf[in_size] = '\0'; - setLast(last - in_size); - } else { - setOnHeap(); - data.size = in_size; - data.capacity = data.size + 1; - data.ptr = new char[data.capacity]; - memcpy(data.ptr, in, in_size); - data.ptr[in_size] = '\0'; - } +String::String(const char* in, size_type in_size) { + memcpy(allocate(in_size), in, in_size); +} + +String::String(std::istream& in, size_type in_size) { + in.read(allocate(in_size), in_size); } String::String(const String& other) { copy(other); } @@ -3273,10 +3608,9 @@ String& String::operator=(const String& other) { } String& String::operator+=(const String& other) { - const unsigned my_old_size = size(); - const unsigned other_size = other.size(); - const unsigned total_size = my_old_size + other_size; - using namespace std; + const size_type my_old_size = size(); + const size_type other_size = other.size(); + const size_type total_size = my_old_size + other_size; if(isOnStack()) { if(total_size < len) { // append to the current stack space @@ -3323,18 +3657,13 @@ String& String::operator+=(const String& other) { return *this; } -// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) -String String::operator+(const String& other) const { return String(*this) += other; } - -String::String(String&& other) { - using namespace std; +String::String(String&& other) noexcept { memcpy(buf, other.buf, len); other.buf[0] = '\0'; other.setLast(); } -String& String::operator=(String&& other) { - using namespace std; +String& String::operator=(String&& other) noexcept { if(this != &other) { if(!isOnStack()) delete[] data.ptr; @@ -3345,30 +3674,60 @@ String& String::operator=(String&& other) { return *this; } -char String::operator[](unsigned i) const { - return const_cast(this)->operator[](i); // NOLINT +char String::operator[](size_type i) const { + return const_cast(this)->operator[](i); } -char& String::operator[](unsigned i) { +char& String::operator[](size_type i) { if(isOnStack()) return reinterpret_cast(buf)[i]; return data.ptr[i]; } DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmaybe-uninitialized") -unsigned String::size() const { +String::size_type String::size() const { if(isOnStack()) - return last - (unsigned(buf[last]) & 31); // using "last" would work only if "len" is 32 + return last - (size_type(buf[last]) & 31); // using "last" would work only if "len" is 32 return data.size; } DOCTEST_GCC_SUPPRESS_WARNING_POP -unsigned String::capacity() const { +String::size_type String::capacity() const { if(isOnStack()) return len; return data.capacity; } +String String::substr(size_type pos, size_type cnt) && { + cnt = std::min(cnt, size() - 1 - pos); + char* cptr = c_str(); + memmove(cptr, cptr + pos, cnt); + setSize(cnt); + return std::move(*this); +} + +String String::substr(size_type pos, size_type cnt) const & { + cnt = std::min(cnt, size() - 1 - pos); + return String{ c_str() + pos, cnt }; +} + +String::size_type String::find(char ch, size_type pos) const { + const char* begin = c_str(); + const char* end = begin + size(); + const char* it = begin + pos; + for (; it < end && *it != ch; it++); + if (it < end) { return static_cast(it - begin); } + else { return npos; } +} + +String::size_type String::rfind(char ch, size_type pos) const { + const char* begin = c_str(); + const char* it = begin + std::min(pos, size() - 1); + for (; it >= begin && *it != ch; it--); + if (it >= begin) { return static_cast(it - begin); } + else { return npos; } +} + int String::compare(const char* other, bool no_case) const { if(no_case) return doctest::stricmp(c_str(), other); @@ -3379,17 +3738,32 @@ int String::compare(const String& other, bool no_case) const { return compare(other.c_str(), no_case); } -// clang-format off +String operator+(const String& lhs, const String& rhs) { return String(lhs) += rhs; } + bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } bool operator< (const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } bool operator> (const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } bool operator<=(const String& lhs, const String& rhs) { return (lhs != rhs) ? lhs.compare(rhs) < 0 : true; } bool operator>=(const String& lhs, const String& rhs) { return (lhs != rhs) ? lhs.compare(rhs) > 0 : true; } -// clang-format on std::ostream& operator<<(std::ostream& s, const String& in) { return s << in.c_str(); } +Contains::Contains(const String& str) : string(str) { } + +bool Contains::checkWith(const String& other) const { + return strstr(other.c_str(), string.c_str()) != nullptr; +} + +String toString(const Contains& in) { + return "Contains( " + in.string + " )"; +} + +bool operator==(const String& lhs, const Contains& rhs) { return rhs.checkWith(lhs); } +bool operator==(const Contains& lhs, const String& rhs) { return lhs.checkWith(rhs); } +bool operator!=(const String& lhs, const Contains& rhs) { return !rhs.checkWith(lhs); } +bool operator!=(const Contains& lhs, const String& rhs) { return !lhs.checkWith(rhs); } + namespace { void color_to_stream(std::ostream&, Color::Enum) DOCTEST_BRANCH_ON_DISABLED({}, ;) } // namespace @@ -3403,64 +3777,42 @@ namespace Color { // clang-format off const char* assertString(assertType::Enum at) { - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4062) // enum 'x' in switch of enum 'y' is not handled - switch(at) { //!OCLINT missing default in switch statements - case assertType::DT_WARN : return "WARN"; - case assertType::DT_CHECK : return "CHECK"; - case assertType::DT_REQUIRE : return "REQUIRE"; + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4061) // enum 'x' in switch of enum 'y' is not explicitely handled + #define DOCTEST_GENERATE_ASSERT_TYPE_CASE(assert_type) case assertType::DT_ ## assert_type: return #assert_type + #define DOCTEST_GENERATE_ASSERT_TYPE_CASES(assert_type) \ + DOCTEST_GENERATE_ASSERT_TYPE_CASE(WARN_ ## assert_type); \ + DOCTEST_GENERATE_ASSERT_TYPE_CASE(CHECK_ ## assert_type); \ + DOCTEST_GENERATE_ASSERT_TYPE_CASE(REQUIRE_ ## assert_type) + switch(at) { + DOCTEST_GENERATE_ASSERT_TYPE_CASE(WARN); + DOCTEST_GENERATE_ASSERT_TYPE_CASE(CHECK); + DOCTEST_GENERATE_ASSERT_TYPE_CASE(REQUIRE); - case assertType::DT_WARN_FALSE : return "WARN_FALSE"; - case assertType::DT_CHECK_FALSE : return "CHECK_FALSE"; - case assertType::DT_REQUIRE_FALSE : return "REQUIRE_FALSE"; + DOCTEST_GENERATE_ASSERT_TYPE_CASES(FALSE); - case assertType::DT_WARN_THROWS : return "WARN_THROWS"; - case assertType::DT_CHECK_THROWS : return "CHECK_THROWS"; - case assertType::DT_REQUIRE_THROWS : return "REQUIRE_THROWS"; + DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS); - case assertType::DT_WARN_THROWS_AS : return "WARN_THROWS_AS"; - case assertType::DT_CHECK_THROWS_AS : return "CHECK_THROWS_AS"; - case assertType::DT_REQUIRE_THROWS_AS : return "REQUIRE_THROWS_AS"; + DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_AS); - case assertType::DT_WARN_THROWS_WITH : return "WARN_THROWS_WITH"; - case assertType::DT_CHECK_THROWS_WITH : return "CHECK_THROWS_WITH"; - case assertType::DT_REQUIRE_THROWS_WITH : return "REQUIRE_THROWS_WITH"; + DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_WITH); - case assertType::DT_WARN_THROWS_WITH_AS : return "WARN_THROWS_WITH_AS"; - case assertType::DT_CHECK_THROWS_WITH_AS : return "CHECK_THROWS_WITH_AS"; - case assertType::DT_REQUIRE_THROWS_WITH_AS : return "REQUIRE_THROWS_WITH_AS"; + DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_WITH_AS); - case assertType::DT_WARN_NOTHROW : return "WARN_NOTHROW"; - case assertType::DT_CHECK_NOTHROW : return "CHECK_NOTHROW"; - case assertType::DT_REQUIRE_NOTHROW : return "REQUIRE_NOTHROW"; + DOCTEST_GENERATE_ASSERT_TYPE_CASES(NOTHROW); - case assertType::DT_WARN_EQ : return "WARN_EQ"; - case assertType::DT_CHECK_EQ : return "CHECK_EQ"; - case assertType::DT_REQUIRE_EQ : return "REQUIRE_EQ"; - case assertType::DT_WARN_NE : return "WARN_NE"; - case assertType::DT_CHECK_NE : return "CHECK_NE"; - case assertType::DT_REQUIRE_NE : return "REQUIRE_NE"; - case assertType::DT_WARN_GT : return "WARN_GT"; - case assertType::DT_CHECK_GT : return "CHECK_GT"; - case assertType::DT_REQUIRE_GT : return "REQUIRE_GT"; - case assertType::DT_WARN_LT : return "WARN_LT"; - case assertType::DT_CHECK_LT : return "CHECK_LT"; - case assertType::DT_REQUIRE_LT : return "REQUIRE_LT"; - case assertType::DT_WARN_GE : return "WARN_GE"; - case assertType::DT_CHECK_GE : return "CHECK_GE"; - case assertType::DT_REQUIRE_GE : return "REQUIRE_GE"; - case assertType::DT_WARN_LE : return "WARN_LE"; - case assertType::DT_CHECK_LE : return "CHECK_LE"; - case assertType::DT_REQUIRE_LE : return "REQUIRE_LE"; + DOCTEST_GENERATE_ASSERT_TYPE_CASES(EQ); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(NE); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(GT); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(LT); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(GE); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(LE); - case assertType::DT_WARN_UNARY : return "WARN_UNARY"; - case assertType::DT_CHECK_UNARY : return "CHECK_UNARY"; - case assertType::DT_REQUIRE_UNARY : return "REQUIRE_UNARY"; - case assertType::DT_WARN_UNARY_FALSE : return "WARN_UNARY_FALSE"; - case assertType::DT_CHECK_UNARY_FALSE : return "CHECK_UNARY_FALSE"; - case assertType::DT_REQUIRE_UNARY_FALSE : return "REQUIRE_UNARY_FALSE"; + DOCTEST_GENERATE_ASSERT_TYPE_CASES(UNARY); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(UNARY_FALSE); + + default: DOCTEST_INTERNAL_ERROR("Tried stringifying invalid assert type!"); } DOCTEST_MSVC_SUPPRESS_WARNING_POP - return ""; } // clang-format on @@ -3494,6 +3846,12 @@ const char* skipPathFromFilename(const char* file) { DOCTEST_CLANG_SUPPRESS_WARNING_POP DOCTEST_GCC_SUPPRESS_WARNING_POP +bool SubcaseSignature::operator==(const SubcaseSignature& other) const { + return m_line == other.m_line + && std::strcmp(m_file, other.m_file) == 0 + && m_name == other.m_name; +} + bool SubcaseSignature::operator<(const SubcaseSignature& other) const { if(m_line != other.m_line) return m_line < other.m_line; @@ -3502,45 +3860,53 @@ bool SubcaseSignature::operator<(const SubcaseSignature& other) const { return m_name.compare(other.m_name) < 0; } -IContextScope::IContextScope() = default; -IContextScope::~IContextScope() = default; +DOCTEST_DEFINE_INTERFACE(IContextScope) -#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -String toString(char* in) { return toString(static_cast(in)); } -// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) -String toString(const char* in) { return String("\"") + (in ? in : "{null string}") + "\""; } -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -String toString(bool in) { return in ? "true" : "false"; } -String toString(float in) { return fpToString(in, 5) + "f"; } -String toString(double in) { return fpToString(in, 10); } -String toString(double long in) { return fpToString(in, 15); } - -#define DOCTEST_TO_STRING_OVERLOAD(type, fmt) \ - String toString(type in) { \ - char buf[64]; \ - std::sprintf(buf, fmt, in); \ - return buf; \ +namespace detail { + void filldata::fill(std::ostream* stream, const void* in) { + if (in) { *stream << in; } + else { *stream << "nullptr"; } } -DOCTEST_TO_STRING_OVERLOAD(char, "%d") -DOCTEST_TO_STRING_OVERLOAD(char signed, "%d") -DOCTEST_TO_STRING_OVERLOAD(char unsigned, "%u") -DOCTEST_TO_STRING_OVERLOAD(int short, "%d") -DOCTEST_TO_STRING_OVERLOAD(int short unsigned, "%u") -DOCTEST_TO_STRING_OVERLOAD(int, "%d") -DOCTEST_TO_STRING_OVERLOAD(unsigned, "%u") -DOCTEST_TO_STRING_OVERLOAD(int long, "%ld") -DOCTEST_TO_STRING_OVERLOAD(int long unsigned, "%lu") -DOCTEST_TO_STRING_OVERLOAD(int long long, "%lld") -DOCTEST_TO_STRING_OVERLOAD(int long long unsigned, "%llu") + template + String toStreamLit(T t) { + std::ostream* os = tlssPush(); + os->operator<<(t); + return tlssPop(); + } +} -String toString(std::nullptr_t) { return "NULL"; } +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +String toString(const char* in) { return String("\"") + (in ? in : "{null string}") + "\""; } +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING #if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) -// see this issue on why this is needed: https://github.com/onqtam/doctest/issues/183 +// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 String toString(const std::string& in) { return in.c_str(); } #endif // VS 2019 +String toString(String in) { return in; } + +String toString(std::nullptr_t) { return "nullptr"; } + +String toString(bool in) { return in ? "true" : "false"; } + +String toString(float in) { return toStreamLit(in); } +String toString(double in) { return toStreamLit(in); } +String toString(double long in) { return toStreamLit(in); } + +String toString(char in) { return toStreamLit(static_cast(in)); } +String toString(char signed in) { return toStreamLit(static_cast(in)); } +String toString(char unsigned in) { return toStreamLit(static_cast(in)); } +String toString(short in) { return toStreamLit(in); } +String toString(short unsigned in) { return toStreamLit(in); } +String toString(signed in) { return toStreamLit(in); } +String toString(unsigned in) { return toStreamLit(in); } +String toString(long in) { return toStreamLit(in); } +String toString(long unsigned in) { return toStreamLit(in); } +String toString(long long in) { return toStreamLit(in); } +String toString(long long unsigned in) { return toStreamLit(in); } + Approx::Approx(double value) : m_epsilon(static_cast(std::numeric_limits::epsilon()) * 100) , m_scale(1.0) @@ -3580,11 +3946,25 @@ bool operator>(double lhs, const Approx& rhs) { return lhs > rhs.m_value && lhs bool operator>(const Approx& lhs, double rhs) { return lhs.m_value > rhs && lhs != rhs; } String toString(const Approx& in) { - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - return String("Approx( ") + doctest::toString(in.m_value) + " )"; + return "Approx( " + doctest::toString(in.m_value) + " )"; } const ContextOptions* getContextOptions() { return DOCTEST_BRANCH_ON_DISABLED(nullptr, g_cs); } +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4738) +template +IsNaN::operator bool() const { + return std::isnan(value) ^ flipped; +} +DOCTEST_MSVC_SUPPRESS_WARNING_POP +template struct DOCTEST_INTERFACE_DEF IsNaN; +template struct DOCTEST_INTERFACE_DEF IsNaN; +template struct DOCTEST_INTERFACE_DEF IsNaN; +template +String toString(IsNaN in) { return String(in.flipped ? "! " : "") + "IsNaN( " + doctest::toString(in.value) + " )"; } +String toString(IsNaN in) { return toString(in); } +String toString(IsNaN in) { return toString(in); } +String toString(IsNaN in) { return toString(in); } + } // namespace doctest #ifdef DOCTEST_CONFIG_DISABLE @@ -3594,15 +3974,15 @@ Context::~Context() = default; void Context::applyCommandLine(int, const char* const*) {} void Context::addFilter(const char*, const char*) {} void Context::clearFilters() {} +void Context::setOption(const char*, bool) {} void Context::setOption(const char*, int) {} void Context::setOption(const char*, const char*) {} bool Context::shouldExit() { return false; } void Context::setAsDefaultForAssertsOutOfTestCases() {} void Context::setAssertHandler(detail::assert_handler) {} +void Context::setCout(std::ostream*) {} int Context::run() { return 0; } -IReporter::~IReporter() = default; - int IReporter::get_num_active_contexts() { return 0; } const IContextScope* const* IReporter::get_active_contexts() { return nullptr; } int IReporter::get_num_stringified_contexts() { return 0; } @@ -3635,7 +4015,7 @@ namespace doctest { namespace { // the int (priority) is part of the key for automatic sorting - sadly one can register a // reporter with a duplicate name and a different priority but hopefully that won't happen often :| - typedef std::map, reporterCreatorFunc> reporterMap; + using reporterMap = std::map, reporterCreatorFunc>; reporterMap& getReporters() { static reporterMap data; @@ -3667,8 +4047,8 @@ namespace detail { #ifndef DOCTEST_CONFIG_NO_EXCEPTIONS DOCTEST_NORETURN void throwException() { g_cs->shouldLogCurrentException = false; - throw TestFailureException(); - } // NOLINT(cert-err60-cpp) + throw TestFailureException(); // NOLINT(hicpp-exception-baseclass) + } #else // DOCTEST_CONFIG_NO_EXCEPTIONS void throwException() {} #endif // DOCTEST_CONFIG_NO_EXCEPTIONS @@ -3714,91 +4094,132 @@ namespace { return !*wild; } - //// C string hash function (djb2) - taken from http://www.cse.yorku.ca/~oz/hash.html - //unsigned hashStr(unsigned const char* str) { - // unsigned long hash = 5381; - // char c; - // while((c = *str++)) - // hash = ((hash << 5) + hash) + c; // hash * 33 + c - // return hash; - //} - // checks if the name matches any of the filters (and can be configured what to do when empty) bool matchesAny(const char* name, const std::vector& filters, bool matchEmpty, - bool caseSensitive) { - if(filters.empty() && matchEmpty) + bool caseSensitive) { + if (filters.empty() && matchEmpty) return true; - for(auto& curr : filters) - if(wildcmp(name, curr.c_str(), caseSensitive)) + for (auto& curr : filters) + if (wildcmp(name, curr.c_str(), caseSensitive)) return true; return false; } + + unsigned long long hash(unsigned long long a, unsigned long long b) { + return (a << 5) + b; + } + + // C string hash function (djb2) - taken from http://www.cse.yorku.ca/~oz/hash.html + unsigned long long hash(const char* str) { + unsigned long long hash = 5381; + char c; + while ((c = *str++)) + hash = ((hash << 5) + hash) + c; // hash * 33 + c + return hash; + } + + unsigned long long hash(const SubcaseSignature& sig) { + return hash(hash(hash(sig.m_file), hash(sig.m_name.c_str())), sig.m_line); + } + + unsigned long long hash(const std::vector& sigs, size_t count) { + unsigned long long running = 0; + auto end = sigs.begin() + count; + for (auto it = sigs.begin(); it != end; it++) { + running = hash(running, hash(*it)); + } + return running; + } + + unsigned long long hash(const std::vector& sigs) { + unsigned long long running = 0; + for (const SubcaseSignature& sig : sigs) { + running = hash(running, hash(sig)); + } + return running; + } } // namespace namespace detail { + bool Subcase::checkFilters() { + if (g_cs->subcaseStack.size() < size_t(g_cs->subcase_filter_levels)) { + if (!matchesAny(m_signature.m_name.c_str(), g_cs->filters[6], true, g_cs->case_sensitive)) + return true; + if (matchesAny(m_signature.m_name.c_str(), g_cs->filters[7], false, g_cs->case_sensitive)) + return true; + } + return false; + } Subcase::Subcase(const String& name, const char* file, int line) : m_signature({name, file, line}) { - auto* s = g_cs; + if (!g_cs->reachedLeaf) { + if (g_cs->nextSubcaseStack.size() <= g_cs->subcaseStack.size() + || g_cs->nextSubcaseStack[g_cs->subcaseStack.size()] == m_signature) { + // Going down. + if (checkFilters()) { return; } - // check subcase filters - if(s->subcasesStack.size() < size_t(s->subcase_filter_levels)) { - if(!matchesAny(m_signature.m_name.c_str(), s->filters[6], true, s->case_sensitive)) - return; - if(matchesAny(m_signature.m_name.c_str(), s->filters[7], false, s->case_sensitive)) - return; + g_cs->subcaseStack.push_back(m_signature); + g_cs->currentSubcaseDepth++; + m_entered = true; + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); + } + } else { + if (g_cs->subcaseStack[g_cs->currentSubcaseDepth] == m_signature) { + // This subcase is reentered via control flow. + g_cs->currentSubcaseDepth++; + m_entered = true; + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); + } else if (g_cs->nextSubcaseStack.size() <= g_cs->currentSubcaseDepth + && g_cs->fullyTraversedSubcases.find(hash(hash(g_cs->subcaseStack, g_cs->currentSubcaseDepth), hash(m_signature))) + == g_cs->fullyTraversedSubcases.end()) { + if (checkFilters()) { return; } + // This subcase is part of the one to be executed next. + g_cs->nextSubcaseStack.clear(); + g_cs->nextSubcaseStack.insert(g_cs->nextSubcaseStack.end(), + g_cs->subcaseStack.begin(), g_cs->subcaseStack.begin() + g_cs->currentSubcaseDepth); + g_cs->nextSubcaseStack.push_back(m_signature); + } } - - // if a Subcase on the same level has already been entered - if(s->subcasesStack.size() < size_t(s->subcasesCurrentMaxLevel)) { - s->should_reenter = true; - return; - } - - // push the current signature to the stack so we can check if the - // current stack + the current new subcase have been traversed - s->subcasesStack.push_back(m_signature); - if(s->subcasesPassed.count(s->subcasesStack) != 0) { - // pop - revert to previous stack since we've already passed this - s->subcasesStack.pop_back(); - return; - } - - s->subcasesCurrentMaxLevel = s->subcasesStack.size(); - m_entered = true; - - DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); } - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 - DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") Subcase::~Subcase() { - if(m_entered) { - // only mark the subcase stack as passed if no subcases have been skipped - if(g_cs->should_reenter == false) - g_cs->subcasesPassed.insert(g_cs->subcasesStack); - g_cs->subcasesStack.pop_back(); + if (m_entered) { + g_cs->currentSubcaseDepth--; + + if (!g_cs->reachedLeaf) { + // Leaf. + g_cs->fullyTraversedSubcases.insert(hash(g_cs->subcaseStack)); + g_cs->nextSubcaseStack.clear(); + g_cs->reachedLeaf = true; + } else if (g_cs->nextSubcaseStack.empty()) { + // All children are finished. + g_cs->fullyTraversedSubcases.insert(hash(g_cs->subcaseStack)); + } #if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L && (!defined(__MAC_OS_X_VERSION_MIN_REQUIRED) || __MAC_OS_X_VERSION_MIN_REQUIRED >= 101200) if(std::uncaught_exceptions() > 0 #else if(std::uncaught_exception() #endif - && g_cs->shouldLogCurrentException) { + && g_cs->shouldLogCurrentException) { DOCTEST_ITERATE_THROUGH_REPORTERS( test_case_exception, {"exception thrown in subcase - will translate later " - "when the whole test case has been exited (cannot " - "translate while there is an active exception)", - false}); + "when the whole test case has been exited (cannot " + "translate while there is an active exception)", + false}); g_cs->shouldLogCurrentException = false; } + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); } } - DOCTEST_CLANG_SUPPRESS_WARNING_POP - DOCTEST_GCC_SUPPRESS_WARNING_POP + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP DOCTEST_MSVC_SUPPRESS_WARNING_POP Subcase::operator bool() const { return m_entered; } @@ -3812,20 +4233,11 @@ namespace detail { TestSuite& TestSuite::operator*(const char* in) { m_test_suite = in; - // clear state - m_description = nullptr; - m_skip = false; - m_no_breaks = false; - m_no_output = false; - m_may_fail = false; - m_should_fail = false; - m_expected_failures = 0; - m_timeout = 0; return *this; } TestCase::TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, - const char* type, int template_id) { + const String& type, int template_id) { m_file = file; m_line = line; m_name = nullptr; // will be later overridden in operator* @@ -3850,10 +4262,8 @@ namespace detail { } DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function - DOCTEST_MSVC_SUPPRESS_WARNING(26437) // Do not slice TestCase& TestCase::operator=(const TestCase& other) { - static_cast(*this) = static_cast(other); - + TestCaseData::operator=(other); m_test = other.m_test; m_type = other.m_type; m_template_id = other.m_template_id; @@ -3869,7 +4279,7 @@ namespace detail { m_name = in; // make a new name with an appended type for templated test case if(m_template_id != -1) { - m_full_name = String(m_name) + m_type; + m_full_name = String(m_name) + "<" + m_type + ">"; // redirect the name to point to the newly constructed full name m_name = m_full_name.c_str(); } @@ -3925,29 +4335,6 @@ namespace { return suiteOrderComparator(lhs, rhs); } -#ifdef DOCTEST_CONFIG_COLORS_WINDOWS - HANDLE g_stdoutHandle; - WORD g_origFgAttrs; - WORD g_origBgAttrs; - bool g_attrsInitted = false; - - int colors_init() { - if(!g_attrsInitted) { - g_stdoutHandle = GetStdHandle(STD_OUTPUT_HANDLE); - g_attrsInitted = true; - CONSOLE_SCREEN_BUFFER_INFO csbiInfo; - GetConsoleScreenBufferInfo(g_stdoutHandle, &csbiInfo); - g_origFgAttrs = csbiInfo.wAttributes & ~(BACKGROUND_GREEN | BACKGROUND_RED | - BACKGROUND_BLUE | BACKGROUND_INTENSITY); - g_origBgAttrs = csbiInfo.wAttributes & ~(FOREGROUND_GREEN | FOREGROUND_RED | - FOREGROUND_BLUE | FOREGROUND_INTENSITY); - } - return 0; - } - - int dumy_init_console_colors = colors_init(); -#endif // DOCTEST_CONFIG_COLORS_WINDOWS - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") void color_to_stream(std::ostream& s, Color::Enum code) { static_cast(s); // for DOCTEST_CONFIG_COLORS_NONE or DOCTEST_CONFIG_COLORS_WINDOWS @@ -3981,10 +4368,26 @@ namespace { #ifdef DOCTEST_CONFIG_COLORS_WINDOWS if(g_no_colors || - (isatty(fileno(stdout)) == false && getContextOptions()->force_colors == false)) + (_isatty(_fileno(stdout)) == false && getContextOptions()->force_colors == false)) return; -#define DOCTEST_SET_ATTR(x) SetConsoleTextAttribute(g_stdoutHandle, x | g_origBgAttrs) + static struct ConsoleHelper { + HANDLE stdoutHandle; + WORD origFgAttrs; + WORD origBgAttrs; + + ConsoleHelper() { + stdoutHandle = GetStdHandle(STD_OUTPUT_HANDLE); + CONSOLE_SCREEN_BUFFER_INFO csbiInfo; + GetConsoleScreenBufferInfo(stdoutHandle, &csbiInfo); + origFgAttrs = csbiInfo.wAttributes & ~(BACKGROUND_GREEN | BACKGROUND_RED | + BACKGROUND_BLUE | BACKGROUND_INTENSITY); + origBgAttrs = csbiInfo.wAttributes & ~(FOREGROUND_GREEN | FOREGROUND_RED | + FOREGROUND_BLUE | FOREGROUND_INTENSITY); + } + } ch; + +#define DOCTEST_SET_ATTR(x) SetConsoleTextAttribute(ch.stdoutHandle, x | ch.origBgAttrs) // clang-format off switch (code) { @@ -4001,7 +4404,7 @@ namespace { case Color::BrightWhite: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; case Color::None: case Color::Bright: // invalid - default: DOCTEST_SET_ATTR(g_origFgAttrs); + default: DOCTEST_SET_ATTR(ch.origFgAttrs); } // clang-format on #endif // DOCTEST_CONFIG_COLORS_WINDOWS @@ -4118,35 +4521,22 @@ namespace detail { getExceptionTranslators().push_back(et); } -#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - void toStream(std::ostream* s, char* in) { *s << in; } - void toStream(std::ostream* s, const char* in) { *s << in; } -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - void toStream(std::ostream* s, bool in) { *s << std::boolalpha << in << std::noboolalpha; } - void toStream(std::ostream* s, float in) { *s << in; } - void toStream(std::ostream* s, double in) { *s << in; } - void toStream(std::ostream* s, double long in) { *s << in; } - - void toStream(std::ostream* s, char in) { *s << in; } - void toStream(std::ostream* s, char signed in) { *s << in; } - void toStream(std::ostream* s, char unsigned in) { *s << in; } - void toStream(std::ostream* s, int short in) { *s << in; } - void toStream(std::ostream* s, int short unsigned in) { *s << in; } - void toStream(std::ostream* s, int in) { *s << in; } - void toStream(std::ostream* s, int unsigned in) { *s << in; } - void toStream(std::ostream* s, int long in) { *s << in; } - void toStream(std::ostream* s, int long unsigned in) { *s << in; } - void toStream(std::ostream* s, int long long in) { *s << in; } - void toStream(std::ostream* s, int long long unsigned in) { *s << in; } - DOCTEST_THREAD_LOCAL std::vector g_infoContexts; // for logging with INFO() ContextScopeBase::ContextScopeBase() { g_infoContexts.push_back(this); } - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 - DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + ContextScopeBase::ContextScopeBase(ContextScopeBase&& other) noexcept { + if (other.need_to_destroy) { + other.destroy(); + } + other.need_to_destroy = false; + g_infoContexts.push_back(this); + } + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") // destroy cannot be inlined into the destructor because that would mean calling stringify after @@ -4165,8 +4555,8 @@ namespace detail { g_infoContexts.pop_back(); } - DOCTEST_CLANG_SUPPRESS_WARNING_POP - DOCTEST_GCC_SUPPRESS_WARNING_POP + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP DOCTEST_MSVC_SUPPRESS_WARNING_POP } // namespace detail namespace { @@ -4207,10 +4597,10 @@ namespace { static LONG CALLBACK handleException(PEXCEPTION_POINTERS ExceptionInfo) { // Multiple threads may enter this filter/handler at once. We want the error message to be printed on the // console just once no matter how many threads have crashed. - static std::mutex mutex; + DOCTEST_DECLARE_STATIC_MUTEX(mutex) static bool execute = true; { - std::lock_guard lock(mutex); + DOCTEST_LOCK_MUTEX(mutex) if(execute) { bool reported = false; for(size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { @@ -4313,7 +4703,7 @@ namespace { static unsigned int prev_abort_behavior; static int prev_report_mode; static _HFILE prev_report_file; - static void (*prev_sigabrt_handler)(int); + static void (DOCTEST_CDECL *prev_sigabrt_handler)(int); static std::terminate_handler original_terminate_handler; static bool isSet; static ULONG guaranteeSize; @@ -4325,7 +4715,7 @@ namespace { unsigned int FatalConditionHandler::prev_abort_behavior; int FatalConditionHandler::prev_report_mode; _HFILE FatalConditionHandler::prev_report_file; - void (*FatalConditionHandler::prev_sigabrt_handler)(int); + void (DOCTEST_CDECL *FatalConditionHandler::prev_sigabrt_handler)(int); std::terminate_handler FatalConditionHandler::original_terminate_handler; bool FatalConditionHandler::isSet = false; ULONG FatalConditionHandler::guaranteeSize = 0; @@ -4383,7 +4773,7 @@ namespace { sigStack.ss_flags = 0; sigaltstack(&sigStack, &oldSigStack); struct sigaction sa = {}; - sa.sa_handler = handleSignal; // NOLINT + sa.sa_handler = handleSignal; sa.sa_flags = SA_ONSTACK; for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { sigaction(signalDefs[i].id, &sa, &oldSigActions[i]); @@ -4422,7 +4812,7 @@ namespace { #define DOCTEST_OUTPUT_DEBUG_STRING(text) ::OutputDebugStringA(text) #else // TODO: integration with XCode and other IDEs -#define DOCTEST_OUTPUT_DEBUG_STRING(text) // NOLINT(clang-diagnostic-unused-macros) +#define DOCTEST_OUTPUT_DEBUG_STRING(text) #endif // Platform void addAssert(assertType::Enum at) { @@ -4441,8 +4831,8 @@ namespace { DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, {message.c_str(), true}); - while(g_cs->subcasesStack.size()) { - g_cs->subcasesStack.pop_back(); + while (g_cs->subcaseStack.size()) { + g_cs->subcaseStack.pop_back(); DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); } @@ -4454,25 +4844,26 @@ namespace { } #endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH } // namespace + +AssertData::AssertData(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const StringContains& exception_string) + : m_test_case(g_cs->currentTest), m_at(at), m_file(file), m_line(line), m_expr(expr), + m_failed(true), m_threw(false), m_threw_as(false), m_exception_type(exception_type), + m_exception_string(exception_string) { +#if DOCTEST_MSVC + if (m_expr[0] == ' ') // this happens when variadic macros are disabled under MSVC + ++m_expr; +#endif // MSVC +} + namespace detail { + ResultBuilder::ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const String& exception_string) + : AssertData(at, file, line, expr, exception_type, exception_string) { } ResultBuilder::ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type, const char* exception_string) { - m_test_case = g_cs->currentTest; - m_at = at; - m_file = file; - m_line = line; - m_expr = expr; - m_failed = true; - m_threw = false; - m_threw_as = false; - m_exception_type = exception_type; - m_exception_string = exception_string; -#if DOCTEST_MSVC - if(m_expr[0] == ' ') // this happens when variadic macros are disabled under MSVC - ++m_expr; -#endif // MSVC - } + const char* exception_type, const Contains& exception_string) + : AssertData(at, file, line, expr, exception_type, exception_string) { } void ResultBuilder::setResult(const Result& res) { m_decomp = res.m_decomp; @@ -4488,17 +4879,17 @@ namespace detail { if(m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional m_failed = !m_threw; } else if((m_at & assertType::is_throws_as) && (m_at & assertType::is_throws_with)) { //!OCLINT - m_failed = !m_threw_as || (m_exception != m_exception_string); + m_failed = !m_threw_as || !m_exception_string.check(m_exception); } else if(m_at & assertType::is_throws_as) { //!OCLINT bitwise operator in conditional m_failed = !m_threw_as; } else if(m_at & assertType::is_throws_with) { //!OCLINT bitwise operator in conditional - m_failed = m_exception != m_exception_string; + m_failed = !m_exception_string.check(m_exception); } else if(m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional m_failed = m_threw; } if(m_exception.size()) - m_exception = String("\"") + m_exception + "\""; + m_exception = "\"" + m_exception + "\""; if(is_running_in_test) { addAssert(m_at); @@ -4526,8 +4917,8 @@ namespace detail { std::abort(); } - void decomp_assert(assertType::Enum at, const char* file, int line, const char* expr, - Result result) { + bool decomp_assert(assertType::Enum at, const char* file, int line, const char* expr, + const Result& result) { bool failed = !result.m_passed; // ################################################################################### @@ -4536,21 +4927,29 @@ namespace detail { // ################################################################################### DOCTEST_ASSERT_OUT_OF_TESTS(result.m_decomp); DOCTEST_ASSERT_IN_TESTS(result.m_decomp); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + return !failed; } MessageBuilder::MessageBuilder(const char* file, int line, assertType::Enum severity) { - m_stream = getTlsOss(); + m_stream = tlssPush(); m_file = file; m_line = line; m_severity = severity; } - IExceptionTranslator::IExceptionTranslator() = default; - IExceptionTranslator::~IExceptionTranslator() = default; + MessageBuilder::~MessageBuilder() { + if (!logged) + tlssPop(); + } + + DOCTEST_DEFINE_INTERFACE(IExceptionTranslator) bool MessageBuilder::log() { - m_string = getTlsOssResult(); + if (!logged) { + m_string = tlssPop(); + logged = true; + } + DOCTEST_ITERATE_THROUGH_REPORTERS(log_message, *this); const bool isWarn = m_severity & assertType::is_warn; @@ -4569,29 +4968,10 @@ namespace detail { if(m_severity & assertType::is_require) //!OCLINT bitwise operator in conditional throwException(); } - - MessageBuilder::~MessageBuilder() = default; } // namespace detail namespace { using namespace detail; - template - DOCTEST_NORETURN void throw_exception(Ex const& e) { -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - throw e; -#else // DOCTEST_CONFIG_NO_EXCEPTIONS - std::cerr << "doctest will terminate because it needed to throw an exception.\n" - << "The message was: " << e.what() << '\n'; - std::terminate(); -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - } - -#ifndef DOCTEST_INTERNAL_ERROR -#define DOCTEST_INTERNAL_ERROR(msg) \ - throw_exception(std::logic_error( \ - __FILE__ ":" DOCTEST_TOSTR(__LINE__) ": Internal doctest error: " msg)) -#endif // DOCTEST_INTERNAL_ERROR - // clang-format off // ================================================================================================= @@ -4673,10 +5053,10 @@ namespace { void ensureTagClosed(); - private: - void writeDeclaration(); + private: + void newlineIfNecessary(); bool m_tagIsOpen = false; @@ -4865,7 +5245,7 @@ namespace { XmlWriter::XmlWriter( std::ostream& os ) : m_os( os ) { - writeDeclaration(); + // writeDeclaration(); // called explicitly by the reporters that use the writer class - see issue #627 } XmlWriter::~XmlWriter() { @@ -4976,8 +5356,8 @@ namespace { struct XmlReporter : public IReporter { - XmlWriter xml; - std::mutex mutex; + XmlWriter xml; + DOCTEST_DECLARE_MUTEX(mutex) // caching pointers/references to objects of these types - safe to do const ContextOptions& opt; @@ -5054,7 +5434,8 @@ namespace { xml.scopedElement("TestCase").writeAttribute("name", in.data[i]->m_name) .writeAttribute("testsuite", in.data[i]->m_test_suite) .writeAttribute("filename", skipPathFromFilename(in.data[i]->m_file.c_str())) - .writeAttribute("line", line(in.data[i]->m_line)); + .writeAttribute("line", line(in.data[i]->m_line)) + .writeAttribute("skipped", in.data[i]->m_skip); } xml.scopedElement("OverallResultsTestCases") .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); @@ -5070,6 +5451,8 @@ namespace { } void test_run_start() override { + xml.writeDeclaration(); + // remove .exe extension - mainly to have the same output on UNIX and Windows std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); #ifdef DOCTEST_PLATFORM_WINDOWS @@ -5124,7 +5507,8 @@ namespace { xml.startElement("OverallResultsAsserts") .writeAttribute("successes", st.numAssertsCurrentTest - st.numAssertsFailedCurrentTest) - .writeAttribute("failures", st.numAssertsFailedCurrentTest); + .writeAttribute("failures", st.numAssertsFailedCurrentTest) + .writeAttribute("test_case_success", st.testCaseSuccess); if(opt.duration) xml.writeAttribute("duration", st.seconds); if(tc->m_expected_failures) @@ -5135,7 +5519,7 @@ namespace { } void test_case_exception(const TestCaseException& e) override { - std::lock_guard lock(mutex); + DOCTEST_LOCK_MUTEX(mutex) xml.scopedElement("Exception") .writeAttribute("crash", e.is_crash) @@ -5143,8 +5527,6 @@ namespace { } void subcase_start(const SubcaseSignature& in) override { - std::lock_guard lock(mutex); - xml.startElement("SubCase") .writeAttribute("name", in.m_name) .writeAttribute("filename", skipPathFromFilename(in.m_file)) @@ -5158,7 +5540,7 @@ namespace { if(!rb.m_failed && !opt.success) return; - std::lock_guard lock(mutex); + DOCTEST_LOCK_MUTEX(mutex) xml.startElement("Expression") .writeAttribute("success", !rb.m_failed) @@ -5174,7 +5556,7 @@ namespace { if(rb.m_at & assertType::is_throws_as) xml.scopedElement("ExpectedException").writeText(rb.m_exception_type); if(rb.m_at & assertType::is_throws_with) - xml.scopedElement("ExpectedExceptionString").writeText(rb.m_exception_string); + xml.scopedElement("ExpectedExceptionString").writeText(rb.m_exception_string.c_str()); if((rb.m_at & assertType::is_normal) && !rb.m_threw) xml.scopedElement("Expanded").writeText(rb.m_decomp.c_str()); @@ -5184,7 +5566,7 @@ namespace { } void log_message(const MessageData& mb) override { - std::lock_guard lock(mutex); + DOCTEST_LOCK_MUTEX(mutex) xml.startElement("Message") .writeAttribute("type", failureString(mb.m_severity)) @@ -5220,7 +5602,8 @@ namespace { } else if((rb.m_at & assertType::is_throws_as) && (rb.m_at & assertType::is_throws_with)) { //!OCLINT s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" - << rb.m_exception_string << "\", " << rb.m_exception_type << " ) " << Color::None; + << rb.m_exception_string.c_str() + << "\", " << rb.m_exception_type << " ) " << Color::None; if(rb.m_threw) { if(!rb.m_failed) { s << "threw as expected!\n"; @@ -5241,7 +5624,8 @@ namespace { } else if(rb.m_at & assertType::is_throws_with) { //!OCLINT bitwise operator in conditional s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" - << rb.m_exception_string << "\" ) " << Color::None + << rb.m_exception_string.c_str() + << "\" ) " << Color::None << (rb.m_threw ? (!rb.m_failed ? "threw as expected!" : "threw a DIFFERENT exception: ") : "did NOT throw at all!") @@ -5266,8 +5650,8 @@ namespace { // - more attributes in tags struct JUnitReporter : public IReporter { - XmlWriter xml; - std::mutex mutex; + XmlWriter xml; + DOCTEST_DECLARE_MUTEX(mutex) Timer timer; std::vector deepestSubcaseStackNames; @@ -5363,9 +5747,13 @@ namespace { // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE // ========================================================================================= - void report_query(const QueryData&) override {} + void report_query(const QueryData&) override { + xml.writeDeclaration(); + } - void test_run_start() override {} + void test_run_start() override { + xml.writeDeclaration(); + } void test_run_end(const TestRunStats& p) override { // remove .exe extension - mainly to have the same output on UNIX and Windows @@ -5435,12 +5823,11 @@ namespace { } void test_case_exception(const TestCaseException& e) override { - std::lock_guard lock(mutex); + DOCTEST_LOCK_MUTEX(mutex) testCaseData.addError("exception", e.error_string.c_str()); } void subcase_start(const SubcaseSignature& in) override { - std::lock_guard lock(mutex); deepestSubcaseStackNames.push_back(in.m_name); } @@ -5450,7 +5837,7 @@ namespace { if(!rb.m_failed) // report only failures & ignore the `success` option return; - std::lock_guard lock(mutex); + DOCTEST_LOCK_MUTEX(mutex) std::ostringstream os; os << skipPathFromFilename(rb.m_file) << (opt.gnu_file_line ? ":" : "(") @@ -5501,7 +5888,7 @@ namespace { bool hasLoggedCurrentTestStart; std::vector subcasesStack; size_t currentSubcaseLevel; - std::mutex mutex; + DOCTEST_DECLARE_MUTEX(mutex) // caching pointers/references to objects of these types - safe to do const ContextOptions& opt; @@ -5606,9 +5993,11 @@ namespace { } void printIntro() { - printVersion(); - s << Color::Cyan << "[doctest] " << Color::None - << "run with \"--" DOCTEST_OPTIONS_PREFIX_DISPLAY "help\" for options\n"; + if(opt.no_intro == false) { + printVersion(); + s << Color::Cyan << "[doctest] " << Color::None + << "run with \"--" DOCTEST_OPTIONS_PREFIX_DISPLAY "help\" for options\n"; + } } void printHelp() { @@ -5693,12 +6082,18 @@ namespace { << Whitespace(sizePrefixDisplay*1) << "exits after the tests finish\n"; s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "d, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "duration= " << Whitespace(sizePrefixDisplay*1) << "prints the time duration of each test\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "m, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "minimal= " + << Whitespace(sizePrefixDisplay*1) << "minimal console output (only failures)\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "q, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "quiet= " + << Whitespace(sizePrefixDisplay*1) << "no console output\n"; s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nt, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-throw= " << Whitespace(sizePrefixDisplay*1) << "skips exceptions-related assert checks\n"; s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ne, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-exitcode= " << Whitespace(sizePrefixDisplay*1) << "returns (or exits) always with success\n"; s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-run= " << Whitespace(sizePrefixDisplay*1) << "skips all runtime doctest operations\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ni, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-intro= " + << Whitespace(sizePrefixDisplay*1) << "omit the framework intro in the output\n"; s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nv, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-version= " << Whitespace(sizePrefixDisplay*1) << "omit the framework version in the output\n"; s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-colors= " @@ -5736,22 +6131,6 @@ namespace { printReporters(getReporters(), "reporters"); } - void list_query_results() { - separator_to_stream(); - if(opt.count || opt.list_test_cases) { - s << Color::Cyan << "[doctest] " << Color::None - << "unskipped test cases passing the current filters: " - << g_cs->numTestCasesPassingFilters << "\n"; - } else if(opt.list_test_suites) { - s << Color::Cyan << "[doctest] " << Color::None - << "unskipped test cases passing the current filters: " - << g_cs->numTestCasesPassingFilters << "\n"; - s << Color::Cyan << "[doctest] " << Color::None - << "test suites with unskipped test cases passing the current filters: " - << g_cs->numTestSuitesPassingFilters << "\n"; - } - } - // ========================================================================================= // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE // ========================================================================================= @@ -5797,9 +6176,15 @@ namespace { } } - void test_run_start() override { printIntro(); } + void test_run_start() override { + if(!opt.minimal) + printIntro(); + } void test_run_end(const TestRunStats& p) override { + if(opt.minimal && p.numTestCasesFailed == 0) + return; + separator_to_stream(); s << std::dec; @@ -5849,7 +6234,7 @@ namespace { // log the preamble of the test case only if there is something // else to print - something other than that an assert has failed if(opt.duration || - (st.failure_flags && st.failure_flags != TestCaseFailureReason::AssertFailure)) + (st.failure_flags && st.failure_flags != static_cast(TestCaseFailureReason::AssertFailure))) logTestStart(); if(opt.duration) @@ -5880,6 +6265,7 @@ namespace { } void test_case_exception(const TestCaseException& e) override { + DOCTEST_LOCK_MUTEX(mutex) if(tc->m_no_output) return; @@ -5904,14 +6290,12 @@ namespace { } void subcase_start(const SubcaseSignature& subc) override { - std::lock_guard lock(mutex); subcasesStack.push_back(subc); ++currentSubcaseLevel; hasLoggedCurrentTestStart = false; } void subcase_end() override { - std::lock_guard lock(mutex); --currentSubcaseLevel; hasLoggedCurrentTestStart = false; } @@ -5920,7 +6304,7 @@ namespace { if((!rb.m_failed && !opt.success) || tc->m_no_output) return; - std::lock_guard lock(mutex); + DOCTEST_LOCK_MUTEX(mutex) logTestStart(); @@ -5936,7 +6320,7 @@ namespace { if(tc->m_no_output) return; - std::lock_guard lock(mutex); + DOCTEST_LOCK_MUTEX(mutex) logTestStart(); @@ -6047,18 +6431,42 @@ namespace { std::vector& res) { String filtersString; if(parseOption(argc, argv, pattern, &filtersString)) { - // tokenize with "," as a separator - // cppcheck-suppress strtokCalled - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") - auto pch = std::strtok(filtersString.c_str(), ","); // modifies the string - while(pch != nullptr) { - if(strlen(pch)) - res.push_back(pch); - // uses the strtok() internal state to go to the next token - // cppcheck-suppress strtokCalled - pch = std::strtok(nullptr, ","); + // tokenize with "," as a separator, unless escaped with backslash + std::ostringstream s; + auto flush = [&s, &res]() { + auto string = s.str(); + if(string.size() > 0) { + res.push_back(string.c_str()); + } + s.str(""); + }; + + bool seenBackslash = false; + const char* current = filtersString.c_str(); + const char* end = current + strlen(current); + while(current != end) { + char character = *current++; + if(seenBackslash) { + seenBackslash = false; + if(character == ',' || character == '\\') { + s.put(character); + continue; + } + s.put('\\'); + } + if(character == '\\') { + seenBackslash = true; + } else if(character == ',') { + flush(); + } else { + s.put(character); + } } - DOCTEST_CLANG_SUPPRESS_WARNING_POP + + if(seenBackslash) { + s.put('\\'); + } + flush(); return true; } return false; @@ -6077,30 +6485,30 @@ namespace { if(!parseOption(argc, argv, pattern, &parsedValue)) return false; - if(type == 0) { + if(type) { + // integer + // TODO: change this to use std::stoi or something else! currently it uses undefined behavior - assumes '0' on failed parse... + int theInt = std::atoi(parsedValue.c_str()); + if (theInt != 0) { + res = theInt; //!OCLINT parameter reassignment + return true; + } + } else { // boolean - const char positive[][5] = {"1", "true", "on", "yes"}; // 5 - strlen("true") + 1 - const char negative[][6] = {"0", "false", "off", "no"}; // 6 - strlen("false") + 1 + const char positive[][5] = { "1", "true", "on", "yes" }; // 5 - strlen("true") + 1 + const char negative[][6] = { "0", "false", "off", "no" }; // 6 - strlen("false") + 1 // if the value matches any of the positive/negative possibilities - for(unsigned i = 0; i < 4; i++) { - if(parsedValue.compare(positive[i], true) == 0) { + for (unsigned i = 0; i < 4; i++) { + if (parsedValue.compare(positive[i], true) == 0) { res = 1; //!OCLINT parameter reassignment return true; } - if(parsedValue.compare(negative[i], true) == 0) { + if (parsedValue.compare(negative[i], true) == 0) { res = 0; //!OCLINT parameter reassignment return true; } } - } else { - // integer - // TODO: change this to use std::stoi or something else! currently it uses undefined behavior - assumes '0' on failed parse... - int theInt = std::atoi(parsedValue.c_str()); // NOLINT - if(theInt != 0) { - res = theInt; //!OCLINT parameter reassignment - return true; - } } return false; } @@ -6191,9 +6599,12 @@ void Context::parseArgs(int argc, const char* const* argv, bool withDefaults) { DOCTEST_PARSE_AS_BOOL_OR_FLAG("case-sensitive", "cs", case_sensitive, false); DOCTEST_PARSE_AS_BOOL_OR_FLAG("exit", "e", exit, false); DOCTEST_PARSE_AS_BOOL_OR_FLAG("duration", "d", duration, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("minimal", "m", minimal, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("quiet", "q", quiet, false); DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-throw", "nt", no_throw, false); DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-exitcode", "ne", no_exitcode, false); DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-run", "nr", no_run, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-intro", "ni", no_intro, false); DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-version", "nv", no_version, false); DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-colors", "nc", no_colors, false); DOCTEST_PARSE_AS_BOOL_OR_FLAG("force-colors", "fc", force_colors, false); @@ -6257,10 +6668,14 @@ void Context::clearFilters() { curr.clear(); } -// allows the user to override procedurally the int/bool options from the command line +// allows the user to override procedurally the bool options from the command line +void Context::setOption(const char* option, bool value) { + setOption(option, value ? "true" : "false"); +} + +// allows the user to override procedurally the int options from the command line void Context::setOption(const char* option, int value) { setOption(option, toString(value).c_str()); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) } // allows the user to override procedurally the string options from the command line @@ -6277,6 +6692,31 @@ void Context::setAsDefaultForAssertsOutOfTestCases() { g_cs = p; } void Context::setAssertHandler(detail::assert_handler ah) { p->ah = ah; } +void Context::setCout(std::ostream* out) { p->cout = out; } + +static class DiscardOStream : public std::ostream +{ +private: + class : public std::streambuf + { + private: + // allowing some buffering decreases the amount of calls to overflow + char buf[1024]; + + protected: + std::streamsize xsputn(const char_type*, std::streamsize count) override { return count; } + + int_type overflow(int_type ch) override { + setp(std::begin(buf), std::end(buf)); + return traits_type::not_eof(ch); + } + } discardBuf; + +public: + DiscardOStream() + : std::ostream(&discardBuf) {} +} discardOut; + // the main function that does all the filtering and test running int Context::run() { using namespace detail; @@ -6290,15 +6730,18 @@ int Context::run() { g_no_colors = p->no_colors; p->resetRunData(); - // stdout by default - p->cout = &std::cout; - p->cerr = &std::cerr; - - // or to a file if specified std::fstream fstr; - if(p->out.size()) { - fstr.open(p->out.c_str(), std::fstream::out); - p->cout = &fstr; + if(p->cout == nullptr) { + if(p->quiet) { + p->cout = &discardOut; + } else if(p->out.size()) { + // to a file if specified + fstr.open(p->out.c_str(), std::fstream::out); + p->cout = &fstr; + } else { + // stdout by default + p->cout = &std::cout; + } } FatalConditionHandler::allocateAltStackMem(); @@ -6370,7 +6813,7 @@ int Context::run() { // random_shuffle implementation const auto first = &testArray[0]; for(size_t i = testArray.size() - 1; i > 0; --i) { - int idxToSwap = std::rand() % (i + 1); // NOLINT + int idxToSwap = std::rand() % (i + 1); const auto temp = first[i]; @@ -6457,7 +6900,7 @@ int Context::run() { p->numAssertsFailedCurrentTest_atomic = 0; p->numAssertsCurrentTest_atomic = 0; - p->subcasesPassed.clear(); + p->fullyTraversedSubcases.clear(); DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_start, tc); @@ -6467,9 +6910,10 @@ int Context::run() { do { // reset some of the fields for subcases (except for the set of fully passed ones) - p->should_reenter = false; - p->subcasesCurrentMaxLevel = 0; - p->subcasesStack.clear(); + p->reachedLeaf = false; + // May not be empty if previous subcase exited via exception. + p->subcaseStack.clear(); + p->currentSubcaseDepth = 0; p->shouldLogCurrentException = true; @@ -6503,9 +6947,9 @@ DOCTEST_MSVC_SUPPRESS_WARNING_POP p->failure_flags |= TestCaseFailureReason::TooManyFailedAsserts; } - if(p->should_reenter && run_test) + if(!p->nextSubcaseStack.empty() && run_test) DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_reenter, tc); - if(!p->should_reenter) + if(p->nextSubcaseStack.empty()) run_test = false; } while(run_test); @@ -6531,17 +6975,10 @@ DOCTEST_MSVC_SUPPRESS_WARNING_POP DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, qdata); } - // see these issues on the reasoning for this: - // - https://github.com/onqtam/doctest/issues/143#issuecomment-414418903 - // - https://github.com/onqtam/doctest/issues/126 - auto DOCTEST_FIX_FOR_MACOS_LIBCPP_IOSFWD_STRING_LINK_ERRORS = []() DOCTEST_NOINLINE - { std::cout << std::string(); }; - DOCTEST_FIX_FOR_MACOS_LIBCPP_IOSFWD_STRING_LINK_ERRORS(); - return cleanup_and_return(); } -IReporter::~IReporter() = default; +DOCTEST_DEFINE_INTERFACE(IReporter) int IReporter::get_num_active_contexts() { return detail::g_infoContexts.size(); } const IContextScope* const* IReporter::get_active_contexts() { @@ -6576,5 +7013,7 @@ DOCTEST_CLANG_SUPPRESS_WARNING_POP DOCTEST_MSVC_SUPPRESS_WARNING_POP DOCTEST_GCC_SUPPRESS_WARNING_POP +DOCTEST_SUPPRESS_COMMON_WARNINGS_POP + #endif // DOCTEST_LIBRARY_IMPLEMENTATION #endif // DOCTEST_CONFIG_IMPLEMENT diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 22483f9e..f64b6157 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -333,7 +333,7 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) try { Luau::BytecodeBuilder bcb; - Luau::compileOrThrow(bcb, parseResult.root, parseNameTable, compileOptions); + Luau::compileOrThrow(bcb, parseResult, parseNameTable, compileOptions); bytecode = bcb.getBytecode(); } catch (const Luau::CompileError&) diff --git a/rfcs/README.md b/rfcs/README.md index f6c4c145..4b5e7b04 100644 --- a/rfcs/README.md +++ b/rfcs/README.md @@ -18,6 +18,13 @@ For changes in semantics, we should be asking: - Can it be sandboxed assuming malicious usage? - Is it compatible with type checking and other forms of static analysis? +For new standard library functions, we should be asking: + +- Is the new functionality used/useful often enough in existing code? +- Does the standard library implementation carry important performance benefits that can't be achieved in user code? +- Is the behavior general and unambiguous, as opposed to solving a problem / providing an interface that's too specific? +- Is the function interface amenable to type checking / linting? + In addition to these questions, we also need to consider that every addition carries a cost, and too many features will result in a language that is harder to learn, harder to implement and ensure consistent implementation quality throughout, slower, etc. In addition, any language is greater than the sum of its parts and features often have non-intuitive interactions with each other. Since reversing these decisions is incredibly costly and can be impossible due to backwards compatibility implications, all user facing changes to Luau language and core libraries must go through an RFC process. diff --git a/rfcs/STATUS.md b/rfcs/STATUS.md index 23a1be83..d2fe86f0 100644 --- a/rfcs/STATUS.md +++ b/rfcs/STATUS.md @@ -26,15 +26,3 @@ This document tracks unimplemented RFCs. [RFC: Lower bounds calculation](https://github.com/Roblox/luau/blob/master/rfcs/lower-bounds-calculation.md) **Status**: Implemented but not fully rolled out yet. - -## never and unknown types - -[RFC: never and unknown types](https://github.com/Roblox/luau/blob/master/rfcs/never-and-unknown-types.md) - -**Status**: Needs implementation - -## __len metamethod for tables and rawlen function - -[RFC: Support __len metamethod for tables and rawlen function](https://github.com/Roblox/luau/blob/master/rfcs/len-metamethod-rawlen.md) - -**Status**: Needs implementation diff --git a/rfcs/disallow-proposals-leading-to-ambiguity-in-grammar.md b/rfcs/disallow-proposals-leading-to-ambiguity-in-grammar.md new file mode 100644 index 00000000..d9c5c7d7 --- /dev/null +++ b/rfcs/disallow-proposals-leading-to-ambiguity-in-grammar.md @@ -0,0 +1,129 @@ +# Disallow `name T` and `name(T)` in future syntactic extensions for type annotations + +## Summary + +We propose to disallow the syntax `` `('`` as well as ` ` in future syntax extensions for type annotations to ensure that all existing programs continue to parse correctly. This still keeps the door open for future syntax extensions of different forms such as `` `<' `>'``. + +## Motivation + +Lua and by extension Luau's syntax is very free form, which means that when the parser finishes parsing a node, it doesn't try to look for a semi-colon or any termination token e.g. a `{` to start a block, or `;` to end a statement, or a newline, etc. It just immediately invokes the next parser to figure out how to parse the next node based on the remainder's starting token. + +That feature is sometimes quite troublesome when we want to add new syntax. + +We have had cases where we talked about using syntax like `setmetatable(T, MT)` and `keyof T`. They all look innocent, but when you look beyond that, and try to apply it onto Luau's grammar, things break down really fast. + +### `F(T)`? + +An example that _will_ cause a change in semantics: + +``` +local t: F +(u):m() +``` + +where today, `local t: F` is one statement, and `(u):m()` is another. If we had the syntax for `F(T)` here, it becomes invalid input because it gets parsed as + +``` +local t: F(u) +:m() +``` + +This is important because of the `setmetatable(T, MT)` case: + +``` +type Foo = setmetatable({ x: number }, { ... }) +``` + +For `setmetatable`, the parser isn't sure whether `{}` is actually a type or an expression, because _today_ `setmetatable` is parsed as a type reference, and `({}, {})` is the remainder that we'll attempt to parse as a statement. This means `{ x: number }` is invalid table _literal_. Recovery by backtracking is technically possible here, but this means performance loss on invalid input + may introduce false positives wrt how things are parsed. We'd much rather take a very strict stance about how things get parsed. + +### `F T`? + +An example that _will_ cause a change in semantics: + +``` +local function f(t): F T + (t or u):m() +end +``` + +where today, the return type annotation `F T` is simply parsed as just `F`, followed by a ambiguous parse error from the statement `T(t or u)` because its `(` is on the next line. If at some point in the future we were to allow `T` followed by `(` on the next line, then there's yet another semantic change. `F T` could be parsed as a type annotation and the first statement is `(t or u):m()` instead of `F` followed by `T(t or u):m()`. + +For `keyof`, here's a practical example of the above issue: + +``` +type Vec2 = {x: number, y: number} + +local function f(t, u): keyof Vec2 + (t or u):m() +end +``` + +There's three possible outcomes: + 1. Return type of `f` is `keyof`, statement throws a parse error because `(` is on the next line after `Vec2`, + 2. Return type of `f` is `keyof Vec2` and next statement is `(t or u):m()`, or + 3. Return type of `f` is `keyof` and next statement is `Vec2(t or u):m()` (if we allow `(` on the next line to be part of previous line). + +This particular case is even worse when we keep going: + +``` +local function f(t): F + T(t or u):m() +end +``` + +``` +local function f(t): F T + {1, 2, 3} +end +``` + +where today, `F` is the return type annotation of `f`, and `T(t or u):m()`/`T{1, 2, 3}` is the first statement, respectively. + +Adding some syntax for `F T` **will** cause the parser to change the semantics of the above three examples. + +### But what about `typeof(...)`? + +This syntax is grandfathered in because the parser supported `typeof(...)` before we stabilized our syntax, and especially before type annotations were released to the public, so we didn't need to worry about compatibility here. We are very glad that we used parentheses in this case, because it's natural for expressions to belong within parentheses `()`, and types to belong within angles `<>`. + +## The One Exception with a caveat + +This is a strict requirement! + +`function() -> ()` has been talked about in the past, and this one is different despite falling under the same category as `` `('``. The token `function` is in actual fact a "hard keyword," meaning that it cannot be parsed as a type annotation because it is not an identifier, just a keyword. + +Likewise, we also have talked about adding standalone `function` as a type annotation (semantics of it is irrelevant for this RFC) + +It's possible that we may end up adding both, but the requirements are as such: + 1. `function() -> ()` must be added first before standalone `function`, OR + 2. `function` can be added first, but with a future-proofing parse error if `<` or `(` follows after it + +If #1 is what ends up happening, there's not much to worry about because the type annotation parser will parse greedily already, so any new valid input will remain valid and have same semantics, except it also allows omitting of `(` and `<`. + +If #2 is what ends up happening, there could be a problem if we didn't future-proof against `<` and `(` to follow `function`: + +``` + return f :: function(T) -> U +``` + +which would be a parse error because at the point of `(` we expect one of `until`, `end`, or `EOF`, and + +``` + return f :: function(a) -> a +``` + +which would also be a parse error by the time we reach `->`, that is the production of the above is semantically equivalent to `(f < a) > (a)` which would compare whether the value of `f` is less than the value of `a`, then whether the result of that value is greater than `a`. + +## Alternatives + +Only allow these syntax when used inside parentheses e.g. `(F T)` or `(F(T))`. This makes it inconsistent with the existing `typeof(...)` type annotation, and changing that over is also breaking change. + +Support backtracking in the parser, so if `: MyType(t or u):m()` is invalid syntax, revert and parse `MyType` as a type, and `(t or u):m()` as an expression statement. Even so, this option is terrible for: + 1. parsing performance (backtracking means losing progress on invalid input), + 2. user experience (why was this annotation parsed as `X(...)` instead of `X` followed by a statement `(...)`), + 3. has false positives (`foo(bar)(baz)` may be parsed as `foo(bar)` as the type annotation and `(baz)` is the remainder to parse) + +## Drawbacks + +To be able to expose some kind of type-level operations using `F` syntax, means one of the following must be chosen: + 1. introduce the concept of "magic type functions" into type inference, or + 2. introduce them into the prelude as `export type F = ...` (where `...` is to be read as "we haven't decided") diff --git a/rfcs/len-metamethod-rawlen.md b/rfcs/len-metamethod-rawlen.md index 45284b71..60278dda 100644 --- a/rfcs/len-metamethod-rawlen.md +++ b/rfcs/len-metamethod-rawlen.md @@ -1,5 +1,7 @@ # Support `__len` metamethod for tables and `rawlen` function +**Status**: Implemented + ## Summary `__len` metamethod will be called by `#` operator on tables, matching Lua 5.2 diff --git a/rfcs/never-and-unknown-types.md b/rfcs/never-and-unknown-types.md index d996afc6..5ad216ef 100644 --- a/rfcs/never-and-unknown-types.md +++ b/rfcs/never-and-unknown-types.md @@ -1,5 +1,7 @@ # never and unknown types +**Status**: Implemented + ## Summary Add `unknown` and `never` types that are inhabited by everything and nothing respectively. diff --git a/rfcs/syntax-string-interpolation.md b/rfcs/syntax-string-interpolation.md index 208143a0..ad182620 100644 --- a/rfcs/syntax-string-interpolation.md +++ b/rfcs/syntax-string-interpolation.md @@ -31,9 +31,9 @@ Because we care about backward compatibility, we need some new syntax in order t 1. A string chunk (`` `...{ ``, `}...{`, and `` }...` ``) where `...` is a range of 0 to many characters. * `\` escapes `` ` ``, `{`, and itself `\`. - * Restriction: the string interpolation literal must have at least one value to interpolate. We do not need 3 ways to express a single line string literal. * The pairs must be on the same line (unless a `\` escapes the newline) but expressions needn't be on the same line. 2. An expression between the braces. This is the value that will be interpolated into the string. + * Restriction: we explicitly reject `{{` as it is considered an attempt to escape and get a single `{` character at runtime. 3. Formatting specification may follow after the expression, delimited by an unambiguous character. * Restriction: the formatting specification must be constant at parse time. * In the absence of an explicit formatting specification, the `%*` token will be used. @@ -61,7 +61,6 @@ local set2 = Set.new({0, 5, 4}) print(`{set1} ∪ {set2} = {Set.union(set1, set2)}`) --> {0, 1, 3} ∪ {0, 5, 4} = {0, 1, 3, 4, 5} --- For illustrative purposes. These are illegal specifically because they don't interpolate anything. print(`Some example escaping the braces \{like so}`) print(`backslash \ that escapes the space is not a part of the string...`) print(`backslash \\ will escape the second backslash...`) @@ -88,13 +87,25 @@ print(`Welcome to \ -- Luau! ``` -This expression will not be allowed to come after a `prefixexp`. I believe this is fully additive, so a future RFC may allow this. So for now, we explicitly reject the following: +This expression can also come after a `prefixexp`: ``` local name = "world" print`Hello {name}` ``` +The restriction on `{{` exists solely for the people coming from languages e.g. C#, Rust, or Python which uses `{{` to escape and get the character `{` at runtime. We're also rejecting this at parse time too, since the proper way to escape it is `\{`, so: + +```lua +print(`{{1, 2, 3}} = {myCoolSet}`) -- parse error +``` + +If we did not apply this as a parse error, then the above would wind up printing as the following, which is obviously a gotcha we can and should avoid. + +``` +--> table: 0xSOMEADDRESS = {1, 2, 3} +``` + Since the string interpolation expression is going to be lowered into a `string.format` call, we'll also need to extend `string.format`. The bare minimum to support the lowering is to add a new token whose definition is to perform a `tostring` call. `%*` is currently an invalid token, so this is a backward compatible extension. This RFC shall define `%*` to have the same behavior as if `tostring` was called. ```lua @@ -121,6 +132,13 @@ print(string.format("%* %* %*", return_two_nils())) --> error: value #3 is missing, got 2 ``` +It must be said that we are not allowing this style of string literals in type annotations at this time, regardless of zero or many interpolating expressions, so the following two type annotations below are illegal syntax: + +```lua +local foo: `foo` +local bar: `bar{baz}` +``` + ## Drawbacks If we want to use backticks for other purposes, it may introduce some potential ambiguity. One option to solve that is to only ever produce string interpolation tokens from the context of an expression. This is messy but doable because the parser and the lexer are already implemented to work in tandem. The other option is to pick a different delimiter syntax to keep backticks available for use in the future. diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 7f863c6f..28ce6a82 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -155,6 +155,13 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "BaseBinaryInstructionForms") SINGLE_COMPARE(add(qword[rax + r13 * 2 + 0x1b], rsi), 0x4a, 0x01, 0x74, 0x68, 0x1b); SINGLE_COMPARE(add(qword[rbp + rbx * 2], rsi), 0x48, 0x01, 0x74, 0x5d, 0x00); SINGLE_COMPARE(add(qword[rsp + r10 * 2 + 0x1b], r10), 0x4e, 0x01, 0x54, 0x54, 0x1b); + + // [addr], imm + SINGLE_COMPARE(add(byte[rax], 2), 0x80, 0x00, 0x02); + SINGLE_COMPARE(add(dword[rax], 2), 0x83, 0x00, 0x02); + SINGLE_COMPARE(add(dword[rax], 0xabcd), 0x81, 0x00, 0xcd, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(qword[rax], 2), 0x48, 0x83, 0x00, 0x02); + SINGLE_COMPARE(add(qword[rax], 0xabcd), 0x48, 0x81, 0x00, 0xcd, 0xab, 0x00, 0x00); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "BaseUnaryInstructionForms") @@ -213,6 +220,16 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea") SINGLE_COMPARE(lea(rax, qword[r13 + r12 * 4 + 4]), 0x4b, 0x8d, 0x44, 0xa5, 0x04); } +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfAbsoluteJumps") +{ + SINGLE_COMPARE(jmp(rax), 0x48, 0xff, 0xe0); + SINGLE_COMPARE(jmp(r14), 0x49, 0xff, 0xe6); + SINGLE_COMPARE(jmp(qword[r14 + rdx * 4]), 0x49, 0xff, 0x24, 0x96); + SINGLE_COMPARE(call(rax), 0x48, 0xff, 0xd0); + SINGLE_COMPARE(call(r14), 0x49, 0xff, 0xd6); + SINGLE_COMPARE(call(qword[r14 + rdx * 4]), 0x49, 0xff, 0x14, 0x96); +} + TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow") { // Jump back @@ -260,6 +277,23 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow") {0xe9, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xe7, 0x3e}); } +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelCall") +{ + check( + [](AssemblyBuilderX64& build) { + Label fnB; + + build.and_(rcx, 0x3e); + build.call(fnB); + build.ret(); + + build.setLabel(fnB); + build.lea(rax, qword[rcx + 0x1f]); + build.ret(); + }, + {0x48, 0x83, 0xe1, 0x3e, 0xe8, 0x01, 0x00, 0x00, 0x00, 0xc3, 0x48, 0x8d, 0x41, 0x1f, 0xc3}); +} + TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") { SINGLE_COMPARE(vaddpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xa9, 0x58, 0xc6); @@ -277,6 +311,13 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") SINGLE_COMPARE(vaddps(xmm9, xmm12, xmmword[r9 + r14 * 2 + 0x1c]), 0xc4, 0x01, 0x98, 0x58, 0x4c, 0x71, 0x1c); SINGLE_COMPARE(vaddps(ymm1, ymm2, ymm3), 0xc4, 0xe1, 0xec, 0x58, 0xcb); SINGLE_COMPARE(vaddps(ymm9, ymm12, ymmword[r9 + r14 * 2 + 0x1c]), 0xc4, 0x01, 0x9c, 0x58, 0x4c, 0x71, 0x1c); + + // Coverage for other instructions that follow the same pattern + SINGLE_COMPARE(vsubsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xab, 0x5c, 0xc6); + SINGLE_COMPARE(vmulsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xab, 0x59, 0xc6); + SINGLE_COMPARE(vdivsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xab, 0x5e, 0xc6); + + SINGLE_COMPARE(vxorpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xa9, 0x57, 0xc6); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXUnaryMergeInstructionForms") @@ -291,6 +332,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXUnaryMergeInstructionForms") SINGLE_COMPARE(vsqrtsd(xmm8, xmm10, qword[r9]), 0xc4, 0x41, 0xab, 0x51, 0x01); SINGLE_COMPARE(vsqrtss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xaa, 0x51, 0xc6); SINGLE_COMPARE(vsqrtss(xmm8, xmm10, dword[r9]), 0xc4, 0x41, 0xaa, 0x51, 0x01); + + // Coverage for other instructions that follow the same pattern + SINGLE_COMPARE(vcomisd(xmm8, xmm10), 0xc4, 0x41, 0xf9, 0x2f, 0xc2); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXMoveInstructionForms") @@ -315,6 +359,11 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXMoveInstructionForms") SINGLE_COMPARE(vmovups(ymm8, ymmword[r9]), 0xc4, 0x41, 0xfc, 0x10, 0x01); } +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions") +{ + SINGLE_COMPARE(int3(), 0xcc); +} + TEST_CASE("LogTest") { AssemblyBuilderX64 build(/* logText= */ true); @@ -339,6 +388,7 @@ TEST_CASE("LogTest") build.vmovapd(xmmword[rax], xmm11); build.pop(r12); build.ret(); + build.int3(); build.finalize(); @@ -361,6 +411,7 @@ TEST_CASE("LogTest") vmovapd xmmword ptr [rax],xmm11 pop r12 ret + int3 )"; CHECK(same); } diff --git a/tests/JsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp similarity index 60% rename from tests/JsonEncoder.test.cpp rename to tests/AstJsonEncoder.test.cpp index 8a263bd2..3ff36741 100644 --- a/tests/JsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -1,6 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Ast.h" -#include "Luau/JsonEncoder.h" +#include "Luau/AstJsonEncoder.h" #include "Luau/Parser.h" #include "doctest.h" @@ -56,10 +56,19 @@ TEST_CASE("encode_constants") AstExprConstantNil nil{Location()}; AstExprConstantBool b{Location(), true}; AstExprConstantNumber n{Location(), 8.2}; + AstExprConstantNumber bigNum{Location(), 0.1677721600000003}; + + AstArray charString; + charString.data = const_cast("a\x1d\0\\\"b"); + charString.size = 6; + + AstExprConstantString needsEscaping{Location(), charString}; CHECK_EQ(R"({"type":"AstExprConstantNil","location":"0,0 - 0,0"})", toJson(&nil)); CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b)); - CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":8.2})", toJson(&n)); + CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":8.1999999999999993})", toJson(&n)); + CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":0.16777216000000031})", toJson(&bigNum)); + CHECK_EQ("{\"type\":\"AstExprConstantString\",\"location\":\"0,0 - 0,0\",\"value\":\"a\\u001d\\u0000\\\\\\\"b\"}", toJson(&needsEscaping)); } TEST_CASE("basic_escaping") @@ -87,7 +96,7 @@ TEST_CASE("encode_AstStatBlock") AstStatBlock block{Location(), bodyArray}; CHECK_EQ( - (R"({"type":"AstStatBlock","location":"0,0 - 0,0","body":[{"type":"AstStatLocal","location":"0,0 - 0,0","vars":[{"type":null,"name":"a_local","location":"0,0 - 0,0"}],"values":[]}]})"), + (R"({"type":"AstStatBlock","location":"0,0 - 0,0","body":[{"type":"AstStatLocal","location":"0,0 - 0,0","vars":[{"luauType":null,"name":"a_local","type":"AstLocal","location":"0,0 - 0,0"}],"values":[]}]})"), toJson(&block)); } @@ -106,7 +115,31 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_tables") CHECK( json == - R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"type":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","location":"2,12 - 2,15","type":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":false},"name":"x","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); + R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"luauType":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","type":"AstTableProp","location":"2,12 - 2,15","propType":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":null},"name":"x","type":"AstLocal","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"type":"AstExprTableItem","kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_table_array") +{ + std::string src = R"(type X = {string})"; + + AstStatBlock* root = expectParse(src); + std::string json = toJson(root); + + CHECK( + json == + R"({"type":"AstStatBlock","location":"0,0 - 0,17","body":[{"type":"AstStatTypeAlias","location":"0,0 - 0,17","name":"X","generics":[],"genericPacks":[],"type":{"type":"AstTypeTable","location":"0,9 - 0,17","props":[],"indexer":{"location":"0,10 - 0,16","indexType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"number","parameters":[]},"resultType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"string","parameters":[]}}},"exported":false}]})"); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_table_indexer") +{ + std::string src = R"(type X = {string})"; + + AstStatBlock* root = expectParse(src); + std::string json = toJson(root); + + CHECK( + json == + R"({"type":"AstStatBlock","location":"0,0 - 0,17","body":[{"type":"AstStatTypeAlias","location":"0,0 - 0,17","name":"X","generics":[],"genericPacks":[],"type":{"type":"AstTypeTable","location":"0,9 - 0,17","props":[],"indexer":{"location":"0,10 - 0,16","indexType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"number","parameters":[]},"resultType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"string","parameters":[]}}},"exported":false}]})"); } TEST_CASE("encode_AstExprGroup") @@ -132,12 +165,24 @@ TEST_CASE("encode_AstExprGlobal") CHECK(json == expected); } +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprIfThen") +{ + AstStat* statement = expectParseStatement("local a = if x then y else z"); + + std::string_view expected = + R"({"type":"AstStatLocal","location":"0,0 - 0,28","vars":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,6 - 0,7"}],"values":[{"type":"AstExprIfElse","location":"0,10 - 0,28","condition":{"type":"AstExprGlobal","location":"0,13 - 0,14","global":"x"},"hasThen":true,"trueExpr":{"type":"AstExprGlobal","location":"0,20 - 0,21","global":"y"},"hasElse":true,"falseExpr":{"type":"AstExprGlobal","location":"0,27 - 0,28","global":"z"}}]})"; + + CHECK(toJson(statement) == expected); +} + + TEST_CASE("encode_AstExprLocal") { AstLocal local{AstName{"foo"}, Location{}, nullptr, 0, 0, nullptr}; AstExprLocal exprLocal{Location{}, &local, false}; - CHECK(toJson(&exprLocal) == R"({"type":"AstExprLocal","location":"0,0 - 0,0","local":{"type":null,"name":"foo","location":"0,0 - 0,0"}})"); + CHECK(toJson(&exprLocal) == + R"({"type":"AstExprLocal","location":"0,0 - 0,0","local":{"luauType":null,"name":"foo","type":"AstLocal","location":"0,0 - 0,0"}})"); } TEST_CASE("encode_AstExprVarargs") @@ -181,7 +226,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprFunction") AstExpr* expr = expectParseExpr("function (a) return a end"); std::string_view expected = - R"({"type":"AstExprFunction","location":"0,4 - 0,29","generics":[],"genericPacks":[],"args":[{"type":null,"name":"a","location":"0,14 - 0,15"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,16 - 0,26","body":[{"type":"AstStatReturn","location":"0,17 - 0,25","list":[{"type":"AstExprLocal","location":"0,24 - 0,25","local":{"type":null,"name":"a","location":"0,14 - 0,15"}}]}]},"functionDepth":1,"debugname":"","hasEnd":true})"; + R"({"type":"AstExprFunction","location":"0,4 - 0,29","generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,14 - 0,15"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,16 - 0,26","body":[{"type":"AstStatReturn","location":"0,17 - 0,25","list":[{"type":"AstExprLocal","location":"0,24 - 0,25","local":{"luauType":null,"name":"a","type":"AstLocal","location":"0,14 - 0,15"}}]}]},"functionDepth":1,"debugname":"","hasEnd":true})"; CHECK(toJson(expr) == expected); } @@ -191,7 +236,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprTable") AstExpr* expr = expectParseExpr("{true, key=true, [key2]=true}"); std::string_view expected = - R"({"type":"AstExprTable","location":"0,4 - 0,33","items":[{"kind":"item","value":{"type":"AstExprConstantBool","location":"0,5 - 0,9","value":true}},{"kind":"record","key":{"type":"AstExprConstantString","location":"0,11 - 0,14","value":"key"},"value":{"type":"AstExprConstantBool","location":"0,15 - 0,19","value":true}},{"kind":"general","key":{"type":"AstExprGlobal","location":"0,22 - 0,26","global":"key2"},"value":{"type":"AstExprConstantBool","location":"0,28 - 0,32","value":true}}]})"; + R"({"type":"AstExprTable","location":"0,4 - 0,33","items":[{"type":"AstExprTableItem","kind":"item","value":{"type":"AstExprConstantBool","location":"0,5 - 0,9","value":true}},{"type":"AstExprTableItem","kind":"record","key":{"type":"AstExprConstantString","location":"0,11 - 0,14","value":"key"},"value":{"type":"AstExprConstantBool","location":"0,15 - 0,19","value":true}},{"type":"AstExprTableItem","kind":"general","key":{"type":"AstExprGlobal","location":"0,22 - 0,26","global":"key2"},"value":{"type":"AstExprConstantBool","location":"0,28 - 0,32","value":true}}]})"; CHECK(toJson(expr) == expected); } @@ -201,7 +246,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprUnary") AstExpr* expr = expectParseExpr("-b"); std::string_view expected = - R"({"type":"AstExprUnary","location":"0,4 - 0,6","op":"minus","expr":{"type":"AstExprGlobal","location":"0,5 - 0,6","global":"b"}})"; + R"({"type":"AstExprUnary","location":"0,4 - 0,6","op":"Minus","expr":{"type":"AstExprGlobal","location":"0,5 - 0,6","global":"b"}})"; CHECK(toJson(expr) == expected); } @@ -259,7 +304,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatWhile") AstStat* statement = expectParseStatement("while true do end"); std::string_view expected = - R"({"type":"AtStatWhile","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasDo":true,"hasEnd":true})"; + R"({"type":"AstStatWhile","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasDo":true,"hasEnd":true})"; CHECK(toJson(statement) == expected); } @@ -279,7 +324,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatBreak") AstStat* statement = expectParseStatement("while true do break end"); std::string_view expected = - R"({"type":"AtStatWhile","location":"0,0 - 0,23","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,20","body":[{"type":"AstStatBreak","location":"0,14 - 0,19"}]},"hasDo":true,"hasEnd":true})"; + R"({"type":"AstStatWhile","location":"0,0 - 0,23","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,20","body":[{"type":"AstStatBreak","location":"0,14 - 0,19"}]},"hasDo":true,"hasEnd":true})"; CHECK(toJson(statement) == expected); } @@ -289,7 +334,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatContinue") AstStat* statement = expectParseStatement("while true do continue end"); std::string_view expected = - R"({"type":"AtStatWhile","location":"0,0 - 0,26","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,23","body":[{"type":"AstStatContinue","location":"0,14 - 0,22"}]},"hasDo":true,"hasEnd":true})"; + R"({"type":"AstStatWhile","location":"0,0 - 0,26","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,23","body":[{"type":"AstStatContinue","location":"0,14 - 0,22"}]},"hasDo":true,"hasEnd":true})"; CHECK(toJson(statement) == expected); } @@ -299,7 +344,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatFor") AstStat* statement = expectParseStatement("for a=0,1 do end"); std::string_view expected = - R"({"type":"AstStatFor","location":"0,0 - 0,16","var":{"type":null,"name":"a","location":"0,4 - 0,5"},"from":{"type":"AstExprConstantNumber","location":"0,6 - 0,7","value":0},"to":{"type":"AstExprConstantNumber","location":"0,8 - 0,9","value":1},"body":{"type":"AstStatBlock","location":"0,12 - 0,13","body":[]},"hasDo":true,"hasEnd":true})"; + R"({"type":"AstStatFor","location":"0,0 - 0,16","var":{"luauType":null,"name":"a","type":"AstLocal","location":"0,4 - 0,5"},"from":{"type":"AstExprConstantNumber","location":"0,6 - 0,7","value":0},"to":{"type":"AstExprConstantNumber","location":"0,8 - 0,9","value":1},"body":{"type":"AstStatBlock","location":"0,12 - 0,13","body":[]},"hasDo":true,"hasEnd":true})"; CHECK(toJson(statement) == expected); } @@ -309,7 +354,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatForIn") AstStat* statement = expectParseStatement("for a in b do end"); std::string_view expected = - R"({"type":"AstStatForIn","location":"0,0 - 0,17","vars":[{"type":null,"name":"a","location":"0,4 - 0,5"}],"values":[{"type":"AstExprGlobal","location":"0,9 - 0,10","global":"b"}],"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasIn":true,"hasDo":true,"hasEnd":true})"; + R"({"type":"AstStatForIn","location":"0,0 - 0,17","vars":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,4 - 0,5"}],"values":[{"type":"AstExprGlobal","location":"0,9 - 0,10","global":"b"}],"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasIn":true,"hasDo":true,"hasEnd":true})"; CHECK(toJson(statement) == expected); } @@ -329,7 +374,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatLocalFunction") AstStat* statement = expectParseStatement("local function a(b) return end"); std::string_view expected = - R"({"type":"AstStatLocalFunction","location":"0,0 - 0,30","name":{"type":null,"name":"a","location":"0,15 - 0,16"},"func":{"type":"AstExprFunction","location":"0,0 - 0,30","generics":[],"genericPacks":[],"args":[{"type":null,"name":"b","location":"0,17 - 0,18"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,19 - 0,27","body":[{"type":"AstStatReturn","location":"0,20 - 0,26","list":[]}]},"functionDepth":1,"debugname":"a","hasEnd":true}})"; + R"({"type":"AstStatLocalFunction","location":"0,0 - 0,30","name":{"luauType":null,"name":"a","type":"AstLocal","location":"0,15 - 0,16"},"func":{"type":"AstExprFunction","location":"0,0 - 0,30","generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"b","type":"AstLocal","location":"0,17 - 0,18"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,19 - 0,27","body":[{"type":"AstStatReturn","location":"0,20 - 0,26","list":[]}]},"functionDepth":1,"debugname":"a","hasEnd":true}})"; CHECK(toJson(statement) == expected); } @@ -349,7 +394,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareFunction") AstStat* statement = expectParseStatement("declare function foo(x: number): string"); std::string_view expected = - R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","name":"foo","params":{"types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","parameters":[]}]},"retTypes":{"types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","parameters":[]}]},"generics":[],"genericPacks":[]})"; + R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","name":"foo","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","parameters":[]}]},"retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","parameters":[]}]},"generics":[],"genericPacks":[]})"; CHECK(toJson(statement) == expected); } @@ -370,11 +415,11 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") REQUIRE(2 == root->body.size); std::string_view expected1 = - R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","parameters":[]}},{"name":"method","type":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","parameters":[]}]},"returnTypes":{"types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","parameters":[]}]}}}]})"; + R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","parameters":[]}},{"name":"method","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","parameters":[]}]}}}]})"; CHECK(toJson(root->body.data[0]) == expected1); std::string_view expected2 = - R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","parameters":[]}}]})"; + R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","parameters":[]}}]})"; CHECK(toJson(root->body.data[1]) == expected2); } @@ -383,7 +428,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_annotation") AstStat* statement = expectParseStatement("type T = ((number) -> (string | nil)) & ((string) -> ())"); std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,35","generics":[],"genericPacks":[],"argTypes":{"types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","parameters":[]}]},"returnTypes":{"types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","generics":[],"genericPacks":[],"argTypes":{"types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","parameters":[]}]},"returnTypes":{"types":[]}}]},"exported":false})"; + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,35","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; CHECK(toJson(statement) == expected); } @@ -411,7 +456,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypePackExplicit") CHECK(2 == root->body.size); std::string_view expected = - R"({"type":"AstStatLocal","location":"2,8 - 2,36","vars":[{"type":{"type":"AstTypeReference","location":"2,17 - 2,36","name":"A","parameters":[{"type":"AstTypePackExplicit","location":"2,19 - 2,20","typeList":{"types":[{"type":"AstTypeReference","location":"2,20 - 2,26","name":"number","parameters":[]},{"type":"AstTypeReference","location":"2,28 - 2,34","name":"string","parameters":[]}]}}]},"name":"a","location":"2,14 - 2,15"}],"values":[]})"; + R"({"type":"AstStatLocal","location":"2,8 - 2,36","vars":[{"luauType":{"type":"AstTypeReference","location":"2,17 - 2,36","name":"A","parameters":[{"type":"AstTypePackExplicit","location":"2,19 - 2,20","typeList":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"2,20 - 2,26","name":"number","parameters":[]},{"type":"AstTypeReference","location":"2,28 - 2,34","name":"string","parameters":[]}]}}]},"name":"a","type":"AstLocal","location":"2,14 - 2,15"}],"values":[]})"; CHECK(toJson(root->body.data[1]) == expected); } diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index f0017509..6ec1426c 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -105,4 +105,37 @@ if true then REQUIRE(parentStat->is()); } +TEST_CASE_FIXTURE(Fixture, "ac_ast_ancestry_at_number_const") +{ + check(R"( +print(3.) + )"); + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*getMainSourceModule(), Position(1, 8)); + REQUIRE_GE(ancestry.size(), 2); + REQUIRE(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(Fixture, "ac_ast_ancestry_in_workspace_dot") +{ + check(R"( +print(workspace.) + )"); + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*getMainSourceModule(), Position(1, 16)); + REQUIRE_GE(ancestry.size(), 2); + REQUIRE(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(Fixture, "ac_ast_ancestry_in_workspace_colon") +{ + check(R"( +print(workspace:) + )"); + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*getMainSourceModule(), Position(1, 16)); + REQUIRE_GE(ancestry.size(), 2); + REQUIRE(ancestry.back()->is()); +} + TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index f3b0bcad..75c5a606 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1974,7 +1974,7 @@ TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses check(R"( local function foo() return 1 end local function bar(a: number) return -a end -local abc = bar(@1) +local abc = bar(@1) )"); auto ac = autocomplete('1'); @@ -2240,43 +2240,18 @@ local a: aaa.do CHECK(ac.entryMap.count("other")); } -TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteSource") + +TEST_CASE_FIXTURE(ACFixture, "comments") { - std::string_view source = R"( - local a = table. -- Line 1 - -- | Column 23 - )"; + fileResolver.source["Comments"] = "--!str"; - auto ac = autocompleteSource(frontend, source, Position{1, 24}, nullCallback).result; - - CHECK_EQ(17, ac.entryMap.size()); - CHECK(ac.entryMap.count("find")); - CHECK(ac.entryMap.count("pack")); - CHECK(!ac.entryMap.count("math")); -} - -TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_require") -{ - std::string_view source = R"( - local a = require(w -- Line 1 - -- | Column 27 - )"; - - // CLI-43699 require shouldn't crash inside autocompleteSource - auto ac = autocompleteSource(frontend, source, Position{1, 27}, nullCallback).result; -} - -TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_comments") -{ - std::string_view source = "--!str"; - - auto ac = autocompleteSource(frontend, source, Position{0, 6}, nullCallback).result; + auto ac = Luau::autocomplete(frontend, "Comments", Position{0, 6}, nullCallback); CHECK_EQ(0, ac.entryMap.size()); } TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteProp_index_function_metamethod_is_variadic") { - std::string_view source = R"( + fileResolver.source["Module/A"] = R"( type Foo = {x: number} local t = {} setmetatable(t, { @@ -2289,7 +2264,7 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteProp_index_function_metamethod -- | Column 20 )"; - auto ac = autocompleteSource(frontend, source, Position{9, 20}, nullCallback).result; + auto ac = Luau::autocomplete(frontend, "Module/A", Position{9, 20}, nullCallback); REQUIRE_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); } @@ -2378,35 +2353,36 @@ end CHECK(ac.entryMap.count("elsewhere")); } -TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_not_the_var_we_are_defining") +TEST_CASE_FIXTURE(ACFixture, "not_the_var_we_are_defining") { - std::string_view source = "abc,de"; + fileResolver.source["Module/A"] = "abc,de"; - auto ac = autocompleteSource(frontend, source, Position{0, 6}, nullCallback).result; + auto ac = Luau::autocomplete(frontend, "Module/A", Position{0, 6}, nullCallback); CHECK(!ac.entryMap.count("de")); } -TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_recursive_function") +TEST_CASE_FIXTURE(ACFixture, "recursive_function_global") { - { - std::string_view global = R"(function abc() + fileResolver.source["global"] = R"(function abc() end )"; - auto ac = autocompleteSource(frontend, global, Position{1, 0}, nullCallback).result; - CHECK(ac.entryMap.count("abc")); - } + auto ac = Luau::autocomplete(frontend, "global", Position{1, 0}, nullCallback); + CHECK(ac.entryMap.count("abc")); +} - { - std::string_view local = R"(local function abc() + + +TEST_CASE_FIXTURE(ACFixture, "recursive_function_local") +{ + fileResolver.source["local"] = R"(local function abc() end )"; - auto ac = autocompleteSource(frontend, local, Position{1, 0}, nullCallback).result; - CHECK(ac.entryMap.count("abc")); - } + auto ac = Luau::autocomplete(frontend, "local", Position{1, 0}, nullCallback); + CHECK(ac.entryMap.count("abc")); } TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys") @@ -2869,7 +2845,7 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_on_class") { - ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; + ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; loadDefinition(R"( declare class Foo @@ -2907,9 +2883,25 @@ t.@1 } } +TEST_CASE_FIXTURE(ACFixture, "do_compatible_self_calls") +{ + ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; + + check(R"( +local t = {} +function t:m() end +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("m")); + CHECK(!ac.entryMap["m"].wrongIndexType); +} + TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls") { - ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; + ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; check(R"( local t = {} @@ -2925,7 +2917,7 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_2") { - ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; + ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; check(R"( local f: (() -> number) & ((number) -> number) = function(x: number?) return 2 end @@ -2940,7 +2932,7 @@ t:@1 CHECK(ac.entryMap["f"].wrongIndexType); } -TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_provisional") +TEST_CASE_FIXTURE(ACFixture, "do_wrong_compatible_self_calls") { check(R"( local t = {} @@ -2955,9 +2947,26 @@ t:@1 CHECK(!ac.entryMap["m"].wrongIndexType); } +TEST_CASE_FIXTURE(ACFixture, "no_wrong_compatible_self_calls_with_generics") +{ + ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; + + check(R"( +local t = {} +function t.m(a: T) end +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("m")); + // While this call is compatible with the type, this requires instantiation of a generic type which we don't perform + CHECK(ac.entryMap["m"].wrongIndexType); +} + TEST_CASE_FIXTURE(ACFixture, "string_prim_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; + ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; check(R"( local s = "hello" @@ -2976,7 +2985,7 @@ s:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_non_self_calls_are_avoided") { - ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; + ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; check(R"( local s = "hello" @@ -2993,7 +3002,7 @@ s.@1 TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_non_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; + ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; check(R"( string.@1 @@ -3024,7 +3033,7 @@ table.@1 TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_self_calls_are_invalid") { - ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; + ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; check(R"( string:@1 @@ -3036,8 +3045,11 @@ string:@1 CHECK(ac.entryMap["byte"].wrongIndexType == true); REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == true); + + // We want the next test to evaluate to 'true', but we have to allow function defined with 'self' to be callable with ':' + // We may change the definition of the string metatable to not use 'self' types in the future (like byte/char/pack/unpack) REQUIRE(ac.entryMap.count("sub")); - CHECK(ac.entryMap["sub"].wrongIndexType == true); + CHECK(ac.entryMap["sub"].wrongIndexType == false); } TEST_CASE_FIXTURE(ACFixture, "source_module_preservation_and_invalidation") diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index fafafd71..0a1c5a8b 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -165,8 +165,8 @@ LOADN R1 1 FASTCALL2K 18 R1 K0 L0 LOADK R2 K0 GETIMPORT R0 3 -L0: CALL R0 2 -1 -RETURN R0 -1 +CALL R0 2 -1 +L0: RETURN R0 -1 )"); } @@ -2100,12 +2100,12 @@ FASTCALL2 18 R0 R1 L0 MOVE R5 R0 MOVE R6 R1 GETIMPORT R4 2 -L0: CALL R4 2 1 -FASTCALL2 19 R4 R2 L1 +CALL R4 2 1 +L0: FASTCALL2 19 R4 R2 L1 MOVE R5 R2 GETIMPORT R3 4 -L1: CALL R3 2 -1 -RETURN R3 -1 +CALL R3 2 -1 +L1: RETURN R3 -1 )"); } @@ -2382,11 +2382,13 @@ end TEST_CASE("DebugLineInfoRepeatUntil") { + ScopedFastFlag sff("LuauCompileXEQ", true); + CHECK_EQ("\n" + compileFunction0Coverage(R"( local f = 0 repeat f += 1 - if f == 1 then + if f == 1 then print(f) else f = 0 @@ -2397,13 +2399,13 @@ until f == 0 R"( 2: LOADN R0 0 4: L0: ADDK R0 R0 K0 -5: JUMPIFNOTEQK R0 K0 L1 +5: JUMPXEQKN R0 K0 L1 NOT 6: GETIMPORT R1 2 6: MOVE R2 R0 6: CALL R1 1 0 6: JUMP L2 8: L1: LOADN R0 0 -10: L2: JUMPIFEQK R0 K3 L3 +10: L2: JUMPXEQKN R0 K3 L3 10: JUMPBACK L0 11: L3: RETURN R0 0 )"); @@ -2511,8 +2513,8 @@ return 5: MOVE R3 R0 5: MOVE R4 R1 5: GETIMPORT R2 2 -5: L0: CALL R2 2 -1 -5: RETURN R2 -1 +5: CALL R2 2 -1 +5: L0: RETURN R2 -1 )"); } @@ -2828,8 +2830,8 @@ TEST_CASE("FastcallBytecode") LOADN R1 -5 FASTCALL1 2 R1 L0 GETIMPORT R0 2 -L0: CALL R0 1 -1 -RETURN R0 -1 +CALL R0 1 -1 +L0: RETURN R0 -1 )"); // call through a local variable @@ -2838,8 +2840,8 @@ GETIMPORT R0 2 LOADN R2 -5 FASTCALL1 2 R2 L0 MOVE R1 R0 -L0: CALL R1 1 -1 -RETURN R1 -1 +CALL R1 1 -1 +L0: RETURN R1 -1 )"); // call through an upvalue @@ -2847,8 +2849,8 @@ RETURN R1 -1 LOADN R1 -5 FASTCALL1 2 R1 L0 GETUPVAL R0 0 -L0: CALL R0 1 -1 -RETURN R0 -1 +CALL R0 1 -1 +L0: RETURN R0 -1 )"); // mutating the global in the script breaks the optimization @@ -2893,8 +2895,8 @@ LOADK R1 K0 FASTCALL1 57 R1 L0 GETIMPORT R0 2 GETVARARGS R2 -1 -L0: CALL R0 -1 1 -RETURN R0 1 +CALL R0 -1 1 +L0: RETURN R0 1 )"); // more complex example: select inside a for loop bound + select from a iterator @@ -2912,16 +2914,16 @@ LOADK R5 K0 FASTCALL1 57 R5 L0 GETIMPORT R4 2 GETVARARGS R6 -1 -L0: CALL R4 -1 1 -MOVE R1 R4 +CALL R4 -1 1 +L0: MOVE R1 R4 LOADN R2 1 FORNPREP R1 L3 L1: FASTCALL1 57 R3 L2 GETIMPORT R4 2 MOVE R5 R3 GETVARARGS R6 -1 -L2: CALL R4 -1 1 -ADD R0 R0 R4 +CALL R4 -1 1 +L2: ADD R0 R0 R4 FORNLOOP R1 L1 L3: RETURN R0 1 )"); @@ -3242,7 +3244,7 @@ LOADN R2 -1 FASTCALL1 2 R2 L0 GETGLOBAL R3 K1024 GETTABLEKS R1 R3 K1025 -L0: CALL R1 1 -1 +CALL R1 1 -1 )"); } @@ -3509,13 +3511,15 @@ RETURN R0 1 TEST_CASE("ConstantJumpCompare") { + ScopedFastFlag sff("LuauCompileXEQ", true); + CHECK_EQ("\n" + compileFunction0(R"( local obj = ... local b = obj == 1 )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 L0 +JUMPXEQKN R0 K0 L0 LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 @@ -3527,7 +3531,7 @@ local b = 1 == obj )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 L0 +JUMPXEQKN R0 K0 L0 LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 @@ -3539,7 +3543,7 @@ local b = "Hello, Sailor!" == obj )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 L0 +JUMPXEQKS R0 K0 L0 LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 @@ -3551,7 +3555,7 @@ local b = nil == obj )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 L0 +JUMPXEQKNIL R0 L0 LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 @@ -3563,7 +3567,7 @@ local b = true == obj )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 L0 +JUMPXEQKB R0 1 L0 LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 @@ -3575,7 +3579,7 @@ local b = nil ~= obj )"), R"( GETVARARGS R0 1 -JUMPIFNOTEQK R0 K0 L0 +JUMPXEQKNIL R0 L0 NOT LOADB R1 0 +1 L0: LOADB R1 1 L1: RETURN R0 0 @@ -3793,6 +3797,8 @@ RETURN R0 1 TEST_CASE("SharedClosure") { + ScopedFastFlag sff("LuauCompileFreeReassign", true); + // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( local val = ... @@ -3940,11 +3946,10 @@ LOADN R2 1 LOADN R0 10 LOADN R1 1 FORNPREP R0 L5 -L4: MOVE R3 R2 -GETIMPORT R4 1 -NEWCLOSURE R5 P2 -CAPTURE VAL R3 -CALL R4 1 0 +L4: GETIMPORT R3 1 +NEWCLOSURE R4 P2 +CAPTURE VAL R2 +CALL R3 1 0 FORNLOOP R0 L4 L5: RETURN R0 0 )"); @@ -4063,8 +4068,8 @@ LOADN R2 2 LOADN R3 3 FASTCALL 54 L0 GETIMPORT R0 2 -L0: CALL R0 3 -1 -RETURN R0 -1 +CALL R0 3 -1 +L0: RETURN R0 -1 )"); } @@ -4414,7 +4419,7 @@ L2: RETURN R0 0 // continue needs to properly close upvalues CHECK_EQ("\n" + compileFunction(R"( for i=1,1 do - local j = math.abs(i) + local j = global(i) print(function() return j end) if math.random() < 0.5 then continue @@ -4424,21 +4429,20 @@ end )", 1, 2), R"( +GETIMPORT R0 1 LOADN R1 1 -FASTCALL1 2 R1 L0 -GETIMPORT R0 2 -L0: CALL R0 1 1 -GETIMPORT R1 4 +CALL R0 1 1 +GETIMPORT R1 3 NEWCLOSURE R2 P0 CAPTURE REF R0 CALL R1 1 0 GETIMPORT R1 6 CALL R1 0 1 LOADK R2 K7 -JUMPIFNOTLT R1 R2 L1 +JUMPIFNOTLT R1 R2 L0 CLOSEUPVALS R0 RETURN R0 0 -L1: ADDK R0 R0 K8 +L0: ADDK R0 R0 K8 CLOSEUPVALS R0 RETURN R0 0 )"); @@ -4625,11 +4629,11 @@ FORNPREP R1 L3 L0: FASTCALL1 24 R3 L1 MOVE R6 R3 GETIMPORT R5 2 -L1: CALL R5 1 -1 -FASTCALL 2 L2 +CALL R5 1 -1 +L1: FASTCALL 2 L2 GETIMPORT R4 4 -L2: CALL R4 -1 1 -SETTABLE R4 R0 R3 +CALL R4 -1 1 +L2: SETTABLE R4 R0 R3 FORNLOOP R1 L0 L3: RETURN R0 1 )"); @@ -4660,6 +4664,131 @@ L1: RETURN R0 0 )"); } +TEST_CASE("LoopUnrollCostBuiltins") +{ + ScopedFastInt sfis[] = { + {"LuauCompileLoopUnrollThreshold", 25}, + {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, + }; + + // this loop uses builtins and is close to the cost budget so it's important that we model builtins as cheaper than regular calls + CHECK_EQ("\n" + compileFunction(R"( +function cipher(block, nonce) + for i = 0,3 do + block[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff) + end +end +)", + 0, 2), + R"( +FASTCALL2K 39 R1 K0 L0 +MOVE R4 R1 +LOADK R5 K0 +GETIMPORT R3 3 +CALL R3 2 1 +L0: FASTCALL2K 29 R3 K4 L1 +LOADK R4 K4 +GETIMPORT R2 6 +CALL R2 2 1 +L1: SETTABLEN R2 R0 1 +FASTCALL2K 39 R1 K7 L2 +MOVE R4 R1 +LOADK R5 K7 +GETIMPORT R3 3 +CALL R3 2 1 +L2: FASTCALL2K 29 R3 K4 L3 +LOADK R4 K4 +GETIMPORT R2 6 +CALL R2 2 1 +L3: SETTABLEN R2 R0 2 +FASTCALL2K 39 R1 K8 L4 +MOVE R4 R1 +LOADK R5 K8 +GETIMPORT R3 3 +CALL R3 2 1 +L4: FASTCALL2K 29 R3 K4 L5 +LOADK R4 K4 +GETIMPORT R2 6 +CALL R2 2 1 +L5: SETTABLEN R2 R0 3 +FASTCALL2K 39 R1 K9 L6 +MOVE R4 R1 +LOADK R5 K9 +GETIMPORT R3 3 +CALL R3 2 1 +L6: FASTCALL2K 29 R3 K4 L7 +LOADK R4 K4 +GETIMPORT R2 6 +CALL R2 2 1 +L7: SETTABLEN R2 R0 4 +RETURN R0 0 +)"); + + // note that if we break compiler's ability to reason about bit32 builtin the loop is no longer unrolled as it's too expensive + CHECK_EQ("\n" + compileFunction(R"( +bit32 = {} + +function cipher(block, nonce) + for i = 0,3 do + block[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff) + end +end +)", + 0, 2), + R"( +LOADN R4 0 +LOADN R2 3 +LOADN R3 1 +FORNPREP R2 L1 +L0: ADDK R5 R4 K0 +GETGLOBAL R7 K1 +GETTABLEKS R6 R7 K2 +GETGLOBAL R8 K1 +GETTABLEKS R7 R8 K3 +MOVE R8 R1 +MULK R9 R4 K4 +CALL R7 2 1 +LOADN R8 255 +CALL R6 2 1 +SETTABLE R6 R0 R5 +FORNLOOP R2 L0 +L1: RETURN R0 0 +)"); + + // additionally, if we pass too many constants the builtin stops being cheap because of argument setup + CHECK_EQ("\n" + compileFunction(R"( +function cipher(block, nonce) + for i = 0,3 do + block[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff, 0xff, 0xff, 0xff, 0xff) + end +end +)", + 0, 2), + R"( +LOADN R4 0 +LOADN R2 3 +LOADN R3 1 +FORNPREP R2 L3 +L0: ADDK R5 R4 K0 +MULK R9 R4 K1 +FASTCALL2 39 R1 R9 L1 +MOVE R8 R1 +GETIMPORT R7 4 +CALL R7 2 1 +L1: LOADN R8 255 +LOADN R9 255 +LOADN R10 255 +LOADN R11 255 +LOADN R12 255 +FASTCALL 29 L2 +GETIMPORT R6 6 +CALL R6 6 1 +L2: SETTABLE R6 R0 R5 +FORNLOOP R2 L0 +L3: RETURN R0 0 +)"); +} + TEST_CASE("InlineBasic") { // inline function that returns a constant @@ -5216,8 +5345,8 @@ DUPCLOSURE R0 K0 LOADK R3 K1 FASTCALL1 20 R3 L0 GETIMPORT R2 4 -L0: CALL R2 1 2 -ADD R1 R2 R3 +CALL R2 1 2 +L0: ADD R1 R2 R3 RETURN R1 1 )"); @@ -5483,14 +5612,14 @@ NEWTABLE R2 0 0 FASTCALL2K 49 R2 K1 L0 LOADK R3 K1 GETIMPORT R1 3 -L0: CALL R1 2 0 -NEWTABLE R1 0 0 +CALL R1 2 0 +L0: NEWTABLE R1 0 0 NEWTABLE R3 0 0 FASTCALL2 49 R3 R1 L1 MOVE R4 R1 GETIMPORT R2 3 -L1: CALL R2 2 0 -RETURN R0 0 +CALL R2 2 0 +L1: RETURN R0 0 )"); } @@ -5762,4 +5891,355 @@ RETURN R0 2 )"); } +TEST_CASE("OptimizationLevel") +{ + // at optimization level 1, no inlining is performed + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +return foo(42) +)", + 1, 1), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // you can override the level from 1 to 2 to force it + CHECK_EQ("\n" + compileFunction(R"( +--!optimize 2 +local function foo(a) + return a +end + +return foo(42) +)", + 1, 1), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // you can also override it externally + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // ... after which you can downgrade it back via hot comment + CHECK_EQ("\n" + compileFunction(R"( +--!optimize 1 +local function foo(a) + return a +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); +} + +TEST_CASE("BuiltinFolding") +{ + CHECK_EQ("\n" + compileFunction(R"( +return + math.abs(-42), + math.acos(1), + math.asin(0), + math.atan2(0, 1), + math.atan(0), + math.ceil(1.5), + math.cosh(0), + math.cos(0), + math.deg(3.14159265358979323846), + math.exp(0), + math.floor(-1.5), + math.fmod(7, 3), + math.ldexp(0.5, 3), + math.log10(100), + math.log(1), + math.log(4, 2), + math.log(27, 3), + math.max(1, 2, 3), + math.min(1, 2, 3), + math.pow(3, 3), + math.floor(math.rad(180)), + math.sinh(0), + math.sin(0), + math.sqrt(9), + math.tanh(0), + math.tan(0), + bit32.arshift(-10, 1), + bit32.arshift(10, 1), + bit32.band(1, 3), + bit32.bnot(-2), + bit32.bor(1, 2), + bit32.bxor(3, 7), + bit32.btest(1, 3), + bit32.extract(100, 1, 3), + bit32.lrotate(100, -1), + bit32.lshift(100, 1), + bit32.replace(100, 5, 1, 3), + bit32.rrotate(100, -1), + bit32.rshift(100, 1), + type(100), + string.byte("a"), + string.byte("abc", 2), + string.len("abc"), + typeof(true), + math.clamp(-1, 0, 1), + math.sign(77), + math.round(7.6), + (type("fin")) +)", + 0, 2), + R"( +LOADN R0 42 +LOADN R1 0 +LOADN R2 0 +LOADN R3 0 +LOADN R4 0 +LOADN R5 2 +LOADN R6 1 +LOADN R7 1 +LOADN R8 180 +LOADN R9 1 +LOADN R10 -2 +LOADN R11 1 +LOADN R12 4 +LOADN R13 2 +LOADN R14 0 +LOADN R15 2 +LOADN R16 3 +LOADN R17 3 +LOADN R18 1 +LOADN R19 27 +LOADN R20 3 +LOADN R21 0 +LOADN R22 0 +LOADN R23 3 +LOADN R24 0 +LOADN R25 0 +LOADK R26 K0 +LOADN R27 5 +LOADN R28 1 +LOADN R29 1 +LOADN R30 3 +LOADN R31 4 +LOADB R32 1 +LOADN R33 2 +LOADN R34 50 +LOADN R35 200 +LOADN R36 106 +LOADN R37 200 +LOADN R38 50 +LOADK R39 K1 +LOADN R40 97 +LOADN R41 98 +LOADN R42 3 +LOADK R43 K2 +LOADN R44 0 +LOADN R45 1 +LOADN R46 8 +LOADK R47 K3 +RETURN R0 48 +)"); +} + +TEST_CASE("BuiltinFoldingProhibited") +{ + CHECK_EQ("\n" + compileFunction(R"( +return + math.abs(), + math.max(1, true), + string.byte("abc", 42), + bit32.rshift(10, 42) +)", + 0, 2), + R"( +FASTCALL 2 L0 +GETIMPORT R0 2 +CALL R0 0 1 +L0: LOADN R2 1 +FASTCALL2K 18 R2 K3 L1 +LOADK R3 K3 +GETIMPORT R1 5 +CALL R1 2 1 +L1: LOADK R3 K6 +FASTCALL2K 41 R3 K7 L2 +LOADK R4 K7 +GETIMPORT R2 10 +CALL R2 2 1 +L2: LOADN R4 10 +FASTCALL2K 39 R4 K7 L3 +LOADK R5 K7 +GETIMPORT R3 13 +CALL R3 2 -1 +L3: RETURN R0 -1 +)"); +} + +TEST_CASE("BuiltinFoldingMultret") +{ + ScopedFastFlag sff("LuauCompileXEQ", true); + + CHECK_EQ("\n" + compileFunction(R"( +local NoLanes: Lanes = --[[ ]] 0b0000000000000000000000000000000 +local OffscreenLane: Lane = --[[ ]] 0b1000000000000000000000000000000 + +local function getLanesToRetrySynchronouslyOnError(root: FiberRoot): Lanes + local everythingButOffscreen = bit32.band(root.pendingLanes, bit32.bnot(OffscreenLane)) + if everythingButOffscreen ~= NoLanes then + return everythingButOffscreen + end + if bit32.band(everythingButOffscreen, OffscreenLane) ~= 0 then + return OffscreenLane + end + return NoLanes +end +)", + 0, 2), + R"( +GETTABLEKS R2 R0 K0 +FASTCALL2K 29 R2 K1 L0 +LOADK R3 K1 +GETIMPORT R1 4 +CALL R1 2 1 +L0: JUMPXEQKN R1 K5 L1 +RETURN R1 1 +L1: FASTCALL2K 29 R1 K6 L2 +MOVE R3 R1 +LOADK R4 K6 +GETIMPORT R2 4 +CALL R2 2 1 +L2: JUMPXEQKN R2 K5 L3 +LOADK R2 K6 +RETURN R2 1 +L3: LOADN R2 0 +RETURN R2 1 +)"); + + // Note: similarly, here we should have folded the return value but haven't because it's the last call in the sequence + CHECK_EQ("\n" + compileFunction(R"( +return math.abs(-42) +)", + 0, 2), + R"( +LOADN R0 42 +RETURN R0 1 +)"); +} + +TEST_CASE("LocalReassign") +{ + ScopedFastFlag sff("LuauCompileFreeReassign", true); + + // locals can be re-assigned and the register gets reused + CHECK_EQ("\n" + compileFunction0(R"( +local function test(a, b) + local c = a + return c + b +end +)"), + R"( +ADD R2 R0 R1 +RETURN R2 1 +)"); + + // this works if the expression is using type casts or grouping + CHECK_EQ("\n" + compileFunction0(R"( +local function test(a, b) + local c = (a :: number) + return c + b +end +)"), + R"( +ADD R2 R0 R1 +RETURN R2 1 +)"); + + // the optimization requires that neither local is mutated + CHECK_EQ("\n" + compileFunction0(R"( +local function test(a, b) + local c = a + c += 0 + local d = b + b += 0 + return c + d +end +)"), + R"( +MOVE R2 R0 +ADDK R2 R2 K0 +MOVE R3 R1 +ADDK R1 R1 K0 +ADD R4 R2 R3 +RETURN R4 1 +)"); + + // sanity check for two values + CHECK_EQ("\n" + compileFunction0(R"( +local function test(a, b) + local c = a + local d = b + return c + d +end +)"), + R"( +ADD R2 R0 R1 +RETURN R2 1 +)"); + + // note: we currently only support this for single assignments + CHECK_EQ("\n" + compileFunction0(R"( +local function test(a, b) + local c, d = a, b + return c + d +end +)"), + R"( +MOVE R2 R0 +MOVE R3 R1 +ADD R4 R2 R3 +RETURN R4 1 +)"); + + // of course, captures capture the original register as well (by value since it's immutable) + CHECK_EQ("\n" + compileFunction(R"( +local function test(a, b) + local c = a + local d = b + return function() return c + d end +end +)", + 1), + R"( +NEWCLOSURE R2 P0 +CAPTURE VAL R0 +CAPTURE VAL R1 +RETURN R2 1 +)"); +} + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 3f415149..e07ba12a 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -17,6 +17,16 @@ #include extern bool verbose; +extern int optimizationLevel; + +static lua_CompileOptions defaultOptions() +{ + lua_CompileOptions copts = {}; + copts.optimizationLevel = optimizationLevel; + copts.debugLevel = 1; + + return copts; +} static int lua_collectgarbage(lua_State* L) { @@ -60,8 +70,8 @@ static int lua_loadstring(lua_State* L) return 1; lua_pushnil(L); - lua_insert(L, -2); /* put before error message */ - return 2; /* return nil plus error message */ + lua_insert(L, -2); // put before error message + return 2; // return nil plus error message } static int lua_vector(lua_State* L) @@ -127,7 +137,7 @@ int lua_silence(lua_State* L) using StateRef = std::unique_ptr; static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, - lua_State* initialLuaState = nullptr, lua_CompileOptions* copts = nullptr) + lua_State* initialLuaState = nullptr, lua_CompileOptions* options = nullptr) { std::string path = __FILE__; path.erase(path.find_last_of("\\/")); @@ -189,8 +199,11 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n std::string chunkname = "=" + std::string(name); + // note: luau_compile supports nullptr options, but we need to customize our defaults to improve test coverage + lua_CompileOptions opts = options ? *options : defaultOptions(); + size_t bytecodeSize = 0; - char* bytecode = luau_compile(source.data(), source.size(), copts, &bytecodeSize); + char* bytecode = luau_compile(source.data(), source.size(), &opts, &bytecodeSize); int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); free(bytecode); @@ -231,8 +244,6 @@ TEST_CASE("Assert") TEST_CASE("Basic") { - ScopedFastFlag sff("LuauLenTM", true); - runConformance("basic.lua"); } @@ -241,9 +252,9 @@ TEST_CASE("Math") runConformance("math.lua"); } -TEST_CASE("Table") +TEST_CASE("Tables") { - runConformance("nextvar.lua", [](lua_State* L) { + runConformance("tables.lua", [](lua_State* L) { lua_pushcfunction( L, [](lua_State* L) { @@ -278,6 +289,8 @@ TEST_CASE("Clear") TEST_CASE("Strings") { + ScopedFastFlag sff{"LuauTostringFormatSpecifier", true}; + runConformance("strings.lua"); } @@ -298,12 +311,14 @@ TEST_CASE("Literals") TEST_CASE("Errors") { + ScopedFastFlag sff("LuauNicerMethodErrors", true); + runConformance("errors.lua"); } TEST_CASE("Events") { - ScopedFastFlag sff("LuauLenTM", true); + ScopedFastFlag sff("LuauBetterNewindex", true); runConformance("events.lua"); } @@ -383,9 +398,7 @@ TEST_CASE("Pack") TEST_CASE("Vector") { - lua_CompileOptions copts = {}; - copts.optimizationLevel = 1; - copts.debugLevel = 1; + lua_CompileOptions copts = defaultOptions(); copts.vectorCtor = "vector"; runConformance( @@ -479,8 +492,6 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { - ScopedFastFlag sff("LuauCheckLenMT", true); - runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; @@ -519,8 +530,7 @@ TEST_CASE("Debugger") breakhits = 0; interruptedthread = nullptr; - lua_CompileOptions copts = {}; - copts.optimizationLevel = 1; + lua_CompileOptions copts = defaultOptions(); copts.debugLevel = 2; runConformance( @@ -850,6 +860,41 @@ TEST_CASE("ApiCalls") } } +TEST_CASE("ApiAtoms") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_callbacks(L)->useratom = [](const char* s, size_t l) -> int16_t { + if (strcmp(s, "string") == 0) + return 0; + if (strcmp(s, "important") == 0) + return 1; + + return -1; + }; + + lua_pushstring(L, "string"); + lua_pushstring(L, "import"); + lua_pushstring(L, "ant"); + lua_concat(L, 2); + lua_pushstring(L, "unimportant"); + + int a1, a2, a3; + const char* s1 = lua_tostringatom(L, -3, &a1); + const char* s2 = lua_tostringatom(L, -2, &a2); + const char* s3 = lua_tostringatom(L, -1, &a3); + + CHECK(strcmp(s1, "string") == 0); + CHECK(a1 == 0); + + CHECK(strcmp(s2, "important") == 0); + CHECK(a2 == 1); + + CHECK(strcmp(s3, "unimportant") == 0); + CHECK(a3 == -1); +} + static bool endsWith(const std::string& str, const std::string& suffix) { if (suffix.length() > str.length()) @@ -957,9 +1002,8 @@ TEST_CASE("TagMethodError") TEST_CASE("Coverage") { - lua_CompileOptions copts = {}; - copts.optimizationLevel = 1; - copts.debugLevel = 1; + lua_CompileOptions copts = defaultOptions(); + copts.optimizationLevel = 1; // disable inlining to get fixed expected hit results copts.coverageLevel = 2; runConformance( @@ -1059,6 +1103,9 @@ TEST_CASE("GCDump") TEST_CASE("Interrupt") { + lua_CompileOptions copts = defaultOptions(); + copts.optimizationLevel = 1; // disable loop unrolling to get fixed expected hit results + static const int expectedhits[] = { 2, 9, @@ -1109,7 +1156,8 @@ TEST_CASE("Interrupt") }, [](lua_State* L) { CHECK(index == 5); // a single yield point - }); + }, + nullptr, &copts); CHECK(index == int(std::size(expectedhits))); } @@ -1152,6 +1200,10 @@ TEST_CASE("UserdataApi") CHECK(lua_touserdatatagged(L, -1, 41) == nullptr); CHECK(lua_userdatatag(L, -1) == 42); + lua_setuserdatatag(L, -1, 43); + CHECK(lua_userdatatag(L, -1) == 43); + lua_setuserdatatag(L, -1, 42); + // user data with inline dtor void* ud3 = lua_newuserdatadtor(L, 4, [](void* data) { dtorhits += *(int*)data; diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp index f521c667..e33f6570 100644 --- a/tests/ConstraintSolver.test.cpp +++ b/tests/ConstraintSolver.test.cpp @@ -9,7 +9,7 @@ using namespace Luau; -static TypeId requireBinding(NotNull scope, const char* name) +static TypeId requireBinding(NotNull scope, const char* name) { auto b = linearSearchForBinding(scope, name); LUAU_ASSERT(b.has_value()); @@ -26,7 +26,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") )"); cgb.visit(block); - NotNull rootScope = NotNull(cgb.rootScope); + NotNull rootScope = NotNull(cgb.rootScope); ConstraintSolver cs{&arena, rootScope}; @@ -46,7 +46,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") )"); cgb.visit(block); - NotNull rootScope = NotNull(cgb.rootScope); + NotNull rootScope = NotNull(cgb.rootScope); ConstraintSolver cs{&arena, rootScope}; @@ -73,7 +73,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") )"); cgb.visit(block); - NotNull rootScope = NotNull(cgb.rootScope); + NotNull rootScope = NotNull(cgb.rootScope); ToStringOptions opts; diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp index c709ba8e..eacc718b 100644 --- a/tests/CostModel.test.cpp +++ b/tests/CostModel.test.cpp @@ -10,7 +10,7 @@ namespace Luau namespace Compile { -uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount); +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount, const DenseHashMap& builtins); int computeCost(uint64_t model, const bool* varsConst, size_t varCount); } // namespace Compile @@ -29,7 +29,7 @@ static uint64_t modelFunction(const char* source) AstStatFunction* func = result.root->body.data[0]->as(); REQUIRE(func); - return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size); + return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size, {nullptr}); } TEST_CASE("Expression") diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index c92c4457..f51a9d1b 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -258,7 +258,7 @@ std::optional Fixture::getType(const std::string& name) REQUIRE(module); if (FFlag::DebugLuauDeferredConstraintResolution) - return linearSearchForBinding(module->getModuleScope2(), name.c_str()); + return linearSearchForBinding(module->getModuleScope().get(), name.c_str()); else return lookupName(module->getModuleScope(), name); } @@ -410,7 +410,7 @@ void Fixture::validateErrors(const std::vector& errors) LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) { unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, source, "@test"); + LoadDefinitionFileResult result = frontend.loadDefinitionFile(source, "@test"); freeze(typeChecker.globalTypes); REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); @@ -434,7 +434,7 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() : Fixture() - , cgb(mainModuleName, &arena, NotNull(&ice), frontend.getGlobalScope2()) + , cgb(mainModuleName, &arena, NotNull(&ice), frontend.getGlobalScope()) , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} { BlockedTypeVar::nextIndex = 0; @@ -479,17 +479,17 @@ std::optional lookupName(ScopePtr scope, const std::string& name) return std::nullopt; } -std::optional linearSearchForBinding(Scope2* scope, const char* name) +std::optional linearSearchForBinding(Scope* scope, const char* name) { while (scope) { for (const auto& [n, ty] : scope->bindings) { if (n.astName() == name) - return ty; + return ty.typeId; } - scope = scope->parent; + scope = scope->parent.get(); } return std::nullopt; diff --git a/tests/Fixture.h b/tests/Fixture.h index 1bc573da..a716fe9b 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -128,6 +128,7 @@ struct Fixture std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; + ScopedFastFlag sff_UnknownNever{"LuauUnknownAndNeverType", true}; TestFileResolver fileResolver; TestConfigResolver configResolver; @@ -191,7 +192,7 @@ void dump(const std::vector& constraints); std::optional lookupName(ScopePtr scope, const std::string& name); // Warning: This function runs in O(n**2) -std::optional linearSearchForBinding(Scope2* scope, const char* name); +std::optional linearSearchForBinding(Scope* scope, const char* name); } // namespace Luau diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index b568ce01..8b1c04b1 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -756,26 +756,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "test_lint_uses_correct_config") CHECK_EQ(0, result4.warnings.size()); } -TEST_CASE_FIXTURE(FrontendFixture, "lintFragment") -{ - LintOptions lintOptions; - lintOptions.enableWarning(LintWarning::Code_ForRange); - - auto [_sourceModule, result] = frontend.lintFragment(R"( - local t = {} - - for i=#t,1 do - end - - for i=#t,1,-1 do - end - )", - lintOptions); - - CHECK_EQ(1, result.warnings.size()); - CHECK_EQ(0, result.errors.size()); -} - TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") { Frontend fe{&fileResolver, &configResolver, {false}}; @@ -791,6 +771,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") CHECK_EQ(0, module->internalTypes.typeVars.size()); CHECK_EQ(0, module->internalTypes.typePacks.size()); CHECK_EQ(0, module->astTypes.size()); + CHECK_EQ(0, module->astResolvedTypes.size()); + CHECK_EQ(0, module->astResolvedTypePacks.size()); } TEST_CASE_FIXTURE(FrontendFixture, "it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded") @@ -1066,4 +1048,71 @@ TEST_CASE("check_without_builtin_next") frontend.check("Module/B"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_cyclic_type") +{ + ScopedFastFlag sff[] = { + {"LuauForceExportSurfacesToBeNormal", true}, + {"LuauLowerBoundsCalculation", true}, + }; + + fileResolver.source["Module/A"] = R"( + type F = (set: G) -> () + + export type G = { + forEach: (a: F) -> (), + } + + function X(a: F): () + end + + return X + )"; + + fileResolver.source["Module/B"] = R"( + --!strict + local A = require(script.Parent.A) + + export type G = A.G + + return { + A = A, + } + )"; + + CheckResult result = frontend.check("Module/B"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_type_alias") +{ + ScopedFastFlag sff[] = { + {"LuauForceExportSurfacesToBeNormal", true}, + {"LuauLowerBoundsCalculation", true}, + }; + + fileResolver.source["Module/A"] = R"( + type KeyOfTestEvents = "test-file-start" | "test-file-success" | "test-file-failure" | "test-case-result" + type unknown = any + + export type TestFileEvent = ( + eventName: T, + args: any --[[ ROBLOX TODO: Unhandled node for type: TSIndexedAccessType ]] --[[ TestEvents[T] ]] + ) -> unknown + + return {} + )"; + + fileResolver.source["Module/B"] = R"( + --!strict + local A = require(script.Parent.A) + + export type TestFileEvent = A.TestFileEvent + )"; + + CheckResult result = frontend.check("Module/B"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/JsonEmitter.test.cpp b/tests/JsonEmitter.test.cpp new file mode 100644 index 00000000..ebe83209 --- /dev/null +++ b/tests/JsonEmitter.test.cpp @@ -0,0 +1,195 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/JsonEmitter.h" + +#include "doctest.h" + +using namespace Luau::Json; + +TEST_SUITE_BEGIN("JsonEmitter"); + +TEST_CASE("write_array") +{ + JsonEmitter emitter; + ArrayEmitter a = emitter.writeArray(); + a.writeValue(123); + a.writeValue("foo"); + a.finish(); + + std::string result = emitter.str(); + CHECK(result == "[123,\"foo\"]"); +} + +TEST_CASE("write_object") +{ + JsonEmitter emitter; + ObjectEmitter o = emitter.writeObject(); + o.writePair("foo", "bar"); + o.writePair("bar", "baz"); + o.finish(); + + std::string result = emitter.str(); + CHECK(result == "{\"foo\":\"bar\",\"bar\":\"baz\"}"); +} + +TEST_CASE("write_bool") +{ + JsonEmitter emitter; + write(emitter, false); + CHECK(emitter.str() == "false"); + + emitter = JsonEmitter{}; + write(emitter, true); + CHECK(emitter.str() == "true"); +} + +TEST_CASE("write_null") +{ + JsonEmitter emitter; + write(emitter, nullptr); + CHECK(emitter.str() == "null"); +} + +TEST_CASE("write_string") +{ + JsonEmitter emitter; + write(emitter, R"(foo,bar,baz, +"this should be escaped")"); + CHECK(emitter.str() == "\"foo,bar,baz,\\n\\\"this should be escaped\\\"\""); +} + +TEST_CASE("write_comma") +{ + JsonEmitter emitter; + emitter.writeComma(); + write(emitter, true); + emitter.writeComma(); + write(emitter, false); + CHECK(emitter.str() == "true,false"); +} + +TEST_CASE("push_and_pop_comma") +{ + JsonEmitter emitter; + emitter.writeComma(); + write(emitter, true); + emitter.writeComma(); + emitter.writeRaw('['); + bool comma = emitter.pushComma(); + emitter.writeComma(); + write(emitter, true); + emitter.writeComma(); + write(emitter, false); + emitter.writeRaw(']'); + emitter.popComma(comma); + emitter.writeComma(); + write(emitter, false); + + CHECK(emitter.str() == "true,[true,false],false"); +} + +TEST_CASE("write_optional") +{ + JsonEmitter emitter; + emitter.writeComma(); + write(emitter, std::optional{true}); + emitter.writeComma(); + write(emitter, std::nullopt); + + CHECK(emitter.str() == "true,null"); +} + +TEST_CASE("write_vector") +{ + std::vector values{1, 2, 3, 4}; + JsonEmitter emitter; + write(emitter, values); + CHECK(emitter.str() == "[1,2,3,4]"); +} + +TEST_CASE("prevent_multiple_object_finish") +{ + JsonEmitter emitter; + ObjectEmitter o = emitter.writeObject(); + o.writePair("a", "b"); + o.finish(); + o.finish(); + + CHECK(emitter.str() == "{\"a\":\"b\"}"); +} + +TEST_CASE("prevent_multiple_array_finish") +{ + JsonEmitter emitter; + ArrayEmitter a = emitter.writeArray(); + a.writeValue(1); + a.finish(); + a.finish(); + + CHECK(emitter.str() == "[1]"); +} + +TEST_CASE("cannot_write_pair_after_finished") +{ + JsonEmitter emitter; + ObjectEmitter o = emitter.writeObject(); + o.finish(); + o.writePair("a", "b"); + + CHECK(emitter.str() == "{}"); +} + +TEST_CASE("cannot_write_value_after_finished") +{ + JsonEmitter emitter; + ArrayEmitter a = emitter.writeArray(); + a.finish(); + a.writeValue(1); + + CHECK(emitter.str() == "[]"); +} + +TEST_CASE("finish_when_destructing_object") +{ + JsonEmitter emitter; + emitter.writeObject(); + + CHECK(emitter.str() == "{}"); +} + +TEST_CASE("finish_when_destructing_array") +{ + JsonEmitter emitter; + emitter.writeArray(); + + CHECK(emitter.str() == "[]"); +} + +namespace Luau::Json +{ + +struct Special +{ + int foo; + int bar; +}; + +void write(JsonEmitter& emitter, const Special& value) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("foo", value.foo); + o.writePair("bar", value.bar); +} + +} // namespace Luau::Json + +TEST_CASE("afford_extensibility") +{ + std::vector vec{Special{1, 2}, Special{3, 4}}; + JsonEmitter e; + write(e, vec); + + std::string result = e.str(); + CHECK(result == R"([{"foo":1,"bar":2},{"foo":3,"bar":4}])"); +} + +TEST_SUITE_END(); diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp new file mode 100644 index 00000000..20d8d0d5 --- /dev/null +++ b/tests/Lexer.test.cpp @@ -0,0 +1,141 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Lexer.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("LexerTests"); + +TEST_CASE("broken_string_works") +{ + const std::string testInput = "[["; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::Type::BrokenString); + CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, 2))); +} + +TEST_CASE("broken_comment") +{ + const std::string testInput = "--[[ "; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::Type::BrokenComment); + CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, 6))); +} + +TEST_CASE("broken_comment_kept") +{ + const std::string testInput = "--[[ "; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + lexer.setSkipComments(true); + CHECK_EQ(lexer.next().type, Lexeme::Type::BrokenComment); +} + +TEST_CASE("comment_skipped") +{ + const std::string testInput = "-- "; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + lexer.setSkipComments(true); + CHECK_EQ(lexer.next().type, Lexeme::Type::Eof); +} + +TEST_CASE("multilineCommentWithLexemeInAndAfter") +{ + const std::string testInput = "--[[ function \n" + "]] end"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme comment = lexer.next(); + Lexeme end = lexer.next(); + + CHECK_EQ(comment.type, Lexeme::Type::BlockComment); + CHECK_EQ(comment.location, Luau::Location(Luau::Position(0, 0), Luau::Position(1, 2))); + CHECK_EQ(end.type, Lexeme::Type::ReservedEnd); + CHECK_EQ(end.location, Luau::Location(Luau::Position(1, 3), Luau::Position(1, 6))); +} + +TEST_CASE("testBrokenEscapeTolerant") +{ + const std::string testInput = "'\\3729472897292378'"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme item = lexer.next(); + + CHECK_EQ(item.type, Lexeme::QuotedString); + CHECK_EQ(item.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, int(testInput.size())))); +} + +TEST_CASE("testBigDelimiters") +{ + const std::string testInput = "--[===[\n" + "\n" + "\n" + "\n" + "]===]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme item = lexer.next(); + + CHECK_EQ(item.type, Lexeme::Type::BlockComment); + CHECK_EQ(item.location, Luau::Location(Luau::Position(0, 0), Luau::Position(4, 5))); +} + +TEST_CASE("lookahead") +{ + const std::string testInput = "foo --[[ comment ]] bar : nil end"; + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + lexer.setSkipComments(true); + lexer.next(); // must call next() before reading data from lexer at least once + + CHECK_EQ(lexer.current().type, Lexeme::Name); + CHECK_EQ(lexer.current().name, std::string("foo")); + CHECK_EQ(lexer.lookahead().type, Lexeme::Name); + CHECK_EQ(lexer.lookahead().name, std::string("bar")); + + lexer.next(); + + CHECK_EQ(lexer.current().type, Lexeme::Name); + CHECK_EQ(lexer.current().name, std::string("bar")); + CHECK_EQ(lexer.lookahead().type, ':'); + + lexer.next(); + + CHECK_EQ(lexer.current().type, ':'); + CHECK_EQ(lexer.lookahead().type, Lexeme::ReservedNil); + + lexer.next(); + + CHECK_EQ(lexer.current().type, Lexeme::ReservedNil); + CHECK_EQ(lexer.lookahead().type, Lexeme::ReservedEnd); + + lexer.next(); + + CHECK_EQ(lexer.current().type, Lexeme::ReservedEnd); + CHECK_EQ(lexer.lookahead().type, Lexeme::Eof); + + lexer.next(); + + CHECK_EQ(lexer.current().type, Lexeme::Eof); + CHECK_EQ(lexer.lookahead().type, Lexeme::Eof); +} + +TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 202aeceb..35c40508 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1658,4 +1658,53 @@ end CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); } +TEST_CASE_FIXTURE(Fixture, "WrongCommentOptimize") +{ + LintResult result = lint(R"( +--!optimize +--!optimize +--!optimize me +--!optimize 100500 +--!optimize 2 +)"); + + REQUIRE_EQ(result.warnings.size(), 4); + CHECK_EQ(result.warnings[0].text, "optimize directive requires an optimization level"); + CHECK_EQ(result.warnings[1].text, "optimize directive requires an optimization level"); + CHECK_EQ(result.warnings[2].text, "optimize directive uses unknown optimization level 'me', 0..2 expected"); + CHECK_EQ(result.warnings[3].text, "optimize directive uses unknown optimization level '100500', 0..2 expected"); +} + +TEST_CASE_FIXTURE(Fixture, "LintIntegerParsing") +{ + ScopedFastFlag luauLintParseIntegerIssues{"LuauLintParseIntegerIssues", true}; + + LintResult result = lint(R"( +local _ = 0b10000000000000000000000000000000000000000000000000000000000000000 +local _ = 0x10000000000000000 +)"); + + REQUIRE_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].text, "Binary number literal exceeded available precision and has been truncated to 2^64"); + CHECK_EQ(result.warnings[1].text, "Hexadecimal number literal exceeded available precision and has been truncated to 2^64"); +} + +// TODO: remove with FFlagLuauErrorDoubleHexPrefix +TEST_CASE_FIXTURE(Fixture, "LintIntegerParsingDoublePrefix") +{ + ScopedFastFlag luauLintParseIntegerIssues{"LuauLintParseIntegerIssues", true}; + ScopedFastFlag luauErrorDoubleHexPrefix{"LuauErrorDoubleHexPrefix", false}; // Lint will be available until we start rejecting code + + LintResult result = lint(R"( +local _ = 0x0x123 +local _ = 0x0xffffffffffffffffffffffffffffffffff +)"); + + REQUIRE_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].text, + "Hexadecimal number literal has a double prefix, which will fail to parse in the future; remove the extra 0x to fix"); + CHECK_EQ(result.warnings[1].text, + "Hexadecimal number literal has a double prefix, which will fail to parse in the future; remove the extra 0x to fix"); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 7c2f4d1c..dd94e9d7 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -301,8 +301,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") { - ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; - fileResolver.source["Module/A"] = R"( export type A = B type B = A diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 50dcbad0..02e02e6b 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -13,77 +13,6 @@ using namespace Luau; TEST_SUITE_BEGIN("NonstrictModeTests"); -TEST_CASE_FIXTURE(Fixture, "globals") -{ - CheckResult result = check(R"( - --!nonstrict - foo = true - foo = "now i'm a string!" - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("foo"))); -} - -TEST_CASE_FIXTURE(Fixture, "globals2") -{ - ScopedFastFlag sff[]{ - {"LuauReturnTypeInferenceInNonstrict", true}, - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - --!nonstrict - foo = function() return 1 end - foo = "now i'm a string!" - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("() -> number", toString(tm->wantedType)); - CHECK_EQ("string", toString(tm->givenType)); - CHECK_EQ("() -> number", toString(requireType("foo"))); -} - -TEST_CASE_FIXTURE(Fixture, "globals_everywhere") -{ - CheckResult result = check(R"( - --!nonstrict - foo = 1 - - if true then - bar = 2 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("foo"))); - CHECK_EQ("any", toString(requireType("bar"))); -} - -TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_number_or_string") -{ - ScopedFastFlag sff[]{{"LuauReturnTypeInferenceInNonstrict", true}, {"LuauLowerBoundsCalculation", true}}; - - CheckResult result = check(R"( - --!nonstrict - local function f() - if math.random() > 0.5 then - return 5 - else - return "hi" - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK("() -> number | string" == toString(requireType("f"))); -} - TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") { CheckResult result = check(R"( @@ -106,13 +35,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") REQUIRE_EQ(0, rets.size()); } -TEST_CASE_FIXTURE(Fixture, "first_return_type_dictates_number_of_return_types") +TEST_CASE_FIXTURE(Fixture, "infer_the_maximum_number_of_values_the_function_could_return") { - ScopedFastFlag sff[]{ - {"LuauReturnTypeInferenceInNonstrict", true}, - {"LuauLowerBoundsCalculation", true}, - }; - CheckResult result = check(R"( --!nonstrict function getMinCardCountForWidth(width) @@ -127,18 +51,22 @@ TEST_CASE_FIXTURE(Fixture, "first_return_type_dictates_number_of_return_types") TypeId t = requireType("getMinCardCountForWidth"); REQUIRE(t); - REQUIRE_EQ("(any) -> number", toString(t)); + REQUIRE_EQ("(any) -> (...any)", toString(t)); } +#if 0 +// Maybe we want this? TEST_CASE_FIXTURE(Fixture, "return_annotation_is_still_checked") { CheckResult result = check(R"( - --!nonstrict function foo(x): number return 'hello' end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE_NE(*typeChecker.anyType, *requireType("foo")); } +#endif TEST_CASE_FIXTURE(Fixture, "function_parameters_are_any") { @@ -324,11 +252,6 @@ TEST_CASE_FIXTURE(Fixture, "delay_function_does_not_require_its_argument_to_retu TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") { - ScopedFastFlag sff[]{ - {"LuauReturnTypeInferenceInNonstrict", true}, - {"LuauLowerBoundsCalculation", true}, - }; - CheckResult result = check(R"( --!nonstrict @@ -345,7 +268,7 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("((any) -> string) | {| foo: any |}", toString(getMainModule()->getModuleScope()->returnType)); + REQUIRE_EQ("any", toString(getMainModule()->getModuleScope()->returnType)); } TEST_CASE_FIXTURE(Fixture, "returning_insufficient_return_values") diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index a474b6e7..c64c41c5 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -621,7 +621,6 @@ TEST_CASE_FIXTURE(Fixture, "normalize_module_return_type") { ScopedFastFlag sff[] = { {"LuauLowerBoundsCalculation", true}, - {"LuauReturnTypeInferenceInNonstrict", true}, }; check(R"( @@ -642,7 +641,7 @@ TEST_CASE_FIXTURE(Fixture, "normalize_module_return_type") end )"); - CHECK_EQ("(any, any) -> (any, any) -> any", toString(getMainModule()->getModuleScope()->returnType)); + CHECK_EQ("(any, any) -> (...any)", toString(getMainModule()->getModuleScope()->returnType)); } TEST_CASE_FIXTURE(Fixture, "return_type_is_not_a_constrained_intersection") @@ -681,26 +680,12 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_with_annotation") CHECK_EQ("((a) -> b, a) -> b", toString(requireType("apply"))); } -TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_marked_normal") -{ - ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", true}, {"LuauNormalizeFlagIsConservative", false}}; - - check(R"( - type Fiber = { - return_: Fiber? - } - - local f: Fiber - )"); - - TypeId t = requireType("f"); - CHECK(t->normal); -} - // Unfortunately, getting this right in the general case is difficult. TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_not_marked_normal") { - ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", true}, {"LuauNormalizeFlagIsConservative", true}}; + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; check(R"( type Fiber = { @@ -756,6 +741,65 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_table_normalizes_sensibly") CHECK_EQ("t1 where t1 = { get: () -> t1 }", toString(ty, {true})); } +TEST_CASE_FIXTURE(Fixture, "cyclic_union") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauFixNormalizationOfCyclicUnions", true}, + }; + + CheckResult result = check(R"( + type T = {T?}? + + local a: T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("t1? where t1 = {t1?}" == toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_intersection") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauFixNormalizationOfCyclicUnions", true}, + }; + + CheckResult result = check(R"( + type T = {T & {}} + + local a: T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // FIXME: We are not properly normalizing this type, but we are at least not improperly discarding information + CHECK("t1 where t1 = {{t1 & {| |}}}" == toString(requireType("a"), {true})); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_indexers") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauFixNormalizationOfCyclicUnions", true}, + }; + + CheckResult result = check(R"( + type A = {number} + type B = {string} + + type C = A & B + + local a: C + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // FIXME: We are not properly normalizing this type, but we are at least not improperly discarding information + CHECK("{number & string}" == toString(requireType("a"), {true})); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "union_of_distinct_free_types") { ScopedFastFlag flags[] = { @@ -1023,7 +1067,6 @@ TEST_CASE_FIXTURE(Fixture, "bound_typevars_should_only_be_marked_normal_if_their { ScopedFastFlag sff[]{ {"LuauLowerBoundsCalculation", true}, - {"LuauNormalizeFlagIsConservative", true}, }; CheckResult result = check(R"( @@ -1055,8 +1098,6 @@ export type t1 = { a: typeof(string.byte) } TEST_CASE_FIXTURE(Fixture, "intersection_combine_on_bound_self") { - ScopedFastFlag luauNormalizeCombineEqFix{"LuauNormalizeCombineEqFix", true}; - CheckResult result = check(R"( export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,})) )"); @@ -1064,6 +1105,46 @@ export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "normalize_unions_containing_never") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + type Foo = string | never + local foo: Foo + )"); + + CHECK_EQ("string", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "normalize_unions_containing_unknown") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + type Foo = string | unknown + local foo: Foo + )"); + + CHECK_EQ("unknown", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "any_wins_the_battle_over_unknown_in_unions") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + type Foo = unknown | any + local foo: Foo + + type Bar = any | unknown + local bar: Bar + )"); + + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_does_not_convert_ever") { ScopedFastFlag sff[]{ diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index c3c75998..5b86807c 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -97,138 +97,6 @@ TEST_CASE("initial_double_is_aligned") TEST_SUITE_END(); -TEST_SUITE_BEGIN("LexerTests"); - -TEST_CASE("broken_string_works") -{ - const std::string testInput = "[["; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - Lexeme lexeme = lexer.next(); - CHECK_EQ(lexeme.type, Lexeme::Type::BrokenString); - CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, 2))); -} - -TEST_CASE("broken_comment") -{ - const std::string testInput = "--[[ "; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - Lexeme lexeme = lexer.next(); - CHECK_EQ(lexeme.type, Lexeme::Type::BrokenComment); - CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, 6))); -} - -TEST_CASE("broken_comment_kept") -{ - const std::string testInput = "--[[ "; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - lexer.setSkipComments(true); - CHECK_EQ(lexer.next().type, Lexeme::Type::BrokenComment); -} - -TEST_CASE("comment_skipped") -{ - const std::string testInput = "-- "; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - lexer.setSkipComments(true); - CHECK_EQ(lexer.next().type, Lexeme::Type::Eof); -} - -TEST_CASE("multilineCommentWithLexemeInAndAfter") -{ - const std::string testInput = "--[[ function \n" - "]] end"; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - Lexeme comment = lexer.next(); - Lexeme end = lexer.next(); - - CHECK_EQ(comment.type, Lexeme::Type::BlockComment); - CHECK_EQ(comment.location, Luau::Location(Luau::Position(0, 0), Luau::Position(1, 2))); - CHECK_EQ(end.type, Lexeme::Type::ReservedEnd); - CHECK_EQ(end.location, Luau::Location(Luau::Position(1, 3), Luau::Position(1, 6))); -} - -TEST_CASE("testBrokenEscapeTolerant") -{ - const std::string testInput = "'\\3729472897292378'"; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - Lexeme item = lexer.next(); - - CHECK_EQ(item.type, Lexeme::QuotedString); - CHECK_EQ(item.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, int(testInput.size())))); -} - -TEST_CASE("testBigDelimiters") -{ - const std::string testInput = "--[===[\n" - "\n" - "\n" - "\n" - "]===]"; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - Lexeme item = lexer.next(); - - CHECK_EQ(item.type, Lexeme::Type::BlockComment); - CHECK_EQ(item.location, Luau::Location(Luau::Position(0, 0), Luau::Position(4, 5))); -} - -TEST_CASE("lookahead") -{ - const std::string testInput = "foo --[[ comment ]] bar : nil end"; - - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - lexer.setSkipComments(true); - lexer.next(); // must call next() before reading data from lexer at least once - - CHECK_EQ(lexer.current().type, Lexeme::Name); - CHECK_EQ(lexer.current().name, std::string("foo")); - CHECK_EQ(lexer.lookahead().type, Lexeme::Name); - CHECK_EQ(lexer.lookahead().name, std::string("bar")); - - lexer.next(); - - CHECK_EQ(lexer.current().type, Lexeme::Name); - CHECK_EQ(lexer.current().name, std::string("bar")); - CHECK_EQ(lexer.lookahead().type, ':'); - - lexer.next(); - - CHECK_EQ(lexer.current().type, ':'); - CHECK_EQ(lexer.lookahead().type, Lexeme::ReservedNil); - - lexer.next(); - - CHECK_EQ(lexer.current().type, Lexeme::ReservedNil); - CHECK_EQ(lexer.lookahead().type, Lexeme::ReservedEnd); - - lexer.next(); - - CHECK_EQ(lexer.current().type, Lexeme::ReservedEnd); - CHECK_EQ(lexer.lookahead().type, Lexeme::Eof); - - lexer.next(); - - CHECK_EQ(lexer.current().type, Lexeme::Eof); - CHECK_EQ(lexer.lookahead().type, Lexeme::Eof); -} - -TEST_SUITE_END(); - TEST_SUITE_BEGIN("ParserTests"); TEST_CASE_FIXTURE(Fixture, "basic_parse") @@ -814,20 +682,23 @@ TEST_CASE_FIXTURE(Fixture, "parse_numbers_binary") TEST_CASE_FIXTURE(Fixture, "parse_numbers_error") { - ScopedFastFlag luauErrorParseIntegerIssues{"LuauErrorParseIntegerIssues", true}; + ScopedFastFlag luauLintParseIntegerIssues{"LuauLintParseIntegerIssues", true}; + ScopedFastFlag luauErrorDoubleHexPrefix{"LuauErrorDoubleHexPrefix", true}; CHECK_EQ(getParseError("return 0b123"), "Malformed number"); CHECK_EQ(getParseError("return 123x"), "Malformed number"); CHECK_EQ(getParseError("return 0xg"), "Malformed number"); CHECK_EQ(getParseError("return 0x0x123"), "Malformed number"); + CHECK_EQ(getParseError("return 0xffffffffffffffffffffllllllg"), "Malformed number"); + CHECK_EQ(getParseError("return 0x0xffffffffffffffffffffffffffff"), "Malformed number"); } -TEST_CASE_FIXTURE(Fixture, "parse_numbers_range_error") +TEST_CASE_FIXTURE(Fixture, "parse_numbers_error_soft") { - ScopedFastFlag luauErrorParseIntegerIssues{"LuauErrorParseIntegerIssues", true}; + ScopedFastFlag luauLintParseIntegerIssues{"LuauLintParseIntegerIssues", true}; + ScopedFastFlag luauErrorDoubleHexPrefix{"LuauErrorDoubleHexPrefix", false}; - CHECK_EQ(getParseError("return 0x10000000000000000"), "Integer number value is out of range"); - CHECK_EQ(getParseError("return 0b10000000000000000000000000000000000000000000000000000000000000000"), "Integer number value is out of range"); + CHECK_EQ(getParseError("return 0x0x0x0x0x0x0x0"), "Malformed number"); } TEST_CASE_FIXTURE(Fixture, "break_return_not_last_error") @@ -2648,7 +2519,6 @@ type Z = { a: string | T..., b: number } TEST_CASE_FIXTURE(Fixture, "recover_function_return_type_annotations") { - ScopedFastFlag sff{"LuauReturnTypeTokenConfusion", true}; ParseResult result = tryParse(R"( type Custom = { x: A, y: B, z: C } type Packed = { x: (A...) -> () } diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index 87a1e1e2..18c243b0 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -82,7 +82,7 @@ private: capturedoutput = "" function arraytostring(arr) - local strings = {} + local strings = {} table.foreachi(arr, function(k,v) table.insert(strings, pptostring(v)) end ) return "{" .. table.concat(strings, ", ") .. "}" end diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 387e07cd..fe376d86 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -10,6 +10,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); +LUAU_FASTFLAG(LuauSpecialTypesAsterisked); TEST_SUITE_BEGIN("ToString"); @@ -62,7 +63,6 @@ TEST_CASE_FIXTURE(Fixture, "named_table") TEST_CASE_FIXTURE(Fixture, "empty_table") { - ScopedFastFlag LuauToStringTableBracesNewlines("LuauToStringTableBracesNewlines", true); CheckResult result = check(R"( local a: {} )"); @@ -77,7 +77,6 @@ TEST_CASE_FIXTURE(Fixture, "empty_table") TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") { - ScopedFastFlag LuauToStringTableBracesNewlines("LuauToStringTableBracesNewlines", true); CheckResult result = check(R"( local a: { prop: string, anotherProp: number, thirdProp: boolean } )"); @@ -96,6 +95,37 @@ TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") //clang-format on } +TEST_CASE_FIXTURE(Fixture, "metatable") +{ + TypeVar table{TypeVariant(TableTypeVar())}; + TypeVar metatable{TypeVariant(TableTypeVar())}; + TypeVar mtv{TypeVariant(MetatableTypeVar{&table, &metatable})}; + CHECK_EQ("{ @metatable { }, { } }", toString(&mtv)); +} + +TEST_CASE_FIXTURE(Fixture, "named_metatable") +{ + TypeVar table{TypeVariant(TableTypeVar())}; + TypeVar metatable{TypeVariant(TableTypeVar())}; + TypeVar mtv{TypeVariant(MetatableTypeVar{&table, &metatable, "NamedMetatable"})}; + CHECK_EQ("NamedMetatable", toString(&mtv)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "named_metatable_toStringNamedFunction") +{ + CheckResult result = check(R"( + local function createTbl(): NamedMetatable + return setmetatable({}, {}) + end + type NamedMetatable = typeof(createTbl()) + )"); + + TypeId ty = requireType("createTbl"); + const FunctionTypeVar* ftv = get(follow(ty)); + REQUIRE(ftv); + CHECK_EQ("createTbl(): NamedMetatable", toStringNamedFunction("createTbl", *ftv)); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") { CheckResult result = check(R"( @@ -238,8 +268,16 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") o.maxTypeLength = 40; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); - CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); - CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); + if (FFlag::LuauSpecialTypesAsterisked) + { + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + } + else + { + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); + } } TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") @@ -257,8 +295,16 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") o.maxTypeLength = 40; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); - CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); - CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); + if (FFlag::LuauSpecialTypesAsterisked) + { + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + } + else + { + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); + } } TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table_state_braces") @@ -468,7 +514,10 @@ local function target(callback: nil) return callback(4, "hello") end )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("(nil) -> (*unknown*)", toString(requireType("target"))); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("(nil) -> (*error-type*)", toString(requireType("target"))); + else + CHECK_EQ("(nil) -> ()", toString(requireType("target"))); } TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") @@ -676,10 +725,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_generics") { - ScopedFastFlag sff[] = { - {"LuauAlwaysQuantify", true}, - }; - CheckResult result = check(R"( function foo(x: a, y) end )"); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index b02a52b2..d2ed9aef 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -583,7 +583,7 @@ TEST_CASE_FIXTURE(Fixture, "transpile_error_expr") auto names = AstNameTable{allocator}; ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); - CHECK_EQ("local a = (error-expr: f.%error-id%)-(error-expr)", transpileWithTypes(*parseResult.root)); + CHECK_EQ("local a = (error-expr: f:%error-id%)-(error-expr)", transpileWithTypes(*parseResult.root)); } TEST_CASE_FIXTURE(Fixture, "transpile_error_stat") diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index bdd4d6fd..e487fd48 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -94,6 +94,65 @@ TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") } } +TEST_CASE_FIXTURE(Fixture, "mismatched_generic_type_param") +{ + CheckResult result = check(R"( + type T = (A...) -> () + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == + "Generic type 'A' is used as a variadic type parameter; consider changing 'A' to 'A...' in the generic argument list"); + CHECK(result.errors[0].location == Location{{1, 21}, {1, 25}}); +} + +TEST_CASE_FIXTURE(Fixture, "mismatched_generic_pack_type_param") +{ + CheckResult result = check(R"( + type T = (A) -> () + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == + "Variadic type parameter 'A...' is used as a regular generic type; consider changing 'A...' to 'A' in the generic argument list"); + CHECK(result.errors[0].location == Location{{1, 24}, {1, 25}}); +} + +TEST_CASE_FIXTURE(Fixture, "default_type_parameter") +{ + CheckResult result = check(R"( + type T = { a: A, b: B } + local x: T = { a = "foo", b = "bar" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("x")) == "T"); +} + +TEST_CASE_FIXTURE(Fixture, "default_pack_parameter") +{ + CheckResult result = check(R"( + type T = { fn: (A...) -> () } + local x: T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("x")) == "T"); +} + +TEST_CASE_FIXTURE(Fixture, "saturate_to_first_type_pack") +{ + CheckResult result = check(R"( + type T = { fn: (A, B) -> C... } + local x: T + local f = x.fn + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("x")) == "T"); + CHECK(toString(requireType("f")) == "(string, number) -> (string, boolean)"); +} + TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") { CheckResult result = check(R"( @@ -126,6 +185,40 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_aliases") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "generic_aliases") +{ + ScopedFastFlag sff_DebugLuauDeferredConstraintResolution{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + type T = { v: a } + local x: T = { v = 123 } + local y: T = { v = "foo" } + local bad: T = { v = "foo" } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(result.errors[0].location == Location{{4, 31}, {4, 44}}); + CHECK(toString(result.errors[0]) == "Type '{ v: string }' could not be converted into 'T'"); +} + +TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") +{ + ScopedFastFlag sff_DebugLuauDeferredConstraintResolution{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + type T = { v: a } + type U = { t: T } + local x: U = { t = { v = 123 } } + local bad: U = { t = { v = "foo" } } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(result.errors[0].location == Location{{4, 31}, {4, 52}}); + CHECK(toString(result.errors[0]) == "Type '{ t: { v: string } }' could not be converted into 'U'"); +} + TEST_CASE_FIXTURE(Fixture, "mutually_recursive_generic_aliases") { CheckResult result = check(R"( @@ -360,6 +453,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") LUAU_REQUIRE_ERROR_COUNT(1, result); auto e = get(result.errors[0]); + REQUIRE(e != nullptr); CHECK_EQ("Node?", toString(e->givenType)); CHECK_EQ("Node", toString(e->wantedType)); } diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index bc55940e..f4766104 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -13,6 +13,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauSpecialTypesAsterisked) + TEST_SUITE_BEGIN("TypeInferAnyError"); TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") @@ -94,7 +96,10 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("*unknown*", toString(requireType("a"))); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(requireType("a"))); + else + CHECK_EQ("", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") @@ -110,7 +115,10 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("*unknown*", toString(requireType("a"))); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(requireType("a"))); + else + CHECK_EQ("", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") @@ -225,7 +233,10 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") CHECK_EQ("unknown", err->name); - CHECK_EQ("*unknown*", toString(requireType("a"))); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(requireType("a"))); + else + CHECK_EQ("", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") @@ -234,7 +245,10 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") local a = Utility.Create "Foo" {} )"); - CHECK_EQ("*unknown*", toString(requireType("a"))); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(requireType("a"))); + else + CHECK_EQ("", toString(requireType("a"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 2f0266ec..10da0efa 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -9,6 +9,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauSpecialTypesAsterisked); TEST_SUITE_BEGIN("BuiltinTests"); @@ -555,6 +556,29 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_correctly_ordered_types") CHECK_EQ(tm->givenType, typeChecker.numberType); } +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_tostring_specifier") +{ + CheckResult result = check(R"( + --!strict + string.format("%* %* %* %*", "string", 1, true, function() end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_tostring_specifier_type_constraint") +{ + CheckResult result = check(R"( + local function f(x): string + local _ = string.format("%*", x) + return x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(string) -> string", toString(requireType("f"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall") { CheckResult result = check(R"( @@ -925,7 +949,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_returns_false_and_string_iff_it_knows )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(nil) -> nil", toString(requireType("f"))); + CHECK_EQ("(nil) -> (never, ...never)", toString(requireType("f"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") @@ -952,7 +976,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("string", toString(requireType("b"))); CHECK_EQ("boolean", toString(requireType("c"))); - CHECK_EQ("*unknown*", toString(requireType("d"))); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(requireType("d"))); + else + CHECK_EQ("", toString(requireType("d"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") @@ -965,8 +992,8 @@ a:b() a:b({}) )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 0}, {2, 5}}, CountMismatch{2, 0}})); - CHECK_EQ(result.errors[1], (TypeError{Location{{3, 0}, {3, 5}}, CountMismatch{2, 1}})); + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function expects 2 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function expects 2 arguments, but only 1 is specified"); } TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") @@ -1008,4 +1035,249 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c = string.gmatch("This is a string", "(.()(%a+))")() + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types2") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c = ("This is a string"):gmatch("(.()(%a+))")() + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c, d = string.gmatch("T(his)() is a string", ".")() + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 4); + + CHECK_EQ(toString(requireType("a")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c, d = string.gmatch("T(his) is a string", "((.)%b()())")() + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 3); + CHECK_EQ(acm->actual, 4); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "string"); + CHECK_EQ(toString(requireType("c")), "number"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_ignored") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c = string.gmatch("T(his)() is a string", "(T[()])()")() + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 3); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b = string.gmatch("[[[", "()([[])")() + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "number"); + CHECK_EQ(toString(requireType("b")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_leading_end_bracket_is_part_of_set") +{ + CheckResult result = check(R"END( + -- An immediate right-bracket following a left-bracket is included within the set; + -- thus, '[]]'' is the set containing ']', and '[]' is an invalid set missing an enclosing + -- right-bracket. We detect an invalid set in this case and fall back to to default gmatch + -- typing. + local foo = string.gmatch("T[hi%]s]]]() is a string", "([]s)") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_invalid_pattern_fallback_to_builtin") +{ + CheckResult result = check(R"END( + local foo = string.gmatch("T(his)() is a string", ")") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_invalid_pattern_fallback_to_builtin2") +{ + CheckResult result = check(R"END( + local foo = string.gmatch("T(his)() is a string", "[") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types") +{ + ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c = string.match("This is a string", "(.()(%a+))") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2") +{ + ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c = string.match("This is a string", "(.()(%a+))", "this should be a number") + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(toString(tm->wantedType), "number?"); + CHECK_EQ(toString(tm->givenType), "string"); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types") +{ + ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true}; + CheckResult result = check(R"END( + local d, e, a, b, c = string.find("This is a string", "(.()(%a+))") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + CHECK_EQ(toString(requireType("d")), "number?"); + CHECK_EQ(toString(requireType("e")), "number?"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types2") +{ + ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true}; + CheckResult result = check(R"END( + local d, e, a, b, c = string.find("This is a string", "(.()(%a+))", "this should be a number") + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(toString(tm->wantedType), "number?"); + CHECK_EQ(toString(tm->givenType), "string"); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + CHECK_EQ(toString(requireType("d")), "number?"); + CHECK_EQ(toString(requireType("e")), "number?"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") +{ + ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true}; + CheckResult result = check(R"END( + local d, e, a, b, c = string.find("This is a string", "(.()(%a+))", 1, "this should be a bool") + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(toString(tm->wantedType), "boolean?"); + CHECK_EQ(toString(tm->givenType), "string"); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + CHECK_EQ(toString(requireType("d")), "number?"); + CHECK_EQ(toString(requireType("e")), "number?"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") +{ + ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true}; + CheckResult result = check(R"END( + local d, e, a, b = string.find("This is a string", "(.()(%a+))", 1, true) + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 4); + + CHECK_EQ(toString(requireType("d")), "number?"); + CHECK_EQ(toString(requireType("e")), "number?"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 401a6c64..074c86c3 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -14,6 +14,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauSpecialTypesAsterisked); TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -677,11 +678,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toposort_doesnt_break_mutual_recursion") TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") { - ScopedFastFlag sff[]{ - {"LuauReturnTypeInferenceInNonstrict", true}, - {"LuauLowerBoundsCalculation", true}, - }; - CheckResult result = check(R"( --!nonstrict @@ -690,7 +686,7 @@ TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") end return function() - return f() + return f():andThen() end )"); @@ -817,18 +813,14 @@ TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields TEST_CASE_FIXTURE(BuiltinsFixture, "calling_function_with_anytypepack_doesnt_leak_free_types") { - ScopedFastFlag sff[]{ - {"LuauReturnTypeInferenceInNonstrict", true}, - {"LuauLowerBoundsCalculation", true}, - }; - CheckResult result = check(R"( --!nonstrict - function Test(a): ...any + function Test(a) return 1, "" end + local tab = {} table.insert(tab, Test(1)); )"); @@ -916,13 +908,19 @@ TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") REQUIRE(tm1); CHECK_EQ("(string) -> number", toString(tm1->wantedType)); - CHECK_EQ("(string, *unknown*) -> number", toString(tm1->givenType)); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("(string, *error-type*) -> number", toString(tm1->givenType)); + else + CHECK_EQ("(string, ) -> number", toString(tm1->givenType)); auto tm2 = get(result.errors[1]); REQUIRE(tm2); CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); - CHECK_EQ("(string, *unknown*) -> number", toString(tm2->givenType)); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("(string, *error-type*) -> number", toString(tm2->givenType)); + else + CHECK_EQ("(string, ) -> number", toString(tm2->givenType)); } TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") @@ -1535,10 +1533,20 @@ function t:b() return 2 end -- not OK )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type '(*unknown*) -> number' could not be converted into '() -> number' + if (FFlag::LuauSpecialTypesAsterisked) + { + CHECK_EQ(R"(Type '(*error-type*) -> number' could not be converted into '() -> number' caused by: Argument count mismatch. Function expects 1 argument, but none are specified)", - toString(result.errors[0])); + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '() -> number' could not be converted into '() -> number' +caused by: + Argument count mismatch. Function expects 1 argument, but none are specified)", + toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") @@ -1625,21 +1633,6 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_failure_in_function_return_type") CHECK(nullptr != get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") -{ - ScopedFastFlag sff[]{ - {"LuauReturnTypeInferenceInNonstrict", true}, - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - local function f() return end - local g = function() return f() end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "quantify_constrained_types") { ScopedFastFlag sff[]{ @@ -1692,4 +1685,52 @@ TEST_CASE_FIXTURE(Fixture, "call_o_with_another_argument_after_foo_was_quantifie // TODO: check the normalized type of f } +TEST_CASE_FIXTURE(Fixture, "free_is_not_bound_to_unknown") +{ + CheckResult result = check(R"( + local function foo(f: (unknown) -> (), x) + f(x) + end + )"); + + CHECK_EQ("((unknown) -> (), a) -> ()", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_infer_parameter_types_for_functions_from_their_call_site") +{ + CheckResult result = check(R"( + local t = {} + + function t.f(x) + return x + end + + t.__index = t + + function g(s) + local q = s.p and s.p.q or nil + return q and t.f(q) or nil + end + + local f = t.f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(a) -> a", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_mutate_the_underlying_head_of_typepack_when_calling_with_self") +{ + CheckResult result = check(R"( + local t = {} + function t:m(x) end + function f(): never return 5 :: never end + t:m(f()) + t:m(f()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index e9e94cfb..a8325727 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -9,6 +9,9 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauCheckGenericHOFTypes) +LUAU_FASTFLAG(LuauSpecialTypesAsterisked) + using namespace Luau; TEST_SUITE_BEGIN("GenericsTests"); @@ -1001,7 +1004,10 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK_EQ("*unknown*", toString(t0->type)); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(t0->type)); + else + CHECK_EQ("", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -1095,10 +1101,18 @@ local b = sumrec(sum) -- ok local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred )"); - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " - "parameters", - toString(result.errors[0])); + if (FFlag::LuauCheckGenericHOFTypes) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ( + "Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " + "parameters", + toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") @@ -1172,10 +1186,6 @@ end) TEST_CASE_FIXTURE(Fixture, "quantify_functions_even_if_they_have_an_explicit_generic") { - ScopedFastFlag sff[] = { - {"LuauAlwaysQuantify", true}, - }; - CheckResult result = check(R"( function foo(f, x: X) return f(x) @@ -1185,4 +1195,23 @@ TEST_CASE_FIXTURE(Fixture, "quantify_functions_even_if_they_have_an_explicit_gen CHECK("((X) -> (a...), X) -> (a...)" == toString(requireType("foo"))); } +TEST_CASE_FIXTURE(Fixture, "do_not_always_instantiate_generic_intersection_types") +{ + ScopedFastFlag sff[] = { + {"LuauMaybeGenericIntersectionTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type Array = { [number]: T } + + type Array_Statics = { + new: () -> Array, + } + + local _Arr : Array & Array_Statics = {} :: Array_Statics + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 1c6fe1d8..9b10092c 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -13,6 +13,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauSpecialTypesAsterisked) + TEST_SUITE_BEGIN("TypeInferLoops"); TEST_CASE_FIXTURE(Fixture, "for_loop") @@ -142,7 +144,10 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") CHECK_EQ(2, result.errors.size()); TypeId p = requireType("p"); - CHECK_EQ("*unknown*", toString(p)); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(p)); + else + CHECK_EQ("", toString(p)); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") @@ -516,7 +521,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") CHECK_EQ(*typeChecker.nilType, *requireType("extra")); } -TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") +TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_strict") { CheckResult result = check(R"( local t = {} @@ -531,6 +536,17 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") CHECK_EQ("Cannot iterate over a table without indexer", ge->message); } +TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_nonstrict") +{ + CheckResult result = check(Mode::Nonstrict, R"( + local t = {} + for k, v in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_iter_metamethod") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index a0f670f1..a1d41339 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -12,6 +12,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauSpecialTypesAsterisked) + TEST_SUITE_BEGIN("TypeInferModules"); TEST_CASE_FIXTURE(BuiltinsFixture, "require") @@ -143,7 +145,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_module_that_does_not_export") auto hootyType = requireType(bModule, "Hooty"); - CHECK_EQ("*unknown*", toString(hootyType)); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(hootyType)); + else + CHECK_EQ("", toString(hootyType)); } TEST_CASE_FIXTURE(BuiltinsFixture, "warn_if_you_try_to_require_a_non_modulescript") @@ -244,7 +249,11 @@ local ModuleA = require(game.A) LUAU_REQUIRE_NO_ERRORS(result); std::optional oty = requireType("ModuleA"); - CHECK_EQ("*unknown*", toString(*oty)); + + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(*oty)); + else + CHECK_EQ("", toString(*oty)); } TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types") @@ -302,6 +311,30 @@ type Rename = typeof(x.x) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_4") +{ + fileResolver.source["game/A"] = R"( +export type Array = {T} +local arrayops = {} +function arrayops.foo(x: Array) end +return arrayops + )"; + + CheckResult result = check(R"( +local arrayops = require(game.A) + +local tbl = {} +tbl.a = 2 +function tbl:foo(b: number, c: number) + -- introduce BoundTypeVar to imported type + arrayops.foo(self._regions) +end +type Table = typeof(tbl) +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict") { fileResolver.source["game/A"] = R"( @@ -363,4 +396,21 @@ caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "constrained_anyification_clone_immutable_types") +{ + ScopedFastFlag luauAnyificationMustClone{"LuauAnyificationMustClone", true}; + + fileResolver.source["game/A"] = R"( +return function(...) end + )"; + + fileResolver.source["game/B"] = R"( +local l0 = require(game.A) +return l0 + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index e6174df2..3d6c0193 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -490,8 +490,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") { - ScopedFastFlag sff("LuauCheckLenMT", true); - CheckResult result = check(R"( --!strict local foo = { @@ -871,4 +869,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "equality_operations_succeed_if_any_union_bra CHECK(toString(result2.errors[0]) == "Types Foo and Bar cannot be compared with == because they do not have the same metatable"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_and") +{ + ScopedFastFlag sff{"LuauBinaryNeedsExpectedTypesToo", true}; + + CheckResult result = check(R"( + local x: "a" | "b" | boolean = math.random() > 0.5 and "a" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_or") +{ + ScopedFastFlag sff{"LuauBinaryNeedsExpectedTypesToo", true}; + + CheckResult result = check(R"( + local x: "a" | "b" | boolean = math.random() > 0.5 or "b" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index e1684df7..a06fd749 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -11,6 +11,9 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauDeduceFindMatchReturnTypes) +LUAU_FASTFLAG(LuauSpecialTypesAsterisked) + using namespace Luau; TEST_SUITE_BEGIN("TypeInferPrimitives"); @@ -47,7 +50,10 @@ TEST_CASE_FIXTURE(Fixture, "string_index") REQUIRE(nat); CHECK_EQ("string", toString(nat->ty)); - CHECK_EQ("*unknown*", toString(requireType("t"))); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(requireType("t"))); + else + CHECK_EQ("", toString(requireType("t"))); } TEST_CASE_FIXTURE(Fixture, "string_method") @@ -80,7 +86,10 @@ TEST_CASE_FIXTURE(Fixture, "string_function_other") )"); CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(requireType("p")), "string?"); + if (FFlag::LuauDeduceFindMatchReturnTypes) + CHECK_EQ(toString(requireType("p")), "string"); + else + CHECK_EQ(toString(requireType("p")), "string?"); } TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 059aed2e..01923f38 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -225,7 +225,7 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") CHECK_EQ("{| x: nil, y: nil |} | {| x: string, y: number |}", toString(requireTypeAtPosition({7, 28}))); } -TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) +TEST_CASE_FIXTURE(BuiltinsFixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) { ScopedFastInt sffi{"LuauTarjanChildLimit", 1}; ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1}; @@ -343,6 +343,20 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") LUAU_REQUIRE_ERRORS(result); // Should not have any errors. } +TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", false}, + }; + + CheckResult result = check(R"( + local function f() return end + local g = function() return f() end + )"); + + LUAU_REQUIRE_ERRORS(result); // Should not have any errors. +} + TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") { ScopedFastFlag sff[] = { @@ -471,7 +485,6 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") { ScopedFastFlag sff[]{ {"LuauLowerBoundsCalculation", true}, - {"LuauNormalizeFlagIsConservative", true}, {"LuauQuantifyConstrained", true}, }; @@ -499,6 +512,17 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); } +TEST_CASE_FIXTURE(Fixture, "free_is_not_bound_to_any") +{ + CheckResult result = check(R"( + local function foo(f: (any) -> (), x) + f(x) + end + )"); + + CHECK_EQ("((any) -> (), any) -> ()", toString(requireType("foo"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_function_with_no_returns") { ScopedFastFlag sff{"DebugLuauSharedSelf", true}; @@ -518,7 +542,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_f )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Not all codepaths in this function return '{ @metatable T, {| |} }, a...'.", toString(result.errors[0])); + CHECK_EQ("Not all codepaths in this function return 'self, a...'.", toString(result.errors[0])); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 3f5dad3d..8a1cadcf 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,6 +8,7 @@ #include "doctest.h" LUAU_FASTFLAG(LuauLowerBoundsCalculation) +LUAU_FASTFLAG(LuauSpecialTypesAsterisked) using namespace Luau; @@ -272,8 +273,8 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44}))); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38}))); + CHECK_EQ("never", toString(requireTypeAtPosition({8, 44}))); + CHECK_EQ("never", toString(requireTypeAtPosition({9, 38}))); } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -526,7 +527,10 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(requireTypeAtPosition({3, 28}))); + else + CHECK_EQ("", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") @@ -651,7 +655,7 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } -TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_narrowed_into_nothingness") { CheckResult result = check(R"( local function f(t: {x: number}) @@ -666,7 +670,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_warns_on_no_overlapping_types_onl LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") @@ -1074,7 +1078,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" + CHECK_EQ("never", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" } @@ -1206,6 +1210,24 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") +{ + CheckResult result = check(R"( + local function f(x: unknown) + if type(x) == "string" then + local foo = x + else + local bar = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("unknown", toString(requireTypeAtPosition({5, 28}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") { ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; @@ -1227,4 +1249,19 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_ni CHECK_EQ("number", toString(requireTypeAtPosition({6, 28}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "what_nonsensical_condition") +{ + CheckResult result = check(R"( + local function f(x) + if type(x) == "string" and type(x) == "number" then + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 4a88abee..0a130d49 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -262,7 +262,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") { CheckResult result = check(R"( --!strict - local x: { ["<>"] : number } + local x: { ["<>"] : number } x = { ["\n"] = 5 } )"); @@ -476,4 +476,21 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 23}))); } +TEST_CASE_FIXTURE(Fixture, "no_widening_from_callsites") +{ + ScopedFastFlag sff{"LuauReturnsFromCallsitesAreNotWidened", true}; + + CheckResult result = check(R"( + type Direction = "North" | "East" | "West" | "South" + + local function direction(): Direction + return "North" + end + + local d: Direction = direction() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index eead5b30..d9bfc89d 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2990,6 +2990,15 @@ TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra_2") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "expected_indexer_from_table_union") +{ + ScopedFastFlag luauExpectedTableUnionIndexerType{"LuauExpectedTableUnionIndexerType", true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"(local a: {[string]: {number | string}} = {a = {2, 's'}})")); + LUAU_REQUIRE_NO_ERRORS(check(R"(local a: {[string]: {number | string}}? = {a = {2, 's'}})")); + LUAU_REQUIRE_NO_ERRORS(check(R"(local a: {[string]: {[string]: {string?}}?} = {["a"] = {["b"] = {"a", "b"}}})")); +} + TEST_CASE_FIXTURE(Fixture, "prop_access_on_key_whose_types_mismatches") { ScopedFastFlag sff{"LuauReportErrorsOnIndexerKeyMismatch", true}; @@ -3070,4 +3079,100 @@ TEST_CASE_FIXTURE(Fixture, "quantify_even_that_table_was_never_exported_at_all") CHECK_EQ("{| m: ({+ x: a, y: b +}) -> a, n: ({+ x: a, y: b +}) -> b |}", toString(requireType("T"), opts)); } +TEST_CASE_FIXTURE(BuiltinsFixture, "leaking_bad_metatable_errors") +{ + ScopedFastFlag luauIndexSilenceErrors{"LuauIndexSilenceErrors", true}; + + CheckResult result = check(R"( +local a = setmetatable({}, 1) +local b = a.x + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Metatable was not a table", toString(result.errors[0])); + CHECK_EQ("Type 'a' does not have key 'x'", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "scalar_is_a_subtype_of_a_compatible_polymorphic_shape_type") +{ + ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + + CheckResult result = check(R"( + local function f(s) + return s:lower() + end + + f("foo" :: string) + f("bar" :: "bar") + f("baz" :: "bar" | "baz") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type") +{ + ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + + CheckResult result = check(R"( + local function f(s) + return s:absolutely_no_scalar_has_this_method() + end + + f("foo" :: string) + f("bar" :: "bar") + f("baz" :: "bar" | "baz") + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[0])); + CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[1])); + CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + Not all union options are compatible. Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[2])); +} + +TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compatible") +{ + ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + + CheckResult result = check(R"( + local function f(s): string + local foo = s:lower() + return s + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(string) -> string", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible") +{ + ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + + CheckResult result = check(R"( + local function f(s): string + local foo = s:absolutely_no_scalar_has_this_method() + return s + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' +caused by: + The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[0])); + CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index efdfe0b1..80889363 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauSpecialTypesAsterisked); using namespace Luau; @@ -85,20 +86,19 @@ TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") { ScopedFastFlag sff[]{ {"DebugLuauDeferredConstraintResolution", false}, - {"LuauReturnTypeInferenceInNonstrict", true}, {"LuauLowerBoundsCalculation", true}, }; CheckResult result = check(R"( --!nocheck function f(x) - return 5 + return x end -- we get type information even if there's type errors f(1, 2) )"); - CHECK_EQ("(any) -> number", toString(requireType("f"))); + CHECK_EQ("(any) -> (...any)", toString(requireType("f"))); LUAU_REQUIRE_NO_ERRORS(result); } @@ -238,10 +238,20 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") // TODO: Should we assert anything about these tests when DCR is being used? if (!FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ("*unknown*", toString(requireType("c"))); - CHECK_EQ("*unknown*", toString(requireType("d"))); - CHECK_EQ("*unknown*", toString(requireType("e"))); - CHECK_EQ("*unknown*", toString(requireType("f"))); + if (FFlag::LuauSpecialTypesAsterisked) + { + CHECK_EQ("*error-type*", toString(requireType("c"))); + CHECK_EQ("*error-type*", toString(requireType("d"))); + CHECK_EQ("*error-type*", toString(requireType("e"))); + CHECK_EQ("*error-type*", toString(requireType("f"))); + } + else + { + CHECK_EQ("", toString(requireType("c"))); + CHECK_EQ("", toString(requireType("d"))); + CHECK_EQ("", toString(requireType("e"))); + CHECK_EQ("", toString(requireType("f"))); + } } } @@ -355,6 +365,35 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") CHECK(nullptr != get(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "globals") +{ + CheckResult result = check(R"( + --!nonstrict + foo = true + foo = "now i'm a string!" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "globals2") +{ + CheckResult result = check(R"( + --!nonstrict + foo = function() return 1 end + foo = "now i'm a string!" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("() -> (...any)", toString(tm->wantedType)); + CHECK_EQ("string", toString(tm->givenType)); + CHECK_EQ("() -> (...any)", toString(requireType("foo"))); +} + TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") { CheckResult result = check(R"( @@ -622,7 +661,11 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK_EQ("*unknown*", toString(t0->type)); + + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(t0->type)); + else + CHECK_EQ("", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -1003,4 +1046,27 @@ TEST_CASE_FIXTURE(Fixture, "do_not_bind_a_free_table_to_a_union_containing_that_ )"); } +TEST_CASE_FIXTURE(Fixture, "types stored in astResolvedTypes") +{ + CheckResult result = check(R"( +type alias = typeof("hello") +local function foo(param: alias) +end + )"); + + auto node = findNodeAtPosition(*getMainSourceModule(), {2, 16}); + auto ty = lookupType("alias"); + REQUIRE(node); + REQUIRE(node->is()); + REQUIRE(ty); + + auto func = node->as(); + REQUIRE(func->args.size == 1); + + auto arg = *func->args.begin(); + auto annotation = arg->annotation; + + CHECK_EQ(*getMainModule()->astResolvedTypes.find(annotation), *ty); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 49deae71..e0a0e5b5 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauSpecialTypesAsterisked) + struct TryUnifyFixture : Fixture { TypeArena arena; @@ -121,7 +123,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("a", toString(requireType("a"))); - CHECK_EQ("*unknown*", toString(requireType("b"))); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(requireType("b"))); + else + CHECK_EQ("", toString(requireType("b"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") @@ -136,7 +141,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_con LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("a", toString(requireType("a"))); - CHECK_EQ("*unknown*", toString(requireType("b"))); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(requireType("b"))); + else + CHECK_EQ("", toString(requireType("b"))); CHECK_EQ("number", toString(requireType("c"))); } diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index bcd30498..7aefa00d 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -203,7 +203,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_packs") ), "@test" ); - addGlobalBinding(typeChecker, "bar", + addGlobalBinding(typeChecker, "bar", arena.addType( FunctionTypeVar{ arena.addTypePack({{typeChecker.numberType}, listOfStrings}), diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 2b48133d..8eb485e9 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -7,6 +7,7 @@ #include "doctest.h" LUAU_FASTFLAG(LuauLowerBoundsCalculation) +LUAU_FASTFLAG(LuauSpecialTypesAsterisked) using namespace Luau; @@ -199,7 +200,10 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->key, "x"); - CHECK_EQ("*unknown*", toString(requireType("r"))); + if (FFlag::LuauSpecialTypesAsterisked) + CHECK_EQ("*error-type*", toString(requireType("r"))); + else + CHECK_EQ("", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp new file mode 100644 index 00000000..2288db4e --- /dev/null +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -0,0 +1,316 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferUnknownNever"); + +TEST_CASE_FIXTURE(Fixture, "string_subtype_and_unknown_supertype") +{ + CheckResult result = check(R"( + local function f(x: string) + local foo: unknown = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_subtype_and_string_supertype") +{ + CheckResult result = check(R"( + local function f(x: unknown) + local foo: string = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_is_reflexive") +{ + CheckResult result = check(R"( + local function f(x: unknown) + local foo: unknown = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_subtype_and_never_supertype") +{ + CheckResult result = check(R"( + local function f(x: string) + local foo: never = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "never_subtype_and_string_supertype") +{ + CheckResult result = check(R"( + local function f(x: never) + local foo: string = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "never_is_reflexive") +{ + CheckResult result = check(R"( + local function f(x: never) + local foo: never = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_is_optional_because_it_too_encompasses_nil") +{ + CheckResult result = check(R"( + local t: {x: unknown} = {} + )"); +} + +TEST_CASE_FIXTURE(Fixture, "table_with_prop_of_type_never_is_uninhabitable") +{ + CheckResult result = check(R"( + local t: {x: never} = {} + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "table_with_prop_of_type_never_is_also_reflexive") +{ + CheckResult result = check(R"( + local t: {x: never} = {x = 5 :: never} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "array_like_table_of_never_is_inhabitable") +{ + CheckResult result = check(R"( + local t: {never} = {} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable") +{ + CheckResult result = check(R"( + local function f() return "foo", 5 :: never end + + local x, y, z = f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); + CHECK_EQ("never", toString(requireType("y"))); + CHECK_EQ("never", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable2") +{ + CheckResult result = check(R"( + local function f(): (string, never) return "", 5 :: never end + local function g(): (never, string) return 5 :: never, "" end + + local x1, x2 = f() + local y1, y2 = g() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x1"))); + CHECK_EQ("never", toString(requireType("x2"))); + CHECK_EQ("never", toString(requireType("y1"))); + CHECK_EQ("never", toString(requireType("y2"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_never") +{ + CheckResult result = check(R"( + local x: never = 5 :: never + local z = x.y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "call_never") +{ + CheckResult result = check(R"( + local f: never = 5 :: never + local x, y, z = f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); + CHECK_EQ("never", toString(requireType("y"))); + CHECK_EQ("never", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_local_which_is_never") +{ + CheckResult result = check(R"( + local t: never + t = 3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_global_which_is_never") +{ + CheckResult result = check(R"( + --!nonstrict + t = 5 :: never + t = "" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_prop_which_is_never") +{ + CheckResult result = check(R"( + local t: never + t.x = 5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never") +{ + CheckResult result = check(R"( + local t: never + t[5] = 7 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never") +{ + CheckResult result = check(R"( + for i, v in (5 :: never) do + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "pick_never_from_variadic_type_pack") +{ + CheckResult result = check(R"( + local function f(...: never) + local x, y = (...) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_union_of_tables_for_properties_that_is_never") +{ + CheckResult result = check(R"( + type Disjoint = {foo: never, bar: unknown, tag: "ok"} | {foo: never, baz: unknown, tag: "err"} + local disjoint: Disjoint = {foo = 5 :: never, bar = true, tag = "ok"} + local foo = disjoint.foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_union_of_tables_for_properties_that_is_sorta_never") +{ + CheckResult result = check(R"( + type Disjoint = {foo: string, bar: unknown, tag: "ok"} | {foo: never, baz: unknown, tag: "err"} + local disjoint: Disjoint = {foo = 5 :: never, bar = true, tag = "ok"} + local foo = disjoint.foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "unary_minus_of_never") +{ + CheckResult result = check(R"( + local x = -(5 :: never) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "length_of_never") +{ + ScopedFastFlag sff{"LuauNeverTypesAndOperatorsInference", true}; + + CheckResult result = check(R"( + local x = #({} :: never) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators") +{ + ScopedFastFlag sff[]{ + {"LuauUnknownAndNeverType", true}, + {"LuauNeverTypesAndOperatorsInference", true}, + }; + + CheckResult result = check(R"( + local function ord(x: nil, y) + return x ~= nil and x > y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(nil, a) -> boolean", toString(requireType("ord"))); +} + +TEST_CASE_FIXTURE(Fixture, "math_operators_and_never") +{ + ScopedFastFlag sff[]{ + {"LuauUnknownAndNeverType", true}, + {"LuauNeverTypesAndOperatorsInference", true}, + }; + + CheckResult result = check(R"( + local function mul(x: nil, y) + return x ~= nil and x * y -- infers boolean | never, which is normalized into boolean + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(nil, a) -> boolean", toString(requireType("mul"))); +} + +TEST_SUITE_END(); diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index 8a5a65fe..1087a24c 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -25,7 +25,7 @@ struct TypePackFixture TypePackId freshTypePack() { - typePacks.emplace_back(new TypePackVar{Unifiable::Free{{}}}); + typePacks.emplace_back(new TypePackVar{Unifiable::Free{TypeLevel{}}}); return typePacks.back().get(); } @@ -199,8 +199,6 @@ TEST_CASE_FIXTURE(TypePackFixture, "std_distance") TEST_CASE("content_reassignment") { - ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; - TypePackVar myError{Unifiable::Error{}, /*presistent*/ true}; TypeArena arena; diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 4f8fc502..f4670048 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -418,8 +418,6 @@ TEST_CASE("proof_that_isBoolean_uses_all_of") TEST_CASE("content_reassignment") { - ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; - TypeVar myAny{AnyTypeVar{}, /*presistent*/ true}; myAny.normal = true; myAny.documentationSymbol = "@global/any"; diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 385a0450..b2dcaf94 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -727,16 +727,20 @@ assert((function() local abs = math.abs function foo(...) return abs(...) end re -- NOTE: getfenv breaks fastcalls for the remainder of the source! hence why this is delayed until the end function testgetfenv() getfenv() + + -- declare constant so that at O2 this test doesn't interfere with constant folding which we can't deoptimize + local negfive negfive = -5 + -- getfenv breaks fastcalls (we assume we can't rely on knowing the semantics), but behavior shouldn't change - assert((function() return math.abs(-5) end)() == 5) - assert((function() local abs = math.abs return abs(-5) end)() == 5) - assert((function() local abs = math.abs function foo() return abs(-5) end return foo() end)() == 5) + assert((function() return math.abs(negfive) end)() == 5) + assert((function() local abs = math.abs return abs(negfive) end)() == 5) + assert((function() local abs = math.abs function foo() return abs(negfive) end return foo() end)() == 5) -- ... unless you actually reassign the function :D getfenv().math = { abs = function(n) return n*n end } - assert((function() return math.abs(-5) end)() == 25) - assert((function() local abs = math.abs return abs(-5) end)() == 25) - assert((function() local abs = math.abs function foo() return abs(-5) end return foo() end)() == 25) + assert((function() return math.abs(negfive) end)() == 25) + assert((function() local abs = math.abs return abs(negfive) end)() == 25) + assert((function() local abs = math.abs function foo() return abs(negfive) end return foo() end)() == 25) end -- you need to have enough arguments and arguments of the right type; if you don't, we'll fallback to the regular code. This checks coercions diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index 0b5aafed..b13e7a82 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -295,8 +295,9 @@ end -- testing syntax limits +local syntaxdepth = if limitedstack then 200 else 1000 local function testrep (init, rep) - local s = "local a; "..init .. string.rep(rep, 300) + local s = "local a; "..init .. string.rep(rep, syntaxdepth) local a,b = loadstring(s) assert(not a) -- and string.find(b, "syntax levels")) end @@ -380,11 +381,19 @@ assert(ecall(function() return "a" + "b" end) == "attempt to perform arithmetic assert(ecall(function() return 1 > nil end) == "attempt to compare nil < number") -- note reversed order (by design) assert(ecall(function() return "a" <= 5 end) == "attempt to compare string <= number") -assert(ecall(function() local t = {} setmetatable(t, { __newindex = function(t,i,v) end }) t[nil] = 2 end) == "table index is nil") +assert(ecall(function() local t = {} t[nil] = 2 end) == "table index is nil") +assert(ecall(function() local t = {} t[0/0] = 2 end) == "table index is NaN") -- for loop type errors assert(ecall(function() for i='a',2 do end end) == "invalid 'for' initial value (number expected, got string)") assert(ecall(function() for i=1,'a' do end end) == "invalid 'for' limit (number expected, got string)") assert(ecall(function() for i=1,2,'a' do end end) == "invalid 'for' step (number expected, got string)") +-- method call errors +assert(ecall(function() ({}):foo() end) == "attempt to call missing method 'foo' of table") +assert(ecall(function() (""):foo() end) == "attempt to call missing method 'foo' of string") +assert(ecall(function() (42):foo() end) == "attempt to index number with 'foo'") +assert(ecall(function() ({foo=42}):foo() end) == "attempt to call a number value") +assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = {} ud:foo() end) == "attempt to call missing method 'foo' of userdata") + return('OK') diff --git a/tests/conformance/events.lua b/tests/conformance/events.lua index 42f1beda..6dcdbf0e 100644 --- a/tests/conformance/events.lua +++ b/tests/conformance/events.lua @@ -424,4 +424,57 @@ do assert(not ok and err:match("table or string expected")) end +-- verify that NaN/nil keys are passed to __newindex even though table assignment with them anywhere in the chain fails +do + assert(pcall(function() local t = {} t[nil] = 5 end) == false) + assert(pcall(function() local t = {} setmetatable(t, { __newindex = {} }) t[nil] = 5 end) == false) + assert(pcall(function() local t = {} setmetatable(t, { __newindex = function() end }) t[nil] = 5 end) == true) + + assert(pcall(function() local t = {} t[0/0] = 5 end) == false) + assert(pcall(function() local t = {} setmetatable(t, { __newindex = {} }) t[0/0] = 5 end) == false) + assert(pcall(function() local t = {} setmetatable(t, { __newindex = function() end }) t[0/0] = 5 end) == true) +end + +-- verify that __newindex gets called for frozen tables but only if the assignment is to a key absent from the table +do + local ni = {} + local t = table.create(2) + + t[1] = 42 + -- t[2] is semantically absent with storage allocated for it + + t.a = 1 + t.b = 2 + t.b = nil -- this sets 'b' value to nil but leaves key as is to exercise more internal paths -- no observable behavior change expected between b and other absent keys + + setmetatable(t, { __newindex = function(_, k, v) + assert(v == 42) + table.insert(ni, k) + end }) + table.freeze(t) + + -- "redundant" combinations are there to test all three of SETTABLEN/SETTABLEKS/SETTABLE + assert(pcall(function() t.a = 42 end) == false) + assert(pcall(function() t[1] = 42 end) == false) + assert(pcall(function() local key key = "a" t[key] = 42 end) == false) + assert(pcall(function() local key key = 1 t[key] = 42 end) == false) + + -- now repeat the same for keys absent from the table: b (semantically absent), c (physically absent), 2 (semantically absent), 3 (physically absent) + assert(pcall(function() t.b = 42 end) == true) + assert(pcall(function() t.c = 42 end) == true) + assert(pcall(function() local key key = "b" t[key] = 42 end) == true) + assert(pcall(function() local key key = "c" t[key] = 42 end) == true) + assert(pcall(function() t[2] = 42 end) == true) + assert(pcall(function() t[3] = 42 end) == true) + assert(pcall(function() local key key = 2 t[key] = 42 end) == true) + assert(pcall(function() local key key = 3 t[key] = 42 end) == true) + + -- validate the assignment sequence + local ei = { "b", "c", "b", "c", 2, 3, 2, 3 } + assert(#ni == #ei) + for k,v in ni do + assert(ei[k] == v) + end +end + return 'OK' diff --git a/tests/conformance/strings.lua b/tests/conformance/strings.lua index 98a5721a..c87cf15c 100644 --- a/tests/conformance/strings.lua +++ b/tests/conformance/strings.lua @@ -130,6 +130,26 @@ assert(string.format('"-%20s.20s"', string.rep("%", 2000)) == -- longest number that can be formated assert(string.len(string.format('%99.99f', -1e308)) >= 100) +local function return_one_thing() + return "hi" +end +local function return_two_nils() + return nil, nil +end + +assert(string.format("%*", return_one_thing()) == "hi") +assert(string.format("%* %*", return_two_nils()) == "nil nil") +assert(pcall(function() + string.format("%* %* %*", return_two_nils()) +end) == false) + +assert(string.format("%*", "a\0b\0c") == "a\0b\0c") +assert(string.format("%*", string.rep("doge", 3000)) == string.rep("doge", 3000)) + +assert(pcall(function() + string.format("%#*", "bad form") +end) == false) + assert(loadstring("return 1\n--comentário sem EOL no final")() == 1) diff --git a/tests/conformance/nextvar.lua b/tests/conformance/tables.lua similarity index 96% rename from tests/conformance/nextvar.lua rename to tests/conformance/tables.lua index 93c4ddf7..0eff8540 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/tables.lua @@ -592,4 +592,24 @@ do assert(countud() == 3) end +-- test __newindex-as-a-table indirection: this had memory safety bugs in Lua 5.1.0 +do + local hit = false + + local grandparent = {} + grandparent.__newindex = function(s,k,v) + assert(k == "foo" and v == 10) + hit = true + end + + local parent = {} + parent.__newindex = parent + setmetatable(parent, grandparent) + + local child = setmetatable({}, parent) + child.foo = 10 + + assert(hit and child.foo == nil and parent.foo == nil) +end + return"OK" diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index 22d6adfc..6164e929 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -101,4 +101,20 @@ if vector_size == 4 then assert(vector(1, 2, 3, 4).W == 4) end +-- negative zero should hash the same as zero +-- note: our earlier test only really checks the low hash bit, so in absence of perfect avalanche it's insufficient +do + local larget = {} + for i = 1, 2^14 do + larget[vector(0, 0, i)] = true + end + + larget[vector(0, 0, 0)] = 42 + + assert(larget[vector(0, 0, 0)] == 42) + assert(larget[vector(0, 0, -0)] == 42) + assert(larget[vector(0, -0, 0)] == 42) + assert(larget[vector(-0, 0, 0)] == 42) +end + return 'OK' diff --git a/tests/main.cpp b/tests/main.cpp index 2af9f702..3e480c9f 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -23,10 +23,13 @@ #include -// Indicates if verbose output is enabled. -// Currently, this enables output from lua's 'print', but other verbose output could be enabled eventually. +// Indicates if verbose output is enabled; can be overridden via --verbose +// Currently, this enables output from 'print', but other verbose output could be enabled eventually. bool verbose = false; +// Default optimization level for conformance test; can be overridden via -On +int optimizationLevel = 1; + static bool skipFastFlag(const char* flagName) { if (strncmp(flagName, "Test", 4) == 0) @@ -60,7 +63,7 @@ static int testAssertionHandler(const char* expr, const char* file, int line, co if (debuggerPresent()) LUAU_DEBUGBREAK(); - ADD_FAIL_AT(file, line, "Assertion failed: ", expr); + ADD_FAIL_AT(file, line, "Assertion failed: ", std::string(expr)); return 1; } @@ -249,6 +252,15 @@ int main(int argc, char** argv) verbose = true; } + int level = -1; + if (doctest::parseIntOption(argc, argv, "-O", doctest::option_int, level)) + { + if (level < 0 || level > 2) + std::cerr << "Optimization level must be between 0 and 2 inclusive." << std::endl; + else + optimizationLevel = level; + } + if (std::vector flags; doctest::parseCommaSepArgs(argc, argv, "--fflags=", flags)) setFastFlags(flags); @@ -279,11 +291,10 @@ int main(int argc, char** argv) if (doctest::parseFlag(argc, argv, "--help") || doctest::parseFlag(argc, argv, "-h")) { printf("Additional command line options:\n"); + printf(" -O[n] Changes default optimization level (1) for conformance runs\n"); printf(" --verbose Enables verbose output (e.g. lua 'print' statements)\n"); printf(" --fflags= Sets specified fast flags\n"); printf(" --list-fflags List all fast flags\n"); } return result; } - - diff --git a/tools/faillist.txt b/tools/faillist.txt new file mode 100644 index 00000000..6e93345b --- /dev/null +++ b/tools/faillist.txt @@ -0,0 +1,982 @@ +AnnotationTests.as_expr_does_not_propagate_type_info +AnnotationTests.as_expr_is_bidirectional +AnnotationTests.as_expr_warns_on_unrelated_cast +AnnotationTests.builtin_types_are_not_exported +AnnotationTests.cannot_use_nonexported_type +AnnotationTests.cloned_interface_maintains_pointers_between_definitions +AnnotationTests.define_generic_type_alias +AnnotationTests.duplicate_type_param_name +AnnotationTests.for_loop_counter_annotation_is_checked +AnnotationTests.function_return_annotations_are_checked +AnnotationTests.generic_aliases_are_cloned_properly +AnnotationTests.interface_types_belong_to_interface_arena +AnnotationTests.luau_ice_triggers_an_ice +AnnotationTests.luau_ice_triggers_an_ice_exception_with_flag +AnnotationTests.luau_ice_triggers_an_ice_exception_with_flag_handler +AnnotationTests.luau_ice_triggers_an_ice_handler +AnnotationTests.luau_print_is_magic_if_the_flag_is_set +AnnotationTests.luau_print_is_not_special_without_the_flag +AnnotationTests.occurs_check_on_cyclic_intersection_typevar +AnnotationTests.occurs_check_on_cyclic_union_typevar +AnnotationTests.self_referential_type_alias +AnnotationTests.too_many_type_params +AnnotationTests.two_type_params +AnnotationTests.type_annotations_inside_function_bodies +AnnotationTests.type_assertion_expr +AnnotationTests.unknown_type_reference_generates_error +AnnotationTests.use_type_required_from_another_file +AstQuery.last_argument_function_call_type +AstQuery::getDocumentationSymbolAtPosition.binding +AstQuery::getDocumentationSymbolAtPosition.event_callback_arg +AstQuery::getDocumentationSymbolAtPosition.overloaded_fn +AstQuery::getDocumentationSymbolAtPosition.prop +AutocompleteTest.argument_types +AutocompleteTest.arguments_to_global_lambda +AutocompleteTest.as_types +AutocompleteTest.autocomplete_boolean_singleton +AutocompleteTest.autocomplete_end_with_fn_exprs +AutocompleteTest.autocomplete_end_with_lambda +AutocompleteTest.autocomplete_first_function_arg_expected_type +AutocompleteTest.autocomplete_for_in_middle_keywords +AutocompleteTest.autocomplete_for_middle_keywords +AutocompleteTest.autocomplete_if_else_regression +AutocompleteTest.autocomplete_if_middle_keywords +AutocompleteTest.autocomplete_ifelse_expressions +AutocompleteTest.autocomplete_on_string_singletons +AutocompleteTest.autocomplete_oop_implicit_self +AutocompleteTest.autocomplete_repeat_middle_keyword +AutocompleteTest.autocomplete_string_singleton_equality +AutocompleteTest.autocomplete_string_singleton_escape +AutocompleteTest.autocomplete_string_singletons +AutocompleteTest.autocomplete_until_expression +AutocompleteTest.autocomplete_until_in_repeat +AutocompleteTest.autocomplete_while_middle_keywords +AutocompleteTest.autocompleteProp_index_function_metamethod_is_variadic +AutocompleteTest.bias_toward_inner_scope +AutocompleteTest.cyclic_table +AutocompleteTest.do_compatible_self_calls +AutocompleteTest.do_not_overwrite_context_sensitive_kws +AutocompleteTest.do_not_suggest_internal_module_type +AutocompleteTest.do_wrong_compatible_self_calls +AutocompleteTest.dont_offer_any_suggestions_from_within_a_broken_comment +AutocompleteTest.dont_offer_any_suggestions_from_within_a_broken_comment_at_the_very_end_of_the_file +AutocompleteTest.dont_offer_any_suggestions_from_within_a_comment +AutocompleteTest.dont_suggest_local_before_its_definition +AutocompleteTest.function_expr_params +AutocompleteTest.function_in_assignment_has_parentheses +AutocompleteTest.function_in_assignment_has_parentheses_2 +AutocompleteTest.function_parameters +AutocompleteTest.function_result_passed_to_function_has_parentheses +AutocompleteTest.generic_types +AutocompleteTest.get_suggestions_for_the_very_start_of_the_script +AutocompleteTest.global_function_params +AutocompleteTest.global_functions_are_not_scoped_lexically +AutocompleteTest.if_then_else_elseif_completions +AutocompleteTest.if_then_else_full_keywords +AutocompleteTest.keyword_methods +AutocompleteTest.keyword_types +AutocompleteTest.library_non_self_calls_are_fine +AutocompleteTest.library_self_calls_are_invalid +AutocompleteTest.local_function +AutocompleteTest.local_function_params +AutocompleteTest.local_functions_fall_out_of_scope +AutocompleteTest.method_call_inside_function_body +AutocompleteTest.module_type_members +AutocompleteTest.nested_member_completions +AutocompleteTest.nested_recursive_function +AutocompleteTest.no_function_name_suggestions +AutocompleteTest.no_incompatible_self_calls +AutocompleteTest.no_incompatible_self_calls_2 +AutocompleteTest.no_incompatible_self_calls_on_class +AutocompleteTest.no_wrong_compatible_self_calls_with_generics +AutocompleteTest.recursive_function +AutocompleteTest.recursive_function_global +AutocompleteTest.recursive_function_local +AutocompleteTest.return_types +AutocompleteTest.sometimes_the_metatable_is_an_error +AutocompleteTest.source_module_preservation_and_invalidation +AutocompleteTest.statement_between_two_statements +AutocompleteTest.stop_at_first_stat_when_recommending_keywords +AutocompleteTest.string_prim_non_self_calls_are_avoided +AutocompleteTest.string_prim_self_calls_are_fine +AutocompleteTest.suggest_external_module_type +AutocompleteTest.table_intersection +AutocompleteTest.table_union +AutocompleteTest.type_correct_argument_type_suggestion +AutocompleteTest.type_correct_expected_argument_type_pack_suggestion +AutocompleteTest.type_correct_expected_argument_type_suggestion +AutocompleteTest.type_correct_expected_argument_type_suggestion_optional +AutocompleteTest.type_correct_expected_argument_type_suggestion_self +AutocompleteTest.type_correct_expected_return_type_pack_suggestion +AutocompleteTest.type_correct_expected_return_type_suggestion +AutocompleteTest.type_correct_full_type_suggestion +AutocompleteTest.type_correct_function_no_parenthesis +AutocompleteTest.type_correct_function_return_types +AutocompleteTest.type_correct_function_type_suggestion +AutocompleteTest.type_correct_keywords +AutocompleteTest.type_correct_local_type_suggestion +AutocompleteTest.type_correct_sealed_table +AutocompleteTest.type_correct_suggestion_for_overloads +AutocompleteTest.type_correct_suggestion_in_argument +AutocompleteTest.unsealed_table +AutocompleteTest.unsealed_table_2 +AutocompleteTest.user_defined_local_functions_in_own_definition +BuiltinTests.aliased_string_format +BuiltinTests.assert_removes_falsy_types +BuiltinTests.assert_removes_falsy_types2 +BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type +BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy +BuiltinTests.bad_select_should_not_crash +BuiltinTests.builtin_tables_sealed +BuiltinTests.coroutine_resume_anything_goes +BuiltinTests.coroutine_wrap_anything_goes +BuiltinTests.debug_info_is_crazy +BuiltinTests.debug_traceback_is_crazy +BuiltinTests.dont_add_definitions_to_persistent_types +BuiltinTests.find_capture_types +BuiltinTests.find_capture_types2 +BuiltinTests.find_capture_types3 +BuiltinTests.gcinfo +BuiltinTests.getfenv +BuiltinTests.global_singleton_types_are_sealed +BuiltinTests.gmatch_capture_types +BuiltinTests.gmatch_capture_types2 +BuiltinTests.gmatch_capture_types_balanced_escaped_parens +BuiltinTests.gmatch_capture_types_default_capture +BuiltinTests.gmatch_capture_types_invalid_pattern_fallback_to_builtin +BuiltinTests.gmatch_capture_types_invalid_pattern_fallback_to_builtin2 +BuiltinTests.gmatch_capture_types_leading_end_bracket_is_part_of_set +BuiltinTests.gmatch_capture_types_parens_in_sets_are_ignored +BuiltinTests.gmatch_capture_types_set_containing_lbracket +BuiltinTests.gmatch_definition +BuiltinTests.ipairs_iterator_should_infer_types_and_type_check +BuiltinTests.lua_51_exported_globals_all_exist +BuiltinTests.match_capture_types +BuiltinTests.match_capture_types2 +BuiltinTests.math_max_checks_for_numbers +BuiltinTests.math_max_variatic +BuiltinTests.math_things_are_defined +BuiltinTests.next_iterator_should_infer_types_and_type_check +BuiltinTests.no_persistent_typelevel_change +BuiltinTests.os_time_takes_optional_date_table +BuiltinTests.pairs_iterator_should_infer_types_and_type_check +BuiltinTests.see_thru_select +BuiltinTests.see_thru_select_count +BuiltinTests.select_on_variadic +BuiltinTests.select_slightly_out_of_range +BuiltinTests.select_way_out_of_range +BuiltinTests.select_with_decimal_argument_is_rounded_down +BuiltinTests.select_with_variadic_typepack_tail +BuiltinTests.select_with_variadic_typepack_tail_and_string_head +BuiltinTests.set_metatable_needs_arguments +BuiltinTests.setmetatable_should_not_mutate_persisted_types +BuiltinTests.setmetatable_unpacks_arg_types_correctly +BuiltinTests.sort +BuiltinTests.sort_with_bad_predicate +BuiltinTests.sort_with_predicate +BuiltinTests.string_format_arg_count_mismatch +BuiltinTests.string_format_arg_types_inference +BuiltinTests.string_format_as_method +BuiltinTests.string_format_correctly_ordered_types +BuiltinTests.string_format_report_all_type_errors_at_correct_positions +BuiltinTests.string_format_use_correct_argument +BuiltinTests.string_format_use_correct_argument2 +BuiltinTests.string_lib_self_noself +BuiltinTests.table_concat_returns_string +BuiltinTests.table_dot_remove_optionally_returns_generic +BuiltinTests.table_freeze_is_generic +BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload +BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload +BuiltinTests.table_pack +BuiltinTests.table_pack_reduce +BuiltinTests.table_pack_variadic +BuiltinTests.thread_is_a_type +BuiltinTests.tonumber_returns_optional_number_type +BuiltinTests.tonumber_returns_optional_number_type2 +BuiltinTests.xpcall +DefinitionTests.class_definition_function_prop +DefinitionTests.declaring_generic_functions +DefinitionTests.definition_file_class_function_args +DefinitionTests.definition_file_classes +DefinitionTests.definition_file_loading +DefinitionTests.single_class_type_identity_in_global_types +FrontendTest.accumulate_cached_errors +FrontendTest.accumulate_cached_errors_in_consistent_order +FrontendTest.any_annotation_breaks_cycle +FrontendTest.ast_node_at_position +FrontendTest.automatically_check_cyclically_dependent_scripts +FrontendTest.automatically_check_dependent_scripts +FrontendTest.check_without_builtin_next +FrontendTest.clearStats +FrontendTest.cycle_detection_between_check_and_nocheck +FrontendTest.cycle_detection_disabled_in_nocheck +FrontendTest.cycle_error_paths +FrontendTest.cycle_errors_can_be_fixed +FrontendTest.cycle_incremental_type_surface +FrontendTest.cycle_incremental_type_surface_longer +FrontendTest.dont_recheck_script_that_hasnt_been_marked_dirty +FrontendTest.dont_reparse_clean_file_when_linting +FrontendTest.environments +FrontendTest.ignore_require_to_nonexistent_file +FrontendTest.imported_table_modification_2 +FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded +FrontendTest.no_use_after_free_with_type_fun_instantiation +FrontendTest.nocheck_cycle_used_by_checked +FrontendTest.nocheck_modules_are_typed +FrontendTest.produce_errors_for_unchanged_file_with_a_syntax_error +FrontendTest.re_report_type_error_in_required_file +FrontendTest.recheck_if_dependent_script_is_dirty +FrontendTest.reexport_cyclic_type +FrontendTest.reexport_type_alias +FrontendTest.report_require_to_nonexistent_file +FrontendTest.report_syntax_error_in_required_file +FrontendTest.reports_errors_from_multiple_sources +FrontendTest.stats_are_not_reset_between_checks +FrontendTest.trace_requires_in_nonstrict_mode +GenericsTests.apply_type_function_nested_generics1 +GenericsTests.apply_type_function_nested_generics2 +GenericsTests.better_mismatch_error_messages +GenericsTests.bound_tables_do_not_clone_original_fields +GenericsTests.check_generic_typepack_function +GenericsTests.check_mutual_generic_functions +GenericsTests.correctly_instantiate_polymorphic_member_functions +GenericsTests.do_not_always_instantiate_generic_intersection_types +GenericsTests.do_not_infer_generic_functions +GenericsTests.dont_unify_bound_types +GenericsTests.duplicate_generic_type_packs +GenericsTests.duplicate_generic_types +GenericsTests.error_detailed_function_mismatch_generic_pack +GenericsTests.error_detailed_function_mismatch_generic_types +GenericsTests.factories_of_generics +GenericsTests.function_arguments_can_be_polytypes +GenericsTests.function_results_can_be_polytypes +GenericsTests.generic_argument_count_too_few +GenericsTests.generic_argument_count_too_many +GenericsTests.generic_factories +GenericsTests.generic_functions_dont_cache_type_parameters +GenericsTests.generic_functions_in_types +GenericsTests.generic_functions_should_be_memory_safe +GenericsTests.generic_table_method +GenericsTests.generic_type_pack_parentheses +GenericsTests.generic_type_pack_syntax +GenericsTests.generic_type_pack_unification1 +GenericsTests.generic_type_pack_unification2 +GenericsTests.generic_type_pack_unification3 +GenericsTests.infer_generic_function_function_argument +GenericsTests.infer_generic_function_function_argument_overloaded +GenericsTests.infer_generic_lib_function_function_argument +GenericsTests.infer_generic_property +GenericsTests.inferred_local_vars_can_be_polytypes +GenericsTests.instantiate_cyclic_generic_function +GenericsTests.instantiate_generic_function_in_assignments +GenericsTests.instantiate_generic_function_in_assignments2 +GenericsTests.instantiated_function_argument_names +GenericsTests.instantiation_sharing_types +GenericsTests.local_vars_can_be_instantiated_polytypes +GenericsTests.mutable_state_polymorphism +GenericsTests.no_stack_overflow_from_quantifying +GenericsTests.properties_can_be_instantiated_polytypes +GenericsTests.properties_can_be_polytypes +GenericsTests.rank_N_types_via_typeof +GenericsTests.reject_clashing_generic_and_pack_names +GenericsTests.self_recursive_instantiated_param +GenericsTests.variadic_generics +IntersectionTypes.argument_is_intersection +IntersectionTypes.error_detailed_intersection_all +IntersectionTypes.error_detailed_intersection_part +IntersectionTypes.fx_intersection_as_argument +IntersectionTypes.fx_union_as_argument_fails +IntersectionTypes.index_on_an_intersection_type_with_mixed_types +IntersectionTypes.index_on_an_intersection_type_with_one_part_missing_the_property +IntersectionTypes.index_on_an_intersection_type_with_one_property_of_type_any +IntersectionTypes.index_on_an_intersection_type_with_property_guaranteed_to_exist +IntersectionTypes.index_on_an_intersection_type_works_at_arbitrary_depth +IntersectionTypes.no_stack_overflow_from_flattenintersection +IntersectionTypes.overload_is_not_a_function +IntersectionTypes.select_correct_union_fn +IntersectionTypes.should_still_pick_an_overload_whose_arguments_are_unions +IntersectionTypes.table_intersection_setmetatable +IntersectionTypes.table_intersection_write +IntersectionTypes.table_intersection_write_sealed +IntersectionTypes.table_intersection_write_sealed_indirect +IntersectionTypes.table_write_sealed_indirect +isSubtype.functions_and_any +isSubtype.intersection_of_functions_of_different_arities +isSubtype.intersection_of_tables +isSubtype.table_with_any_prop +isSubtype.table_with_table_prop +isSubtype.tables +Linter.DeprecatedApi +Linter.TableOperations +ModuleTests.builtin_types_point_into_globalTypes_arena +ModuleTests.clone_self_property +ModuleTests.deepClone_cyclic_table +NonstrictModeTests.delay_function_does_not_require_its_argument_to_return_anything +NonstrictModeTests.for_in_iterator_variables_are_any +NonstrictModeTests.function_parameters_are_any +NonstrictModeTests.inconsistent_module_return_types_are_ok +NonstrictModeTests.inconsistent_return_types_are_ok +NonstrictModeTests.infer_nullary_function +NonstrictModeTests.infer_the_maximum_number_of_values_the_function_could_return +NonstrictModeTests.inline_table_props_are_also_any +NonstrictModeTests.local_tables_are_not_any +NonstrictModeTests.locals_are_any_by_default +NonstrictModeTests.offer_a_hint_if_you_use_a_dot_instead_of_a_colon +NonstrictModeTests.parameters_having_type_any_are_optional +NonstrictModeTests.returning_insufficient_return_values +NonstrictModeTests.returning_too_many_values +NonstrictModeTests.table_dot_insert_and_recursive_calls +NonstrictModeTests.table_props_are_any +Normalize.any_wins_the_battle_over_unknown_in_unions +Normalize.constrained_intersection_of_intersections +Normalize.cyclic_intersection +Normalize.cyclic_table_normalizes_sensibly +Normalize.cyclic_union +Normalize.fuzz_failure_bound_type_is_normal_but_not_its_bounded_to +Normalize.fuzz_failure_instersection_combine_must_follow +Normalize.higher_order_function +Normalize.intersection_combine_on_bound_self +Normalize.intersection_inside_a_table_inside_another_intersection +Normalize.intersection_inside_a_table_inside_another_intersection_2 +Normalize.intersection_inside_a_table_inside_another_intersection_3 +Normalize.intersection_inside_a_table_inside_another_intersection_4 +Normalize.intersection_of_confluent_overlapping_tables +Normalize.intersection_of_disjoint_tables +Normalize.intersection_of_functions +Normalize.intersection_of_overlapping_tables +Normalize.intersection_of_tables_with_indexers +Normalize.nested_table_normalization_with_non_table__no_ice +Normalize.normalization_does_not_convert_ever +Normalize.normalize_module_return_type +Normalize.normalize_unions_containing_never +Normalize.normalize_unions_containing_unknown +Normalize.return_type_is_not_a_constrained_intersection +Normalize.skip_force_normal_on_external_types +Normalize.union_of_distinct_free_types +Normalize.variadic_tail_is_marked_normal +Normalize.visiting_a_type_twice_is_not_considered_normal +ParseErrorRecovery.generic_type_list_recovery +ParseErrorRecovery.recovery_of_parenthesized_expressions +ParserTests.parse_nesting_based_end_detection_failsafe_earlier +ParserTests.parse_nesting_based_end_detection_local_function +ProvisionalTests.bail_early_if_unification_is_too_complicated +ProvisionalTests.choose_the_right_overload_for_pcall +ProvisionalTests.constrained_is_level_dependent +ProvisionalTests.discriminate_from_x_not_equal_to_nil +ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack +ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean +ProvisionalTests.free_is_not_bound_to_any +ProvisionalTests.function_returns_many_things_but_first_of_it_is_forgotten +ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns +ProvisionalTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound +ProvisionalTests.it_should_be_agnostic_of_actual_size +ProvisionalTests.lower_bounds_calculation_is_too_permissive_with_overloaded_higher_order_functions +ProvisionalTests.lvalue_equals_another_lvalue_with_no_overlap +ProvisionalTests.normalization_fails_on_certain_kinds_of_cyclic_tables +ProvisionalTests.operator_eq_completely_incompatible +ProvisionalTests.pcall_returns_at_least_two_value_but_function_returns_nothing +ProvisionalTests.setmetatable_constrains_free_type_into_free_table +ProvisionalTests.typeguard_inference_incomplete +ProvisionalTests.weird_fail_to_unify_type_pack +ProvisionalTests.weirditer_should_not_loop_forever +ProvisionalTests.while_body_are_also_refined +ProvisionalTests.xpcall_returns_what_f_returns +RefinementTest.and_constraint +RefinementTest.and_or_peephole_refinement +RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string +RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number +RefinementTest.assert_non_binary_expressions_actually_resolve_constraints +RefinementTest.assign_table_with_refined_property_with_a_similar_type_is_illegal +RefinementTest.call_a_more_specific_function_using_typeguard +RefinementTest.correctly_lookup_a_shadowed_local_that_which_was_previously_refined +RefinementTest.correctly_lookup_property_whose_base_was_previously_refined +RefinementTest.correctly_lookup_property_whose_base_was_previously_refined2 +RefinementTest.discriminate_from_isa_of_x +RefinementTest.discriminate_from_truthiness_of_x +RefinementTest.discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false +RefinementTest.discriminate_tag +RefinementTest.either_number_or_string +RefinementTest.eliminate_subclasses_of_instance +RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil +RefinementTest.free_type_is_equal_to_an_lvalue +RefinementTest.impossible_type_narrow_is_not_an_error +RefinementTest.index_on_a_refined_property +RefinementTest.invert_is_truthy_constraint +RefinementTest.invert_is_truthy_constraint_ifelse_expression +RefinementTest.is_truthy_constraint +RefinementTest.is_truthy_constraint_ifelse_expression +RefinementTest.lvalue_is_equal_to_a_term +RefinementTest.lvalue_is_equal_to_another_lvalue +RefinementTest.lvalue_is_not_nil +RefinementTest.merge_should_be_fully_agnostic_of_hashmap_ordering +RefinementTest.narrow_property_of_a_bounded_variable +RefinementTest.narrow_this_large_union +RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true +RefinementTest.not_a_and_not_b +RefinementTest.not_a_and_not_b2 +RefinementTest.not_a_or_not_b +RefinementTest.not_a_or_not_b2 +RefinementTest.not_and_constraint +RefinementTest.not_t_or_some_prop_of_t +RefinementTest.or_predicate_with_truthy_predicates +RefinementTest.parenthesized_expressions_are_followed_through +RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table +RefinementTest.refine_the_correct_types_opposite_of_when_a_is_not_number_or_string +RefinementTest.refine_unknowns +RefinementTest.string_not_equal_to_string_or_nil +RefinementTest.term_is_equal_to_an_lvalue +RefinementTest.truthy_constraint_on_properties +RefinementTest.type_assertion_expr_carry_its_constraints +RefinementTest.type_comparison_ifelse_expression +RefinementTest.type_guard_can_filter_for_intersection_of_tables +RefinementTest.type_guard_can_filter_for_overloaded_function +RefinementTest.type_guard_narrowed_into_nothingness +RefinementTest.type_narrow_for_all_the_userdata +RefinementTest.type_narrow_to_vector +RefinementTest.typeguard_cast_free_table_to_vector +RefinementTest.typeguard_cast_instance_or_vector3_to_vector +RefinementTest.typeguard_doesnt_leak_to_elseif +RefinementTest.typeguard_in_assert_position +RefinementTest.typeguard_in_if_condition_position +RefinementTest.typeguard_narrows_for_functions +RefinementTest.typeguard_narrows_for_table +RefinementTest.typeguard_not_to_be_string +RefinementTest.typeguard_only_look_up_types_from_global_scope +RefinementTest.unknown_lvalue_is_not_synonymous_with_other_on_not_equal +RefinementTest.what_nonsensical_condition +RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table +RefinementTest.x_is_not_instance_or_else_not_part +RuntimeLimits.typescript_port_of_Result_type +TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible +TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible +TableTests.access_index_metamethod_that_returns_variadic +TableTests.accidentally_checked_prop_in_opposite_branch +TableTests.assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer +TableTests.augment_nested_table +TableTests.augment_table +TableTests.builtin_table_names +TableTests.call_method +TableTests.call_method_with_explicit_self_argument +TableTests.cannot_augment_sealed_table +TableTests.cannot_call_tables +TableTests.cannot_change_type_of_unsealed_table_prop +TableTests.casting_sealed_tables_with_props_into_table_with_indexer +TableTests.casting_tables_with_props_into_table_with_indexer3 +TableTests.casting_tables_with_props_into_table_with_indexer4 +TableTests.checked_prop_too_early +TableTests.common_table_element_union_in_call +TableTests.common_table_element_union_in_call_tail +TableTests.confusing_indexing +TableTests.defining_a_method_for_a_builtin_sealed_table_must_fail +TableTests.defining_a_method_for_a_local_sealed_table_must_fail +TableTests.defining_a_method_for_a_local_unsealed_table_is_ok +TableTests.defining_a_self_method_for_a_builtin_sealed_table_must_fail +TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail +TableTests.defining_a_self_method_for_a_local_unsealed_table_is_ok +TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar +TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index +TableTests.dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back +TableTests.dont_leak_free_table_props +TableTests.dont_quantify_table_that_belongs_to_outer_scope +TableTests.dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table +TableTests.dont_suggest_exact_match_keys +TableTests.error_detailed_indexer_key +TableTests.error_detailed_indexer_value +TableTests.error_detailed_metatable_prop +TableTests.error_detailed_prop +TableTests.error_detailed_prop_nested +TableTests.expected_indexer_from_table_union +TableTests.expected_indexer_value_type_extra +TableTests.expected_indexer_value_type_extra_2 +TableTests.explicitly_typed_table +TableTests.explicitly_typed_table_error +TableTests.explicitly_typed_table_with_indexer +TableTests.found_like_key_in_table_function_call +TableTests.found_like_key_in_table_property_access +TableTests.found_multiple_like_keys +TableTests.function_calls_produces_sealed_table_given_unsealed_table +TableTests.generalize_table_argument +TableTests.getmetatable_returns_pointer_to_metatable +TableTests.give_up_after_one_metatable_index_look_up +TableTests.hide_table_error_properties +TableTests.indexer_fn +TableTests.indexer_on_sealed_table_must_unify_with_free_table +TableTests.indexer_table +TableTests.indexing_from_a_table_should_prefer_properties_when_possible +TableTests.inequality_operators_imply_exactly_matching_types +TableTests.infer_array_2 +TableTests.infer_indexer_from_value_property_in_literal +TableTests.inferred_return_type_of_free_table +TableTests.inferring_crazy_table_should_also_be_quick +TableTests.instantiate_table_cloning +TableTests.instantiate_table_cloning_2 +TableTests.instantiate_table_cloning_3 +TableTests.instantiate_tables_at_scope_level +TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound +TableTests.leaking_bad_metatable_errors +TableTests.length_operator_intersection +TableTests.length_operator_non_table_union +TableTests.length_operator_union +TableTests.length_operator_union_errors +TableTests.less_exponential_blowup_please +TableTests.meta_add +TableTests.meta_add_both_ways +TableTests.meta_add_inferred +TableTests.metatable_mismatch_should_fail +TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred +TableTests.mixed_tables_with_implicit_numbered_keys +TableTests.MixedPropertiesAndIndexers +TableTests.nil_assign_doesnt_hit_indexer +TableTests.okay_to_add_property_to_unsealed_tables_by_function_call +TableTests.only_ascribe_synthetic_names_at_module_scope +TableTests.oop_indexer_works +TableTests.oop_polymorphic +TableTests.open_table_unification_2 +TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table +TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table_2 +TableTests.pass_incompatible_union_to_a_generic_table_without_crashing +TableTests.passing_compatible_unions_to_a_generic_table_without_crashing +TableTests.persistent_sealed_table_is_immutable +TableTests.property_lookup_through_tabletypevar_metatable +TableTests.quantify_even_that_table_was_never_exported_at_all +TableTests.quantify_metatables_of_metatables_of_table +TableTests.quantifying_a_bound_var_works +TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table +TableTests.recursive_metatable_type_call +TableTests.result_is_always_any_if_lhs_is_any +TableTests.result_is_bool_for_equality_operators_if_lhs_is_any +TableTests.right_table_missing_key +TableTests.right_table_missing_key2 +TableTests.scalar_is_a_subtype_of_a_compatible_polymorphic_shape_type +TableTests.scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type +TableTests.shared_selfs +TableTests.shared_selfs_from_free_param +TableTests.shared_selfs_through_metatables +TableTests.table_function_check_use_after_free +TableTests.table_indexing_error_location +TableTests.table_insert_should_cope_with_optional_properties_in_nonstrict +TableTests.table_insert_should_cope_with_optional_properties_in_strict +TableTests.table_length +TableTests.table_param_row_polymorphism_2 +TableTests.table_param_row_polymorphism_3 +TableTests.table_simple_call +TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors +TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors +TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors2 +TableTests.table_unifies_into_map +TableTests.tables_get_names_from_their_locals +TableTests.tc_member_function +TableTests.tc_member_function_2 +TableTests.top_table_type +TableTests.type_mismatch_on_massive_table_is_cut_short +TableTests.unification_of_unions_in_a_self_referential_type +TableTests.unifying_tables_shouldnt_uaf1 +TableTests.unifying_tables_shouldnt_uaf2 +TableTests.used_colon_correctly +TableTests.used_colon_instead_of_dot +TableTests.used_dot_instead_of_colon +TableTests.used_dot_instead_of_colon_but_correctly +TableTests.width_subtyping +ToDot.bound_table +ToDot.class +ToDot.function +ToDot.metatable +ToDot.primitive +ToDot.table +ToString.exhaustive_toString_of_cyclic_table +ToString.function_type_with_argument_names_and_self +ToString.function_type_with_argument_names_generic +ToString.named_metatable_toStringNamedFunction +ToString.no_parentheses_around_cyclic_function_type_in_union +ToString.toStringDetailed2 +ToString.toStringErrorPack +ToString.toStringNamedFunction_generic_pack +ToString.toStringNamedFunction_hide_type_params +ToString.toStringNamedFunction_id +ToString.toStringNamedFunction_map +ToString.toStringNamedFunction_overrides_param_names +ToString.toStringNamedFunction_variadics +TranspilerTests.type_lists_should_be_emitted_correctly +TranspilerTests.types_should_not_be_considered_cyclic_if_they_are_not_recursive +TryUnifyTests.cli_41095_concat_log_in_sealed_table_unification +TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType +TryUnifyTests.result_of_failed_typepack_unification_is_constrained +TryUnifyTests.typepack_unification_should_trim_free_tails +TryUnifyTests.variadics_should_use_reversed_properly +TypeAliases.cli_38393_recursive_intersection_oom +TypeAliases.corecursive_types_generic +TypeAliases.do_not_quantify_unresolved_aliases +TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any +TypeAliases.general_require_multi_assign +TypeAliases.generic_param_remap +TypeAliases.mismatched_generic_pack_type_param +TypeAliases.mismatched_generic_type_param +TypeAliases.mutually_recursive_generic_aliases +TypeAliases.mutually_recursive_types_restriction_not_ok_1 +TypeAliases.mutually_recursive_types_restriction_not_ok_2 +TypeAliases.mutually_recursive_types_swapsies_not_ok +TypeAliases.recursive_types_restriction_not_ok +TypeAliases.stringify_optional_parameterized_alias +TypeAliases.stringify_type_alias_of_recursive_template_table_type +TypeAliases.stringify_type_alias_of_recursive_template_table_type2 +TypeAliases.type_alias_import_mutation +TypeAliases.type_alias_local_mutation +TypeAliases.type_alias_local_rename +TypeAliases.type_alias_of_an_imported_recursive_generic_type +TypeAliases.type_alias_of_an_imported_recursive_type +TypeInfer.check_expr_recursion_limit +TypeInfer.checking_should_not_ice +TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error +TypeInfer.cyclic_follow +TypeInfer.do_not_bind_a_free_table_to_a_union_containing_that_table +TypeInfer.dont_report_type_errors_within_an_AstStatError +TypeInfer.follow_on_new_types_in_substitution +TypeInfer.free_typevars_introduced_within_control_flow_constructs_do_not_get_an_elevated_TypeLevel +TypeInfer.globals +TypeInfer.globals2 +TypeInfer.index_expr_should_be_checked +TypeInfer.infer_assignment_value_types +TypeInfer.infer_assignment_value_types_mutable_lval +TypeInfer.infer_through_group_expr +TypeInfer.infer_type_assertion_value_type +TypeInfer.no_heap_use_after_free_error +TypeInfer.no_stack_overflow_from_isoptional +TypeInfer.recursive_metatable_crash +TypeInfer.tc_after_error_recovery_no_replacement_name_in_error +TypeInfer.tc_if_else_expressions1 +TypeInfer.tc_if_else_expressions2 +TypeInfer.tc_if_else_expressions_expected_type_1 +TypeInfer.tc_if_else_expressions_expected_type_2 +TypeInfer.tc_if_else_expressions_expected_type_3 +TypeInfer.tc_if_else_expressions_type_union +TypeInfer.type_infer_recursion_limit_no_ice +TypeInfer.types stored in astResolvedTypes +TypeInfer.warn_on_lowercase_parent_property +TypeInfer.weird_case +TypeInferAnyError.any_type_propagates +TypeInferAnyError.assign_prop_to_table_by_calling_any_yields_any +TypeInferAnyError.call_to_any_yields_any +TypeInferAnyError.calling_error_type_yields_error +TypeInferAnyError.can_get_length_of_any +TypeInferAnyError.can_subscript_any +TypeInferAnyError.CheckMethodsOfAny +TypeInferAnyError.for_in_loop_iterator_is_any +TypeInferAnyError.for_in_loop_iterator_is_any2 +TypeInferAnyError.for_in_loop_iterator_is_error +TypeInferAnyError.for_in_loop_iterator_is_error2 +TypeInferAnyError.for_in_loop_iterator_returns_any +TypeInferAnyError.for_in_loop_iterator_returns_any2 +TypeInferAnyError.indexing_error_type_does_not_produce_an_error +TypeInferAnyError.length_of_error_type_does_not_produce_an_error +TypeInferAnyError.metatable_of_any_can_be_a_table +TypeInferAnyError.prop_access_on_any_with_other_options +TypeInferAnyError.quantify_any_does_not_bind_to_itself +TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any +TypeInferAnyError.type_error_addition +TypeInferClasses.assign_to_prop_of_class +TypeInferClasses.call_base_method +TypeInferClasses.call_instance_method +TypeInferClasses.call_method_of_a_child_class +TypeInferClasses.call_method_of_a_class +TypeInferClasses.can_assign_to_prop_of_base_class +TypeInferClasses.can_assign_to_prop_of_base_class_using_string +TypeInferClasses.can_read_prop_of_base_class +TypeInferClasses.can_read_prop_of_base_class_using_string +TypeInferClasses.cannot_call_method_of_child_on_base_instance +TypeInferClasses.cannot_call_unknown_method_of_a_class +TypeInferClasses.cannot_unify_class_instance_with_primitive +TypeInferClasses.class_type_mismatch_with_name_conflict +TypeInferClasses.class_unification_type_mismatch_is_correct_order +TypeInferClasses.classes_can_have_overloaded_operators +TypeInferClasses.classes_without_overloaded_operators_cannot_be_added +TypeInferClasses.detailed_class_unification_error +TypeInferClasses.function_arguments_are_covariant +TypeInferClasses.higher_order_function_arguments_are_contravariant +TypeInferClasses.higher_order_function_return_type_is_not_contravariant +TypeInferClasses.higher_order_function_return_values_are_covariant +TypeInferClasses.optional_class_field_access_error +TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties +TypeInferClasses.table_indexers_are_invariant +TypeInferClasses.table_properties_are_invariant +TypeInferClasses.warn_when_prop_almost_matches +TypeInferClasses.we_can_infer_that_a_parameter_must_be_a_particular_class +TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class +TypeInferFunctions.another_indirect_function_case_where_it_is_ok_to_provide_too_many_arguments +TypeInferFunctions.another_recursive_local_function +TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types +TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument +TypeInferFunctions.cannot_hoist_interior_defns_into_signature +TypeInferFunctions.check_function_before_lambda_that_uses_it +TypeInferFunctions.complicated_return_types_require_an_explicit_annotation +TypeInferFunctions.cyclic_function_type_in_args +TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists +TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site +TypeInferFunctions.dont_mutate_the_underlying_head_of_typepack_when_calling_with_self +TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict +TypeInferFunctions.error_detailed_function_mismatch_arg +TypeInferFunctions.error_detailed_function_mismatch_arg_count +TypeInferFunctions.error_detailed_function_mismatch_ret +TypeInferFunctions.error_detailed_function_mismatch_ret_count +TypeInferFunctions.error_detailed_function_mismatch_ret_mult +TypeInferFunctions.first_argument_can_be_optional +TypeInferFunctions.free_is_not_bound_to_unknown +TypeInferFunctions.func_expr_doesnt_leak_free +TypeInferFunctions.function_cast_error_uses_correct_language +TypeInferFunctions.function_decl_non_self_sealed_overwrite +TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 +TypeInferFunctions.function_decl_non_self_unsealed_overwrite +TypeInferFunctions.function_decl_quantify_right_type +TypeInferFunctions.function_does_not_return_enough_values +TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer +TypeInferFunctions.higher_order_function_2 +TypeInferFunctions.higher_order_function_4 +TypeInferFunctions.ignored_return_values +TypeInferFunctions.inconsistent_higher_order_function +TypeInferFunctions.inconsistent_return_types +TypeInferFunctions.infer_anonymous_function_arguments +TypeInferFunctions.infer_anonymous_function_arguments_outside_call +TypeInferFunctions.infer_return_type_from_selected_overload +TypeInferFunctions.infer_that_function_does_not_return_a_table +TypeInferFunctions.inferred_higher_order_functions_are_quantified_at_the_right_time +TypeInferFunctions.inferred_higher_order_functions_are_quantified_at_the_right_time2 +TypeInferFunctions.it_is_ok_not_to_supply_enough_retvals +TypeInferFunctions.it_is_ok_to_oversaturate_a_higher_order_function_argument +TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count +TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count +TypeInferFunctions.mutual_recursion +TypeInferFunctions.no_lossy_function_type +TypeInferFunctions.occurs_check_failure_in_function_return_type +TypeInferFunctions.quantify_constrained_types +TypeInferFunctions.record_matching_overload +TypeInferFunctions.recursive_function +TypeInferFunctions.recursive_local_function +TypeInferFunctions.report_exiting_without_return_nonstrict +TypeInferFunctions.report_exiting_without_return_strict +TypeInferFunctions.return_type_by_overload +TypeInferFunctions.strict_mode_ok_with_missing_arguments +TypeInferFunctions.too_few_arguments_variadic +TypeInferFunctions.too_few_arguments_variadic_generic +TypeInferFunctions.too_few_arguments_variadic_generic2 +TypeInferFunctions.too_many_arguments +TypeInferFunctions.too_many_return_values +TypeInferFunctions.toposort_doesnt_break_mutual_recursion +TypeInferFunctions.vararg_function_is_quantified +TypeInferFunctions.vararg_functions_should_allow_calls_of_any_types_and_size +TypeInferLoops.correctly_scope_locals_while +TypeInferLoops.for_in_loop +TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values +TypeInferLoops.for_in_loop_error_on_iterator_requiring_args_but_none_given +TypeInferLoops.for_in_loop_on_error +TypeInferLoops.for_in_loop_on_non_function +TypeInferLoops.for_in_loop_should_fail_with_non_function_iterator +TypeInferLoops.for_in_loop_where_iteratee_is_free +TypeInferLoops.for_in_loop_with_custom_iterator +TypeInferLoops.for_in_loop_with_next +TypeInferLoops.for_in_with_a_custom_iterator_should_type_check +TypeInferLoops.for_in_with_an_iterator_of_type_any +TypeInferLoops.for_in_with_just_one_iterator_is_ok +TypeInferLoops.fuzz_fail_missing_instantitation_follow +TypeInferLoops.ipairs_produces_integral_indices +TypeInferLoops.loop_iter_basic +TypeInferLoops.loop_iter_iter_metamethod +TypeInferLoops.loop_iter_no_indexer_nonstrict +TypeInferLoops.loop_iter_no_indexer_strict +TypeInferLoops.loop_iter_trailing_nil +TypeInferLoops.loop_typecheck_crash_on_empty_optional +TypeInferLoops.properly_infer_iteratee_is_a_free_table +TypeInferLoops.repeat_loop +TypeInferLoops.repeat_loop_condition_binds_to_its_block +TypeInferLoops.symbols_in_repeat_block_should_not_be_visible_beyond_until_condition +TypeInferLoops.unreachable_code_after_infinite_loop +TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free +TypeInferLoops.while_loop +TypeInferModules.bound_free_table_export_is_ok +TypeInferModules.constrained_anyification_clone_immutable_types +TypeInferModules.custom_require_global +TypeInferModules.do_not_modify_imported_types +TypeInferModules.do_not_modify_imported_types_2 +TypeInferModules.do_not_modify_imported_types_3 +TypeInferModules.do_not_modify_imported_types_4 +TypeInferModules.general_require_call_expression +TypeInferModules.general_require_type_mismatch +TypeInferModules.module_type_conflict +TypeInferModules.module_type_conflict_instantiated +TypeInferModules.require +TypeInferModules.require_a_variadic_function +TypeInferModules.require_failed_module +TypeInferModules.require_module_that_does_not_export +TypeInferModules.require_types +TypeInferModules.type_error_of_unknown_qualified_type +TypeInferModules.warn_if_you_try_to_require_a_non_modulescript +TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works +TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 +TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon +TypeInferOOP.inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table +TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory +TypeInferOOP.method_depends_on_table +TypeInferOOP.methods_are_topologically_sorted +TypeInferOOP.nonstrict_self_mismatch_tail +TypeInferOOP.object_constructor_can_refer_to_method_of_self +TypeInferOOP.table_oop +TypeInferOperators.and_adds_boolean +TypeInferOperators.and_adds_boolean_no_superfluous_union +TypeInferOperators.and_binexps_dont_unify +TypeInferOperators.and_or_ternary +TypeInferOperators.CallAndOrOfFunctions +TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable +TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable +TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators +TypeInferOperators.cli_38355_recursive_union +TypeInferOperators.compare_numbers +TypeInferOperators.compare_strings +TypeInferOperators.compound_assign_basic +TypeInferOperators.compound_assign_metatable +TypeInferOperators.compound_assign_mismatch_metatable +TypeInferOperators.compound_assign_mismatch_op +TypeInferOperators.compound_assign_mismatch_result +TypeInferOperators.concat_op_on_free_lhs_and_string_rhs +TypeInferOperators.concat_op_on_string_lhs_and_free_rhs +TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops +TypeInferOperators.dont_strip_nil_from_rhs_or_operator +TypeInferOperators.equality_operations_succeed_if_any_union_branch_succeeds +TypeInferOperators.error_on_invalid_operand_types_to_relational_operators +TypeInferOperators.error_on_invalid_operand_types_to_relational_operators2 +TypeInferOperators.expected_types_through_binary_and +TypeInferOperators.expected_types_through_binary_or +TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators +TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown +TypeInferOperators.operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap +TypeInferOperators.operator_eq_verifies_types_do_intersect +TypeInferOperators.or_joins_types +TypeInferOperators.or_joins_types_with_no_extras +TypeInferOperators.primitive_arith_no_metatable +TypeInferOperators.primitive_arith_no_metatable_with_follows +TypeInferOperators.primitive_arith_possible_metatable +TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not +TypeInferOperators.refine_and_or +TypeInferOperators.some_primitive_binary_ops +TypeInferOperators.strict_binary_op_where_lhs_unknown +TypeInferOperators.strip_nil_from_lhs_or_operator +TypeInferOperators.strip_nil_from_lhs_or_operator2 +TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection +TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs +TypeInferOperators.typecheck_unary_len_error +TypeInferOperators.typecheck_unary_minus +TypeInferOperators.typecheck_unary_minus_error +TypeInferOperators.unary_not_is_boolean +TypeInferOperators.unknown_type_in_comparison +TypeInferOperators.UnknownGlobalCompoundAssign +TypeInferPrimitives.cannot_call_primitives +TypeInferPrimitives.CheckMethodsOfNumber +TypeInferPrimitives.string_function_other +TypeInferPrimitives.string_index +TypeInferPrimitives.string_length +TypeInferPrimitives.string_method +TypeInferUnknownNever.array_like_table_of_never_is_inhabitable +TypeInferUnknownNever.assign_to_global_which_is_never +TypeInferUnknownNever.assign_to_local_which_is_never +TypeInferUnknownNever.assign_to_prop_which_is_never +TypeInferUnknownNever.assign_to_subscript_which_is_never +TypeInferUnknownNever.call_never +TypeInferUnknownNever.dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators +TypeInferUnknownNever.index_on_never +TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_never +TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_sorta_never +TypeInferUnknownNever.length_of_never +TypeInferUnknownNever.math_operators_and_never +TypeInferUnknownNever.never_is_reflexive +TypeInferUnknownNever.never_subtype_and_string_supertype +TypeInferUnknownNever.pick_never_from_variadic_type_pack +TypeInferUnknownNever.string_subtype_and_never_supertype +TypeInferUnknownNever.string_subtype_and_unknown_supertype +TypeInferUnknownNever.table_with_prop_of_type_never_is_also_reflexive +TypeInferUnknownNever.table_with_prop_of_type_never_is_uninhabitable +TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable +TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 +TypeInferUnknownNever.unary_minus_of_never +TypeInferUnknownNever.unknown_is_reflexive +TypeInferUnknownNever.unknown_subtype_and_string_supertype +TypePackTests.cyclic_type_packs +TypePackTests.higher_order_function +TypePackTests.multiple_varargs_inference_are_not_confused +TypePackTests.no_return_size_should_be_zero +TypePackTests.pack_tail_unification_check +TypePackTests.parenthesized_varargs_returns_any +TypePackTests.self_and_varargs_should_work +TypePackTests.type_alias_backwards_compatible +TypePackTests.type_alias_default_export +TypePackTests.type_alias_default_mixed_self +TypePackTests.type_alias_default_type_chained +TypePackTests.type_alias_default_type_errors +TypePackTests.type_alias_default_type_pack_self_chained_tp +TypePackTests.type_alias_default_type_pack_self_tp +TypePackTests.type_alias_default_type_self +TypePackTests.type_alias_defaults_confusing_types +TypePackTests.type_alias_defaults_recursive_type +TypePackTests.type_alias_type_pack_explicit +TypePackTests.type_alias_type_pack_explicit_multi +TypePackTests.type_alias_type_pack_multi +TypePackTests.type_alias_type_pack_variadic +TypePackTests.type_alias_type_packs +TypePackTests.type_alias_type_packs_errors +TypePackTests.type_alias_type_packs_import +TypePackTests.type_alias_type_packs_nested +TypePackTests.type_pack_hidden_free_tail_infinite_growth +TypePackTests.type_pack_type_parameters +TypePackTests.varargs_inference_through_multiple_scopes +TypePackTests.variadic_argument_tail +TypePackTests.variadic_pack_syntax +TypePackTests.variadic_packs +TypeSingletons.bool_singleton_subtype +TypeSingletons.bool_singletons +TypeSingletons.bool_singletons_mismatch +TypeSingletons.enums_using_singletons +TypeSingletons.enums_using_singletons_mismatch +TypeSingletons.enums_using_singletons_subtyping +TypeSingletons.error_detailed_tagged_union_mismatch_bool +TypeSingletons.error_detailed_tagged_union_mismatch_string +TypeSingletons.function_call_with_singletons_mismatch +TypeSingletons.if_then_else_expression_singleton_options +TypeSingletons.indexing_on_string_singletons +TypeSingletons.indexing_on_union_of_string_singletons +TypeSingletons.no_widening_from_callsites +TypeSingletons.overloaded_function_call_with_singletons +TypeSingletons.overloaded_function_call_with_singletons_mismatch +TypeSingletons.return_type_of_f_is_not_widened +TypeSingletons.string_singleton_subtype +TypeSingletons.string_singletons +TypeSingletons.string_singletons_escape_chars +TypeSingletons.string_singletons_mismatch +TypeSingletons.table_insert_with_a_singleton_argument +TypeSingletons.table_properties_type_error_escapes +TypeSingletons.tagged_unions_using_singletons +TypeSingletons.taking_the_length_of_string_singleton +TypeSingletons.taking_the_length_of_union_of_string_singleton +TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton +TypeSingletons.widening_happens_almost_everywhere +TypeSingletons.widening_happens_almost_everywhere_except_for_tables +UnionTypes.error_detailed_optional +UnionTypes.error_detailed_union_all +UnionTypes.error_detailed_union_part +UnionTypes.error_takes_optional_arguments +UnionTypes.index_on_a_union_type_with_missing_property +UnionTypes.index_on_a_union_type_with_mixed_types +UnionTypes.index_on_a_union_type_with_one_optional_property +UnionTypes.index_on_a_union_type_with_one_property_of_type_any +UnionTypes.index_on_a_union_type_with_property_guaranteed_to_exist +UnionTypes.index_on_a_union_type_works_at_arbitrary_depth +UnionTypes.optional_arguments +UnionTypes.optional_assignment_errors +UnionTypes.optional_call_error +UnionTypes.optional_field_access_error +UnionTypes.optional_index_error +UnionTypes.optional_length_error +UnionTypes.optional_missing_key_error_details +UnionTypes.optional_union_follow +UnionTypes.optional_union_functions +UnionTypes.optional_union_members +UnionTypes.optional_union_methods +UnionTypes.return_types_can_be_disjoint +UnionTypes.table_union_write_indirect +UnionTypes.unify_unsealed_table_union_check +UnionTypes.union_equality_comparisons diff --git a/tools/natvis/Analysis.natvis b/tools/natvis/Analysis.natvis index b9ea3141..7d03dd3f 100644 --- a/tools/natvis/Analysis.natvis +++ b/tools/natvis/Analysis.natvis @@ -38,6 +38,38 @@ {{ typeId=29, value={*($T30*)storage} }} {{ typeId=30, value={*($T31*)storage} }} {{ typeId=31, value={*($T32*)storage} }} + {{ typeId=32, value={*($T33*)storage} }} + {{ typeId=33, value={*($T34*)storage} }} + {{ typeId=34, value={*($T35*)storage} }} + {{ typeId=35, value={*($T36*)storage} }} + {{ typeId=36, value={*($T37*)storage} }} + {{ typeId=37, value={*($T38*)storage} }} + {{ typeId=38, value={*($T39*)storage} }} + {{ typeId=39, value={*($T40*)storage} }} + {{ typeId=40, value={*($T41*)storage} }} + {{ typeId=41, value={*($T42*)storage} }} + {{ typeId=42, value={*($T43*)storage} }} + {{ typeId=43, value={*($T44*)storage} }} + {{ typeId=44, value={*($T45*)storage} }} + {{ typeId=45, value={*($T46*)storage} }} + {{ typeId=46, value={*($T47*)storage} }} + {{ typeId=47, value={*($T48*)storage} }} + {{ typeId=48, value={*($T49*)storage} }} + {{ typeId=49, value={*($T50*)storage} }} + {{ typeId=50, value={*($T51*)storage} }} + {{ typeId=51, value={*($T52*)storage} }} + {{ typeId=52, value={*($T53*)storage} }} + {{ typeId=53, value={*($T54*)storage} }} + {{ typeId=54, value={*($T55*)storage} }} + {{ typeId=55, value={*($T56*)storage} }} + {{ typeId=56, value={*($T57*)storage} }} + {{ typeId=57, value={*($T58*)storage} }} + {{ typeId=58, value={*($T59*)storage} }} + {{ typeId=59, value={*($T60*)storage} }} + {{ typeId=60, value={*($T61*)storage} }} + {{ typeId=61, value={*($T62*)storage} }} + {{ typeId=62, value={*($T63*)storage} }} + {{ typeId=63, value={*($T64*)storage} }} typeId *($T1*)storage @@ -72,6 +104,38 @@ *($T30*)storage *($T31*)storage *($T32*)storage + *($T33*)storage + *($T34*)storage + *($T35*)storage + *($T36*)storage + *($T37*)storage + *($T38*)storage + *($T39*)storage + *($T40*)storage + *($T41*)storage + *($T42*)storage + *($T43*)storage + *($T44*)storage + *($T45*)storage + *($T46*)storage + *($T47*)storage + *($T48*)storage + *($T49*)storage + *($T50*)storage + *($T51*)storage + *($T52*)storage + *($T53*)storage + *($T54*)storage + *($T55*)storage + *($T56*)storage + *($T57*)storage + *($T58*)storage + *($T59*)storage + *($T60*)storage + *($T61*)storage + *($T62*)storage + *($T63*)storage + *($T64*)storage diff --git a/tools/natvis/VM.natvis b/tools/natvis/VM.natvis index ccc7e390..cb2f355f 100644 --- a/tools/natvis/VM.natvis +++ b/tools/natvis/VM.natvis @@ -137,16 +137,28 @@ {data,s} + + + + + + + + + empty + none + + + {proto()->source->data,sb}:{line()} function {proto()->debugname->data,sb}() + {proto()->source->data,sb}:{line()} function() + + + =[C] function {cl().c.debugname,sb}() {cl().c.f,na} + =[C] {cl().c.f,na} + + - - {ci->func->value.gc->cl.c.f,na} - - - {ci->func->value.gc->cl.l.p->source->data,sb}:{ci->func->value.gc->cl.l.p->linedefined,d} {ci->func->value.gc->cl.l.p->debugname->data,sb} - - - {ci->func->value.gc->cl.l.p->source->data,sb}:{ci->func->value.gc->cl.l.p->linedefined,d} - + {ci,na} thread @@ -156,7 +168,7 @@ ci-base_ci - base_ci[ci-base_ci - $i].func->value.gc->cl,view(short) + base_ci[ci-base_ci - $i] diff --git a/tools/patchtests.py b/tools/patchtests.py index dcaf6083..56970c9f 100644 --- a/tools/patchtests.py +++ b/tools/patchtests.py @@ -16,7 +16,11 @@ state = 0 # parse input into errors[] with the state machine; this is using doctest output and expects multi-line match failures for line in input: if state == 0: - match = re.match("tests/[^:]+:(\d+): ERROR: CHECK_EQ", line) + if sys.platform == "win32": + match = re.match("[^(]+\((\d+)\): ERROR: CHECK_EQ", line) + else: + match = re.match("tests/[^:]+:(\d+): ERROR: CHECK_EQ", line) + if match: error_line = int(match[1]) state = 1 @@ -52,12 +56,16 @@ result = [] current = 0 index = 0 +target = 0 while index < len(source): line = source[index] error = errors[current] if current < len(errors) else None - if not error or index < error[0] or line != error[1][0]: + if error: + target = error[0] if sys.platform != "win32" else error[0] - len(error[1]) - 1 + + if not error or index < target or line != error[1][0]: result.append(line) index += 1 else: diff --git a/tools/perfgraph.py b/tools/perfgraph.py index 95baef9c..ae74b7d6 100644 --- a/tools/perfgraph.py +++ b/tools/perfgraph.py @@ -6,6 +6,12 @@ import sys import svg +import argparse +import json + +argumentParser = argparse.ArgumentParser(description='Generate flamegraph SVG from Luau sampling profiler dumps') +argumentParser.add_argument('source_file', type=open) +argumentParser.add_argument('--json', dest='useJson',action='store_const',const=1,default=0,help='Parse source_file as JSON') class Node(svg.Node): def __init__(self): @@ -27,26 +33,64 @@ class Node(svg.Node): def details(self, root): return "Function: {} [{}:{}] ({:,} usec, {:.1%}); self: {:,} usec".format(self.function, self.source, self.line, self.width, self.width / root.width, self.ticks) -with open(sys.argv[1]) as f: - dump = f.readlines() -root = Node() +def nodeFromCallstackListFile(source_file): + dump = source_file.readlines() + root = Node() -for l in dump: - ticks, stack = l.strip().split(" ", 1) - node = root + for l in dump: + ticks, stack = l.strip().split(" ", 1) + node = root - for f in reversed(stack.split(";")): - source, function, line = f.split(",") + for f in reversed(stack.split(";")): + source, function, line = f.split(",") - child = node.child(f) - child.function = function - child.source = source - child.line = int(line) if len(line) > 0 else 0 + child = node.child(f) + child.function = function + child.source = source + child.line = int(line) if len(line) > 0 else 0 + + node = child + + node.ticks += int(ticks) + + return root + + +def nodeFromJSONbject(node, key, obj): + source, function, line = key.split(",") + + node.function = function + node.source = source + node.line = int(line) if len(line) > 0 else 0 + + node.ticks = obj['Duration'] + + for key, obj in obj['Children'].items(): + nodeFromJSONbject(node.child(key), key, obj) + + return node + + +def nodeFromJSONFile(source_file): + dump = json.load(source_file) + + root = Node() + + for key, obj in dump['Children'].items(): + nodeFromJSONbject(root.child(key), key, obj) + + return root + + +arguments = argumentParser.parse_args() + +if arguments.useJson: + root = nodeFromJSONFile(arguments.source_file) +else: + root = nodeFromCallstackListFile(arguments.source_file) - node = child - node.ticks += int(ticks) svg.layout(root, lambda n: n.ticks) svg.display(root, "Flame Graph", "hot", flip = True) diff --git a/tools/test_dcr.py b/tools/test_dcr.py new file mode 100644 index 00000000..0efea3c4 --- /dev/null +++ b/tools/test_dcr.py @@ -0,0 +1,139 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +import argparse +import os.path +import subprocess as sp +import sys +import xml.sax as x + +SCRIPT_PATH = os.path.split(sys.argv[0])[0] +FAIL_LIST_PATH = os.path.join(SCRIPT_PATH, "faillist.txt") + + +def loadFailList(): + with open(FAIL_LIST_PATH) as f: + return set(map(str.strip, f.readlines())) + +def safeParseInt(i, default=0): + try: + return int(i) + except ValueError: + return default + +class Handler(x.ContentHandler): + def __init__(self, failList): + self.currentTest = [] + self.failList = failList # Set of dotted test names that are expected to fail + + self.results = {} # {DottedName: TrueIfTheTestPassed} + + self.numSkippedTests = 0 + + def startElement(self, name, attrs): + if name == "TestSuite": + self.currentTest.append(attrs["name"]) + elif name == "TestCase": + self.currentTest.append(attrs["name"]) + + elif name == "OverallResultsAsserts": + if self.currentTest: + passed = 0 == safeParseInt(attrs["failures"]) + + dottedName = ".".join(self.currentTest) + + # Sometimes we get multiple XML trees for the same test. All of + # them must report a pass in order for us to consider the test + # to have passed. + r = self.results.get(dottedName, True) + self.results[dottedName] = r and passed + + elif name == 'OverallResultsTestCases': + self.numSkippedTests = safeParseInt(attrs.get("skipped", 0)) + + def endElement(self, name): + if name == "TestCase": + self.currentTest.pop() + + elif name == "TestSuite": + self.currentTest.pop() + + +def main(): + parser = argparse.ArgumentParser( + description="Run Luau.UnitTest with deferred constraint resolution enabled" + ) + parser.add_argument( + "path", action="store", help="Path to the Luau.UnitTest executable" + ) + parser.add_argument( + "--dump", + dest="dump", + action="store_true", + help="Instead of doing any processing, dump the raw output of the test run. Useful for debugging this tool.", + ) + parser.add_argument( + "--write", + dest="write", + action="store_true", + help="Write a new faillist.txt after running tests.", + ) + + args = parser.parse_args() + + failList = loadFailList() + + p = sp.Popen( + [ + args.path, + "--reporters=xml", + "--fflags=true,DebugLuauDeferredConstraintResolution=true", + ], + stdout=sp.PIPE, + ) + + handler = Handler(failList) + + if args.dump: + for line in p.stdout: + sys.stdout.buffer.write(line) + return + else: + x.parse(p.stdout, handler) + + p.wait() + + for testName, passed in handler.results.items(): + if passed and testName in failList: + print('UNEXPECTED: {} should have failed'.format(testName)) + elif not passed and testName not in failList: + print('UNEXPECTED: {} should have passed'.format(testName)) + + if args.write: + newFailList = sorted( + ( + dottedName + for dottedName, passed in handler.results.items() + if not passed + ), + key=str.lower, + ) + with open(FAIL_LIST_PATH, "w", newline="\n") as f: + for name in newFailList: + print(name, file=f) + print("Updated faillist.txt") + + if handler.numSkippedTests > 0: + print('{} test(s) were skipped! That probably means that a test segfaulted!'.format(handler.numSkippedTests), file=sys.stderr) + sys.exit(1) + + sys.exit( + 0 + if all( + not passed == (dottedName in failList) + for dottedName, passed in handler.results.items() + ) + else 1 + ) + +if __name__ == "__main__": + main()