diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml new file mode 100644 index 00000000..4df7b2f3 --- /dev/null +++ b/.github/workflows/benchmark.yml @@ -0,0 +1,270 @@ +name: benchmark + +on: + push: + branches: + - master + paths-ignore: + - "docs/**" + - "papers/**" + - "rfcs/**" + - "*.md" + - "prototyping/**" + +jobs: + windows: + name: windows-${{matrix.arch}} + strategy: + fail-fast: false + matrix: + os: [windows-latest] + arch: [Win32, x64] + bench: + - { + script: "run-benchmarks", + timeout: 12, + title: "Luau Benchmarks", + cachegrindTitle: "Performance", + cachegrindIterCount: 20, + } + benchResultsRepo: + - { name: "luau-lang/benchmark-data", branch: "main" } + + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Luau repository + uses: actions/checkout@v3 + + - name: Build Luau + shell: bash # necessary for fail-fast + run: | + mkdir build && cd build + cmake .. -DCMAKE_BUILD_TYPE=Release + cmake --build . --target Luau.Repl.CLI --config Release + cmake --build . --target Luau.Analyze.CLI --config Release + + - name: Move build files to root + run: | + move build/Release/* . + + - uses: actions/setup-python@v3 + with: + python-version: "3.9" + architecture: "x64" + + - name: Install python dependencies + run: | + python -m pip install requests + python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose + + - name: Run benchmark + run: | + python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt + + - name: Checkout Benchmark Results repository + uses: actions/checkout@v3 + with: + repository: ${{ matrix.benchResultsRepo.name }} + ref: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + + - name: Store ${{ matrix.bench.title }} result + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} (Windows ${{matrix.arch}}) + tool: "benchmarkluau" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Push benchmark results + if: github.event_name == 'push' + run: | + echo "Pushing benchmark results..." + cd gh-pages + git config user.name github-actions + git config user.email github@users.noreply.github.com + git add ./dev/bench/data.json + git commit -m "Add benchmarks results for ${{ github.sha }}" + git push + cd .. + + unix: + name: ${{matrix.os}} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + bench: + - { + script: "run-benchmarks", + timeout: 12, + title: "Luau Benchmarks", + cachegrindTitle: "Performance", + cachegrindIterCount: 20, + } + benchResultsRepo: + - { name: "luau-lang/benchmark-data", branch: "main" } + + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Luau repository + uses: actions/checkout@v3 + + - name: Build Luau + run: make config=release luau luau-analyze + + - uses: actions/setup-python@v3 + with: + python-version: "3.9" + architecture: "x64" + + - name: Install python dependencies + run: | + python -m pip install requests + python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose + + - name: Run benchmark + run: | + python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt + + - name: Install valgrind + if: matrix.os == 'ubuntu-latest' + run: | + sudo apt-get install valgrind + + - name: Run ${{ matrix.bench.title }} (Cold Cachegrind) + if: matrix.os == 'ubuntu-latest' + run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/bench.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 | tee -a ${{ matrix.bench.script }}-output.txt + + - name: Run ${{ matrix.bench.title }} (Warm Cachegrind) + if: matrix.os == 'ubuntu-latest' + run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/bench.py "${{ matrix.bench.cachegrindTitle }}" ${{ matrix.bench.cachegrindIterCount }} | tee -a ${{ matrix.bench.script }}-output.txt + + - name: Checkout Benchmark Results repository + uses: actions/checkout@v3 + with: + repository: ${{ matrix.benchResultsRepo.name }} + ref: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + + - name: Store ${{ matrix.bench.title }} result + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} + tool: "benchmarkluau" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} + + - name: Store ${{ matrix.bench.title }} result (CacheGrind) + if: matrix.os == 'ubuntu-latest' + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} (CacheGrind) + tool: "roblox" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} + + - name: Push benchmark results + if: github.event_name == 'push' + run: | + echo "Pushing benchmark results..." + cd gh-pages + git config user.name github-actions + git config user.email github@users.noreply.github.com + git add ./dev/bench/data.json + git commit -m "Add benchmarks results for ${{ github.sha }}" + git push + cd .. + + static-analysis: + name: luau-analyze + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + bench: + - { + script: "run-analyze", + timeout: 12, + title: "Luau Analyze", + cachegrindTitle: "Performance", + cachegrindIterCount: 20, + } + benchResultsRepo: + - { name: "luau-lang/benchmark-data", branch: "main" } + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + with: + token: "${{ secrets.BENCH_GITHUB_TOKEN }}" + + - name: Build Luau + run: make config=release luau luau-analyze + + - uses: actions/setup-python@v4 + with: + python-version: "3.9" + architecture: "x64" + + - name: Install python dependencies + run: | + sudo pip install requests numpy scipy matplotlib ipython jupyter pandas sympy nose + + - name: Install valgrind + run: | + sudo apt-get install valgrind + + - name: Run Luau Analyze on static file + run: sudo python ./bench/measure_time.py ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee ${{ matrix.bench.script }}-output.txt + + - name: Run ${{ matrix.bench.title }} (Cold Cachegrind) + run: sudo ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt + + - name: Run ${{ matrix.bench.title }} (Warm Cachegrind) + run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt + + - name: Checkout Benchmark Results repository + uses: actions/checkout@v3 + with: + repository: ${{ matrix.benchResultsRepo.name }} + ref: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + + - name: Store ${{ matrix.bench.title }} result + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} + tool: "benchmarkluau" + + gh-pages-branch: "main" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} + + - name: Store ${{ matrix.bench.title }} result (CacheGrind) + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} + tool: "roblox" + gh-pages-branch: "main" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} + + - name: Push benchmark results + if: github.event_name == 'push' + run: | + echo "Pushing benchmark results..." + cd gh-pages + git config user.name github-actions + git config user.email github@users.noreply.github.com + git add ./dev/bench/data.json + git commit -m "Add benchmarks results for ${{ github.sha }}" + git push + cd .. diff --git a/.github/workflows/prototyping.yml b/.github/workflows/prototyping.yml index 6bc8a81b..ff66881d 100644 --- a/.github/workflows/prototyping.yml +++ b/.github/workflows/prototyping.yml @@ -10,7 +10,9 @@ jobs: linux: strategy: matrix: - agda: [2.6.2.1] + agda: [2.6.2.2] + hackageDate: ["2022-04-07"] + hackageTime: ["23:06:28"] name: prototyping runs-on: ubuntu-latest steps: @@ -18,7 +20,7 @@ jobs: - uses: actions/cache@v2 with: path: ~/.cabal/store - key: prototyping-${{ runner.os }}-${{ matrix.agda }} + key: "prototyping-${{ runner.os }}-${{ matrix.agda }}-${{ matrix.hackageDate }}-${{ matrix.hackageTime }}" - uses: actions/cache@v2 id: luau-ast-cache with: @@ -28,12 +30,12 @@ jobs: run: sudo apt-get install -y cabal-install - name: cabal update working-directory: prototyping - run: cabal update + run: cabal v2-update "hackage.haskell.org,${{ matrix.hackageDate }}T${{ matrix.hackageTime }}Z" - name: cabal install working-directory: prototyping run: | - cabal install Agda-${{ matrix.agda }} cabal install --lib scientific vector aeson --package-env . + cabal install --allow-newer Agda-${{ matrix.agda }} - name: check targets working-directory: prototyping run: | diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h new file mode 100644 index 00000000..548a58f5 --- /dev/null +++ b/Analysis/include/Luau/Clone.h @@ -0,0 +1,30 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeArena.h" +#include "Luau/TypeVar.h" + +#include + +namespace Luau +{ + +// Only exposed so they can be unit tested. +using SeenTypes = std::unordered_map; +using SeenTypePacks = std::unordered_map; + +struct CloneState +{ + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + int recursionCount = 0; +}; + +TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); +TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); + +TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log); + +} // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h new file mode 100644 index 00000000..8a41c9e8 --- /dev/null +++ b/Analysis/include/Luau/Constraint.h @@ -0,0 +1,88 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/NotNull.h" +#include "Luau/Variant.h" + +#include +#include +#include + +namespace Luau +{ + +struct Scope2; +struct TypeVar; +using TypeId = const TypeVar*; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +// subType <: superType +struct SubtypeConstraint +{ + TypeId subType; + TypeId superType; +}; + +// subPack <: superPack +struct PackSubtypeConstraint +{ + TypePackId subPack; + TypePackId superPack; +}; + +// subType ~ gen superType +struct GeneralizationConstraint +{ + TypeId generalizedType; + TypeId sourceType; + Scope2* scope; +}; + +// subType ~ inst superType +struct InstantiationConstraint +{ + TypeId subType; + TypeId superType; +}; + +// name(namedType) = name +struct NameConstraint +{ + TypeId namedType; + std::string name; +}; + +using ConstraintV = Variant; +using ConstraintPtr = std::unique_ptr; + +struct Constraint +{ + explicit Constraint(ConstraintV&& c); + + Constraint(const Constraint&) = delete; + Constraint& operator=(const Constraint&) = delete; + + ConstraintV c; + std::vector> dependencies; +}; + +inline Constraint& asMutable(const Constraint& c) +{ + return const_cast(c); +} + +template +T* getMutable(Constraint& c) +{ + return ::Luau::get_if(&c.c); +} + +template +const T* get(const Constraint& c) +{ + return getMutable(asMutable(c)); +} + +} // namespace Luau diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h new file mode 100644 index 00000000..9b118691 --- /dev/null +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -0,0 +1,150 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include +#include +#include + +#include "Luau/Ast.h" +#include "Luau/Constraint.h" +#include "Luau/Module.h" +#include "Luau/NotNull.h" +#include "Luau/Symbol.h" +#include "Luau/TypeVar.h" +#include "Luau/Variant.h" + +namespace Luau +{ + +struct Scope2; + +struct ConstraintGraphBuilder +{ + // A list of all the scopes in the module. This vector holds ownership of the + // scope pointers; the scopes themselves borrow pointers to other scopes to + // define the scope hierarchy. + std::vector>> scopes; + SingletonTypes& singletonTypes; + TypeArena* const arena; + // The root scope of the module we're generating constraints for. + Scope2* rootScope; + // A mapping of AST node to TypeId. + DenseHashMap astTypes{nullptr}; + // A mapping of AST node to TypePackId. + DenseHashMap astTypePacks{nullptr}; + DenseHashMap astOriginalCallTypes{nullptr}; + // Types resolved from type annotations. Analogous to astTypes. + DenseHashMap astResolvedTypes{nullptr}; + // Type packs resolved from type annotations. Analogous to astTypePacks. + DenseHashMap astResolvedTypePacks{nullptr}; + + explicit ConstraintGraphBuilder(TypeArena* arena); + + /** + * Fabricates a new free type belonging to a given scope. + * @param scope the scope the free type belongs to. Must not be null. + */ + TypeId freshType(Scope2* scope); + + /** + * Fabricates a new free type pack belonging to a given scope. + * @param scope the scope the free type pack belongs to. Must not be null. + */ + TypePackId freshTypePack(Scope2* scope); + + /** + * Fabricates a scope that is a child of another scope. + * @param location the lexical extent of the scope in the source code. + * @param parent the parent scope of the new scope. Must not be null. + */ + Scope2* childScope(Location location, Scope2* parent); + + /** + * Adds a new constraint with no dependencies to a given scope. + * @param scope the scope to add the constraint to. Must not be null. + * @param cv the constraint variant to add. + */ + void addConstraint(Scope2* scope, ConstraintV cv); + + /** + * Adds a constraint to a given scope. + * @param scope the scope to add the constraint to. Must not be null. + * @param c the constraint to add. + */ + void addConstraint(Scope2* scope, std::unique_ptr c); + + /** + * The entry point to the ConstraintGraphBuilder. This will construct a set + * of scopes, constraints, and free types that can be solved later. + * @param block the root block to generate constraints for. + */ + void visit(AstStatBlock* block); + + void visit(Scope2* scope, AstStat* stat); + void visit(Scope2* scope, AstStatBlock* block); + void visit(Scope2* scope, AstStatLocal* local); + void visit(Scope2* scope, AstStatLocalFunction* function); + void visit(Scope2* scope, AstStatFunction* function); + void visit(Scope2* scope, AstStatReturn* ret); + void visit(Scope2* scope, AstStatAssign* assign); + void visit(Scope2* scope, AstStatIf* ifStatement); + void visit(Scope2* scope, AstStatTypeAlias* alias); + + TypePackId checkExprList(Scope2* scope, const AstArray& exprs); + + TypePackId checkPack(Scope2* scope, AstArray exprs); + TypePackId checkPack(Scope2* scope, AstExpr* expr); + + /** + * Checks an expression that is expected to evaluate to one type. + * @param scope the scope the expression is contained within. + * @param expr the expression to check. + * @return the type of the expression. + */ + TypeId check(Scope2* scope, AstExpr* expr); + + TypeId checkExprTable(Scope2* scope, AstExprTable* expr); + TypeId check(Scope2* scope, AstExprIndexName* indexName); + + std::pair checkFunctionSignature(Scope2* parent, AstExprFunction* fn); + + /** + * Checks the body of a function expression. + * @param scope the interior scope of the body of the function. + * @param fn the function expression to check. + */ + void checkFunctionBody(Scope2* scope, AstExprFunction* fn); + + /** + * Resolves a type from its AST annotation. + * @param scope the scope that the type annotation appears within. + * @param ty the AST annotation to resolve. + * @return the type of the AST annotation. + **/ + TypeId resolveType(Scope2* scope, AstType* ty); + + /** + * Resolves a type pack from its AST annotation. + * @param scope the scope that the type annotation appears within. + * @param tp the AST annotation to resolve. + * @return the type pack of the AST annotation. + **/ + TypePackId resolveTypePack(Scope2* scope, AstTypePack* tp); + + TypePackId resolveTypePack(Scope2* scope, const AstTypeList& list); +}; + +/** + * Collects a vector of borrowed constraints from the scope and all its child + * scopes. It is important to only call this function when you're done adding + * constraints to the scope or its descendants, lest the borrowed pointers + * become invalid due to a container reallocation. + * @param rootScope the root scope of the scope graph to collect constraints + * from. + * @return a list of pointers to constraints contained within the scope graph. + * None of these pointers should be null. + */ +std::vector> collectConstraints(Scope2* rootScope); + +} // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h new file mode 100644 index 00000000..4870157f --- /dev/null +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -0,0 +1,120 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Error.h" +#include "Luau/Variant.h" +#include "Luau/Constraint.h" +#include "Luau/ConstraintSolverLogger.h" +#include "Luau/TypeVar.h" + +#include + +namespace Luau +{ + +// TypeId, TypePackId, or Constraint*. It is impossible to know which, but we +// never dereference this pointer. +using BlockedConstraintId = const void*; + +struct ConstraintSolver +{ + TypeArena* arena; + InternalErrorReporter iceReporter; + // The entire set of constraints that the solver is trying to resolve. It + // is important to not add elements to this vector, lest the underlying + // storage that we retain pointers to be mutated underneath us. + const std::vector> constraints; + Scope2* rootScope; + + // This includes every constraint that has not been fully solved. + // A constraint can be both blocked and unsolved, for instance. + std::vector> unsolvedConstraints; + + // A mapping of constraint pointer to how many things the constraint is + // blocked on. Can be empty or 0 for constraints that are not blocked on + // anything. + std::unordered_map, size_t> blockedConstraints; + // A mapping of type/pack pointers to the constraints they block. + std::unordered_map>> blocked; + + ConstraintSolverLogger logger; + + explicit ConstraintSolver(TypeArena* arena, Scope2* rootScope); + + /** + * Attempts to dispatch all pending constraints and reach a type solution + * that satisfies all of the constraints. + **/ + void run(); + + bool done(); + + bool tryDispatch(NotNull c, bool force); + bool tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const NameConstraint& c, NotNull constraint); + + void block(NotNull target, NotNull constraint); + /** + * Block a constraint on the resolution of a TypeVar. + * @returns false always. This is just to allow tryDispatch to return the result of block() + */ + bool block(TypeId target, NotNull constraint); + bool block(TypePackId target, NotNull constraint); + + void unblock(NotNull progressed); + void unblock(TypeId progressed); + void unblock(TypePackId progressed); + + /** + * @returns true if the TypeId is in a blocked state. + */ + bool isBlocked(TypeId ty); + + /** + * Returns whether the constraint is blocked on anything. + * @param constraint the constraint to check. + */ + bool isBlocked(NotNull constraint); + + /** + * Creates a new Unifier and performs a single unification operation. Commits + * the result. + * @param subType the sub-type to unify. + * @param superType the super-type to unify. + */ + void unify(TypeId subType, TypeId superType); + + /** + * Creates a new Unifier and performs a single unification operation. Commits + * the result. + * @param subPack the sub-type pack to unify. + * @param superPack the super-type pack to unify. + */ + void unify(TypePackId subPack, TypePackId superPack); + +private: + /** + * Marks a constraint as being blocked on a type or type pack. The constraint + * solver will not attempt to dispatch blocked constraints until their + * dependencies have made progress. + * @param target the type or type pack pointer that the constraint is blocked on. + * @param constraint the constraint to block. + **/ + void block_(BlockedConstraintId target, NotNull constraint); + + /** + * Informs the solver that progress has been made on a type or type pack. The + * solver will wake up all constraints that are blocked on the type or type pack, + * and will resume attempting to dispatch them. + * @param progressed the type or type pack pointer that has progressed. + **/ + void unblock_(BlockedConstraintId progressed); +}; + +void dump(Scope2* rootScope, struct ToStringOptions& opts); + +} // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolverLogger.h b/Analysis/include/Luau/ConstraintSolverLogger.h new file mode 100644 index 00000000..55336a23 --- /dev/null +++ b/Analysis/include/Luau/ConstraintSolverLogger.h @@ -0,0 +1,28 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Constraint.h" +#include "Luau/NotNull.h" +#include "Luau/Scope.h" +#include "Luau/ToString.h" + +#include +#include +#include + +namespace Luau +{ + +struct ConstraintSolverLogger +{ + std::string compileOutput(); + void captureBoundarySnapshot(const Scope2* rootScope, std::vector>& unsolvedConstraints); + void prepareStepSnapshot(const Scope2* rootScope, NotNull current, std::vector>& unsolvedConstraints); + void commitPreparedStepSnapshot(); + +private: + std::vector snapshots; + std::optional preparedSnapshot; + ToStringOptions opts; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 53b946a0..a1323960 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -5,6 +5,7 @@ #include "Luau/Location.h" #include "Luau/TypeVar.h" #include "Luau/Variant.h" +#include "Luau/TypeArena.h" namespace Luau { @@ -108,9 +109,6 @@ struct FunctionDoesNotTakeSelf struct FunctionRequiresSelf { - // TODO: Delete with LuauAnyInIsOptionalIsOptional - int requiredExtraNils = 0; - bool operator==(const FunctionRequiresSelf& rhs) const; }; @@ -171,6 +169,13 @@ struct GenericError bool operator==(const GenericError& rhs) const; }; +struct InternalError +{ + std::string message; + + bool operator==(const InternalError& rhs) const; +}; + struct CannotCallNonFunction { TypeId ty; @@ -287,12 +292,20 @@ struct TypesAreUnrelated bool operator==(const TypesAreUnrelated& rhs) const; }; -using TypeErrorData = - Variant; +struct NormalizationTooComplex +{ + bool operator==(const NormalizationTooComplex&) const + { + return true; + } +}; + +using TypeErrorData = Variant; struct TypeError { @@ -333,7 +346,13 @@ T* get(TypeError& e) using ErrorVec = std::vector; +struct TypeErrorToStringOptions +{ + FileResolver* fileResolver = nullptr; +}; + std::string toString(const TypeError& error); +std::string toString(const TypeError& error, TypeErrorToStringOptions options); bool containsParseErrorName(const TypeError& error); @@ -350,4 +369,24 @@ struct InternalErrorReporter [[noreturn]] void ice(const std::string& message); }; +class InternalCompilerError : public std::exception { +public: + explicit InternalCompilerError(const std::string& message, const std::string& moduleName) + : message(message) + , moduleName(moduleName) + { + } + explicit InternalCompilerError(const std::string& message, const std::string& moduleName, const Location& location) + : message(message) + , moduleName(moduleName) + , location(location) + { + } + virtual const char* what() const throw(); + + const std::string message; + const std::string moduleName; + const std::optional location; +}; + } // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 0bf8f362..f4226cc1 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -55,10 +55,23 @@ std::optional pathExprToModuleName(const ModuleName& currentModuleNa struct SourceNode { + bool hasDirtySourceModule() const + { + return dirtySourceModule; + } + + bool hasDirtyModule(bool forAutocomplete) const + { + return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; + } + ModuleName name; - std::unordered_set requires; + std::unordered_set requireSet; std::vector> requireLocations; - bool dirty = true; + bool dirtySourceModule = true; + bool dirtyModule = true; + bool dirtyModuleForAutocomplete = true; + double autocompleteLimitsMult = 1.0; }; struct FrontendOptions @@ -69,14 +82,14 @@ struct FrontendOptions // is complete. bool retainFullTypeGraphs = false; - // When true, we run typechecking twice, once in the regular mode, and once in strict mode - // in order to get more precise type information (e.g. for autocomplete). - bool typecheckTwice = false; + // Run typechecking only in mode required for autocomplete (strict mode in order to get more precise type information) + bool forAutocomplete = false; }; struct CheckResult { std::vector errors; + std::vector timeoutHits; }; struct FrontendModuleResolver : ModuleResolver @@ -120,10 +133,9 @@ struct Frontend */ std::pair lintFragment(std::string_view source, std::optional enabledLintWarnings = {}); - CheckResult check(const SourceModule& module); // OLD. TODO KILL LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); - bool isDirty(const ModuleName& name) const; + bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); /** Borrow a pointer into the SourceModule cache. @@ -147,10 +159,12 @@ struct Frontend void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); private: + ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope); + std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); - bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root); + bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete); static LintResult classifyLints(const std::vector& warnings, const Config& config); @@ -172,7 +186,7 @@ public: std::unordered_map sourceNodes; std::unordered_map sourceModules; - std::unordered_map requires; + std::unordered_map requireTrace; Stats stats = {}; }; diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h new file mode 100644 index 00000000..e05ceebe --- /dev/null +++ b/Analysis/include/Luau/Instantiation.h @@ -0,0 +1,53 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Substitution.h" +#include "Luau/TypeVar.h" +#include "Luau/Unifiable.h" + +namespace Luau +{ + +struct TypeArena; +struct TxnLog; + +// A substitution which replaces generic types in a given set by free types. +struct ReplaceGenerics : Substitution +{ + ReplaceGenerics( + const TxnLog* log, TypeArena* arena, TypeLevel level, const std::vector& generics, const std::vector& genericPacks) + : Substitution(log, arena) + , level(level) + , generics(generics) + , genericPacks(genericPacks) + { + } + + TypeLevel level; + std::vector generics; + std::vector genericPacks; + bool ignoreChildren(TypeId ty) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +// A substitution which replaces generic functions by monomorphic functions +struct Instantiation : Substitution +{ + Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level) + : Substitution(log, arena) + , level(level) + { + } + + TypeLevel level; + bool ignoreChildren(TypeId ty) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/IostreamHelpers.h b/Analysis/include/Luau/IostreamHelpers.h index ee994296..05b94516 100644 --- a/Analysis/include/Luau/IostreamHelpers.h +++ b/Analysis/include/Luau/IostreamHelpers.h @@ -30,6 +30,7 @@ std::ostream& operator<<(std::ostream& lhs, const OccursCheckFailed& error); std::ostream& operator<<(std::ostream& lhs, const UnknownRequire& error); std::ostream& operator<<(std::ostream& lhs, const UnknownPropButFoundLikeProp& e); std::ostream& operator<<(std::ostream& lhs, const GenericError& error); +std::ostream& operator<<(std::ostream& lhs, const InternalError& error); std::ostream& operator<<(std::ostream& lhs, const FunctionExitsWithoutReturning& error); std::ostream& operator<<(std::ostream& lhs, const MissingProperties& error); std::ostream& operator<<(std::ostream& lhs, const IllegalRequire& error); diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index 3d510d5f..1a92d52d 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -34,8 +34,8 @@ const LValue* baseof(const LValue& lvalue); std::optional tryGetLValue(const class AstExpr& expr); -// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. -std::pair> getFullName(const LValue& lvalue); +// Utility function: breaks down an LValue to get at the Symbol +Symbol getBaseSymbol(const LValue& lvalue); template const T* get(const LValue& lvalue) diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 6c689b7c..39f8dfb7 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -1,12 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/FileResolver.h" -#include "Luau/TypePack.h" -#include "Luau/TypedAllocator.h" -#include "Luau/ParseOptions.h" #include "Luau/Error.h" +#include "Luau/FileResolver.h" +#include "Luau/ParseOptions.h" #include "Luau/ParseResult.h" +#include "Luau/Scope.h" +#include "Luau/TypeArena.h" #include #include @@ -21,6 +21,9 @@ struct Module; using ScopePtr = std::shared_ptr; using ModulePtr = std::shared_ptr; +class AstType; +class AstTypePack; + /// Root of the AST of a parsed source file struct SourceModule { @@ -29,8 +32,8 @@ struct SourceModule std::optional environmentName; bool cyclic = false; - std::unique_ptr allocator; - std::unique_ptr names; + std::shared_ptr allocator; + std::shared_ptr names; std::vector parseErrors; AstStatBlock* root = nullptr; @@ -48,49 +51,12 @@ struct SourceModule bool isWithinComment(const SourceModule& sourceModule, Position pos); -struct TypeArena +struct RequireCycle { - TypedAllocator typeVars; - TypedAllocator typePacks; - - void clear(); - - template - TypeId addType(T tv) - { - if constexpr (std::is_same_v) - LUAU_ASSERT(tv.options.size() >= 2); - - return addTV(TypeVar(std::move(tv))); - } - - TypeId addTV(TypeVar&& tv); - - TypeId freshType(TypeLevel level); - - TypePackId addTypePack(std::initializer_list types); - TypePackId addTypePack(std::vector types); - TypePackId addTypePack(TypePack pack); - TypePackId addTypePack(TypePackVar pack); + Location location; + std::vector path; // one of the paths for a require() to go all the way back to the originating module }; -void freeze(TypeArena& arena); -void unfreeze(TypeArena& arena); - -// Only exposed so they can be unit tested. -using SeenTypes = std::unordered_map; -using SeenTypePacks = std::unordered_map; - -struct CloneState -{ - int recursionCount = 0; - bool encounteredFreeType = false; -}; - -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); - struct Module { ~Module(); @@ -98,25 +64,33 @@ struct Module TypeArena interfaceTypes; TypeArena internalTypes; + // Scopes and AST types refer to parse data, so we need to keep that alive + std::shared_ptr allocator; + std::shared_ptr names; + std::vector> scopes; // never empty + std::vector>> scope2s; // never empty DenseHashMap astTypes{nullptr}; + DenseHashMap astTypePacks{nullptr}; DenseHashMap astExpectedTypes{nullptr}; DenseHashMap astOriginalCallTypes{nullptr}; DenseHashMap astOverloadResolvedTypes{nullptr}; + DenseHashMap astResolvedTypes{nullptr}; + DenseHashMap astResolvedTypePacks{nullptr}; std::unordered_map declaredGlobals; ErrorVec errors; Mode mode; SourceCode::Type type; + bool timeout = false; ScopePtr getModuleScope() const; + Scope2* getModuleScope2() const; // Once a module has been typechecked, we clone its public interface into a separate arena. // This helps us to force TypeVar ownership into a DAG rather than a DCG. - // Returns true if there were any free types encountered in the public interface. This - // indicates a bug in the type checker that we want to surface. - bool clonePublicInterface(); + void clonePublicInterface(InternalErrorReporter& ice); }; } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h new file mode 100644 index 00000000..f5fd9886 --- /dev/null +++ b/Analysis/include/Luau/Normalize.h @@ -0,0 +1,20 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Substitution.h" +#include "Luau/TypeVar.h" +#include "Luau/Module.h" + +namespace Luau +{ + +struct InternalErrorReporter; + +bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice); +bool isSubtype(TypePackId subTy, TypePackId superTy, InternalErrorReporter& ice); + +std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice); +std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice); +std::pair normalize(TypePackId ty, TypeArena& arena, InternalErrorReporter& ice); +std::pair normalize(TypePackId ty, const ModulePtr& module, InternalErrorReporter& ice); + +} // namespace Luau diff --git a/Analysis/include/Luau/NotNull.h b/Analysis/include/Luau/NotNull.h new file mode 100644 index 00000000..f6043e9c --- /dev/null +++ b/Analysis/include/Luau/NotNull.h @@ -0,0 +1,88 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include + +namespace Luau +{ + +/** A non-owning, non-null pointer to a T. + * + * A NotNull is notionally identical to a T* with the added restriction that + * it can never store nullptr. + * + * The sole conversion rule from T* to NotNull is the single-argument + * constructor, which is intentionally marked explicit. This constructor + * performs a runtime test to verify that the passed pointer is never nullptr. + * + * Pointer arithmetic, increment, decrement, and array indexing are all + * forbidden. + * + * An implicit coersion from NotNull to T* is afforded, as are the pointer + * indirection and member access operators. (*p and p->prop) + * + * The explicit delete statement is permitted (but not recommended) on a + * NotNull through this implicit conversion. + */ +template +struct NotNull +{ + explicit NotNull(T* t) + : ptr(t) + { + LUAU_ASSERT(t); + } + + explicit NotNull(std::nullptr_t) = delete; + void operator=(std::nullptr_t) = delete; + + template + NotNull(NotNull other) + : ptr(other.get()) + {} + + operator T*() const noexcept + { + return ptr; + } + + T& operator*() const noexcept + { + return *ptr; + } + + T* operator->() const noexcept + { + return ptr; + } + + T& operator[](int) = delete; + + T& operator+(int) = delete; + T& operator-(int) = delete; + + T* get() const noexcept + { + return ptr; + } + +private: + T* ptr; +}; + +} + +namespace std +{ + +template struct hash> +{ + size_t operator()(const Luau::NotNull& p) const + { + return std::hash()(p.get()); + } +}; + +} diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index e48cad40..f46f0cb5 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -6,6 +6,10 @@ namespace Luau { +struct TypeArena; +struct Scope2; + void quantify(TypeId ty, TypeLevel level); +TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope); } // namespace Luau diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h index 89632cea..f964dbfe 100644 --- a/Analysis/include/Luau/RecursionCounter.h +++ b/Analysis/include/Luau/RecursionCounter.h @@ -4,10 +4,19 @@ #include "Luau/Common.h" #include +#include namespace Luau { +struct RecursionLimitException : public std::exception +{ + const char* what() const noexcept + { + return "Internal recursion counter limit exceeded"; + } +}; + struct RecursionCounter { RecursionCounter(int* count) @@ -32,7 +41,9 @@ struct RecursionLimiter : RecursionCounter : RecursionCounter(count) { if (limit > 0 && *count > limit) - throw std::runtime_error("Internal recursion counter limit exceeded"); + { + throw RecursionLimitException(); + } } }; diff --git a/Analysis/include/Luau/RequireTracer.h b/Analysis/include/Luau/RequireTracer.h index c25545f5..f69d133e 100644 --- a/Analysis/include/Luau/RequireTracer.h +++ b/Analysis/include/Luau/RequireTracer.h @@ -19,7 +19,7 @@ struct RequireTraceResult { DenseHashMap exprs{nullptr}; - std::vector> requires; + std::vector> requireList; }; RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName); diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 45338409..cef4b94f 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Constraint.h" #include "Luau/Location.h" #include "Luau/TypeVar.h" @@ -64,4 +65,21 @@ struct Scope std::unordered_map typeAliasTypePackParameters; }; +struct Scope2 +{ + // The parent scope of this scope. Null if there is no parent (i.e. this + // is the module-level scope). + Scope2* parent = nullptr; + // All the children of this scope. + std::vector children; + std::unordered_map bindings; // TODO: I think this can be a DenseHashMap + std::unordered_map typeBindings; + TypePackId returnType; + // All constraints belonging to this scope. + std::vector constraints; + + std::optional lookup(Symbol sym); + std::optional lookupTypeBinding(const Name& name); +}; + } // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 9662d5b3..f3c3ae9a 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -1,8 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Module.h" -#include "Luau/ModuleResolver.h" +#include "Luau/TypeArena.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/DenseHash.h" @@ -90,6 +89,7 @@ struct Tarjan std::vector lowlink; int childCount = 0; + int childLimit = 0; // This should never be null; ensure you initialize it before calling // substitution methods. diff --git a/Analysis/include/Luau/Symbol.h b/Analysis/include/Luau/Symbol.h index b5dd9c89..1fe037e5 100644 --- a/Analysis/include/Luau/Symbol.h +++ b/Analysis/include/Luau/Symbol.h @@ -30,6 +30,9 @@ struct Symbol { } + template + Symbol(const T&) = delete; + AstLocal* local; AstName global; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 49ee82fe..a50fef78 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -3,6 +3,7 @@ #include "Luau/Common.h" #include "Luau/TypeVar.h" +#include "Luau/ConstraintGraphBuilder.h" #include #include @@ -28,6 +29,8 @@ struct ToStringOptions bool functionTypeArguments = false; // If true, output function type argument names when they are available bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. + bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self + bool indent = false; size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); std::optional nameMap; @@ -51,6 +54,7 @@ ToStringResult toStringDetailed(TypePackId ty, const ToStringOptions& opts = {}) std::string toString(TypeId ty, const ToStringOptions& opts); std::string toString(TypePackId ty, const ToStringOptions& opts); +std::string toString(const Constraint& c, ToStringOptions& opts); // These are offered as overloads rather than a default parameter so that they can be easily invoked from within the MSVC debugger. // You can use them in watch expressions! @@ -62,6 +66,11 @@ inline std::string toString(TypePackId ty) { return toString(ty, ToStringOptions{}); } +inline std::string toString(const Constraint& c) +{ + ToStringOptions opts; + return toString(c, opts); +} std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); @@ -72,6 +81,9 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression std::string dump(TypeId ty); std::string dump(TypePackId ty); +std::string dump(const Constraint& c); + +std::string dump(const std::shared_ptr& scope, const char* name); std::string generateName(size_t n); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index c8ebaaeb..cd115e3b 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -7,8 +7,6 @@ #include "Luau/TypeVar.h" #include "Luau/TypePack.h" -LUAU_FASTFLAG(LuauShareTxnSeen); - namespace Luau { @@ -64,13 +62,17 @@ T* getMutable(PendingTypePack* pending) struct TxnLog { TxnLog() - : ownedSeen() + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , ownedSeen() , sharedSeen(&ownedSeen) { } explicit TxnLog(TxnLog* parent) - : parent(parent) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , parent(parent) { if (parent) { @@ -83,12 +85,8 @@ struct TxnLog } explicit TxnLog(std::vector>* sharedSeen) - : sharedSeen(sharedSeen) - { - } - - TxnLog(TxnLog* parent, std::vector>* sharedSeen) - : parent(parent) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) , sharedSeen(sharedSeen) { } @@ -243,6 +241,12 @@ struct TxnLog return Luau::getMutable(ty); } + template + const T* get(TID ty) const + { + return this->getMutable(ty); + } + // Returns whether a given type or type pack is a given state, respecting the // log's pending state. // @@ -263,11 +267,8 @@ private: // unique_ptr is used to give us stable pointers across insertions into the // map. Otherwise, it would be really easy to accidentally invalidate the // pointers returned from queue/pending. - // - // We can't use a DenseHashMap here because we need a non-const iterator - // over the map when we concatenate. - std::unordered_map, DenseHashPointer> typeVarChanges; - std::unordered_map, DenseHashPointer> typePackChanges; + DenseHashMap> typeVarChanges; + DenseHashMap> typePackChanges; TxnLog* parent = nullptr; diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h new file mode 100644 index 00000000..559c55c8 --- /dev/null +++ b/Analysis/include/Luau/TypeArena.h @@ -0,0 +1,42 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypedAllocator.h" +#include "Luau/TypeVar.h" +#include "Luau/TypePack.h" + +#include + +namespace Luau +{ + +struct TypeArena +{ + TypedAllocator typeVars; + TypedAllocator typePacks; + + void clear(); + + template + TypeId addType(T tv) + { + if constexpr (std::is_same_v) + LUAU_ASSERT(tv.options.size() >= 2); + + return addTV(TypeVar(std::move(tv))); + } + + TypeId addTV(TypeVar&& tv); + + TypeId freshType(TypeLevel level); + + TypePackId addTypePack(std::initializer_list types); + TypePackId addTypePack(std::vector types); + TypePackId addTypePack(TypePack pack); + TypePackId addTypePack(TypePackVar pack); +}; + +void freeze(TypeArena& arena); +void unfreeze(TypeArena& arena); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h new file mode 100644 index 00000000..a6c7a3e3 --- /dev/null +++ b/Analysis/include/Luau/TypeChecker2.h @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Module.h" + +namespace Luau +{ + +void check(const SourceModule& sourceModule, Module* module); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 839043cc..28adc9d9 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -34,61 +34,35 @@ const AstStat* getFallthrough(const AstStat* node); struct UnifierOptions; struct Unifier; -// A substitution which replaces generic types in a given set by free types. -struct ReplaceGenerics : Substitution -{ - ReplaceGenerics( - const TxnLog* log, TypeArena* arena, TypeLevel level, const std::vector& generics, const std::vector& genericPacks) - : Substitution(log, arena) - , level(level) - , generics(generics) - , genericPacks(genericPacks) - { - } - - TypeLevel level; - std::vector generics; - std::vector genericPacks; - bool ignoreChildren(TypeId ty) override; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; -}; - -// A substitution which replaces generic functions by monomorphic functions -struct Instantiation : Substitution -{ - Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level) - : Substitution(log, arena) - , level(level) - { - } - - TypeLevel level; - bool ignoreChildren(TypeId ty) override; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; -}; - // A substitution which replaces free types by any struct Anyification : Substitution { - Anyification(TypeArena* arena, TypeId anyType, TypePackId anyTypePack) + Anyification(TypeArena* arena, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) : Substitution(TxnLog::empty(), arena) + , iceHandler(iceHandler) , anyType(anyType) , anyTypePack(anyTypePack) { } + InternalErrorReporter* iceHandler; + TypeId anyType; TypePackId anyTypePack; + bool normalizationTooComplex = false; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; TypeId clean(TypeId ty) override; TypePackId clean(TypePackId tp) override; + + bool ignoreChildren(TypeId ty) override + { + return ty->persistent; + } + bool ignoreChildren(TypePackId ty) override + { + return ty->persistent; + } }; // A substitution which replaces the type parameters of a type function by arguments @@ -124,6 +98,12 @@ struct HashBoolNamePair size_t operator()(const std::pair& pair) const; }; +class TimeLimitError : public std::exception +{ +public: + virtual const char* what() const throw(); +}; + // All TypeVars are retained via Environment::typeVars. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker @@ -133,6 +113,7 @@ struct TypeChecker TypeChecker& operator=(const TypeChecker&) = delete; ModulePtr check(const SourceModule& module, Mode mode, std::optional environmentScope = std::nullopt); + ModulePtr checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope = std::nullopt); std::vector> getScopes() const; @@ -154,27 +135,28 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); + void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); - ExprResult checkExpr( + WithPredicate checkExpr( const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt, bool forceSingleton = false); - ExprResult checkExpr(const ScopePtr& scope, const AstExprLocal& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprCall& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexName& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr(const ScopePtr& scope, const AstExprUnary& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprLocal& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprCall& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIndexName& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprUnary& expr); TypeId checkRelationalOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); TypeId checkBinaryOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); - ExprResult checkExpr(const ScopePtr& scope, const AstExprBinary& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprError& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprError& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType); @@ -197,11 +179,11 @@ struct TypeChecker void checkArgumentList( const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector& argLocations); - ExprResult checkExprPack(const ScopePtr& scope, const AstExpr& expr); - ExprResult checkExprPack(const ScopePtr& scope, const AstExprCall& expr); + WithPredicate checkExprPack(const ScopePtr& scope, const AstExpr& expr); + WithPredicate checkExprPack(const ScopePtr& scope, const AstExprCall& expr); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); - std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, + std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); @@ -209,7 +191,7 @@ struct TypeChecker const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, const std::vector& errors); - ExprResult checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, + WithPredicate checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil = false, const std::vector& lhsAnnotations = {}, const std::vector>& expectedTypes = {}); @@ -252,6 +234,8 @@ struct TypeChecker ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location); + void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location); + std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); @@ -371,7 +355,7 @@ private: const AstArray& genericNames, const AstArray& genericPackNames, bool useCache = false); public: - ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); + void resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); private: void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate); @@ -379,16 +363,17 @@ private: std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); - void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); - void resolve(const Predicate& predicate, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); - void resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); - void resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const PredicateVec& predicates, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); + void resolve(const Predicate& predicate, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); + void resolve(const TruthyPredicate& truthyP, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); + void resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const OrPredicate& orP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const TypeGuardPredicate& typeguardP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense); bool isNonstrictMode() const; + bool useConstrainedIntersections() const; public: /** Extract the types in a type pack, given the assumption that the pack must have some exact length. @@ -413,6 +398,13 @@ public: UnifierSharedState unifierState; + std::vector requireCycles; + + // Type inference limits + std::optional finishTime; + std::optional instantiationChildLimit; + std::optional unifierIterationLimit; + public: const TypeId nilType; const TypeId numberType; @@ -420,7 +412,6 @@ public: const TypeId booleanType; const TypeId threadType; const TypeId anyType; - const TypeId optionalNumberType; const TypePackId anyTypePack; diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 85fa467f..c1de242f 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -40,6 +40,7 @@ struct TypePack struct VariadicTypePack { TypeId ty; + bool hidden = false; // if true, we don't display this when toString()ing a pack with this variadic as its tail. }; struct TypePackVar @@ -47,13 +48,24 @@ struct TypePackVar explicit TypePackVar(const TypePackVariant& ty); explicit TypePackVar(TypePackVariant&& ty); TypePackVar(TypePackVariant&& ty, bool persistent); + bool operator==(const TypePackVar& rhs) const; + TypePackVar& operator=(TypePackVariant&& tp); + TypePackVar& operator=(const TypePackVar& rhs); + + // Re-assignes the content of the pack, but doesn't change the owning arena and can't make pack persistent. + void reassign(const TypePackVar& rhs) + { + ty = rhs.ty; + } + TypePackVariant ty; + bool persistent = false; - // Pointer to the type arena that allocated this type. + // Pointer to the type arena that allocated this pack. TypeArena* owningArena = nullptr; }; @@ -109,10 +121,10 @@ private: }; TypePackIterator begin(TypePackId tp); -TypePackIterator begin(TypePackId tp, TxnLog* log); +TypePackIterator begin(TypePackId tp, const TxnLog* log); TypePackIterator end(TypePackId tp); -using SeenSet = std::set>; +using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); @@ -122,7 +134,7 @@ TypePackId follow(TypePackId tp, std::function mapper); size_t size(TypePackId tp, TxnLog* log = nullptr); bool finite(TypePackId tp, TxnLog* log = nullptr); size_t size(const TypePack& tp, TxnLog* log = nullptr); -std::optional first(TypePackId tp); +std::optional first(TypePackId tp, bool ignoreHiddenVariadics = true); TypePackVar* asMutable(TypePackId tp); TypePack* asMutable(const TypePack* tp); @@ -154,5 +166,12 @@ bool isEmpty(TypePackId tp); /// Flattens out a type pack. Also returns a valid TypePackId tail if the type pack's full size is not known std::pair, std::optional> flatten(TypePackId tp); +std::pair, std::optional> flatten(TypePackId tp, const TxnLog& log); + +/// Returs true if the type pack arose from a function that is declared to be variadic. +/// Returns *false* for function argument packs that are inferred to be safe to oversaturate! +bool isVariadic(TypePackId tp); +bool isVariadic(TypePackId tp, const TxnLog& log); + } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index b8c4b362..20f4107c 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -24,6 +24,7 @@ namespace Luau { struct TypeArena; +struct Scope2; /** * There are three kinds of type variables: @@ -83,6 +84,24 @@ using Tags = std::vector; using ModuleName = std::string; +/** A TypeVar that cannot be computed. + * + * BlockedTypeVars essentially serve as a way to encode partial ordering on the + * constraint graph. Until a BlockedTypeVar is unblocked by its owning + * constraint, nothing at all can be said about it. Constraints that need to + * process a BlockedTypeVar cannot be dispatched. + * + * Whenever a BlockedTypeVar is added to the graph, we also record a constraint + * that will eventually unblock it. + */ +struct BlockedTypeVar +{ + BlockedTypeVar(); + int index; + + static int nextIndex; +}; + struct PrimitiveTypeVar { enum Type @@ -109,6 +128,24 @@ struct PrimitiveTypeVar } }; +struct ConstrainedTypeVar +{ + explicit ConstrainedTypeVar(TypeLevel level) + : level(level) + { + } + + explicit ConstrainedTypeVar(TypeLevel level, const std::vector& parts) + : parts(parts) + , level(level) + { + } + + std::vector parts; + TypeLevel level; + Scope2* scope = nullptr; +}; + // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md // Types for true and false struct BooleanSingleton @@ -212,42 +249,44 @@ struct FunctionDefinition // TODO: Do we actually need this? We'll find out later if we can delete this. // Does not exactly belong in TypeVar.h, but this is the only way to appease the compiler. template -struct ExprResult +struct WithPredicate { T type; PredicateVec predicates; }; -using MagicFunction = std::function>( - struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, ExprResult)>; +using MagicFunction = std::function>( + struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; struct FunctionTypeVar { // Global monomorphic function - FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Global polymorphic function - FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local monomorphic function - FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local polymorphic function - FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); TypeLevel level; + Scope2* scope = nullptr; /// These should all be generic std::vector generics; std::vector genericPacks; TypePackId argTypes; std::vector> argNames; - TypePackId retType; + TypePackId retTypes; std::optional definition; MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr. bool hasSelf; Tags tags; + bool hasNoGenerics = false; }; enum class TableState @@ -305,13 +344,13 @@ struct TableTypeVar TableState state = TableState::Unsealed; TypeLevel level; + Scope2* scope = nullptr; std::optional name; // Sometimes we throw a type on a name to make for nicer error messages, but without creating any entry in the type namespace // We need to know which is which when we stringify types. std::optional syntheticName; - std::map methodDefinitionLocations; std::vector instantiatedTypeParams; std::vector instantiatedTypePackParams; ModuleName definitionModuleName; @@ -355,15 +394,17 @@ struct ClassTypeVar std::optional metatable; // metaclass? Tags tags; std::shared_ptr userData; + ModuleName definitionModuleName; - ClassTypeVar( - Name name, Props props, std::optional parent, std::optional metatable, Tags tags, std::shared_ptr userData) + ClassTypeVar(Name name, Props props, std::optional parent, std::optional metatable, Tags tags, + std::shared_ptr userData, ModuleName definitionModuleName) : name(name) , props(props) , parent(parent) , metatable(metatable) , tags(tags) , userData(userData) + , definitionModuleName(definitionModuleName) { } }; @@ -418,8 +459,8 @@ struct LazyTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant; struct TypeVar final { @@ -436,9 +477,18 @@ struct TypeVar final TypeVar(const TypeVariant& ty, bool persistent) : ty(ty) , persistent(persistent) + , normal(persistent) // We assume that all persistent types are irreducable. { } + // Re-assignes the content of the type, but doesn't change the owning arena and can't make type persistent. + void reassign(const TypeVar& rhs) + { + ty = rhs.ty; + normal = rhs.normal; + documentationSymbol = rhs.documentationSymbol; + } + TypeVariant ty; // Kludge: A persistent TypeVar is one that belongs to the global scope. @@ -446,6 +496,10 @@ struct TypeVar final // Persistent TypeVars do not get cloned. bool persistent = false; + // Normalization sets this for types that are fully normalized. + // This implies that they are transitively immutable. + bool normal = false; + std::optional documentationSymbol; // Pointer to the type arena that allocated this type. @@ -456,9 +510,11 @@ struct TypeVar final TypeVar& operator=(const TypeVariant& rhs); TypeVar& operator=(TypeVariant&& rhs); + + TypeVar& operator=(const TypeVar& rhs); }; -using SeenSet = std::set>; +using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs); // Follow BoundTypeVars until we get to something real @@ -513,8 +569,9 @@ struct SingletonTypes const TypeId stringType; const TypeId booleanType; const TypeId threadType; + const TypeId trueType; + const TypeId falseType; const TypeId anyType; - const TypeId optionalNumberType; const TypePackId anyTypePack; @@ -543,6 +600,8 @@ void persist(TypePackId tp); const TypeLevel* getLevel(TypeId ty); TypeLevel* getMutableLevel(TypeId ty); +std::optional getLevel(TypePackId tp); + const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name); bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent); diff --git a/Analysis/include/Luau/TypedAllocator.h b/Analysis/include/Luau/TypedAllocator.h index c1c04d10..f67e3d8e 100644 --- a/Analysis/include/Luau/TypedAllocator.h +++ b/Analysis/include/Luau/TypedAllocator.h @@ -23,6 +23,12 @@ public: currentBlockSize = kBlockSize; } + TypedAllocator(const TypedAllocator&) = delete; + TypedAllocator& operator=(const TypedAllocator&) = delete; + + TypedAllocator(TypedAllocator&&) = default; + TypedAllocator& operator=(TypedAllocator&&) = default; + ~TypedAllocator() { if (frozen) diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index e8eafe68..4ff91714 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -8,6 +8,8 @@ namespace Luau { +struct Scope2; + /** * The 'level' of a TypeVar is an indirect way to talk about the scope that it 'belongs' too. * To start, read http://okmij.org/ftp/ML/generalization.html @@ -56,6 +58,14 @@ struct TypeLevel } }; +inline TypeLevel max(const TypeLevel& a, const TypeLevel& b) +{ + if (a.subsumes(b)) + return b; + else + return a; +} + inline TypeLevel min(const TypeLevel& a, const TypeLevel& b) { if (a.subsumes(b)) @@ -64,7 +74,9 @@ inline TypeLevel min(const TypeLevel& a, const TypeLevel& b) return b; } -namespace Unifiable +} // namespace Luau + +namespace Luau::Unifiable { using Name = std::string; @@ -72,9 +84,11 @@ using Name = std::string; struct Free { explicit Free(TypeLevel level); + explicit Free(Scope2* scope); int index; TypeLevel level; + Scope2* scope = nullptr; // True if this free type variable is part of a mutually // recursive type alias whose definitions haven't been // resolved yet. @@ -101,12 +115,15 @@ struct Generic Generic(); explicit Generic(TypeLevel level); explicit Generic(const Name& name); + explicit Generic(Scope2* scope); Generic(TypeLevel level, const Name& name); + Generic(Scope2* scope, const Name& name); int index; TypeLevel level; + Scope2* scope = nullptr; Name name; - bool explicitName; + bool explicitName = false; private: static int nextIndex; @@ -125,7 +142,6 @@ private: }; template -using Variant = Variant, Generic, Error, Value...>; +using Variant = Luau::Variant, Generic, Error, Value...>; -} // namespace Unifiable -} // namespace Luau +} // namespace Luau::Unifiable diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 474af50c..4af324cb 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -5,7 +5,7 @@ #include "Luau/Location.h" #include "Luau/TxnLog.h" #include "Luau/TypeInfer.h" -#include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header. +#include "Luau/TypeArena.h" #include "Luau/UnifierSharedState.h" #include @@ -32,6 +32,9 @@ struct Widen : Substitution TypeId clean(TypeId ty) override; TypePackId clean(TypePackId ty) override; bool ignoreChildren(TypeId ty) override; + + TypeId operator()(TypeId ty); + TypePackId operator()(TypePackId ty); }; // TODO: Use this more widely. @@ -49,14 +52,12 @@ struct Unifier ErrorVec errors; Location location; Variance variance = Covariant; + bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. CountMismatch::Context ctx = CountMismatch::Arg; UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, - TxnLog* parentLog = nullptr); - Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); + Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); @@ -78,12 +79,8 @@ private: void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); - void DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); - void tryUnifyFreeTable(TypeId subTy, TypeId superTy); - void tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection); void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); - void tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer); TypeId widen(TypeId ty); TypePackId widen(TypePackId tp); @@ -92,7 +89,6 @@ private: bool canCacheResult(TypeId subTy, TypeId superTy); void cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount); - void cacheResult_DEPRECATED(TypeId subTy, TypeId superTy); public: void tryUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); @@ -106,7 +102,12 @@ private: std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); + void tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy); + void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy); + public: + void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel); + // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); void occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); @@ -115,12 +116,7 @@ public: Unifier makeChildUnifier(); - // A utility function that appends the given error to the unifier's error log. - // This allows setting a breakpoint wherever the unifier reports an error. - void reportError(TypeError error) - { - errors.push_back(error); - } + void reportError(TypeError err); private: bool isNonstrictMode() const; @@ -135,4 +131,6 @@ private: std::optional firstPackErrorPos; }; +void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, TypePackId tp); + } // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index 9a3ba56d..d4315d47 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -28,7 +28,9 @@ struct TypeIdPairHash struct UnifierCounters { int recursionCount = 0; + int recursionLimit = 0; int iterationCount = 0; + int iterationLimit = 0; }; struct UnifierSharedState @@ -40,7 +42,6 @@ struct UnifierSharedState InternalErrorReporter* iceHandler; - DenseHashSet seenAny{nullptr}; DenseHashMap skipCacheForType{nullptr}; DenseHashSet, TypeIdPairHash> cachedUnify{{nullptr, nullptr}}; DenseHashMap, TypeErrorData, TypeIdPairHash> cachedUnifyError{{nullptr, nullptr}}; diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index 63d5a65c..f637222e 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -2,45 +2,15 @@ #pragma once #include "Luau/Common.h" - -#ifndef LUAU_USE_STD_VARIANT -#define LUAU_USE_STD_VARIANT 0 -#endif - -#if LUAU_USE_STD_VARIANT -#include -#else #include #include #include #include -#endif +#include namespace Luau { -#if LUAU_USE_STD_VARIANT -template -using Variant = std::variant; - -template -auto visit(Visitor&& vis, Variant&& var) -{ - // This change resolves the ABI issues with std::variant on libc++; std::visit normally throws bad_variant_access - // but it requires an update to libc++.dylib which ships with macOS 10.14. To work around this, we assert on valueless - // variants since we will never generate them and call into a libc++ function that doesn't throw. - LUAU_ASSERT(!var.valueless_by_exception()); - -#ifdef __APPLE__ - // See https://stackoverflow.com/a/53868971/503215 - return std::__variant_detail::__visitation::__variant::__visit_value(vis, var); -#else - return std::visit(vis, var); -#endif -} - -using std::get_if; -#else template class Variant { @@ -126,6 +96,20 @@ public: return *this; } + template + T& emplace(Args&&... args) + { + using TT = std::decay_t; + constexpr int tid = getTypeId(); + static_assert(tid >= 0, "unsupported T"); + + tableDtor[typeId](&storage); + typeId = tid; + new (&storage) TT(std::forward(args)...); + + return *reinterpret_cast(&storage); + } + template const T* get_if() const { @@ -248,6 +232,8 @@ static void fnVisitV(Visitor& vis, std::conditional_t, const template auto visit(Visitor&& vis, const Variant& var) { + static_assert(std::conjunction_v...>, "visitor must accept every alternative as an argument"); + using Result = std::invoke_result_t::first_alternative>; static_assert(std::conjunction_v>...>, "visitor result type must be consistent between alternatives"); @@ -273,6 +259,8 @@ auto visit(Visitor&& vis, const Variant& var) template auto visit(Visitor&& vis, Variant& var) { + static_assert(std::conjunction_v...>, "visitor must accept every alternative as an argument"); + using Result = std::invoke_result_t::first_alternative&>; static_assert(std::conjunction_v>...>, "visitor result type must be consistent between alternatives"); @@ -294,7 +282,6 @@ auto visit(Visitor&& vis, Variant& var) return res; } } -#endif template inline constexpr bool always_false_v = false; diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 740854b3..5fd43f0b 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -1,9 +1,15 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include + #include "Luau/DenseHash.h" -#include "Luau/TypeVar.h" +#include "Luau/RecursionCounter.h" #include "Luau/TypePack.h" +#include "Luau/TypeVar.h" + +LUAU_FASTINT(LuauVisitRecursionLimit) +LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) namespace Luau { @@ -52,182 +58,296 @@ inline void unsee(std::unordered_set& seen, const void* tv) inline void unsee(DenseHashSet& seen, const void* tv) { - // When DenseHashSet is used for 'visitOnce', where don't forget visited elements -} - -template -void visit(TypePackId tp, F& f, Set& seen); - -template -void visit(TypeId ty, F& f, Set& seen) -{ - if (visit_detail::hasSeen(seen, ty)) - { - f.cycle(ty); - return; - } - - if (auto btv = get(ty)) - { - if (apply(ty, *btv, seen, f)) - visit(btv->boundTo, f, seen); - } - - else if (auto ftv = get(ty)) - apply(ty, *ftv, seen, f); - - else if (auto gtv = get(ty)) - apply(ty, *gtv, seen, f); - - else if (auto etv = get(ty)) - apply(ty, *etv, seen, f); - - else if (auto ptv = get(ty)) - apply(ty, *ptv, seen, f); - - else if (auto ftv = get(ty)) - { - if (apply(ty, *ftv, seen, f)) - { - visit(ftv->argTypes, f, seen); - visit(ftv->retType, f, seen); - } - } - - else if (auto ttv = get(ty)) - { - // Some visitors want to see bound tables, that's why we visit the original type - if (apply(ty, *ttv, seen, f)) - { - if (ttv->boundTo) - { - visit(*ttv->boundTo, f, seen); - } - else - { - for (auto& [_name, prop] : ttv->props) - visit(prop.type, f, seen); - - if (ttv->indexer) - { - visit(ttv->indexer->indexType, f, seen); - visit(ttv->indexer->indexResultType, f, seen); - } - } - } - } - - else if (auto mtv = get(ty)) - { - if (apply(ty, *mtv, seen, f)) - { - visit(mtv->table, f, seen); - visit(mtv->metatable, f, seen); - } - } - - else if (auto ctv = get(ty)) - { - if (apply(ty, *ctv, seen, f)) - { - for (const auto& [name, prop] : ctv->props) - visit(prop.type, f, seen); - - if (ctv->parent) - visit(*ctv->parent, f, seen); - - if (ctv->metatable) - visit(*ctv->metatable, f, seen); - } - } - - else if (auto atv = get(ty)) - apply(ty, *atv, seen, f); - - else if (auto utv = get(ty)) - { - if (apply(ty, *utv, seen, f)) - { - for (TypeId optTy : utv->options) - visit(optTy, f, seen); - } - } - - else if (auto itv = get(ty)) - { - if (apply(ty, *itv, seen, f)) - { - for (TypeId partTy : itv->parts) - visit(partTy, f, seen); - } - } - - visit_detail::unsee(seen, ty); -} - -template -void visit(TypePackId tp, F& f, Set& seen) -{ - if (visit_detail::hasSeen(seen, tp)) - { - f.cycle(tp); - return; - } - - if (auto btv = get(tp)) - { - if (apply(tp, *btv, seen, f)) - visit(btv->boundTo, f, seen); - } - - else if (auto ftv = get(tp)) - apply(tp, *ftv, seen, f); - - else if (auto gtv = get(tp)) - apply(tp, *gtv, seen, f); - - else if (auto etv = get(tp)) - apply(tp, *etv, seen, f); - - else if (auto pack = get(tp)) - { - apply(tp, *pack, seen, f); - - for (TypeId ty : pack->head) - visit(ty, f, seen); - - if (pack->tail) - visit(*pack->tail, f, seen); - } - else if (auto pack = get(tp)) - { - apply(tp, *pack, seen, f); - visit(pack->ty, f, seen); - } - - visit_detail::unsee(seen, tp); + // When DenseHashSet is used for 'visitTypeVarOnce', where don't forget visited elements } } // namespace visit_detail -template -void visitTypeVar(TID ty, F& f, std::unordered_set& seen) +template +struct GenericTypeVarVisitor { - visit_detail::visit(ty, f, seen); -} + using Set = S; -template -void visitTypeVar(TID ty, F& f) -{ - std::unordered_set seen; - visit_detail::visit(ty, f, seen); -} + Set seen; + int recursionCounter = 0; -template -void visitTypeVarOnce(TID ty, F& f, DenseHashSet& seen) + GenericTypeVarVisitor() = default; + + explicit GenericTypeVarVisitor(Set seen) + : seen(std::move(seen)) + { + } + + virtual void cycle(TypeId) {} + virtual void cycle(TypePackId) {} + + virtual bool visit(TypeId ty) + { + return true; + } + virtual bool visit(TypeId ty, const BoundTypeVar& btv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const FreeTypeVar& ftv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const GenericTypeVar& gtv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const ErrorTypeVar& etv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const ConstrainedTypeVar& ctv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const FunctionTypeVar& ftv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const TableTypeVar& ttv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const MetatableTypeVar& mtv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const ClassTypeVar& ctv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const AnyTypeVar& atv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const UnionTypeVar& utv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const IntersectionTypeVar& itv) + { + return visit(ty); + } + + virtual bool visit(TypePackId tp) + { + return true; + } + virtual bool visit(TypePackId tp, const BoundTypePack& btp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const FreeTypePack& ftp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const GenericTypePack& gtp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const Unifiable::Error& etp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const TypePack& pack) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const VariadicTypePack& vtp) + { + return visit(tp); + } + + void traverse(TypeId ty) + { + RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit}; + + if (visit_detail::hasSeen(seen, ty)) + { + cycle(ty); + return; + } + + if (auto btv = get(ty)) + { + if (visit(ty, *btv)) + traverse(btv->boundTo); + } + + else if (auto ftv = get(ty)) + visit(ty, *ftv); + + else if (auto gtv = get(ty)) + visit(ty, *gtv); + + else if (auto etv = get(ty)) + visit(ty, *etv); + + else if (auto ctv = get(ty)) + { + if (visit(ty, *ctv)) + { + for (TypeId part : ctv->parts) + traverse(part); + } + } + + else if (auto ptv = get(ty)) + visit(ty, *ptv); + + else if (auto ftv = get(ty)) + { + if (visit(ty, *ftv)) + { + traverse(ftv->argTypes); + traverse(ftv->retTypes); + } + } + + else if (auto ttv = get(ty)) + { + // Some visitors want to see bound tables, that's why we traverse the original type + if (visit(ty, *ttv)) + { + if (ttv->boundTo) + { + traverse(*ttv->boundTo); + } + else + { + for (auto& [_name, prop] : ttv->props) + traverse(prop.type); + + if (ttv->indexer) + { + traverse(ttv->indexer->indexType); + traverse(ttv->indexer->indexResultType); + } + } + } + } + + else if (auto mtv = get(ty)) + { + if (visit(ty, *mtv)) + { + traverse(mtv->table); + traverse(mtv->metatable); + } + } + + else if (auto ctv = get(ty)) + { + if (visit(ty, *ctv)) + { + for (const auto& [name, prop] : ctv->props) + traverse(prop.type); + + if (ctv->parent) + traverse(*ctv->parent); + + if (ctv->metatable) + traverse(*ctv->metatable); + } + } + + else if (auto atv = get(ty)) + visit(ty, *atv); + + else if (auto utv = get(ty)) + { + if (visit(ty, *utv)) + { + for (TypeId optTy : utv->options) + traverse(optTy); + } + } + + else if (auto itv = get(ty)) + { + if (visit(ty, *itv)) + { + for (TypeId partTy : itv->parts) + traverse(partTy); + } + } + + visit_detail::unsee(seen, ty); + } + + void traverse(TypePackId tp) + { + if (visit_detail::hasSeen(seen, tp)) + { + cycle(tp); + return; + } + + if (auto btv = get(tp)) + { + if (visit(tp, *btv)) + traverse(btv->boundTo); + } + + else if (auto ftv = get(tp)) + visit(tp, *ftv); + + else if (auto gtv = get(tp)) + visit(tp, *gtv); + + else if (auto etv = get(tp)) + visit(tp, *etv); + + else if (auto pack = get(tp)) + { + bool res = visit(tp, *pack); + if (!FFlag::LuauNormalizeFlagIsConservative || res) + { + for (TypeId ty : pack->head) + traverse(ty); + + if (pack->tail) + traverse(*pack->tail); + } + } + else if (auto pack = get(tp)) + { + bool res = visit(tp, *pack); + if (!FFlag::LuauNormalizeFlagIsConservative || res) + traverse(pack->ty); + } + else + LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypePackId) is not exhaustive!"); + + visit_detail::unsee(seen, tp); + } +}; + +/** Visit each type under a given type. Skips over cycles and keeps recursion depth under control. + * + * The same type may be visited multiple times if there are multiple distinct paths to it. If this is undesirable, use + * TypeVarOnceVisitor. + */ +struct TypeVarVisitor : GenericTypeVarVisitor> { - seen.clear(); - visit_detail::visit(ty, f, seen); -} +}; + +/// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it. +struct TypeVarOnceVisitor : GenericTypeVarVisitor> +{ + TypeVarOnceVisitor() + : GenericTypeVarVisitor{DenseHashSet{nullptr}} + { + } +}; } // namespace Luau diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 0aed34c0..0522b1fa 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -71,9 +71,11 @@ struct FindFullAncestry final : public AstVisitor { std::vector nodes; Position pos; + Position documentEnd; - explicit FindFullAncestry(Position pos) + explicit FindFullAncestry(Position pos, Position documentEnd) : pos(pos) + , documentEnd(documentEnd) { } @@ -84,6 +86,16 @@ struct FindFullAncestry final : public AstVisitor nodes.push_back(node); return true; } + + // Edge case: If we ask for the node at the position that is the very end of the document + // return the innermost AST element that ends at that position. + + if (node->location.end == documentEnd && pos >= documentEnd) + { + nodes.push_back(node); + return true; + } + return false; } }; @@ -92,7 +104,11 @@ struct FindFullAncestry final : public AstVisitor std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos) { - FindFullAncestry finder(pos); + const Position end = source.root->location.end; + if (pos > end) + pos = end; + + FindFullAncestry finder(pos, end); source.root->visit(&finder); return std::move(finder.nodes); } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 492edf25..8a63901f 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,8 +13,7 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); -LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) +LUAU_FASTFLAG(LuauSelfCallAutocompleteFix2) static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -150,8 +149,12 @@ static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionTyp auto idxExpr = nodes.back()->as(); bool hasImplicitSelf = idxExpr && idxExpr->op == ':'; - auto args = Luau::flatten(func->argTypes); - bool noArgFunction = (args.first.empty() || (hasImplicitSelf && args.first.size() == 1)) && !args.second.has_value(); + auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes); + + if (argVariadicPack.has_value() && isVariadic(*argVariadicPack)) + return ParenthesesRecommendation::CursorInside; + + bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1); return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; } @@ -243,7 +246,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ ty = follow(ty); auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix); + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2); InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); @@ -262,16 +265,16 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*typeAtPosition); auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) { - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) { - if (std::optional firstRetTy = first(ftv->retType)) + if (std::optional firstRetTy = first(ftv->retTypes)) return checkTypeMatch(typeArena, *firstRetTy, expectedType); return false; } else { - auto [retHead, retTail] = flatten(ftv->retType); + auto [retHead, retTail] = flatten(ftv->retTypes); if (!retHead.empty() && canUnify(retHead.front(), expectedType)) return true; @@ -303,7 +306,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } } - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) return checkTypeMatch(typeArena, ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; else return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; @@ -320,7 +323,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, std::optional containingClass = std::nullopt) { - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) rootTy = follow(rootTy); ty = follow(ty); @@ -330,7 +333,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId seen.insert(ty); auto isWrongIndexer_DEPRECATED = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix); + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2); if (indexType == PropIndexType::Key) return false; @@ -363,7 +366,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } }; auto isWrongIndexer = [typeArena, rootTy, indexType](Luau::TypeId type) { - LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix); + LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix2); if (indexType == PropIndexType::Key) return false; @@ -377,10 +380,15 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId return calledWithSelf == ftv->hasSelf; } - if (std::optional firstArgTy = first(ftv->argTypes)) + // If a call is made with ':', it is invalid if a function has incompatible first argument or no arguments at all + // If a call is made with '.', but it was declared with 'self', it is considered invalid if first argument is compatible + if (calledWithSelf || ftv->hasSelf) { - if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) - return calledWithSelf; + if (std::optional firstArgTy = first(ftv->argTypes)) + { + if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) + return calledWithSelf; + } } return !calledWithSelf; @@ -422,7 +430,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryKind::Property, type, prop.deprecated, - FFlag::LuauSelfCallAutocompleteFix ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), + FFlag::LuauSelfCallAutocompleteFix2 ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), typeCorrect, containingClass, &prop, @@ -445,7 +453,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } else if (auto indexFunction = get(followed)) { - std::optional indexFunctionResult = first(indexFunction->retType); + std::optional indexFunctionResult = first(indexFunction->retTypes); if (indexFunctionResult) autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); } @@ -457,7 +465,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId containingClass = containingClass.value_or(cls); fillProps(cls->props); if (cls->parent) - autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, cls); + autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); } else if (auto tbl = get(ty)) fillProps(tbl->props); @@ -465,7 +473,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId { autocompleteProps(module, typeArena, rootTy, mt->table, indexType, nodes, result, seen); - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) { if (auto mtable = get(mt->metatable)) fillMetatableProps(mtable); @@ -484,7 +492,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId autocompleteProps(module, typeArena, rootTy, followed, indexType, nodes, result, seen); else if (auto indexFunction = get(followed)) { - std::optional indexFunctionResult = first(indexFunction->retType); + std::optional indexFunctionResult = first(indexFunction->retTypes); if (indexFunctionResult) autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); } @@ -531,7 +539,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryMap inner; std::unordered_set innerSeen; - if (!FFlag::LuauSelfCallAutocompleteFix) + if (!FFlag::LuauSelfCallAutocompleteFix2) innerSeen = seen; if (isNil(*iter)) @@ -557,7 +565,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId ++iter; } } - else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix) + else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix2) { if (pt->metatable) { @@ -565,7 +573,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId fillMetatableProps(mtable); } } - else if (FFlag::LuauSelfCallAutocompleteFix && get(get(ty))) + else if (FFlag::LuauSelfCallAutocompleteFix2 && get(get(ty))) { autocompleteProps(module, typeArena, rootTy, getSingletonTypes().stringType, indexType, nodes, result, seen); } @@ -625,6 +633,31 @@ AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position posi return result; } +static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AutocompleteEntryMap& result) +{ + auto formatKey = [addQuotes](const std::string& key) { + if (addQuotes) + return "\"" + escape(key) + "\""; + + return escape(key); + }; + + ty = follow(ty); + + if (auto ss = get(get(ty))) + { + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + else if (auto uty = get(ty)) + { + for (auto el : uty) + { + if (auto ss = get(get(el))) + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + } +}; + static bool canSuggestInferredType(ScopePtr scope, TypeId ty) { ty = follow(ty); @@ -708,7 +741,7 @@ static std::optional findTypeElementAt(AstType* astType, TypeId ty, Posi if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position)) return element; - if (auto element = findTypeElementAt(type->returnTypes, ftv->retType, position)) + if (auto element = findTypeElementAt(type->returnTypes, ftv->retTypes, position)) return element; } @@ -924,7 +957,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = get(follow(*it))) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos)) inferredType = *ty; } } @@ -1016,7 +1049,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, i)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); } @@ -1033,7 +1066,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, ~0u)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); } } @@ -1232,7 +1265,7 @@ static bool autocompleteIfElseExpression( if (!parent) return false; - if (FFlag::LuauIfElseExprFixCompletionIssue && node->is()) + if (node->is()) { // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else // expression. @@ -1310,16 +1343,20 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul } TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); - TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType); + TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().trueType); + TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().falseType); TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; - result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; + result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForTrue}; + result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForFalse}; result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; result["not"] = {AutocompleteEntryKind::Keyword}; result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + + if (auto ty = findExpectedTypeAt(module, node, position)) + autocompleteStringSingleton(*ty, true, result); } } @@ -1466,7 +1503,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - if (!FFlag::LuauSelfCallAutocompleteFix && isString(ty)) + if (!FFlag::LuauSelfCallAutocompleteFix2 && isString(ty)) return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry), finder.ancestry}; else @@ -1625,17 +1662,29 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } else if (node->is()) { + AutocompleteEntryMap result; + + if (auto it = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*it, false, result); + if (finder.ancestry.size() >= 2) { if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) { if (auto it = module->astTypes.find(idxExpr->expr)) + autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry, result); + } + else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) + { + if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { - return {autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry), finder.ancestry}; + if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) + autocompleteStringSingleton(*it, false, result); } } } - return {}; + + return {result, finder.ancestry}; } if (node->is()) @@ -1655,16 +1704,16 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName { // FIXME: We can improve performance here by parsing without checking. // The old type graph is probably fine. (famous last words!) - // FIXME: We don't need to typecheck for script analysis here, just for autocomplete. - frontend.check(moduleName); + FrontendOptions opts; + opts.forAutocomplete = true; + frontend.check(moduleName, opts); const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) return {}; - TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - ModulePtr module = (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + TypeChecker& typeChecker = frontend.typeCheckerForAutocomplete; + ModulePtr module = frontend.moduleResolverForAutocomplete.getModule(moduleName); if (!module) return {}; @@ -1692,8 +1741,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->mode = Mode::Strict; sourceModule->commentLocations = std::move(result.commentLocations); - TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - + TypeChecker& typeChecker = frontend.typeCheckerForAutocomplete; ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); OwningAutocompleteResult autocompleteResult = { diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index bf9ef303..2f57e23c 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -8,8 +8,6 @@ #include -LUAU_FASTFLAG(LuauAssertStripsFalsyTypes) -LUAU_FASTFLAGVARIABLE(LuauTableCloneType, false) LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) /** FIXME: Many of these type definitions are not quite completely accurate. @@ -21,16 +19,16 @@ LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) namespace Luau { -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +static std::optional> magicFunctionSelect( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionSetMetaTable( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionAssert( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionPack( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionRequire( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -181,44 +179,13 @@ void registerBuiltinTypes(TypeChecker& typeChecker) LUAU_ASSERT(!typeChecker.globalTypes.typeVars.isFrozen()); LUAU_ASSERT(!typeChecker.globalTypes.typePacks.isFrozen()); - TypeId numberType = typeChecker.numberType; - TypeId booleanType = typeChecker.booleanType; TypeId nilType = typeChecker.nilType; TypeArena& arena = typeChecker.globalTypes; - TypePackId oneNumberPack = arena.addTypePack({numberType}); - TypePackId oneBooleanPack = arena.addTypePack({booleanType}); - - TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); - TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); - - TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{ - listOfAtLeastOneNumber, - oneNumberPack, - }); - - TypeId listOfAtLeastZeroNumbersToNumberType = arena.addType(FunctionTypeVar{numberVariadicList, oneNumberPack}); - LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); LUAU_ASSERT(loadResult.success); - TypeId mathLibType = getGlobalBinding(typeChecker, "math"); - if (TableTypeVar* ttv = getMutable(mathLibType)) - { - ttv->props["min"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.min"); - ttv->props["max"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.max"); - } - - TypeId bit32LibType = getGlobalBinding(typeChecker, "bit32"); - if (TableTypeVar* ttv = getMutable(bit32LibType)) - { - ttv->props["band"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.band"); - ttv->props["bor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bor"); - ttv->props["bxor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bxor"); - ttv->props["btest"] = makeProperty(arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneBooleanPack}), "@luau/global/bit32.btest"); - } - TypeId genericK = arena.addType(GenericTypeVar{"K"}); TypeId genericV = arena.addType(GenericTypeVar{"V"}); TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); @@ -233,7 +200,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - // next(t: Table, i: K | nil) -> (K, V) + // next(t: Table, i: K?) -> (K, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); addGlobalBinding(typeChecker, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); @@ -243,8 +210,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); - // NOTE we are missing 'i: K | nil' argument in the first return types' argument. - // pairs(t: Table) -> ((Table) -> (K, V), Table, nil) + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); @@ -289,9 +255,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) { // tabTy is a generic table type which we can't express via declaration syntax yet ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); - - if (FFlag::LuauTableCloneType) - ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); } @@ -299,10 +263,10 @@ void registerBuiltinTypes(TypeChecker& typeChecker) attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); } -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionSelect( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; (void)scope; @@ -323,10 +287,10 @@ static std::optional> magicFunctionSelect( if (size_t(offset) < v.size()) { std::vector result(v.begin() + offset, v.end()); - return ExprResult{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})}; + return WithPredicate{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})}; } else if (tail) - return ExprResult{*tail}; + return WithPredicate{*tail}; } typechecker.reportError(TypeError{arg1->location, GenericError{"bad argument #1 to select (index out of range)"}}); @@ -334,16 +298,16 @@ static std::optional> magicFunctionSelect( else if (AstExprConstantString* str = arg1->as()) { if (str->value.size == 1 && str->value.data[0] == '#') - return ExprResult{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})}; + return WithPredicate{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})}; } return std::nullopt; } -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionSetMetaTable( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -379,7 +343,7 @@ static std::optional> magicFunctionSetMetaTable( if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1) { - return ExprResult{}; + return WithPredicate{}; } if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self) @@ -392,7 +356,7 @@ static std::optional> magicFunctionSetMetaTable( } } - return ExprResult{arena.addTypePack({mtTy})}; + return WithPredicate{arena.addTypePack({mtTy})}; } } else if (get(target) || get(target) || isTableIntersection(target)) @@ -403,55 +367,43 @@ static std::optional> magicFunctionSetMetaTable( typechecker.reportError(TypeError{expr.location, GenericError{"setmetatable should take a table"}}); } - return ExprResult{arena.addTypePack({target})}; + return WithPredicate{arena.addTypePack({target})}; } -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionAssert( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, predicates] = exprResult; + auto [paramPack, predicates] = withPredicate; - if (FFlag::LuauAssertStripsFalsyTypes) + TypeArena& arena = typechecker.currentModule->internalTypes; + + auto [head, tail] = flatten(paramPack); + if (head.empty() && tail) { - TypeArena& arena = typechecker.currentModule->internalTypes; - - auto [head, tail] = flatten(paramPack); - if (head.empty() && tail) - { - std::optional fst = first(*tail); - if (!fst) - return ExprResult{paramPack}; - head.push_back(*fst); - } - - typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); - - if (head.size() > 0) - { - std::optional newhead = typechecker.pickTypesFromSense(head[0], true); - if (!newhead) - head = {typechecker.nilType}; - else - head[0] = *newhead; - } - - return ExprResult{arena.addTypePack(TypePack{std::move(head), tail})}; + std::optional fst = first(*tail); + if (!fst) + return WithPredicate{paramPack}; + head.push_back(*fst); } - else + + typechecker.resolve(predicates, scope, true); + + if (head.size() > 0) { - if (expr.args.size < 1) - return ExprResult{paramPack}; - - typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); - - return ExprResult{paramPack}; + std::optional newhead = typechecker.pickTypesFromSense(head[0], true); + if (!newhead) + head = {typechecker.nilType}; + else + head[0] = *newhead; } + + return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; } -static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionPack( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -484,7 +436,7 @@ static std::optional> magicFunctionPack( TypeId packedTable = arena.addType( TableTypeVar{{{"n", {typechecker.numberType}}}, TableIndexer(typechecker.numberType, result), scope->level, TableState::Sealed}); - return ExprResult{arena.addTypePack({packedTable})}; + return WithPredicate{arena.addTypePack({packedTable})}; } static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) @@ -509,8 +461,8 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) return good; } -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionRequire( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { TypeArena& arena = typechecker.currentModule->internalTypes; @@ -524,7 +476,7 @@ static std::optional> magicFunctionRequire( return std::nullopt; if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, expr)) - return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; + return WithPredicate{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; return std::nullopt; } diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp new file mode 100644 index 00000000..df4e0a6b --- /dev/null +++ b/Analysis/src/Clone.cpp @@ -0,0 +1,450 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Clone.h" +#include "Luau/RecursionCounter.h" +#include "Luau/TxnLog.h" +#include "Luau/TypePack.h" +#include "Luau/Unifiable.h" + +LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) + +LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) + +namespace Luau +{ + +namespace +{ + +struct TypePackCloner; + +/* + * Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set. + * They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage. + */ + +struct TypeCloner +{ + TypeCloner(TypeArena& dest, TypeId typeId, CloneState& cloneState) + : dest(dest) + , typeId(typeId) + , seenTypes(cloneState.seenTypes) + , seenTypePacks(cloneState.seenTypePacks) + , cloneState(cloneState) + { + } + + TypeArena& dest; + TypeId typeId; + SeenTypes& seenTypes; + SeenTypePacks& seenTypePacks; + CloneState& cloneState; + + template + void defaultClone(const T& t); + + void operator()(const Unifiable::Free& t); + void operator()(const Unifiable::Generic& t); + void operator()(const Unifiable::Bound& t); + void operator()(const Unifiable::Error& t); + void operator()(const BlockedTypeVar& t); + void operator()(const PrimitiveTypeVar& t); + void operator()(const ConstrainedTypeVar& t); + void operator()(const SingletonTypeVar& t); + void operator()(const FunctionTypeVar& t); + void operator()(const TableTypeVar& t); + void operator()(const MetatableTypeVar& t); + void operator()(const ClassTypeVar& t); + void operator()(const AnyTypeVar& t); + void operator()(const UnionTypeVar& t); + void operator()(const IntersectionTypeVar& t); + void operator()(const LazyTypeVar& t); +}; + +struct TypePackCloner +{ + TypeArena& dest; + TypePackId typePackId; + SeenTypes& seenTypes; + SeenTypePacks& seenTypePacks; + CloneState& cloneState; + + TypePackCloner(TypeArena& dest, TypePackId typePackId, CloneState& cloneState) + : dest(dest) + , typePackId(typePackId) + , seenTypes(cloneState.seenTypes) + , seenTypePacks(cloneState.seenTypePacks) + , cloneState(cloneState) + { + } + + template + void defaultClone(const T& t) + { + TypePackId cloned = dest.addTypePack(TypePackVar{t}); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const Unifiable::Free& t) + { + defaultClone(t); + } + void operator()(const Unifiable::Generic& t) + { + defaultClone(t); + } + void operator()(const Unifiable::Error& t) + { + defaultClone(t); + } + + // While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter. + // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. + void operator()(const Unifiable::Bound& t) + { + TypePackId cloned = clone(t.boundTo, dest, cloneState); + if (FFlag::DebugLuauCopyBeforeNormalizing) + cloned = dest.addTypePack(TypePackVar{BoundTypePack{cloned}}); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const VariadicTypePack& t) + { + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, cloneState), /*hidden*/ t.hidden}}); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const TypePack& t) + { + TypePackId cloned = dest.addTypePack(TypePack{}); + TypePack* destTp = getMutable(cloned); + LUAU_ASSERT(destTp != nullptr); + seenTypePacks[typePackId] = cloned; + + for (TypeId ty : t.head) + destTp->head.push_back(clone(ty, dest, cloneState)); + + if (t.tail) + destTp->tail = clone(*t.tail, dest, cloneState); + } +}; + +template +void TypeCloner::defaultClone(const T& t) +{ + TypeId cloned = dest.addType(t); + seenTypes[typeId] = cloned; +} + +void TypeCloner::operator()(const Unifiable::Free& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const Unifiable::Generic& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const Unifiable::Bound& t) +{ + TypeId boundTo = clone(t.boundTo, dest, cloneState); + if (FFlag::DebugLuauCopyBeforeNormalizing) + boundTo = dest.addType(BoundTypeVar{boundTo}); + seenTypes[typeId] = boundTo; +} + +void TypeCloner::operator()(const Unifiable::Error& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const BlockedTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const PrimitiveTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const ConstrainedTypeVar& t) +{ + TypeId res = dest.addType(ConstrainedTypeVar{t.level}); + ConstrainedTypeVar* ctv = getMutable(res); + LUAU_ASSERT(ctv); + + seenTypes[typeId] = res; + + std::vector parts; + for (TypeId part : t.parts) + parts.push_back(clone(part, dest, cloneState)); + + ctv->parts = std::move(parts); +} + +void TypeCloner::operator()(const SingletonTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const FunctionTypeVar& t) +{ + TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); + FunctionTypeVar* ftv = getMutable(result); + LUAU_ASSERT(ftv != nullptr); + + seenTypes[typeId] = result; + + for (TypeId generic : t.generics) + ftv->generics.push_back(clone(generic, dest, cloneState)); + + for (TypePackId genericPack : t.genericPacks) + ftv->genericPacks.push_back(clone(genericPack, dest, cloneState)); + + ftv->tags = t.tags; + ftv->argTypes = clone(t.argTypes, dest, cloneState); + ftv->argNames = t.argNames; + ftv->retTypes = clone(t.retTypes, dest, cloneState); + ftv->hasNoGenerics = t.hasNoGenerics; +} + +void TypeCloner::operator()(const TableTypeVar& t) +{ + // If table is now bound to another one, we ignore the content of the original + if (!FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) + { + TypeId boundTo = clone(*t.boundTo, dest, cloneState); + seenTypes[typeId] = boundTo; + return; + } + + TypeId result = dest.addType(TableTypeVar{}); + TableTypeVar* ttv = getMutable(result); + LUAU_ASSERT(ttv != nullptr); + + *ttv = t; + + seenTypes[typeId] = result; + + ttv->level = TypeLevel{0, 0}; + + if (FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) + ttv->boundTo = clone(*t.boundTo, dest, cloneState); + + for (const auto& [name, prop] : t.props) + ttv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + + if (t.indexer) + ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)}; + + for (TypeId& arg : ttv->instantiatedTypeParams) + arg = clone(arg, dest, cloneState); + + for (TypePackId& arg : ttv->instantiatedTypePackParams) + arg = clone(arg, dest, cloneState); + + ttv->definitionModuleName = t.definitionModuleName; + ttv->tags = t.tags; +} + +void TypeCloner::operator()(const MetatableTypeVar& t) +{ + TypeId result = dest.addType(MetatableTypeVar{}); + MetatableTypeVar* mtv = getMutable(result); + seenTypes[typeId] = result; + + mtv->table = clone(t.table, dest, cloneState); + mtv->metatable = clone(t.metatable, dest, cloneState); +} + +void TypeCloner::operator()(const ClassTypeVar& t) +{ + TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData, t.definitionModuleName}); + ClassTypeVar* ctv = getMutable(result); + + seenTypes[typeId] = result; + + for (const auto& [name, prop] : t.props) + ctv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + + if (t.parent) + ctv->parent = clone(*t.parent, dest, cloneState); + + if (t.metatable) + ctv->metatable = clone(*t.metatable, dest, cloneState); +} + +void TypeCloner::operator()(const AnyTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const UnionTypeVar& t) +{ + std::vector options; + options.reserve(t.options.size()); + + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, cloneState)); + + TypeId result = dest.addType(UnionTypeVar{std::move(options)}); + seenTypes[typeId] = result; +} + +void TypeCloner::operator()(const IntersectionTypeVar& t) +{ + TypeId result = dest.addType(IntersectionTypeVar{}); + seenTypes[typeId] = result; + + IntersectionTypeVar* option = getMutable(result); + LUAU_ASSERT(option != nullptr); + + for (TypeId ty : t.parts) + option->parts.push_back(clone(ty, dest, cloneState)); +} + +void TypeCloner::operator()(const LazyTypeVar& t) +{ + defaultClone(t); +} + +} // anonymous namespace + +TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) +{ + if (tp->persistent) + return tp; + + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + + TypePackId& res = cloneState.seenTypePacks[tp]; + + if (res == nullptr) + { + TypePackCloner cloner{dest, tp, cloneState}; + Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. + } + + return res; +} + +TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) +{ + if (typeId->persistent) + return typeId; + + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + + TypeId& res = cloneState.seenTypes[typeId]; + + if (res == nullptr) + { + TypeCloner cloner{dest, typeId, cloneState}; + Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. + + // Persistent types are not being cloned and we get the original type back which might be read-only + if (!res->persistent) + { + asMutable(res)->documentationSymbol = typeId->documentationSymbol; + asMutable(res)->normal = typeId->normal; + } + } + + return res; +} + +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) +{ + TypeFun result; + + for (auto param : typeFun.typeParams) + { + TypeId ty = clone(param.ty, dest, cloneState); + std::optional defaultValue; + + if (param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, cloneState); + + result.typeParams.push_back({ty, defaultValue}); + } + + for (auto param : typeFun.typePackParams) + { + TypePackId tp = clone(param.tp, dest, cloneState); + std::optional defaultValue; + + if (param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, cloneState); + + result.typePackParams.push_back({tp, defaultValue}); + } + + result.type = clone(typeFun.type, dest, cloneState); + + return result; +} + +TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log) +{ + ty = log->follow(ty); + + TypeId result = ty; + + if (auto pty = log->pending(ty)) + ty = &pty->pending; + + if (const FunctionTypeVar* ftv = get(ty)) + { + FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; + clone.generics = ftv->generics; + clone.genericPacks = ftv->genericPacks; + clone.magicFunction = ftv->magicFunction; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + result = dest.addType(std::move(clone)); + } + else if (const TableTypeVar* ttv = get(ty)) + { + LUAU_ASSERT(!ttv->boundTo); + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; + clone.definitionModuleName = ttv->definitionModuleName; + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.instantiatedTypeParams = ttv->instantiatedTypeParams; + clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; + clone.tags = ttv->tags; + result = dest.addType(std::move(clone)); + } + else if (const MetatableTypeVar* mtv = get(ty)) + { + MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; + clone.syntheticName = mtv->syntheticName; + result = dest.addType(std::move(clone)); + } + else if (const UnionTypeVar* utv = get(ty)) + { + UnionTypeVar clone; + clone.options = utv->options; + result = dest.addType(std::move(clone)); + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + IntersectionTypeVar clone; + clone.parts = itv->parts; + result = dest.addType(std::move(clone)); + } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + ConstrainedTypeVar clone{ctv->level, ctv->parts}; + result = dest.addType(std::move(clone)); + } + else + return result; + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; +} + +} // namespace Luau diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp new file mode 100644 index 00000000..64e3a666 --- /dev/null +++ b/Analysis/src/Constraint.cpp @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Constraint.h" + +namespace Luau +{ + +Constraint::Constraint(ConstraintV&& c) + : c(std::move(c)) +{ +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp new file mode 100644 index 00000000..d9e8d238 --- /dev/null +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -0,0 +1,773 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ConstraintGraphBuilder.h" + +#include "Luau/Scope.h" + +namespace Luau +{ + +const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp + +ConstraintGraphBuilder::ConstraintGraphBuilder(TypeArena* arena) + : singletonTypes(getSingletonTypes()) + , arena(arena) + , rootScope(nullptr) +{ + LUAU_ASSERT(arena); +} + +TypeId ConstraintGraphBuilder::freshType(Scope2* scope) +{ + LUAU_ASSERT(scope); + return arena->addType(FreeTypeVar{scope}); +} + +TypePackId ConstraintGraphBuilder::freshTypePack(Scope2* scope) +{ + LUAU_ASSERT(scope); + FreeTypePack f{scope}; + return arena->addTypePack(TypePackVar{std::move(f)}); +} + +Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) +{ + LUAU_ASSERT(parent); + auto scope = std::make_unique(); + Scope2* borrow = scope.get(); + scopes.emplace_back(location, std::move(scope)); + + borrow->parent = parent; + borrow->returnType = parent->returnType; + parent->children.push_back(borrow); + + return borrow; +} + +void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv) +{ + LUAU_ASSERT(scope); + scope->constraints.emplace_back(new Constraint{std::move(cv)}); +} + +void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr c) +{ + LUAU_ASSERT(scope); + scope->constraints.emplace_back(std::move(c)); +} + +void ConstraintGraphBuilder::visit(AstStatBlock* block) +{ + LUAU_ASSERT(scopes.empty()); + LUAU_ASSERT(rootScope == nullptr); + scopes.emplace_back(block->location, std::make_unique()); + rootScope = scopes.back().second.get(); + rootScope->returnType = freshTypePack(rootScope); + + // TODO: We should share the global scope. + rootScope->typeBindings["nil"] = singletonTypes.nilType; + rootScope->typeBindings["number"] = singletonTypes.numberType; + rootScope->typeBindings["string"] = singletonTypes.stringType; + rootScope->typeBindings["boolean"] = singletonTypes.booleanType; + rootScope->typeBindings["thread"] = singletonTypes.threadType; + + visit(rootScope, block); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) +{ + LUAU_ASSERT(scope); + + if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto f = stat->as()) + visit(scope, f); + else if (auto f = stat->as()) + visit(scope, f); + else if (auto r = stat->as()) + visit(scope, r); + else if (auto a = stat->as()) + visit(scope, a); + else if (auto e = stat->as()) + checkPack(scope, e->expr); + else if (auto i = stat->as()) + visit(scope, i); + else if (auto a = stat->as()) + visit(scope, a); + else + LUAU_ASSERT(0); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) +{ + LUAU_ASSERT(scope); + + std::vector varTypes; + + for (AstLocal* local : local->vars) + { + TypeId ty = freshType(scope); + + if (local->annotation) + { + TypeId annotation = resolveType(scope, local->annotation); + addConstraint(scope, SubtypeConstraint{ty, annotation}); + } + + varTypes.push_back(ty); + scope->bindings[local] = ty; + } + + for (size_t i = 0; i < local->values.size; ++i) + { + if (local->values.data[i]->is()) + { + // HACK: we leave nil-initialized things floating under the assumption that they will later be populated. + // See the test TypeInfer/infer_locals_with_nil_value. + // Better flow awareness should make this obsolete. + } + else if (i == local->values.size - 1) + { + TypePackId exprPack = checkPack(scope, local->values.data[i]); + + if (i < local->vars.size) + { + std::vector tailValues{varTypes.begin() + i, varTypes.end()}; + TypePackId tailPack = arena->addTypePack(std::move(tailValues)); + addConstraint(scope, PackSubtypeConstraint{exprPack, tailPack}); + } + } + else + { + TypeId exprType = check(scope, local->values.data[i]); + if (i < varTypes.size()) + addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}); + } + } +} + +void addConstraints(Constraint* constraint, Scope2* scope) +{ + LUAU_ASSERT(scope); + + scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); + + for (const auto& c : scope->constraints) + constraint->dependencies.push_back(NotNull{c.get()}); + + for (Scope2* childScope : scope->children) + addConstraints(constraint, childScope); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function) +{ + LUAU_ASSERT(scope); + + // Local + // Global + // Dotted path + // Self? + + TypeId functionType = nullptr; + auto ty = scope->lookup(function->name); + if (ty.has_value()) + { + // TODO: This is duplicate definition of a local function. Is this allowed? + functionType = *ty; + } + else + { + functionType = arena->addType(BlockedTypeVar{}); + scope->bindings[function->name] = functionType; + } + + auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); + innerScope->bindings[function->name] = actualFunctionType; + + checkFunctionBody(innerScope, function->func); + + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; + addConstraints(c.get(), innerScope); + + addConstraint(scope, std::move(c)); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) +{ + // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. + // With or without self + + TypeId functionType = nullptr; + + auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); + + if (AstExprLocal* localName = function->name->as()) + { + std::optional existingFunctionTy = scope->lookup(localName->local); + if (existingFunctionTy) + { + // Duplicate definition + functionType = *existingFunctionTy; + } + else + { + functionType = arena->addType(BlockedTypeVar{}); + scope->bindings[localName->local] = functionType; + } + innerScope->bindings[localName->local] = actualFunctionType; + } + else if (AstExprGlobal* globalName = function->name->as()) + { + std::optional existingFunctionTy = scope->lookup(globalName->name); + if (existingFunctionTy) + { + // Duplicate definition + functionType = *existingFunctionTy; + } + else + { + functionType = arena->addType(BlockedTypeVar{}); + rootScope->bindings[globalName->name] = functionType; + } + innerScope->bindings[globalName->name] = actualFunctionType; + } + else if (AstExprIndexName* indexName = function->name->as()) + { + LUAU_ASSERT(0); // not yet implemented + } + + checkFunctionBody(innerScope, function->func); + + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; + addConstraints(c.get(), innerScope); + + addConstraint(scope, std::move(c)); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret) +{ + LUAU_ASSERT(scope); + + TypePackId exprTypes = checkPack(scope, ret->list); + addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) +{ + LUAU_ASSERT(scope); + + // In order to enable mutually-recursive type aliases, we need to + // populate the type bindings before we actually check any of the + // alias statements. Since we're not ready to actually resolve + // any of the annotations, we just use a fresh type for now. + for (AstStat* stat : block->body) + { + if (auto alias = stat->as()) + { + TypeId initialType = freshType(scope); + scope->typeBindings[alias->name.value] = initialType; + } + } + + for (AstStat* stat : block->body) + visit(scope, stat); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign) +{ + TypePackId varPackId = checkExprList(scope, assign->vars); + TypePackId valuePack = checkPack(scope, assign->values); + + addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement) +{ + check(scope, ifStatement->condition); + + Scope2* thenScope = childScope(ifStatement->thenbody->location, scope); + visit(thenScope, ifStatement->thenbody); + + if (ifStatement->elsebody) + { + Scope2* elseScope = childScope(ifStatement->elsebody->location, scope); + visit(elseScope, ifStatement->elsebody); + } +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias) +{ + // TODO: Exported type aliases + // TODO: Generic type aliases + + auto it = scope->typeBindings.find(alias->name.value); + // This should always be here since we do a separate pass over the + // AST to set up typeBindings. If it's not, we've somehow skipped + // this alias in that first pass. + LUAU_ASSERT(it != scope->typeBindings.end()); + + TypeId ty = resolveType(scope, alias->type); + + // Rather than using a subtype constraint, we instead directly bind + // the free type we generated in the first pass to the resolved type. + // This prevents a case where you could cause another constraint to + // bind the free alias type to an unrelated type, causing havoc. + asMutable(it->second)->ty.emplace(ty); + + addConstraint(scope, NameConstraint{ty, alias->name.value}); +} + +TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray exprs) +{ + LUAU_ASSERT(scope); + + if (exprs.size == 0) + return arena->addTypePack({}); + + std::vector types; + TypePackId last = nullptr; + + for (size_t i = 0; i < exprs.size; ++i) + { + if (i < exprs.size - 1) + types.push_back(check(scope, exprs.data[i])); + else + last = checkPack(scope, exprs.data[i]); + } + + LUAU_ASSERT(last != nullptr); + + return arena->addTypePack(TypePack{std::move(types), last}); +} + +TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray& exprs) +{ + TypePackId result = arena->addTypePack({}); + TypePack* resultPack = getMutable(result); + LUAU_ASSERT(resultPack); + + for (size_t i = 0; i < exprs.size; ++i) + { + AstExpr* expr = exprs.data[i]; + if (i < exprs.size - 1) + resultPack->head.push_back(check(scope, expr)); + else + resultPack->tail = checkPack(scope, expr); + } + + if (resultPack->head.empty() && resultPack->tail) + return *resultPack->tail; + else + return result; +} + +TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) +{ + LUAU_ASSERT(scope); + + TypePackId result = nullptr; + + if (AstExprCall* call = expr->as()) + { + std::vector args; + + for (AstExpr* arg : call->args) + { + args.push_back(check(scope, arg)); + } + + // TODO self + + TypeId fnType = check(scope, call->func); + + astOriginalCallTypes[call->func] = fnType; + + TypeId instantiatedType = freshType(scope); + addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}); + + TypePackId rets = freshTypePack(scope); + FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); + TypeId inferredFnType = arena->addType(ftv); + + addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}); + result = rets; + } + else + { + TypeId t = check(scope, expr); + result = arena->addTypePack({t}); + } + + LUAU_ASSERT(result); + astTypePacks[expr] = result; + return result; +} + +TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) +{ + LUAU_ASSERT(scope); + + TypeId result = nullptr; + + if (auto group = expr->as()) + result = check(scope, group->expr); + else if (expr->is()) + result = singletonTypes.stringType; + else if (expr->is()) + result = singletonTypes.numberType; + else if (expr->is()) + result = singletonTypes.booleanType; + else if (expr->is()) + result = singletonTypes.nilType; + else if (auto a = expr->as()) + { + std::optional ty = scope->lookup(a->local); + if (ty) + result = *ty; + else + result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? + } + else if (auto g = expr->as()) + { + std::optional ty = scope->lookup(g->name); + if (ty) + result = *ty; + else + result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? + } + else if (auto a = expr->as()) + { + TypePackId packResult = checkPack(scope, expr); + if (auto f = first(packResult)) + return *f; + else if (get(packResult)) + { + TypeId typeResult = freshType(scope); + TypePack onePack{{typeResult}, freshTypePack(scope)}; + TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); + + addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack}); + + return typeResult; + } + } + else if (auto a = expr->as()) + { + auto [fnType, functionScope] = checkFunctionSignature(scope, a); + checkFunctionBody(functionScope, a); + return fnType; + } + else if (auto indexName = expr->as()) + { + result = check(scope, indexName); + } + else if (auto table = expr->as()) + { + result = checkExprTable(scope, table); + } + else + { + LUAU_ASSERT(0); + result = freshType(scope); + } + + LUAU_ASSERT(result); + astTypes[expr] = result; + return result; +} + +TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName) +{ + TypeId obj = check(scope, indexName->expr); + TypeId result = freshType(scope); + + TableTypeVar::Props props{{indexName->index.value, Property{result}}}; + const std::optional indexer; + TableTypeVar ttv{std::move(props), indexer, TypeLevel{}, TableState::Free}; + + TypeId expectedTableType = arena->addType(std::move(ttv)); + + addConstraint(scope, SubtypeConstraint{obj, expectedTableType}); + + return result; +} + +TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) +{ + TypeId ty = arena->addType(TableTypeVar{}); + TableTypeVar* ttv = getMutable(ty); + LUAU_ASSERT(ttv); + + auto createIndexer = [this, scope, ttv](TypeId currentIndexType, TypeId currentResultType) { + if (!ttv->indexer) + { + TypeId indexType = this->freshType(scope); + TypeId resultType = this->freshType(scope); + ttv->indexer = TableIndexer{indexType, resultType}; + } + + addConstraint(scope, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}); + addConstraint(scope, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}); + }; + + for (const AstExprTable::Item& item : expr->items) + { + TypeId itemTy = check(scope, item.value); + + if (item.key) + { + // Even though we don't need to use the type of the item's key if + // it's a string constant, we still want to check it to populate + // astTypes. + TypeId keyTy = check(scope, item.key); + + if (AstExprConstantString* key = item.key->as()) + { + ttv->props[key->value.begin()] = {itemTy}; + } + else + { + createIndexer(keyTy, itemTy); + } + } + else + { + TypeId numberType = singletonTypes.numberType; + createIndexer(numberType, itemTy); + } + } + + return ty; +} + +std::pair ConstraintGraphBuilder::checkFunctionSignature(Scope2* parent, AstExprFunction* fn) +{ + Scope2* innerScope = childScope(fn->body->location, parent); + TypePackId returnType = freshTypePack(innerScope); + innerScope->returnType = returnType; + + if (fn->returnAnnotation) + { + TypePackId annotatedRetType = resolveTypePack(innerScope, *fn->returnAnnotation); + addConstraint(innerScope, PackSubtypeConstraint{returnType, annotatedRetType}); + } + + std::vector argTypes; + + for (AstLocal* local : fn->args) + { + TypeId t = freshType(innerScope); + argTypes.push_back(t); + innerScope->bindings[local] = t; + + if (local->annotation) + { + TypeId argAnnotation = resolveType(innerScope, local->annotation); + addConstraint(innerScope, SubtypeConstraint{t, argAnnotation}); + } + } + + // TODO: Vararg annotation. + + FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); + LUAU_ASSERT(actualFunctionType); + astTypes[fn] = actualFunctionType; + + return {actualFunctionType, innerScope}; +} + +void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* fn) +{ + for (AstStat* stat : fn->body->body) + visit(scope, stat); + + // If it is possible for execution to reach the end of the function, the return type must be compatible with () + + if (nullptr != getFallthrough(fn->body)) + { + TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever + addConstraint(scope, PackSubtypeConstraint{scope->returnType, empty}); + } +} + +TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty) +{ + TypeId result = nullptr; + + if (auto ref = ty->as()) + { + // TODO: Support imported types w/ require tracing. + // TODO: Support generic type references. + LUAU_ASSERT(!ref->prefix); + LUAU_ASSERT(!ref->hasParameterList); + + // TODO: If it doesn't exist, should we introduce a free binding? + // This is probably important for handling type aliases. + result = scope->lookupTypeBinding(ref->name.value).value_or(singletonTypes.errorRecoveryType()); + } + else if (auto tab = ty->as()) + { + TableTypeVar::Props props; + std::optional indexer; + + for (const AstTableProp& prop : tab->props) + { + std::string name = prop.name.value; + // TODO: Recursion limit. + TypeId propTy = resolveType(scope, prop.type); + // TODO: Fill in location. + props[name] = {propTy}; + } + + if (tab->indexer) + { + // TODO: Recursion limit. + indexer = TableIndexer{ + resolveType(scope, tab->indexer->indexType), + resolveType(scope, tab->indexer->resultType), + }; + } + + // TODO: Remove TypeLevel{} here, we don't need it. + result = arena->addType(TableTypeVar{props, indexer, TypeLevel{}, TableState::Sealed}); + } + else if (auto fn = ty->as()) + { + // TODO: Generic functions. + // TODO: Scope (though it may not be needed). + // TODO: Recursion limit. + TypePackId argTypes = resolveTypePack(scope, fn->argTypes); + TypePackId returnTypes = resolveTypePack(scope, fn->returnTypes); + + // TODO: Is this the right constructor to use? + result = arena->addType(FunctionTypeVar{argTypes, returnTypes}); + + FunctionTypeVar* ftv = getMutable(result); + ftv->argNames.reserve(fn->argNames.size); + for (const auto& el : fn->argNames) + { + if (el) + { + const auto& [name, location] = *el; + ftv->argNames.push_back(FunctionArgument{name.value, location}); + } + else + { + ftv->argNames.push_back(std::nullopt); + } + } + } + else if (auto tof = ty->as()) + { + // TODO: Recursion limit. + TypeId exprType = check(scope, tof->expr); + result = exprType; + } + else if (auto unionAnnotation = ty->as()) + { + std::vector parts; + for (AstType* part : unionAnnotation->types) + { + // TODO: Recursion limit. + parts.push_back(resolveType(scope, part)); + } + + result = arena->addType(UnionTypeVar{parts}); + } + else if (auto intersectionAnnotation = ty->as()) + { + std::vector parts; + for (AstType* part : intersectionAnnotation->types) + { + // TODO: Recursion limit. + parts.push_back(resolveType(scope, part)); + } + + result = arena->addType(IntersectionTypeVar{parts}); + } + else if (auto boolAnnotation = ty->as()) + { + result = arena->addType(SingletonTypeVar(BooleanSingleton{boolAnnotation->value})); + } + else if (auto stringAnnotation = ty->as()) + { + result = arena->addType(SingletonTypeVar(StringSingleton{std::string(stringAnnotation->value.data, stringAnnotation->value.size)})); + } + else if (ty->is()) + { + result = singletonTypes.errorRecoveryType(); + } + else + { + LUAU_ASSERT(0); + result = singletonTypes.errorRecoveryType(); + } + + astResolvedTypes[ty] = result; + return result; +} + +TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, AstTypePack* tp) +{ + TypePackId result; + if (auto expl = tp->as()) + { + result = resolveTypePack(scope, expl->typeList); + } + else if (auto var = tp->as()) + { + TypeId ty = resolveType(scope, var->variadicType); + result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); + } + else if (auto gen = tp->as()) + { + result = arena->addTypePack(TypePackVar{GenericTypePack{scope, gen->genericName.value}}); + } + else + { + LUAU_ASSERT(0); + result = singletonTypes.errorRecoveryTypePack(); + } + + astResolvedTypePacks[tp] = result; + return result; +} + +TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, const AstTypeList& list) +{ + std::vector head; + + for (AstType* headTy : list.types) + { + head.push_back(resolveType(scope, headTy)); + } + + std::optional tail = std::nullopt; + if (list.tailType) + { + tail = resolveTypePack(scope, list.tailType); + } + + return arena->addTypePack(TypePack{head, tail}); +} + +void collectConstraints(std::vector>& result, Scope2* scope) +{ + for (const auto& c : scope->constraints) + result.push_back(NotNull{c.get()}); + + for (Scope2* child : scope->children) + collectConstraints(result, child); +} + +std::vector> collectConstraints(Scope2* rootScope) +{ + std::vector> result; + collectConstraints(result, rootScope); + return result; +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp new file mode 100644 index 00000000..9e355236 --- /dev/null +++ b/Analysis/src/ConstraintSolver.cpp @@ -0,0 +1,361 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ConstraintSolver.h" +#include "Luau/Instantiation.h" +#include "Luau/Location.h" +#include "Luau/Quantify.h" +#include "Luau/ToString.h" +#include "Luau/Unifier.h" + +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); + +namespace Luau +{ + +[[maybe_unused]] static void dumpBindings(Scope2* scope, ToStringOptions& opts) +{ + for (const auto& [k, v] : scope->bindings) + { + auto d = toStringDetailed(v, opts); + opts.nameMap = d.nameMap; + printf("\t%s : %s\n", k.c_str(), d.name.c_str()); + } + + for (Scope2* child : scope->children) + dumpBindings(child, opts); +} + +static void dumpConstraints(Scope2* scope, ToStringOptions& opts) +{ + for (const ConstraintPtr& c : scope->constraints) + { + printf("\t%s\n", toString(*c, opts).c_str()); + } + + for (Scope2* child : scope->children) + dumpConstraints(child, opts); +} + +void dump(Scope2* rootScope, ToStringOptions& opts) +{ + printf("constraints:\n"); + dumpConstraints(rootScope, opts); +} + +void dump(ConstraintSolver* cs, ToStringOptions& opts) +{ + printf("constraints:\n"); + for (const Constraint* c : cs->unsolvedConstraints) + { + printf("\t%s\n", toString(*c, opts).c_str()); + + for (const Constraint* dep : c->dependencies) + printf("\t\t%s\n", toString(*dep, opts).c_str()); + } +} + +ConstraintSolver::ConstraintSolver(TypeArena* arena, Scope2* rootScope) + : arena(arena) + , constraints(collectConstraints(rootScope)) + , rootScope(rootScope) +{ + for (NotNull c : constraints) + { + unsolvedConstraints.push_back(c); + + for (NotNull dep : c->dependencies) + { + block(dep, c); + } + } +} + +void ConstraintSolver::run() +{ + if (done()) + return; + + ToStringOptions opts; + + if (FFlag::DebugLuauLogSolver) + { + printf("Starting solver\n"); + dump(this, opts); + } + + if (FFlag::DebugLuauLogSolverToJson) + { + logger.captureBoundarySnapshot(rootScope, unsolvedConstraints); + } + + auto runSolverPass = [&](bool force) { + bool progress = false; + + size_t i = 0; + while (i < unsolvedConstraints.size()) + { + NotNull c = unsolvedConstraints[i]; + if (!force && isBlocked(c)) + { + ++i; + continue; + } + + std::string saveMe = FFlag::DebugLuauLogSolver ? toString(*c, opts) : std::string{}; + + if (FFlag::DebugLuauLogSolverToJson) + { + logger.prepareStepSnapshot(rootScope, c, unsolvedConstraints); + } + + bool success = tryDispatch(c, force); + + progress |= success; + + if (success) + { + unsolvedConstraints.erase(unsolvedConstraints.begin() + i); + + if (FFlag::DebugLuauLogSolverToJson) + { + logger.commitPreparedStepSnapshot(); + } + + if (FFlag::DebugLuauLogSolver) + { + if (force) + printf("Force "); + printf("Dispatched\n\t%s\n", saveMe.c_str()); + dump(this, opts); + } + } + else + ++i; + + if (force && success) + return true; + } + + return progress; + }; + + bool progress = false; + do + { + progress = runSolverPass(false); + if (!progress) + progress |= runSolverPass(true); + } while (progress); + + if (FFlag::DebugLuauLogSolver) + { + dumpBindings(rootScope, opts); + } + + if (FFlag::DebugLuauLogSolverToJson) + { + logger.captureBoundarySnapshot(rootScope, unsolvedConstraints); + printf("Logger output:\n%s\n", logger.compileOutput().c_str()); + } +} + +bool ConstraintSolver::done() +{ + return unsolvedConstraints.empty(); +} + +bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) +{ + if (!force && isBlocked(constraint)) + return false; + + bool success = false; + + if (auto sc = get(*constraint)) + success = tryDispatch(*sc, constraint, force); + else if (auto psc = get(*constraint)) + success = tryDispatch(*psc, constraint, force); + else if (auto gc = get(*constraint)) + success = tryDispatch(*gc, constraint, force); + else if (auto ic = get(*constraint)) + success = tryDispatch(*ic, constraint, force); + else if (auto nc = get(*constraint)) + success = tryDispatch(*nc, constraint); + else + LUAU_ASSERT(0); + + if (success) + { + unblock(constraint); + } + + return success; +} + +bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) +{ + if (isBlocked(c.subType)) + return block(c.subType, constraint); + else if (isBlocked(c.superType)) + return block(c.superType, constraint); + + unify(c.subType, c.superType); + + unblock(c.subType); + unblock(c.superType); + + return true; +} + +bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) +{ + unify(c.subPack, c.superPack); + unblock(c.subPack); + unblock(c.superPack); + + return true; +} + +bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force) +{ + if (isBlocked(c.sourceType)) + return block(c.sourceType, constraint); + + if (isBlocked(c.generalizedType)) + asMutable(c.generalizedType)->ty.emplace(c.sourceType); + else + unify(c.generalizedType, c.sourceType); + + TypeId generalized = quantify(arena, c.sourceType, c.scope); + *asMutable(c.sourceType) = *generalized; + + unblock(c.generalizedType); + unblock(c.sourceType); + + return true; +} + +bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force) +{ + if (isBlocked(c.superType)) + return block(c.superType, constraint); + + Instantiation inst(TxnLog::empty(), arena, TypeLevel{}); + + std::optional instantiated = inst.substitute(c.superType); + LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS + + unify(c.subType, *instantiated); + unblock(c.subType); + + return true; +} + +bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull constraint) +{ + if (isBlocked(c.namedType)) + return block(c.namedType, constraint); + + TypeId target = follow(c.namedType); + if (TableTypeVar* ttv = getMutable(target)) + ttv->name = c.name; + else if (MetatableTypeVar* mtv = getMutable(target)) + mtv->syntheticName = c.name; + else + return block(c.namedType, constraint); + + return true; +} + +void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) +{ + blocked[target].push_back(constraint); + + auto& count = blockedConstraints[constraint]; + count += 1; +} + +void ConstraintSolver::block(NotNull target, NotNull constraint) +{ + block_(target, constraint); +} + +bool ConstraintSolver::block(TypeId target, NotNull constraint) +{ + block_(target, constraint); + return false; +} + +bool ConstraintSolver::block(TypePackId target, NotNull constraint) +{ + block_(target, constraint); + return false; +} + +void ConstraintSolver::unblock_(BlockedConstraintId progressed) +{ + auto it = blocked.find(progressed); + if (it == blocked.end()) + return; + + // unblocked should contain a value always, because of the above check + for (NotNull unblockedConstraint : it->second) + { + auto& count = blockedConstraints[unblockedConstraint]; + // This assertion being hit indicates that `blocked` and + // `blockedConstraints` desynchronized at some point. This is problematic + // because we rely on this count being correct to skip over blocked + // constraints. + LUAU_ASSERT(count > 0); + count -= 1; + } + + blocked.erase(it); +} + +void ConstraintSolver::unblock(NotNull progressed) +{ + return unblock_(progressed); +} + +void ConstraintSolver::unblock(TypeId progressed) +{ + return unblock_(progressed); +} + +void ConstraintSolver::unblock(TypePackId progressed) +{ + return unblock_(progressed); +} + +bool ConstraintSolver::isBlocked(TypeId ty) +{ + return nullptr != get(follow(ty)); +} + +bool ConstraintSolver::isBlocked(NotNull constraint) +{ + auto blockedIt = blockedConstraints.find(constraint); + return blockedIt != blockedConstraints.end() && blockedIt->second > 0; +} + +void ConstraintSolver::unify(TypeId subType, TypeId superType) +{ + UnifierSharedState sharedState{&iceReporter}; + Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; + + u.tryUnify(subType, superType); + u.log.commit(); +} + +void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack) +{ + UnifierSharedState sharedState{&iceReporter}; + Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; + + u.tryUnify(subPack, superPack); + u.log.commit(); +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintSolverLogger.cpp b/Analysis/src/ConstraintSolverLogger.cpp new file mode 100644 index 00000000..2f93c280 --- /dev/null +++ b/Analysis/src/ConstraintSolverLogger.cpp @@ -0,0 +1,139 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ConstraintSolverLogger.h" + +namespace Luau +{ + +static std::string dumpScopeAndChildren(const Scope2* scope, ToStringOptions& opts) +{ + std::string output = "{\"bindings\":{"; + + bool comma = false; + for (const auto& [name, type] : scope->bindings) + { + if (comma) + output += ","; + + output += "\""; + output += name.c_str(); + output += "\": \""; + + ToStringResult result = toStringDetailed(type, opts); + opts.nameMap = std::move(result.nameMap); + output += result.name; + output += "\""; + + comma = true; + } + + output += "},\"children\":["; + comma = false; + + for (const Scope2* child : scope->children) + { + if (comma) + output += ","; + + output += dumpScopeAndChildren(child, opts); + comma = true; + } + + output += "]}"; + return output; +} + +static std::string dumpConstraintsToDot(std::vector>& constraints, ToStringOptions& opts) +{ + std::string result = "digraph Constraints {\\n"; + + std::unordered_set> contained; + for (NotNull c : constraints) + { + contained.insert(c); + } + + for (NotNull c : constraints) + { + std::string id = std::to_string(reinterpret_cast(c.get())); + result += id; + result += " [label=\\\""; + result += toString(*c, opts).c_str(); + result += "\\\"];\\n"; + + for (NotNull dep : c->dependencies) + { + if (contained.count(dep) == 0) + continue; + + result += std::to_string(reinterpret_cast(dep.get())); + result += " -> "; + result += id; + result += ";\\n"; + } + } + + result += "}"; + + return result; +} + +std::string ConstraintSolverLogger::compileOutput() +{ + std::string output = "["; + bool comma = false; + + for (const std::string& snapshot : snapshots) + { + if (comma) + output += ","; + output += snapshot; + + comma = true; + } + + output += "]"; + return output; +} + +void ConstraintSolverLogger::captureBoundarySnapshot(const Scope2* rootScope, std::vector>& unsolvedConstraints) +{ + std::string snapshot = "{\"type\":\"boundary\",\"rootScope\":"; + + snapshot += dumpScopeAndChildren(rootScope, opts); + snapshot += ",\"constraintGraph\":\""; + snapshot += dumpConstraintsToDot(unsolvedConstraints, opts); + snapshot += "\"}"; + + snapshots.push_back(std::move(snapshot)); +} + +void ConstraintSolverLogger::prepareStepSnapshot( + const Scope2* rootScope, NotNull current, std::vector>& unsolvedConstraints) +{ + // LUAU_ASSERT(!preparedSnapshot); + + std::string snapshot = "{\"type\":\"step\",\"rootScope\":"; + + snapshot += dumpScopeAndChildren(rootScope, opts); + snapshot += ",\"constraintGraph\":\""; + snapshot += dumpConstraintsToDot(unsolvedConstraints, opts); + snapshot += "\",\"currentId\":\""; + snapshot += std::to_string(reinterpret_cast(current.get())); + snapshot += "\",\"current\":\""; + snapshot += toString(*current, opts); + snapshot += "\"}"; + + preparedSnapshot = std::move(snapshot); +} + +void ConstraintSolverLogger::commitPreparedStepSnapshot() +{ + if (preparedSnapshot) + { + snapshots.push_back(std::move(*preparedSnapshot)); + preparedSnapshot = std::nullopt; + } +} + +} // namespace Luau diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index be3fcd7d..2407e3ef 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -7,7 +7,10 @@ namespace Luau static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC( declare bit32: { - -- band, bor, bxor, and btest are declared in C++ + band: (...number) -> number, + bor: (...number) -> number, + bxor: (...number) -> number, + btest: (number, ...number) -> boolean, rrotate: (number, number) -> number, lrotate: (number, number) -> number, lshift: (number, number) -> number, @@ -50,7 +53,8 @@ declare math: { asin: (number) -> number, atan2: (number, number) -> number, - -- min and max are declared in C++. + min: (number, ...number) -> number, + max: (number, ...number) -> number, pi: number, huge: number, @@ -143,7 +147,7 @@ declare coroutine: { create: ((A...) -> R...) -> thread, resume: (thread, A...) -> (boolean, R...), running: () -> thread, - status: (thread) -> string, + status: (thread) -> "dead" | "running" | "normal" | "suspended", -- FIXME: This technically returns a function, but we can't represent this yet. wrap: ((A...) -> R...) -> any, yield: (A...) -> R..., @@ -179,7 +183,7 @@ declare debug: { } declare utf8: { - char: (number, ...number) -> string, + char: (...number) -> string, charpattern: string, codes: (string) -> ((string, number) -> (number, number), string, number), -- FIXME diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 210c0191..93cb65b9 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -1,14 +1,14 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Error.h" -#include "Luau/Module.h" +#include "Luau/Clone.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" #include -LUAU_FASTFLAGVARIABLE(BetterDiagnosticCodesInStudio, false); -LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleName, false); +LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false) +LUAU_FASTFLAGVARIABLE(LuauUseInternalCompilerErrorException, false) static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { @@ -52,6 +52,8 @@ namespace Luau struct ErrorConverter { + FileResolver* fileResolver = nullptr; + std::string operator()(const Luau::TypeMismatch& tm) const { std::string givenTypeName = Luau::toString(tm.givenType); @@ -59,27 +61,30 @@ struct ErrorConverter std::string result; - if (FFlag::LuauTypeMismatchModuleName) + if (givenTypeName == wantedTypeName) { - if (givenTypeName == wantedTypeName) + if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType)) { - if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType)) + if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) { - if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) + if (FFlag::LuauTypeMismatchModuleNameResolution && fileResolver != nullptr) + { + std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule); + std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule); + result = "Type '" + givenTypeName + "' from '" + givenModuleName + "' could not be converted into '" + wantedTypeName + + "' from '" + wantedModuleName + "'"; + } + else { result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + "' from '" + *wantedDefinitionModule + "'"; } } } + } - if (result.empty()) - result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; - } - else - { + if (result.empty()) result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; - } if (tm.error) { @@ -88,7 +93,14 @@ struct ErrorConverter if (!tm.reason.empty()) result += tm.reason + " "; - result += Luau::toString(*tm.error); + if (FFlag::LuauTypeMismatchModuleNameResolution) + { + result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); + } + else + { + result += Luau::toString(*tm.error); + } } else if (!tm.reason.empty()) { @@ -187,15 +199,7 @@ struct ErrorConverter std::string operator()(const Luau::FunctionRequiresSelf& e) const { - if (e.requiredExtraNils) - { - const char* plural = e.requiredExtraNils == 1 ? "" : "s"; - return format("This function was declared to accept self, but you did not pass enough arguments. Use a colon instead of a dot or " - "pass %i extra nil%s to suppress this warning", - e.requiredExtraNils, plural); - } - else - return "This function must be called with self. Did you mean to use a colon instead of a dot?"; + return "This function must be called with self. Did you mean to use a colon instead of a dot?"; } std::string operator()(const Luau::OccursCheckFailed&) const @@ -251,14 +255,7 @@ struct ErrorConverter std::string operator()(const Luau::SyntaxError& e) const { - if (FFlag::BetterDiagnosticCodesInStudio) - { - return e.message; - } - else - { - return "Syntax error: " + e.message; - } + return e.message; } std::string operator()(const Luau::CodeTooComplex&) const @@ -305,6 +302,11 @@ struct ErrorConverter return e.message; } + std::string operator()(const Luau::InternalError& e) const + { + return e.message; + } + std::string operator()(const Luau::CannotCallNonFunction& e) const { return "Cannot call non-function " + toString(e.ty); @@ -450,6 +452,11 @@ struct ErrorConverter { return "Cannot cast '" + toString(e.left) + "' into '" + toString(e.right) + "' because the types are unrelated"; } + + std::string operator()(const NormalizationTooComplex&) const + { + return "Code is too complex to typecheck! Consider simplifying the code around this area"; + } }; struct InvalidNameChecker @@ -550,7 +557,7 @@ bool FunctionDoesNotTakeSelf::operator==(const FunctionDoesNotTakeSelf&) const bool FunctionRequiresSelf::operator==(const FunctionRequiresSelf& e) const { - return requiredExtraNils == e.requiredExtraNils; + return true; } bool OccursCheckFailed::operator==(const OccursCheckFailed&) const @@ -618,6 +625,11 @@ bool GenericError::operator==(const GenericError& rhs) const return message == rhs.message; } +bool InternalError::operator==(const InternalError& rhs) const +{ + return message == rhs.message; +} + bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const { return ty == rhs.ty; @@ -705,7 +717,12 @@ bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const std::string toString(const TypeError& error) { - ErrorConverter converter; + return toString(error, TypeErrorToStringOptions{}); +} + +std::string toString(const TypeError& error, TypeErrorToStringOptions options) +{ + ErrorConverter converter{options.fileResolver}; return Luau::visit(converter, error.data); } @@ -715,14 +732,14 @@ bool containsParseErrorName(const TypeError& error) } template -void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState cloneState) +void copyError(T& e, TypeArena& destArena, CloneState cloneState) { auto clone = [&](auto&& ty) { - return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks, cloneState); + return ::Luau::clone(ty, destArena, cloneState); }; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks, cloneState); + copyError(e, destArena, cloneState); }; if constexpr (false) @@ -793,6 +810,9 @@ void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& else if constexpr (std::is_same_v) { } + else if constexpr (std::is_same_v) + { + } else if constexpr (std::is_same_v) { e.ty = clone(e.ty); @@ -843,18 +863,19 @@ void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& e.left = clone(e.left); e.right = clone(e.right); } + else if constexpr (std::is_same_v) + { + } else static_assert(always_false_v, "Non-exhaustive type switch"); } void copyErrors(ErrorVec& errors, TypeArena& destArena) { - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks, cloneState); + copyError(e, destArena, cloneState); }; LUAU_ASSERT(!destArena.typeVars.isFrozen()); @@ -866,22 +887,51 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena) void InternalErrorReporter::ice(const std::string& message, const Location& location) { - std::runtime_error error("Internal error in " + moduleName + " at " + toString(location) + ": " + message); + if (FFlag::LuauUseInternalCompilerErrorException) + { + InternalCompilerError error(message, moduleName, location); - if (onInternalError) - onInternalError(error.what()); + if (onInternalError) + onInternalError(error.what()); - throw error; + throw error; + } + else + { + std::runtime_error error("Internal error in " + moduleName + " at " + toString(location) + ": " + message); + + if (onInternalError) + onInternalError(error.what()); + + throw error; + } } void InternalErrorReporter::ice(const std::string& message) { - std::runtime_error error("Internal error in " + moduleName + ": " + message); + if (FFlag::LuauUseInternalCompilerErrorException) + { + InternalCompilerError error(message, moduleName); - if (onInternalError) - onInternalError(error.what()); + if (onInternalError) + onInternalError(error.what()); - throw error; + throw error; + } + else + { + std::runtime_error error("Internal error in " + moduleName + ": " + message); + + if (onInternalError) + onInternalError(error.what()); + + throw error; + } +} + +const char* InternalCompilerError::what() const throw() +{ + return this->message.data(); } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index d8906f6e..85c5dbc8 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,23 +1,31 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" +#include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Config.h" +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintSolver.h" #include "Luau/FileResolver.h" #include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" -#include "Luau/Common.h" #include #include #include +LUAU_FASTINT(LuauTypeInferIterationLimit) +LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) +LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) +LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) namespace Luau { @@ -93,13 +101,11 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, checkedModule}; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); + TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; @@ -109,7 +115,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); + TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; @@ -211,7 +217,7 @@ ErrorVec accumulateErrors( continue; const SourceNode& sourceNode = it->second; - queue.insert(queue.end(), sourceNode.requires.begin(), sourceNode.requires.end()); + queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end()); // FIXME: If a module has a syntax error, we won't be able to re-report it here. // The solution is probably to move errors from Module to SourceNode @@ -234,12 +240,6 @@ ErrorVec accumulateErrors( return result; } -struct RequireCycle -{ - Location location; - std::vector path; // one of the paths for a require() to go all the way back to the originating module -}; - // Given a source node (start), find all requires that start a transitive dependency path that ends back at start // For each such path, record the full path and the location of the require in the starting module. // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) @@ -356,33 +356,44 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond.dirty) + if (it != sourceNodes.end() && !it->second.hasDirtyModule(frontendOptions.forAutocomplete)) { // No recheck required. - auto it2 = moduleResolver.modules.find(name); - if (it2 == moduleResolver.modules.end() || it2->second == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); + if (frontendOptions.forAutocomplete) + { + auto it2 = moduleResolverForAutocomplete.modules.find(name); + if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr) + throw std::runtime_error("Frontend::modules does not have data for " + name); + } + else + { + auto it2 = moduleResolver.modules.find(name); + if (it2 == moduleResolver.modules.end() || it2->second == nullptr) + throw std::runtime_error("Frontend::modules does not have data for " + name); + } - return CheckResult{accumulateErrors(sourceNodes, moduleResolver.modules, name)}; + return CheckResult{ + accumulateErrors(sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; } std::vector buildQueue; - bool cycleDetected = parseGraph(buildQueue, checkResult, name); - - FrontendOptions frontendOptions = optionOverride.value_or(options); + bool cycleDetected = parseGraph(buildQueue, checkResult, name, frontendOptions.forAutocomplete); // Keep track of which AST nodes we've reported cycles in std::unordered_set reportedCycles; + double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; + for (const ModuleName& moduleName : buildQueue) { LUAU_ASSERT(sourceNodes.count(moduleName)); SourceNode& sourceNode = sourceNodes[moduleName]; - if (!sourceNode.dirty) + if (!sourceNode.hasDirtyModule(frontendOptions.forAutocomplete)) continue; LUAU_ASSERT(sourceModules.count(moduleName)); @@ -408,17 +419,64 @@ CheckResult Frontend::check(const ModuleName& name, std::optional 0) + typeCheckerForAutocomplete.instantiationChildLimit = + std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckerForAutocomplete.unifierIterationLimit = + std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; + } + ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; + + double duration = getTimestamp() - timestamp; + + if (moduleForAutocomplete->timeout) + { + checkResult.timeoutHits.push_back(moduleName); + + if (FFlag::LuauAutocompleteDynamicLimits) + sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + } + else if (FFlag::LuauAutocompleteDynamicLimits && duration < autocompleteTimeLimit / 2.0) + { + sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); + } + + stats.timeCheck += duration; + stats.filesStrict += 1; + + sourceNode.dirtyModuleForAutocomplete = false; + continue; } + typeChecker.requireCycles = requireCycles; + + ModulePtr module = FFlag::DebugLuauDeferredConstraintResolution ? check(sourceModule, mode, environmentScope) + : typeChecker.check(sourceModule, mode, environmentScope); + stats.timeCheck += getTimestamp() - timestamp; stats.filesStrict += mode == Mode::Strict; stats.filesNonstrict += mode == Mode::Nonstrict; @@ -461,13 +519,13 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalerrors.begin(), module->errors.end()); moduleResolver.modules[moduleName] = std::move(module); - sourceNode.dirty = false; + sourceNode.dirtyModule = false; } return checkResult; } -bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root) +bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete) { LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); @@ -529,7 +587,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec path.push_back(top); // push children - for (const ModuleName& dep : top->requires) + for (const ModuleName& dep : top->requireSet) { auto it = sourceNodes.find(dep); if (it != sourceNodes.end()) @@ -538,7 +596,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec // this relies on the fact that markDirty marks reverse-dependencies dirty as well // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need // to be built, *and* can't form a cycle with any nodes we did process. - if (!it->second.dirty) + if (!it->second.hasDirtyModule(forAutocomplete)) continue; // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization @@ -625,30 +683,6 @@ std::pair Frontend::lintFragment(std::string_view sour return {std::move(sourceModule), classifyLints(warnings, config)}; } -CheckResult Frontend::check(const SourceModule& module) -{ - LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); - - const Config& config = configResolver->getConfig(module.name); - - Mode mode = module.mode.value_or(config.mode); - - double timestamp = getTimestamp(); - - ModulePtr checkedModule = typeChecker.check(module, mode); - - stats.timeCheck += getTimestamp() - timestamp; - stats.filesStrict += mode == Mode::Strict; - stats.filesNonstrict += mode == Mode::Nonstrict; - - if (checkedModule == nullptr) - throw std::runtime_error("Frontend::check produced a nullptr module for module " + module.name); - moduleResolver.modules[module.name] = checkedModule; - - return CheckResult{checkedModule->errors}; -} - LintResult Frontend::lint(const SourceModule& module, std::optional enabledLintWarnings) { LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); @@ -685,10 +719,10 @@ LintResult Frontend::lint(const SourceModule& module, std::optionalsecond.dirty; + return it == sourceNodes.end() || it->second.hasDirtyModule(forAutocomplete); } /* @@ -699,13 +733,13 @@ bool Frontend::isDirty(const ModuleName& name) const */ void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) { - if (!moduleResolver.modules.count(name)) + if (!moduleResolver.modules.count(name) && !moduleResolverForAutocomplete.modules.count(name)) return; std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) { - for (const auto& dep : module.second.requires) + for (const auto& dep : module.second.requireSet) reverseDeps[dep].push_back(module.first); } @@ -722,10 +756,12 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked if (markedDirty) markedDirty->push_back(next); - if (sourceNode.dirty) + if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) continue; - sourceNode.dirty = true; + sourceNode.dirtySourceModule = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; if (0 == reverseDeps.count(name)) continue; @@ -751,6 +787,30 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons return const_cast(this)->getSourceModule(moduleName); } +ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope) +{ + ModulePtr result = std::make_shared(); + + ConstraintGraphBuilder cgb{&result->internalTypes}; + cgb.visit(sourceModule.root); + + ConstraintSolver cs{&result->internalTypes, cgb.rootScope}; + cs.run(); + + result->scope2s = std::move(cgb.scopes); + result->astTypes = std::move(cgb.astTypes); + result->astTypePacks = std::move(cgb.astTypePacks); + result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes); + result->astResolvedTypes = std::move(cgb.astResolvedTypes); + result->astResolvedTypePacks = std::move(cgb.astResolvedTypePacks); + + result->clonePublicInterface(iceHandler); + + Luau::check(sourceModule, result.get()); + + return result; +} + // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) { @@ -758,7 +818,7 @@ std::pair Frontend::getSourceNode(CheckResult& check LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.dirty) + if (it != sourceNodes.end() && !it->second.hasDirtySourceModule()) { auto moduleIt = sourceModules.find(name); if (moduleIt != sourceModules.end()) @@ -789,8 +849,8 @@ std::pair Frontend::getSourceNode(CheckResult& check SourceModule result = parse(name, source->source, opts); result.type = source->type; - RequireTraceResult& requireTrace = requires[name]; - requireTrace = traceRequires(fileResolver, result.root, name); + RequireTraceResult& require = requireTrace[name]; + require = traceRequires(fileResolver, result.root, name); SourceNode& sourceNode = sourceNodes[name]; SourceModule& sourceModule = sourceModules[name]; @@ -799,14 +859,20 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceModule.environmentName = environmentName; sourceNode.name = name; - sourceNode.requires.clear(); + sourceNode.requireSet.clear(); sourceNode.requireLocations.clear(); - sourceNode.dirty = true; + sourceNode.dirtySourceModule = false; - for (const auto& [moduleName, location] : requireTrace.requires) - sourceNode.requires.insert(moduleName); + if (it == sourceNodes.end()) + { + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; + } - sourceNode.requireLocations = requireTrace.requires; + for (const auto& [moduleName, location] : require.requireList) + sourceNode.requireSet.insert(moduleName); + + sourceNode.requireLocations = require.requireList; return {&sourceNode, &sourceModule}; } @@ -867,8 +933,8 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const std::optional FrontendModuleResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) { // FIXME I think this can be pushed into the FileResolver. - auto it = frontend->requires.find(currentModuleName); - if (it == frontend->requires.end()) + auto it = frontend->requireTrace.find(currentModuleName); + if (it == frontend->requireTrace.end()) { // CLI-43699 // If we can't find the current module name, that's because we bypassed the frontend's initializer @@ -967,7 +1033,7 @@ void Frontend::clear() sourceModules.clear(); moduleResolver.modules.clear(); moduleResolverForAutocomplete.modules.clear(); - requires.clear(); + requireTrace.clear(); } } // namespace Luau diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp new file mode 100644 index 00000000..77c62422 --- /dev/null +++ b/Analysis/src/Instantiation.cpp @@ -0,0 +1,124 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" +#include "Luau/Instantiation.h" +#include "Luau/TxnLog.h" +#include "Luau/TypeArena.h" + +namespace Luau +{ + +bool Instantiation::isDirty(TypeId ty) +{ + if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (ftv->hasNoGenerics) + return false; + + return true; + } + else + { + return false; + } +} + +bool Instantiation::isDirty(TypePackId tp) +{ + return false; +} + +bool Instantiation::ignoreChildren(TypeId ty) +{ + if (log->getMutable(ty)) + return true; + else + return false; +} + +TypeId Instantiation::clean(TypeId ty) +{ + const FunctionTypeVar* ftv = log->getMutable(ty); + LUAU_ASSERT(ftv); + + FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; + clone.magicFunction = ftv->magicFunction; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + TypeId result = addType(std::move(clone)); + + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + ReplaceGenerics replaceGenerics{log, arena, level, ftv->generics, ftv->genericPacks}; + + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = replaceGenerics.substitute(result).value_or(result); + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; +} + +TypePackId Instantiation::clean(TypePackId tp) +{ + LUAU_ASSERT(false); + return tp; +} + +bool ReplaceGenerics::ignoreChildren(TypeId ty) +{ + if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (ftv->hasNoGenerics) + return true; + + // We aren't recursing in the case of a generic function which + // binds the same generics. This can happen if, for example, there's recursive types. + // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. + // It's OK to use vector equality here, since we always generate fresh generics + // whenever we quantify, so the vectors overlap if and only if they are equal. + return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); + } + else + { + return false; + } +} + +bool ReplaceGenerics::isDirty(TypeId ty) +{ + if (const TableTypeVar* ttv = log->getMutable(ty)) + return ttv->state == TableState::Generic; + else if (log->getMutable(ty)) + return std::find(generics.begin(), generics.end(), ty) != generics.end(); + else + return false; +} + +bool ReplaceGenerics::isDirty(TypePackId tp) +{ + if (log->getMutable(tp)) + return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); + else + return false; +} + +TypeId ReplaceGenerics::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + if (const TableTypeVar* ttv = log->getMutable(ty)) + { + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; + clone.definitionModuleName = ttv->definitionModuleName; + return addType(std::move(clone)); + } + else + return addType(FreeTypeVar{level}); +} + +TypePackId ReplaceGenerics::clean(TypePackId tp) +{ + LUAU_ASSERT(isDirty(tp)); + return addTypePack(TypePackVar(FreeTypePack{level})); +} + +} // namespace Luau diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 19c2ddab..e4fac455 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -23,9 +23,182 @@ std::ostream& operator<<(std::ostream& stream, const AstName& name) return stream << ""; } -std::ostream& operator<<(std::ostream& stream, const TypeMismatch& tm) +template +static void errorToString(std::ostream& stream, const T& err) { - return stream << "TypeMismatch { " << toString(tm.wantedType) << ", " << toString(tm.givenType) << " }"; + if constexpr (false) + { + } + else if constexpr (std::is_same_v) + stream << "TypeMismatch { " << toString(err.wantedType) << ", " << toString(err.givenType) << " }"; + else if constexpr (std::is_same_v) + stream << "UnknownSymbol { " << err.name << " , context " << err.context << " }"; + else if constexpr (std::is_same_v) + stream << "UnknownProperty { " << toString(err.table) << ", key = " << err.key << " }"; + else if constexpr (std::is_same_v) + stream << "NotATable { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "CannotExtendTable { " << toString(err.tableType) << ", context " << err.context << ", prop \"" << err.prop << "\" }"; + else if constexpr (std::is_same_v) + stream << "OnlyTablesCanHaveMethods { " << toString(err.tableType) << " }"; + else if constexpr (std::is_same_v) + stream << "DuplicateTypeDefinition { " << err.name << " }"; + else if constexpr (std::is_same_v) + stream << "CountMismatch { expected " << err.expected << ", got " << err.actual << ", context " << err.context << " }"; + else if constexpr (std::is_same_v) + stream << "FunctionDoesNotTakeSelf { }"; + else if constexpr (std::is_same_v) + stream << "FunctionRequiresSelf { }"; + else if constexpr (std::is_same_v) + stream << "OccursCheckFailed { }"; + else if constexpr (std::is_same_v) + stream << "UnknownRequire { " << err.modulePath << " }"; + else if constexpr (std::is_same_v) + { + stream << "IncorrectGenericParameterCount { name = " << err.name; + + if (!err.typeFun.typeParams.empty() || !err.typeFun.typePackParams.empty()) + { + stream << "<"; + bool first = true; + for (auto param : err.typeFun.typeParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(param.ty); + } + + for (auto param : err.typeFun.typePackParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(param.tp); + } + + stream << ">"; + } + + stream << ", typeFun = " << toString(err.typeFun.type) << ", actualCount = " << err.actualParameters << " }"; + } + else if constexpr (std::is_same_v) + stream << "SyntaxError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "CodeTooComplex {}"; + else if constexpr (std::is_same_v) + stream << "UnificationTooComplex {}"; + else if constexpr (std::is_same_v) + { + stream << "UnknownPropButFoundLikeProp { key = '" << err.key << "', suggested = { "; + + bool first = true; + for (Name name : err.candidates) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << name << "'"; + } + + stream << " }, table = " << toString(err.table) << " } "; + } + else if constexpr (std::is_same_v) + stream << "GenericError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "InternalError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "CannotCallNonFunction { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "ExtraInformation { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "DeprecatedApiUsed { " << err.symbol << ", useInstead = " << err.useInstead << " }"; + else if constexpr (std::is_same_v) + { + stream << "ModuleHasCyclicDependency {"; + + bool first = true; + for (const ModuleName& name : err.cycle) + { + if (first) + first = false; + else + stream << ", "; + + stream << name; + } + + stream << "}"; + } + else if constexpr (std::is_same_v) + stream << "IllegalRequire { " << err.moduleName << ", reason = " << err.reason << " }"; + else if constexpr (std::is_same_v) + stream << "FunctionExitsWithoutReturning {" << toString(err.expectedReturnType) << "}"; + else if constexpr (std::is_same_v) + stream << "DuplicateGenericParameter { " + err.parameterName + " }"; + else if constexpr (std::is_same_v) + stream << "CannotInferBinaryOperation { op = " + toString(err.op) + ", suggested = '" + + (err.suggestedToAnnotate ? *err.suggestedToAnnotate : "") + "', kind " + << err.kind << "}"; + else if constexpr (std::is_same_v) + { + stream << "MissingProperties { superType = '" << toString(err.superType) << "', subType = '" << toString(err.subType) << "', properties = { "; + + bool first = true; + for (Name name : err.properties) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << name << "'"; + } + + stream << " }, context " << err.context << " } "; + } + else if constexpr (std::is_same_v) + stream << "SwappedGenericTypeParameter { name = '" + err.name + "', kind = " + std::to_string(err.kind) + " }"; + else if constexpr (std::is_same_v) + stream << "OptionalValueAccess { optional = '" + toString(err.optional) + "' }"; + else if constexpr (std::is_same_v) + { + stream << "MissingUnionProperty { type = '" + toString(err.type) + "', missing = { "; + + bool first = true; + for (auto ty : err.missing) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << toString(ty) << "'"; + } + + stream << " }, key = '" + err.key + "' }"; + } + else if constexpr (std::is_same_v) + stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }"; + else if constexpr (std::is_same_v) + stream << "NormalizationTooComplex { }"; + else + static_assert(always_false_v, "Non-exhaustive type switch"); +} + +std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data) +{ + auto cb = [&](const auto& e) { + return errorToString(stream, e); + }; + visit(cb, data); + return stream; } std::ostream& operator<<(std::ostream& stream, const TypeError& error) @@ -33,241 +206,6 @@ std::ostream& operator<<(std::ostream& stream, const TypeError& error) return stream << "TypeError { \"" << error.moduleName << "\", " << error.location << ", " << error.data << " }"; } -std::ostream& operator<<(std::ostream& stream, const UnknownSymbol& error) -{ - return stream << "UnknownSymbol { " << error.name << " , context " << error.context << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownProperty& error) -{ - return stream << "UnknownProperty { " << toString(error.table) << ", key = " << error.key << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const NotATable& ge) -{ - return stream << "NotATable { " << toString(ge.ty) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CannotExtendTable& error) -{ - return stream << "CannotExtendTable { " << toString(error.tableType) << ", context " << error.context << ", prop \"" << error.prop << "\" }"; -} - -std::ostream& operator<<(std::ostream& stream, const OnlyTablesCanHaveMethods& error) -{ - return stream << "OnlyTablesCanHaveMethods { " << toString(error.tableType) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const DuplicateTypeDefinition& error) -{ - return stream << "DuplicateTypeDefinition { " << error.name << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CountMismatch& error) -{ - return stream << "CountMismatch { expected " << error.expected << ", got " << error.actual << ", context " << error.context << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const FunctionDoesNotTakeSelf&) -{ - return stream << "FunctionDoesNotTakeSelf { }"; -} - -std::ostream& operator<<(std::ostream& stream, const FunctionRequiresSelf& error) -{ - return stream << "FunctionRequiresSelf { extraNils " << error.requiredExtraNils << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const OccursCheckFailed&) -{ - return stream << "OccursCheckFailed { }"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownRequire& error) -{ - return stream << "UnknownRequire { " << error.modulePath << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCount& error) -{ - stream << "IncorrectGenericParameterCount { name = " << error.name; - - if (!error.typeFun.typeParams.empty() || !error.typeFun.typePackParams.empty()) - { - stream << "<"; - bool first = true; - for (auto param : error.typeFun.typeParams) - { - if (first) - first = false; - else - stream << ", "; - - stream << toString(param.ty); - } - - for (auto param : error.typeFun.typePackParams) - { - if (first) - first = false; - else - stream << ", "; - - stream << toString(param.tp); - } - - stream << ">"; - } - - stream << ", typeFun = " << toString(error.typeFun.type) << ", actualCount = " << error.actualParameters << " }"; - return stream; -} - -std::ostream& operator<<(std::ostream& stream, const SyntaxError& ge) -{ - return stream << "SyntaxError { " << ge.message << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CodeTooComplex&) -{ - return stream << "CodeTooComplex {}"; -} - -std::ostream& operator<<(std::ostream& stream, const UnificationTooComplex&) -{ - return stream << "UnificationTooComplex {}"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownPropButFoundLikeProp& e) -{ - stream << "UnknownPropButFoundLikeProp { key = '" << e.key << "', suggested = { "; - - bool first = true; - for (Name name : e.candidates) - { - if (first) - first = false; - else - stream << ", "; - - stream << "'" << name << "'"; - } - - return stream << " }, table = " << toString(e.table) << " } "; -} - -std::ostream& operator<<(std::ostream& stream, const GenericError& ge) -{ - return stream << "GenericError { " << ge.message << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CannotCallNonFunction& e) -{ - return stream << "CannotCallNonFunction { " << toString(e.ty) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const FunctionExitsWithoutReturning& error) -{ - return stream << "FunctionExitsWithoutReturning {" << toString(error.expectedReturnType) << "}"; -} - -std::ostream& operator<<(std::ostream& stream, const ExtraInformation& e) -{ - return stream << "ExtraInformation { " << e.message << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const DeprecatedApiUsed& e) -{ - return stream << "DeprecatedApiUsed { " << e.symbol << ", useInstead = " << e.useInstead << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const ModuleHasCyclicDependency& e) -{ - stream << "ModuleHasCyclicDependency {"; - - bool first = true; - for (const ModuleName& name : e.cycle) - { - if (first) - first = false; - else - stream << ", "; - - stream << name; - } - - return stream << "}"; -} - -std::ostream& operator<<(std::ostream& stream, const IllegalRequire& e) -{ - return stream << "IllegalRequire { " << e.moduleName << ", reason = " << e.reason << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const MissingProperties& e) -{ - stream << "MissingProperties { superType = '" << toString(e.superType) << "', subType = '" << toString(e.subType) << "', properties = { "; - - bool first = true; - for (Name name : e.properties) - { - if (first) - first = false; - else - stream << ", "; - - stream << "'" << name << "'"; - } - - return stream << " }, context " << e.context << " } "; -} - -std::ostream& operator<<(std::ostream& stream, const DuplicateGenericParameter& error) -{ - return stream << "DuplicateGenericParameter { " + error.parameterName + " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CannotInferBinaryOperation& error) -{ - return stream << "CannotInferBinaryOperation { op = " + toString(error.op) + ", suggested = '" + - (error.suggestedToAnnotate ? *error.suggestedToAnnotate : "") + "', kind " - << error.kind << "}"; -} - -std::ostream& operator<<(std::ostream& stream, const SwappedGenericTypeParameter& error) -{ - return stream << "SwappedGenericTypeParameter { name = '" + error.name + "', kind = " + std::to_string(error.kind) + " }"; -} - -std::ostream& operator<<(std::ostream& stream, const OptionalValueAccess& error) -{ - return stream << "OptionalValueAccess { optional = '" + toString(error.optional) + "' }"; -} - -std::ostream& operator<<(std::ostream& stream, const MissingUnionProperty& error) -{ - stream << "MissingUnionProperty { type = '" + toString(error.type) + "', missing = { "; - - bool first = true; - for (auto ty : error.missing) - { - if (first) - first = false; - else - stream << ", "; - - stream << "'" << toString(ty) << "'"; - } - - return stream << " }, key = '" + error.key + "' }"; -} - -std::ostream& operator<<(std::ostream& stream, const TypesAreUnrelated& error) -{ - stream << "TypesAreUnrelated { left = '" + toString(error.left) + "', right = '" + toString(error.right) + "' }"; - return stream; -} - std::ostream& operator<<(std::ostream& stream, const TableState& tv) { return stream << static_cast::type>(tv); @@ -283,15 +221,4 @@ std::ostream& operator<<(std::ostream& stream, const TypePackVar& tv) return stream << toString(tv); } -std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted) -{ - Luau::visit( - [&](const auto& a) { - lhs << a; - }, - ted); - - return lhs; -} - } // namespace Luau diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index 811e7c24..829ffa02 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -403,35 +403,26 @@ struct AstJsonEncoder : public AstVisitor void write(const AstExprTable::Item& item) { writeRaw("{"); - bool comma = pushComma(); + bool c = pushComma(); write("kind", item.kind); switch (item.kind) { case AstExprTable::Item::List: - write(item.value); + write("value", item.value); break; default: - write(item.key); - writeRaw(","); - write(item.value); + write("key", item.key); + write("value", item.value); break; } - popComma(comma); + popComma(c); writeRaw("}"); } void write(class AstExprTable* node) { writeNode(node, "AstExprTable", [&]() { - bool comma = false; - for (const auto& prop : node->items) - { - if (comma) - writeRaw(","); - else - comma = true; - write(prop); - } + PROP(items); }); } diff --git a/Analysis/src/LValue.cpp b/Analysis/src/LValue.cpp index c9466a40..38dfe1ae 100644 --- a/Analysis/src/LValue.cpp +++ b/Analysis/src/LValue.cpp @@ -77,19 +77,15 @@ std::optional tryGetLValue(const AstExpr& node) return std::nullopt; } -std::pair> getFullName(const LValue& lvalue) +Symbol getBaseSymbol(const LValue& lvalue) { const LValue* current = &lvalue; - std::vector keys; while (auto field = get(*current)) - { - keys.push_back(field->key); current = baseof(*current); - } const Symbol* symbol = get(*current); LUAU_ASSERT(symbol); - return {*symbol, std::vector(keys.rbegin(), keys.rend())}; + return *symbol; } void merge(RefinementMap& l, const RefinementMap& r, std::function f) diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index b7480e34..50868e56 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,7 +14,6 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) -LUAU_FASTFLAGVARIABLE(LuauLintNoRobloxBits, false) namespace Luau { @@ -1140,25 +1139,8 @@ private: Kind_Primitive, // primitive type supported by VM - boolean/userdata/etc. No differentiation between types of userdata. Kind_Vector, // 'vector' but only used when type is used Kind_Userdata, // custom userdata type - - // TODO: remove these with LuauLintNoRobloxBits - Kind_Class, // custom userdata type that reflects Roblox Instance-derived hierarchy - Part/etc. - Kind_Enum, // custom userdata type referring to an enum item of enum classes, e.g. Enum.NormalId.Back/Enum.Axis.X/etc. }; - bool containsPropName(TypeId ty, const std::string& propName) - { - LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits); - - if (auto ctv = get(ty)) - return lookupClassProp(ctv, propName) != nullptr; - - if (auto ttv = get(ty)) - return ttv->props.find(propName) != ttv->props.end(); - - return false; - } - TypeKind getTypeKind(const std::string& name) { if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" || @@ -1168,23 +1150,10 @@ private: if (name == "vector") return Kind_Vector; - if (FFlag::LuauLintNoRobloxBits) - { - if (std::optional maybeTy = context->scope->lookupType(name)) - return Kind_Userdata; + if (std::optional maybeTy = context->scope->lookupType(name)) + return Kind_Userdata; - return Kind_Unknown; - } - else - { - if (std::optional maybeTy = context->scope->lookupType(name)) - // Kind_Userdata is probably not 100% precise but is close enough - return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata; - else if (std::optional maybeTy = context->scope->lookupImportedType("Enum", name)) - return Kind_Enum; - - return Kind_Unknown; - } + return Kind_Unknown; } void validateType(AstExprConstantString* expr, std::initializer_list expected, const char* expectedString) @@ -1202,67 +1171,11 @@ private: { if (kind == ek) return; - - // as a special case, Instance and EnumItem are both a userdata type (as returned by typeof) and a class type - if (!FFlag::LuauLintNoRobloxBits && ek == Kind_Userdata && (name == "Instance" || name == "EnumItem")) - return; } emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s' (expected %s)", name.c_str(), expectedString); } - bool acceptsClassName(AstName method) - { - LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits); - - return method.value[0] == 'F' && (method == "FindFirstChildOfClass" || method == "FindFirstChildWhichIsA" || - method == "FindFirstAncestorOfClass" || method == "FindFirstAncestorWhichIsA"); - } - - bool visit(AstExprCall* node) override - { - // TODO: Simply remove the override - if (FFlag::LuauLintNoRobloxBits) - return true; - - if (AstExprIndexName* index = node->func->as()) - { - AstExprConstantString* arg0 = node->args.size > 0 ? node->args.data[0]->as() : NULL; - - if (arg0) - { - if (node->self && index->index == "IsA" && node->args.size == 1) - { - validateType(arg0, {Kind_Class, Kind_Enum}, "class or enum type"); - } - else if (node->self && (index->index == "GetService" || index->index == "FindService") && node->args.size == 1) - { - AstExprGlobal* g = index->expr->as(); - - if (g && (g->name == "game" || g->name == "Game")) - { - validateType(arg0, {Kind_Class}, "class type"); - } - } - else if (node->self && acceptsClassName(index->index) && node->args.size == 1) - { - validateType(arg0, {Kind_Class}, "class type"); - } - else if (!node->self && index->index == "new" && node->args.size <= 2) - { - AstExprGlobal* g = index->expr->as(); - - if (g && g->name == "Instance") - { - validateType(arg0, {Kind_Class}, "class type"); - } - } - } - } - - return true; - } - bool visit(AstExprBinary* node) override { if (node->op == AstExprBinary::CompareNe || node->op == AstExprBinary::CompareEq) @@ -2369,7 +2282,7 @@ private: size_t getReturnCount(TypeId ty) { if (auto ftv = get(ty)) - return size(ftv->retType); + return size(ftv->retTypes); if (auto itv = get(ty)) { @@ -2378,7 +2291,7 @@ private: for (TypeId part : itv->parts) if (auto ftv = get(follow(part))) - result = std::max(result, size(ftv->retType)); + result = std::max(result, size(ftv->retTypes)); return result; } @@ -2740,12 +2653,12 @@ static void lintComments(LintContext& context, const std::vector& ho } else { - std::string::size_type space = hc.content.find_first_of(" \t"); + size_t space = hc.content.find_first_of(" \t"); std::string_view first = std::string_view(hc.content).substr(0, space); if (first == "nolint") { - std::string::size_type notspace = hc.content.find_first_not_of(" \t", space); + size_t notspace = hc.content.find_first_not_of(" \t", space); if (space == std::string::npos || notspace == std::string::npos) { @@ -2914,7 +2827,7 @@ uint64_t LintWarning::parseMask(const std::vector& hotcomments) if (hc.content.compare(0, 6, "nolint") != 0) continue; - std::string::size_type name = hc.content.find_first_not_of(" \t", 6); + size_t name = hc.content.find_first_not_of(" \t", 6); // --!nolint disables everything if (name == std::string::npos) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 0787d3a4..95eb125e 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -1,7 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/Normalize.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" @@ -11,9 +14,9 @@ #include -LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) -LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) -LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false) +LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauNormalizeFlagIsConservative); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); namespace Luau { @@ -53,421 +56,121 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos) return contains(pos, *iter); } -void TypeArena::clear() +struct ForceNormal : TypeVarOnceVisitor { - typeVars.clear(); - typePacks.clear(); -} + const TypeArena* typeArena = nullptr; -TypeId TypeArena::addTV(TypeVar&& tv) -{ - TypeId allocated = typeVars.allocate(std::move(tv)); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypeId TypeArena::freshType(TypeLevel level) -{ - TypeId allocated = typeVars.allocate(FreeTypeVar{level}); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(std::initializer_list types) -{ - TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(std::vector types) -{ - TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(TypePack tp) -{ - TypePackId allocated = typePacks.allocate(std::move(tp)); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(TypePackVar tp) -{ - TypePackId allocated = typePacks.allocate(std::move(tp)); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -namespace -{ - -struct TypePackCloner; - -/* - * Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set. - * They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage. - */ - -struct TypeCloner -{ - TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) - : dest(dest) - , typeId(typeId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) - , cloneState(cloneState) + ForceNormal(const TypeArena* typeArena) + : typeArena(typeArena) { } - TypeArena& dest; - TypeId typeId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - CloneState& cloneState; - - template - void defaultClone(const T& t); - - void operator()(const Unifiable::Free& t); - void operator()(const Unifiable::Generic& t); - void operator()(const Unifiable::Bound& t); - void operator()(const Unifiable::Error& t); - void operator()(const PrimitiveTypeVar& t); - void operator()(const SingletonTypeVar& t); - void operator()(const FunctionTypeVar& t); - void operator()(const TableTypeVar& t); - void operator()(const MetatableTypeVar& t); - void operator()(const ClassTypeVar& t); - void operator()(const AnyTypeVar& t); - void operator()(const UnionTypeVar& t); - void operator()(const IntersectionTypeVar& t); - void operator()(const LazyTypeVar& t); -}; - -struct TypePackCloner -{ - TypeArena& dest; - TypePackId typePackId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - CloneState& cloneState; - - TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) - : dest(dest) - , typePackId(typePackId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) - , cloneState(cloneState) + bool visit(TypeId ty) override { + if (ty->owningArena != typeArena) + return false; + + asMutable(ty)->normal = true; + return true; } - template - void defaultClone(const T& t) + bool visit(TypeId ty, const FreeTypeVar& ftv) override { - TypePackId cloned = dest.addTypePack(TypePackVar{t}); - seenTypePacks[typePackId] = cloned; + visit(ty); + return true; } - void operator()(const Unifiable::Free& t) + bool visit(TypePackId tp, const FreeTypePack& ftp) override { - cloneState.encounteredFreeType = true; - - TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); - TypePackId cloned = dest.addTypePack(*err); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const Unifiable::Generic& t) - { - defaultClone(t); - } - void operator()(const Unifiable::Error& t) - { - defaultClone(t); - } - - // While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter. - // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. - void operator()(const Unifiable::Bound& t) - { - TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const VariadicTypePack& t) - { - TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}}); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const TypePack& t) - { - TypePackId cloned = dest.addTypePack(TypePack{}); - TypePack* destTp = getMutable(cloned); - LUAU_ASSERT(destTp != nullptr); - seenTypePacks[typePackId] = cloned; - - for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - - if (t.tail) - destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState); + return true; } }; -template -void TypeCloner::defaultClone(const T& t) +Module::~Module() { - TypeId cloned = dest.addType(t); - seenTypes[typeId] = cloned; + unfreeze(interfaceTypes); + unfreeze(internalTypes); } -void TypeCloner::operator()(const Unifiable::Free& t) +void Module::clonePublicInterface(InternalErrorReporter& ice) { - cloneState.encounteredFreeType = true; - TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); - TypeId cloned = dest.addType(*err); - seenTypes[typeId] = cloned; -} + LUAU_ASSERT(interfaceTypes.typeVars.empty()); + LUAU_ASSERT(interfaceTypes.typePacks.empty()); -void TypeCloner::operator()(const Unifiable::Generic& t) -{ - defaultClone(t); -} + CloneState cloneState; -void TypeCloner::operator()(const Unifiable::Bound& t) -{ - TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); - seenTypes[typeId] = boundTo; -} + ScopePtr moduleScope = FFlag::DebugLuauDeferredConstraintResolution ? nullptr : getModuleScope(); + Scope2* moduleScope2 = FFlag::DebugLuauDeferredConstraintResolution ? getModuleScope2() : nullptr; -void TypeCloner::operator()(const Unifiable::Error& t) -{ - defaultClone(t); -} + TypePackId returnType = FFlag::DebugLuauDeferredConstraintResolution ? moduleScope2->returnType : moduleScope->returnType; + std::optional varargPack = FFlag::DebugLuauDeferredConstraintResolution ? std::nullopt : moduleScope->varargPack; + std::unordered_map* exportedTypeBindings = + FFlag::DebugLuauDeferredConstraintResolution ? nullptr : &moduleScope->exportedTypeBindings; -void TypeCloner::operator()(const PrimitiveTypeVar& t) -{ - defaultClone(t); -} + returnType = clone(returnType, interfaceTypes, cloneState); -void TypeCloner::operator()(const SingletonTypeVar& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const FunctionTypeVar& t) -{ - TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); - FunctionTypeVar* ftv = getMutable(result); - LUAU_ASSERT(ftv != nullptr); - - seenTypes[typeId] = result; - - for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState)); - - for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState)); - - ftv->tags = t.tags; - ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState); - ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState); -} - -void TypeCloner::operator()(const TableTypeVar& t) -{ - // If table is now bound to another one, we ignore the content of the original - if (t.boundTo) + if (moduleScope) { - TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState); - seenTypes[typeId] = boundTo; - return; + moduleScope->returnType = returnType; + if (varargPack) + { + varargPack = clone(*varargPack, interfaceTypes, cloneState); + moduleScope->varargPack = varargPack; + } + } + else + { + LUAU_ASSERT(moduleScope2); + moduleScope2->returnType = returnType; // TODO varargPack } - TypeId result = dest.addType(TableTypeVar{}); - TableTypeVar* ttv = getMutable(result); - LUAU_ASSERT(ttv != nullptr); - - *ttv = t; - - seenTypes[typeId] = result; - - ttv->level = TypeLevel{0, 0}; - - for (const auto& [name, prop] : t.props) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; - - if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState), - clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)}; - - for (TypeId& arg : ttv->instantiatedTypeParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); - - for (TypePackId& arg : ttv->instantiatedTypePackParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); - - if (ttv->state == TableState::Free) + if (FFlag::LuauLowerBoundsCalculation) { - cloneState.encounteredFreeType = true; - - ttv->state = TableState::Sealed; + normalize(returnType, interfaceTypes, ice); + if (varargPack) + normalize(*varargPack, interfaceTypes, ice); } - ttv->definitionModuleName = t.definitionModuleName; - ttv->methodDefinitionLocations = t.methodDefinitionLocations; - ttv->tags = t.tags; -} + ForceNormal forceNormal{&interfaceTypes}; -void TypeCloner::operator()(const MetatableTypeVar& t) -{ - TypeId result = dest.addType(MetatableTypeVar{}); - MetatableTypeVar* mtv = getMutable(result); - seenTypes[typeId] = result; - - mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState); - mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState); -} - -void TypeCloner::operator()(const ClassTypeVar& t) -{ - TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); - ClassTypeVar* ctv = getMutable(result); - - seenTypes[typeId] = result; - - for (const auto& [name, prop] : t.props) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; - - if (t.parent) - ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState); - - if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState); -} - -void TypeCloner::operator()(const AnyTypeVar& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const UnionTypeVar& t) -{ - std::vector options; - options.reserve(t.options.size()); - - for (TypeId ty : t.options) - options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - - TypeId result = dest.addType(UnionTypeVar{std::move(options)}); - seenTypes[typeId] = result; -} - -void TypeCloner::operator()(const IntersectionTypeVar& t) -{ - TypeId result = dest.addType(IntersectionTypeVar{}); - seenTypes[typeId] = result; - - IntersectionTypeVar* option = getMutable(result); - LUAU_ASSERT(option != nullptr); - - for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); -} - -void TypeCloner::operator()(const LazyTypeVar& t) -{ - defaultClone(t); -} - -} // anonymous namespace - -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) -{ - if (tp->persistent) - return tp; - - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypePackId& res = seenTypePacks[tp]; - - if (res == nullptr) + if (exportedTypeBindings) { - TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState}; - Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. + for (auto& [name, tf] : *exportedTypeBindings) + { + tf = clone(tf, interfaceTypes, cloneState); + if (FFlag::LuauLowerBoundsCalculation) + { + normalize(tf.type, interfaceTypes, ice); + + if (FFlag::LuauNormalizeFlagIsConservative) + { + // We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables + // won't be marked normal. If the types aren't normal by now, they never will be. + forceNormal.traverse(tf.type); + } + } + } } - return res; -} - -TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) -{ - if (typeId->persistent) - return typeId; - - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypeId& res = seenTypes[typeId]; - - if (res == nullptr) + for (TypeId ty : returnType) { - TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; - Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. - - // Persistent types are not being cloned and we get the original type back which might be read-only - if (!res->persistent) - asMutable(res)->documentationSymbol = typeId->documentationSymbol; + if (get(follow(ty))) + { + auto t = asMutable(ty); + t->ty = AnyTypeVar{}; + t->normal = true; + } } - return res; -} - -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) -{ - TypeFun result; - - for (auto param : typeFun.typeParams) + for (auto& [name, ty] : declaredGlobals) { - TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); - std::optional defaultValue; - - if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); - - result.typeParams.push_back({ty, defaultValue}); + ty = clone(ty, interfaceTypes, cloneState); + if (FFlag::LuauLowerBoundsCalculation) + normalize(ty, interfaceTypes, ice); } - for (auto param : typeFun.typePackParams) - { - TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); - std::optional defaultValue; - - if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); - - result.typePackParams.push_back({tp, defaultValue}); - } - - result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); - - return result; + freeze(internalTypes); + freeze(interfaceTypes); } ScopePtr Module::getModuleScope() const @@ -476,62 +179,10 @@ ScopePtr Module::getModuleScope() const return scopes.front().second; } -void freeze(TypeArena& arena) +Scope2* Module::getModuleScope2() const { - if (!FFlag::DebugLuauFreezeArena) - return; - - arena.typeVars.freeze(); - arena.typePacks.freeze(); -} - -void unfreeze(TypeArena& arena) -{ - if (!FFlag::DebugLuauFreezeArena) - return; - - arena.typeVars.unfreeze(); - arena.typePacks.unfreeze(); -} - -Module::~Module() -{ - unfreeze(interfaceTypes); - unfreeze(internalTypes); -} - -bool Module::clonePublicInterface() -{ - LUAU_ASSERT(interfaceTypes.typeVars.empty()); - LUAU_ASSERT(interfaceTypes.typePacks.empty()); - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - CloneState cloneState; - - ScopePtr moduleScope = getModuleScope(); - - moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, cloneState); - if (moduleScope->varargPack) - moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState); - - for (auto& [name, tf] : moduleScope->exportedTypeBindings) - tf = clone(tf, interfaceTypes, seenTypes, seenTypePacks, cloneState); - - for (TypeId ty : moduleScope->returnType) - if (get(follow(ty))) - *asMutable(ty) = AnyTypeVar{}; - - if (FFlag::LuauCloneDeclaredGlobals) - { - for (auto& [name, ty] : declaredGlobals) - ty = clone(ty, interfaceTypes, seenTypes, seenTypePacks, cloneState); - } - - freeze(internalTypes); - freeze(interfaceTypes); - - return cloneState.encounteredFreeType; + LUAU_ASSERT(!scope2s.empty()); + return scope2s.front().second.get(); } } // namespace Luau diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp new file mode 100644 index 00000000..d36665e2 --- /dev/null +++ b/Analysis/src/Normalize.cpp @@ -0,0 +1,859 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Normalize.h" + +#include + +#include "Luau/Clone.h" +#include "Luau/Substitution.h" +#include "Luau/Unifier.h" +#include "Luau/VisitTypeVar.h" + +LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) + +// This could theoretically be 2000 on amd64, but x86 requires this. +LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); +LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); +LUAU_FASTFLAGVARIABLE(LuauReplaceReplacer, false); +LUAU_FASTFLAG(LuauQuantifyConstrained) + +namespace Luau +{ + +namespace +{ + +struct Replacer : Substitution +{ + TypeId sourceType; + TypeId replacedType; + DenseHashMap replacedTypes{nullptr}; + DenseHashMap replacedPacks{nullptr}; + + Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType) + : Substitution(TxnLog::empty(), arena) + , sourceType(sourceType) + , replacedType(replacedType) + { + } + + bool isDirty(TypeId ty) override + { + if (!sourceType) + return false; + + auto vecHasSourceType = [sourceType = sourceType](const auto& vec) { + return end(vec) != std::find(begin(vec), end(vec), sourceType); + }; + + // Walk every kind of TypeVar and find pointers to sourceType + if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return vecHasSourceType(t->parts); + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + { + if (vecHasSourceType(t->generics)) + return true; + + return false; + } + else if (auto t = get(ty)) + { + if (t->boundTo) + return *t->boundTo == sourceType; + + for (const auto& [_name, prop] : t->props) + { + if (prop.type == sourceType) + return true; + } + + if (auto indexer = t->indexer) + { + if (indexer->indexType == sourceType || indexer->indexResultType == sourceType) + return true; + } + + if (vecHasSourceType(t->instantiatedTypeParams)) + return true; + + return false; + } + else if (auto t = get(ty)) + return t->table == sourceType || t->metatable == sourceType; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return vecHasSourceType(t->options); + else if (auto t = get(ty)) + return vecHasSourceType(t->parts); + else if (auto t = get(ty)) + return false; + + LUAU_ASSERT(!"Luau::Replacer::isDirty internal error: Unknown TypeVar type"); + LUAU_UNREACHABLE(); + } + + bool isDirty(TypePackId tp) override + { + if (auto it = replacedPacks.find(tp)) + return false; + + if (auto pack = get(tp)) + { + for (TypeId ty : pack->head) + { + if (ty == sourceType) + return true; + } + return false; + } + else if (auto vtp = get(tp)) + return vtp->ty == sourceType; + else + return false; + } + + TypeId clean(TypeId ty) override + { + LUAU_ASSERT(sourceType && replacedType); + + // Walk every kind of TypeVar and create a copy with sourceType replaced by replacedType + // Before returning, memoize the result for later use. + + // Helpfully, Substitution::clone() only shallow-clones the kinds of types that we care to work with. This + // function returns the identity for things like primitives. + TypeId res = clone(ty); + + if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = getMutable(res)) + { + for (TypeId& part : t->parts) + { + if (part == sourceType) + part = replacedType; + } + } + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = getMutable(res)) + { + // The constituent typepacks are cleaned separately. We just need to walk the generics array. + for (TypeId& g : t->generics) + { + if (g == sourceType) + g = replacedType; + } + } + else if (auto t = getMutable(res)) + { + for (auto& [_key, prop] : t->props) + { + if (prop.type == sourceType) + prop.type = replacedType; + } + } + else if (auto t = getMutable(res)) + { + if (t->table == sourceType) + t->table = replacedType; + if (t->metatable == sourceType) + t->table = replacedType; + } + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = getMutable(res)) + { + for (TypeId& option : t->options) + { + if (option == sourceType) + option = replacedType; + } + } + else if (auto t = getMutable(res)) + { + for (TypeId& part : t->parts) + { + if (part == sourceType) + part = replacedType; + } + } + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else + LUAU_ASSERT(!"Luau::Replacer::clean internal error: Unknown TypeVar type"); + + replacedTypes[ty] = res; + return res; + } + + TypePackId clean(TypePackId tp) override + { + TypePackId res = clone(tp); + + if (auto pack = getMutable(res)) + { + for (TypeId& type : pack->head) + { + if (type == sourceType) + type = replacedType; + } + } + else if (auto vtp = getMutable(res)) + { + if (vtp->ty == sourceType) + vtp->ty = replacedType; + } + + replacedPacks[tp] = res; + return res; + } + + TypeId smartClone(TypeId t) + { + if (FFlag::LuauReplaceReplacer) + { + // The new smartClone is just a memoized clone() + // TODO: Remove the Substitution base class and all other methods from this struct. + // Add DenseHashMap newTypes; + t = log->follow(t); + TypeId* res = newTypes.find(t); + if (res) + return *res; + + TypeId result = shallowClone(t, *arena, TxnLog::empty()); + newTypes[t] = result; + newTypes[result] = result; + + return result; + } + else + { + std::optional res = replace(t); + LUAU_ASSERT(res.has_value()); // TODO think about this + if (*res == t) + return clone(t); + return *res; + } + } +}; + +} // anonymous namespace + +bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice) +{ + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState}; + u.anyIsTop = true; + + u.tryUnify(subTy, superTy); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; +} + +bool isSubtype(TypePackId subPack, TypePackId superPack, InternalErrorReporter& ice) +{ + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState}; + u.anyIsTop = true; + + u.tryUnify(subPack, superPack); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; +} + +template +static bool areNormal_(const T& t, const std::unordered_set& seen, InternalErrorReporter& ice) +{ + int count = 0; + auto isNormal = [&](TypeId ty) { + ++count; + if (count >= FInt::LuauNormalizeIterationLimit) + ice.ice("Luau::areNormal hit iteration limit"); + + if (FFlag::LuauNormalizeFlagIsConservative) + return ty->normal; + else + { + // The follow is here because a bound type may not be normal, but the bound type is normal. + return ty->normal || follow(ty)->normal || seen.find(asMutable(ty)) != seen.end(); + } + }; + + return std::all_of(begin(t), end(t), isNormal); +} + +static bool areNormal(const std::vector& types, const std::unordered_set& seen, InternalErrorReporter& ice) +{ + return areNormal_(types, seen, ice); +} + +static bool areNormal(TypePackId tp, const std::unordered_set& seen, InternalErrorReporter& ice) +{ + tp = follow(tp); + if (get(tp)) + return false; + + auto [head, tail] = flatten(tp); + + if (!areNormal_(head, seen, ice)) + return false; + + if (!tail) + return true; + + if (auto vtp = get(*tail)) + return vtp->ty->normal || follow(vtp->ty)->normal || seen.find(asMutable(vtp->ty)) != seen.end(); + + return true; +} + +#define CHECK_ITERATION_LIMIT(...) \ + do \ + { \ + if (iterationLimit > FInt::LuauNormalizeIterationLimit) \ + { \ + limitExceeded = true; \ + return __VA_ARGS__; \ + } \ + ++iterationLimit; \ + } while (false) + +struct Normalize final : TypeVarVisitor +{ + using TypeVarVisitor::Set; + + Normalize(TypeArena& arena, InternalErrorReporter& ice) + : arena(arena) + , ice(ice) + { + } + + TypeArena& arena; + InternalErrorReporter& ice; + + int iterationLimit = 0; + bool limitExceeded = false; + + bool visit(TypeId ty, const FreeTypeVar&) override + { + LUAU_ASSERT(!ty->normal); + return false; + } + + bool visit(TypeId ty, const BoundTypeVar& btv) override + { + // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. + // So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack. + if (seen.find(asMutable(btv.boundTo)) != seen.end()) + return false; + + // It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases. + LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); + + asMutable(ty)->normal = btv.boundTo->normal; + return !ty->normal; + } + + bool visit(TypeId ty, const PrimitiveTypeVar&) override + { + LUAU_ASSERT(ty->normal); + return false; + } + + bool visit(TypeId ty, const GenericTypeVar&) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + + return false; + } + + bool visit(TypeId ty, const ErrorTypeVar&) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override + { + CHECK_ITERATION_LIMIT(false); + LUAU_ASSERT(!ty->normal); + + ConstrainedTypeVar* ctv = const_cast(&ctvRef); + + std::vector parts = std::move(ctv->parts); + + // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar + for (TypeId part : parts) + traverse(part); + + std::vector newParts = normalizeUnion(parts); + + if (FFlag::LuauQuantifyConstrained) + { + ctv->parts = std::move(newParts); + } + else + { + const bool normal = areNormal(newParts, seen, ice); + + if (newParts.size() == 1) + *asMutable(ty) = BoundTypeVar{newParts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(newParts)}; + + asMutable(ty)->normal = normal; + } + + return false; + } + + bool visit(TypeId ty, const FunctionTypeVar& ftv) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + traverse(ftv.argTypes); + traverse(ftv.retTypes); + + asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retTypes, seen, ice); + + return false; + } + + bool visit(TypeId ty, const TableTypeVar& ttv) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + bool normal = true; + + auto checkNormal = [&](TypeId t) { + // if t is on the stack, it is possible that this type is normal. + // If t is not normal and it is not on the stack, this type is definitely not normal. + if (!t->normal && seen.find(asMutable(t)) == seen.end()) + normal = false; + }; + + if (ttv.boundTo) + { + traverse(*ttv.boundTo); + asMutable(ty)->normal = (*ttv.boundTo)->normal; + return false; + } + + for (const auto& [_name, prop] : ttv.props) + { + traverse(prop.type); + checkNormal(prop.type); + } + + if (ttv.indexer) + { + traverse(ttv.indexer->indexType); + checkNormal(ttv.indexer->indexType); + traverse(ttv.indexer->indexResultType); + checkNormal(ttv.indexer->indexResultType); + } + + // An unsealed table can never be normal, ditto for free tables iff the type it is bound to is also not normal. + if (FFlag::LuauQuantifyConstrained) + { + if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal)) + asMutable(ty)->normal = normal; + } + else + asMutable(ty)->normal = normal; + + return false; + } + + bool visit(TypeId ty, const MetatableTypeVar& mtv) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + traverse(mtv.table); + traverse(mtv.metatable); + + asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal; + + return false; + } + + bool visit(TypeId ty, const ClassTypeVar& ctv) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool visit(TypeId ty, const AnyTypeVar&) override + { + LUAU_ASSERT(ty->normal); + return false; + } + + bool visit(TypeId ty, const UnionTypeVar& utvRef) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + UnionTypeVar* utv = &const_cast(utvRef); + std::vector options = std::move(utv->options); + + // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar + for (TypeId option : options) + traverse(option); + + std::vector newOptions = normalizeUnion(options); + + const bool normal = areNormal(newOptions, seen, ice); + + LUAU_ASSERT(!newOptions.empty()); + + if (newOptions.size() == 1) + *asMutable(ty) = BoundTypeVar{newOptions[0]}; + else + utv->options = std::move(newOptions); + + asMutable(ty)->normal = normal; + + return false; + } + + bool visit(TypeId ty, const IntersectionTypeVar& itvRef) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + IntersectionTypeVar* itv = &const_cast(itvRef); + + std::vector oldParts = std::move(itv->parts); + + for (TypeId part : oldParts) + traverse(part); + + std::vector tables; + for (TypeId part : oldParts) + { + part = follow(part); + if (get(part)) + tables.push_back(part); + else + { + Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD + combineIntoIntersection(replacer, itv, part); + } + } + + // Don't allocate a new table if there's just one in the intersection. + if (tables.size() == 1) + itv->parts.push_back(tables[0]); + else if (!tables.empty()) + { + const TableTypeVar* first = get(tables[0]); + LUAU_ASSERT(first); + + TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); + TableTypeVar* ttv = getMutable(newTable); + for (TypeId part : tables) + { + // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need + // to be rewritten to point at 'newTable' in the clone. + Replacer replacer{&arena, part, newTable}; + combineIntoTable(replacer, ttv, part); + } + + itv->parts.push_back(newTable); + } + + asMutable(ty)->normal = areNormal(itv->parts, seen, ice); + + if (itv->parts.size() == 1) + { + TypeId part = itv->parts[0]; + *asMutable(ty) = BoundTypeVar{part}; + } + + return false; + } + + std::vector normalizeUnion(const std::vector& options) + { + if (options.size() == 1) + return options; + + std::vector result; + + for (TypeId part : options) + combineIntoUnion(result, part); + + return result; + } + + void combineIntoUnion(std::vector& result, TypeId ty) + { + ty = follow(ty); + if (auto utv = get(ty)) + { + for (TypeId t : utv) + combineIntoUnion(result, t); + return; + } + + for (TypeId& part : result) + { + if (isSubtype(ty, part, ice)) + return; // no need to do anything + else if (isSubtype(part, ty, ice)) + { + part = ty; // replace the less general type by the more general one + return; + } + } + + result.push_back(ty); + } + + /** + * @param replacer knows how to clone a type such that any recursive references point at the new containing type. + * @param result is an intersection that is safe for us to mutate in-place. + */ + void combineIntoIntersection(Replacer& replacer, IntersectionTypeVar* result, TypeId ty) + { + // Note: this check guards against running out of stack space + // so if you increase the size of a stack frame, you'll need to decrease the limit. + CHECK_ITERATION_LIMIT(); + + ty = follow(ty); + if (auto itv = get(ty)) + { + for (TypeId part : itv->parts) + combineIntoIntersection(replacer, result, part); + return; + } + + // Let's say that the last part of our result intersection is always a table, if any table is part of this intersection + if (get(ty)) + { + if (result->parts.empty()) + result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); + + TypeId theTable = result->parts.back(); + + if (!get(follow(theTable))) + { + result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); + theTable = result->parts.back(); + } + + TypeId newTable = replacer.smartClone(theTable); + result->parts.back() = newTable; + + combineIntoTable(replacer, getMutable(newTable), ty); + } + else if (auto ftv = get(ty)) + { + bool merged = false; + for (TypeId& part : result->parts) + { + if (isSubtype(part, ty, ice)) + { + merged = true; + break; // no need to do anything + } + else if (isSubtype(ty, part, ice)) + { + merged = true; + part = ty; // replace the less general type by the more general one + break; + } + } + + if (!merged) + result->parts.push_back(ty); + } + else + result->parts.push_back(ty); + } + + TableState combineTableStates(TableState lhs, TableState rhs) + { + if (lhs == rhs) + return lhs; + + if (lhs == TableState::Free || rhs == TableState::Free) + return TableState::Free; + + if (lhs == TableState::Unsealed || rhs == TableState::Unsealed) + return TableState::Unsealed; + + return lhs; + } + + /** + * @param replacer gives us a way to clone a type such that recursive references are rewritten to the new + * "containing" type. + * @param table always points into a table that is safe for us to mutate. + */ + void combineIntoTable(Replacer& replacer, TableTypeVar* table, TypeId ty) + { + // Note: this check guards against running out of stack space + // so if you increase the size of a stack frame, you'll need to decrease the limit. + CHECK_ITERATION_LIMIT(); + + LUAU_ASSERT(table); + + ty = follow(ty); + + TableTypeVar* tyTable = getMutable(ty); + LUAU_ASSERT(tyTable); + + for (const auto& [propName, prop] : tyTable->props) + { + if (auto it = table->props.find(propName); it != table->props.end()) + { + /** + * If we are going to recursively merge intersections of tables, we need to ensure that we never mutate + * a table that comes from somewhere else in the type graph. + * + * smarClone() does some nice things for us: It will perform a clone that is as shallow as possible + * while still rewriting any cyclic references back to the new 'root' table. + * + * replacer also keeps a mapping of types that have previously been copied, so we have the added + * advantage here of knowing that, whether or not a new copy was actually made, the resulting TypeVar is + * safe for us to mutate in-place. + */ + TypeId clone = replacer.smartClone(it->second.type); + it->second.type = combine(replacer, clone, prop.type); + } + else + table->props.insert({propName, prop}); + } + + table->state = combineTableStates(table->state, tyTable->state); + table->level = max(table->level, tyTable->level); + } + + /** + * @param a is always cloned by the caller. It is safe to mutate in-place. + * @param b will never be mutated. + */ + TypeId combine(Replacer& replacer, TypeId a, TypeId b) + { + if (FFlag::LuauNormalizeCombineEqFix) + b = follow(b); + + if (FFlag::LuauNormalizeCombineTableFix && a == b) + return a; + + if (!get(a) && !get(a)) + { + if (!FFlag::LuauNormalizeCombineTableFix && a == b) + return a; + else + return arena.addType(IntersectionTypeVar{{a, b}}); + } + + if (auto itv = getMutable(a)) + { + combineIntoIntersection(replacer, itv, b); + return a; + } + else if (auto ttv = getMutable(a)) + { + if (FFlag::LuauNormalizeCombineTableFix && !get(FFlag::LuauNormalizeCombineEqFix ? b : follow(b))) + return arena.addType(IntersectionTypeVar{{a, b}}); + combineIntoTable(replacer, ttv, b); + return a; + } + + LUAU_ASSERT(!"Impossible"); + LUAU_UNREACHABLE(); + } +}; + +#undef CHECK_ITERATION_LIMIT + +/** + * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) + */ +std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice) +{ + CloneState state; + if (FFlag::DebugLuauCopyBeforeNormalizing) + (void)clone(ty, arena, state); + + Normalize n{arena, ice}; + n.traverse(ty); + + return {ty, !n.limitExceeded}; +} + +// TODO: Think about using a temporary arena and cloning types out of it so that we +// reclaim memory used by wantonly allocated intermediate types here. +// The main wrinkle here is that we don't want clone() to copy a type if the source and dest +// arena are the same. +std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice) +{ + return normalize(ty, module->internalTypes, ice); +} + +/** + * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) + */ +std::pair normalize(TypePackId tp, TypeArena& arena, InternalErrorReporter& ice) +{ + CloneState state; + if (FFlag::DebugLuauCopyBeforeNormalizing) + (void)clone(tp, arena, state); + + Normalize n{arena, ice}; + n.traverse(tp); + + return {tp, !n.limitExceeded}; +} + +std::pair normalize(TypePackId tp, const ModulePtr& module, InternalErrorReporter& ice) +{ + return normalize(tp, module->internalTypes, ice); +} + +} // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 94e169f1..40e14c68 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -2,59 +2,139 @@ #include "Luau/Quantify.h" +#include "Luau/Scope.h" +#include "Luau/Substitution.h" +#include "Luau/TxnLog.h" #include "Luau/VisitTypeVar.h" +LUAU_FASTFLAG(LuauAlwaysQuantify); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false) + namespace Luau { -struct Quantifier +/// @return true if outer encloses inner +static bool subsumes(Scope2* outer, Scope2* inner) +{ + while (inner) + { + if (inner == outer) + return true; + inner = inner->parent; + } + + return false; +} + +struct Quantifier final : TypeVarOnceVisitor { TypeLevel level; std::vector generics; std::vector genericPacks; + Scope2* scope = nullptr; + bool seenGenericType = false; + bool seenMutableType = false; - Quantifier(TypeLevel level) + explicit Quantifier(TypeLevel level) : level(level) { + LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); } - void cycle(TypeId) {} - void cycle(TypePackId) {} - - bool operator()(TypeId ty, const FreeTypeVar& ftv) + explicit Quantifier(Scope2* scope) + : scope(scope) { - if (!level.subsumes(ftv.level)) + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + } + + /// @return true if outer encloses inner + bool subsumes(Scope2* outer, Scope2* inner) + { + while (inner) + { + if (inner == outer) + return true; + inner = inner->parent; + } + + return false; + } + + bool visit(TypeId ty, const FreeTypeVar& ftv) override + { + seenMutableType = true; + + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ftv.scope) : !level.subsumes(ftv.level)) return false; - *asMutable(ty) = GenericTypeVar{level}; + if (FFlag::DebugLuauDeferredConstraintResolution) + *asMutable(ty) = GenericTypeVar{scope}; + else + *asMutable(ty) = GenericTypeVar{level}; + generics.push_back(ty); return false; } - template - bool operator()(TypeId ty, const T& t) + bool visit(TypeId ty, const ConstrainedTypeVar&) override { - return true; + if (FFlag::LuauQuantifyConstrained) + { + ConstrainedTypeVar* ctv = getMutable(ty); + + seenMutableType = true; + + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ctv->scope) : !level.subsumes(ctv->level)) + return false; + + std::vector opts = std::move(ctv->parts); + + // We might transmute, so it's not safe to rely on the builtin traversal logic + for (TypeId opt : opts) + traverse(opt); + + if (opts.size() == 1) + *asMutable(ty) = BoundTypeVar{opts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(opts)}; + + return false; + } + else + return true; } - template - bool operator()(TypePackId, const T&) - { - return true; - } - - bool operator()(TypeId ty, const TableTypeVar&) + bool visit(TypeId ty, const TableTypeVar&) override { + LUAU_ASSERT(getMutable(ty)); TableTypeVar& ttv = *getMutable(ty); - if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) - return false; - if (!level.subsumes(ttv.level)) - return false; + if (ttv.state == TableState::Generic) + seenGenericType = true; if (ttv.state == TableState::Free) + seenMutableType = true; + + if (!FFlag::LuauQuantifyConstrained) + { + if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) + return false; + } + + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ttv.scope) : !level.subsumes(ttv.level)) + { + if (ttv.state == TableState::Unsealed) + seenMutableType = true; + return false; + } + + if (ttv.state == TableState::Free) + { ttv.state = TableState::Generic; + seenGenericType = true; + } else if (ttv.state == TableState::Unsealed) ttv.state = TableState::Sealed; @@ -63,9 +143,11 @@ struct Quantifier return true; } - bool operator()(TypePackId tp, const FreeTypePack& ftp) + bool visit(TypePackId tp, const FreeTypePack& ftp) override { - if (!level.subsumes(ftp.level)) + seenMutableType = true; + + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ftp.scope) : !level.subsumes(ftp.level)) return false; *asMutable(tp) = GenericTypePack{level}; @@ -77,13 +159,145 @@ struct Quantifier void quantify(TypeId ty, TypeLevel level) { Quantifier q{level}; - DenseHashSet seen{nullptr}; - visitTypeVarOnce(ty, q, seen); + q.traverse(ty); FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); - ftv->generics = q.generics; - ftv->genericPacks = q.genericPacks; + if (FFlag::LuauAlwaysQuantify) + { + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + } + else + { + ftv->generics = q.generics; + ftv->genericPacks = q.genericPacks; + } + + if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) + ftv->hasNoGenerics = true; +} + +void quantify(TypeId ty, Scope2* scope) +{ + Quantifier q{scope}; + q.traverse(ty); + + FunctionTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + if (FFlag::LuauAlwaysQuantify) + { + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + } + else + { + ftv->generics = q.generics; + ftv->genericPacks = q.genericPacks; + } + + if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) + ftv->hasNoGenerics = true; +} + +struct PureQuantifier : Substitution +{ + Scope2* scope; + std::vector insertedGenerics; + std::vector insertedGenericPacks; + + PureQuantifier(const TxnLog* log, TypeArena* arena, Scope2* scope) + : Substitution(log, arena) + , scope(scope) + { + } + + bool isDirty(TypeId ty) override + { + LUAU_ASSERT(ty == follow(ty)); + + if (auto ftv = get(ty)) + { + return subsumes(scope, ftv->scope); + } + else if (auto ttv = get(ty)) + { + return ttv->state == TableState::Free && subsumes(scope, ttv->scope); + } + + return false; + } + + bool isDirty(TypePackId tp) override + { + if (auto ftp = get(tp)) + { + return subsumes(scope, ftp->scope); + } + + return false; + } + + TypeId clean(TypeId ty) override + { + if (auto ftv = get(ty)) + { + TypeId result = arena->addType(GenericTypeVar{}); + insertedGenerics.push_back(result); + return result; + } + else if (auto ttv = get(ty)) + { + TypeId result = arena->addType(TableTypeVar{}); + TableTypeVar* resultTable = getMutable(result); + LUAU_ASSERT(resultTable); + + *resultTable = *ttv; + resultTable->scope = nullptr; + resultTable->state = TableState::Generic; + + return result; + } + + return ty; + } + + TypePackId clean(TypePackId tp) override + { + if (auto ftp = get(tp)) + { + TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{}}); + insertedGenericPacks.push_back(result); + return result; + } + + return tp; + } + + bool ignoreChildren(TypeId ty) override + { + return ty->persistent; + } + bool ignoreChildren(TypePackId ty) override + { + return ty->persistent; + } +}; + +TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope) +{ + PureQuantifier quantifier{TxnLog::empty(), arena, scope}; + std::optional result = quantifier.substitute(ty); + LUAU_ASSERT(result); + + FunctionTypeVar* ftv = getMutable(*result); + LUAU_ASSERT(ftv); + ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); + + // TODO: Set hasNoGenerics. + + return *result; } } // namespace Luau diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index 8ed245fb..c036a7a5 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -28,7 +28,7 @@ struct RequireTracer : AstVisitor AstExprGlobal* global = expr->func->as(); if (global && global->name == "require" && expr->args.size >= 1) - requires.push_back(expr); + requireCalls.push_back(expr); return true; } @@ -84,9 +84,9 @@ struct RequireTracer : AstVisitor ModuleInfo moduleContext{currentModuleName}; // seed worklist with require arguments - work.reserve(requires.size()); + work.reserve(requireCalls.size()); - for (AstExprCall* require : requires) + for (AstExprCall* require : requireCalls) work.push_back(require->args.data[0]); // push all dependent expressions to the work stack; note that the vector is modified during traversal @@ -125,15 +125,15 @@ struct RequireTracer : AstVisitor } // resolve all requires according to their argument - result.requires.reserve(requires.size()); + result.requireList.reserve(requireCalls.size()); - for (AstExprCall* require : requires) + for (AstExprCall* require : requireCalls) { AstExpr* arg = require->args.data[0]; if (const ModuleInfo* info = result.exprs.find(arg)) { - result.requires.push_back({info->name, require->location}); + result.requireList.push_back({info->name, require->location}); ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info! result.exprs[require] = std::move(infoCopy); @@ -151,7 +151,7 @@ struct RequireTracer : AstVisitor DenseHashMap locals; std::vector work; - std::vector requires; + std::vector requireCalls; }; RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 0a362a5e..66aaee1f 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -2,8 +2,6 @@ #include "Luau/Scope.h" -LUAU_FASTFLAG(LuauTwoPassAliasDefinitionFix); - namespace Luau { @@ -19,8 +17,7 @@ Scope::Scope(const ScopePtr& parent, int subLevel) , returnType(parent->returnType) , level(parent->level.incr()) { - if (FFlag::LuauTwoPassAliasDefinitionFix) - level = level.incr(); + level = level.incr(); level.subLevel = subLevel; } @@ -124,4 +121,36 @@ std::optional Scope::linearSearchForBinding(const std::string& name, bo return std::nullopt; } +std::optional Scope2::lookup(Symbol sym) +{ + Scope2* s = this; + + while (true) + { + auto it = s->bindings.find(sym); + if (it != s->bindings.end()) + return it->second; + + if (s->parent) + s = s->parent; + else + return std::nullopt; + } +} + +std::optional Scope2::lookupTypeBinding(const Name& name) +{ + Scope2* s = this; + while (s) + { + auto it = s->typeBindings.find(name); + if (it != s->typeBindings.end()) + return it->second; + + s = s->parent; + } + + return std::nullopt; +} + } // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 770c7a47..9c4ce829 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -2,29 +2,34 @@ #include "Luau/Substitution.h" #include "Luau/Common.h" +#include "Luau/Clone.h" #include "Luau/TxnLog.h" #include #include -LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) +LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) namespace Luau { void Tarjan::visitChildren(TypeId ty, int index) { - ty = log->follow(ty); + LUAU_ASSERT(ty == log->follow(ty)); if (ignoreChildren(ty)) return; - if (const FunctionTypeVar* ftv = log->getMutable(ty)) + if (auto pty = log->pending(ty)) + ty = &pty->pending; + + if (const FunctionTypeVar* ftv = get(ty)) { visitChild(ftv->argTypes); - visitChild(ftv->retType); + visitChild(ftv->retTypes); } - else if (const TableTypeVar* ttv = log->getMutable(ty)) + else if (const TableTypeVar* ttv = get(ty)) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) @@ -41,38 +46,46 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId itp : ttv->instantiatedTypePackParams) visitChild(itp); } - else if (const MetatableTypeVar* mtv = log->getMutable(ty)) + else if (const MetatableTypeVar* mtv = get(ty)) { visitChild(mtv->table); visitChild(mtv->metatable); } - else if (const UnionTypeVar* utv = log->getMutable(ty)) + else if (const UnionTypeVar* utv = get(ty)) { for (TypeId opt : utv->options) visitChild(opt); } - else if (const IntersectionTypeVar* itv = log->getMutable(ty)) + else if (const IntersectionTypeVar* itv = get(ty)) { for (TypeId part : itv->parts) visitChild(part); } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + for (TypeId part : ctv->parts) + visitChild(part); + } } void Tarjan::visitChildren(TypePackId tp, int index) { - tp = log->follow(tp); + LUAU_ASSERT(tp == log->follow(tp)); if (ignoreChildren(tp)) return; - if (const TypePack* tpp = log->getMutable(tp)) + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; + + if (const TypePack* tpp = get(tp)) { for (TypeId tv : tpp->head) visitChild(tv); if (tpp->tail) visitChild(*tpp->tail); } - else if (const VariadicTypePack* vtp = log->getMutable(tp)) + else if (const VariadicTypePack* vtp = get(tp)) { visitChild(vtp->ty); } @@ -141,7 +154,7 @@ TarjanResult Tarjan::loop() if (currEdge == -1) { ++childCount; - if (FInt::LuauTarjanChildLimit > 0 && FInt::LuauTarjanChildLimit < childCount) + if (childLimit > 0 && childLimit < childCount) return TarjanResult::TooManyChildren; stack.push_back(index); @@ -229,6 +242,9 @@ TarjanResult Tarjan::loop() TarjanResult Tarjan::visitRoot(TypeId ty) { childCount = 0; + if (childLimit == 0) + childLimit = FInt::LuauTarjanChildLimit; + ty = log->follow(ty); auto [index, fresh] = indexify(ty); @@ -239,6 +255,9 @@ TarjanResult Tarjan::visitRoot(TypeId ty) TarjanResult Tarjan::visitRoot(TypePackId tp) { childCount = 0; + if (childLimit == 0) + childLimit = FInt::LuauTarjanChildLimit; + tp = log->follow(tp); auto [index, fresh] = indexify(tp); @@ -343,67 +362,24 @@ std::optional Substitution::substitute(TypePackId tp) TypeId Substitution::clone(TypeId ty) { - ty = log->follow(ty); - - TypeId result = ty; - - if (const FunctionTypeVar* ftv = log->getMutable(ty)) - { - FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; - clone.generics = ftv->generics; - clone.genericPacks = ftv->genericPacks; - clone.magicFunction = ftv->magicFunction; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - result = addType(std::move(clone)); - } - else if (const TableTypeVar* ttv = log->getMutable(ty)) - { - LUAU_ASSERT(!ttv->boundTo); - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - clone.name = ttv->name; - clone.syntheticName = ttv->syntheticName; - clone.instantiatedTypeParams = ttv->instantiatedTypeParams; - clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - clone.tags = ttv->tags; - result = addType(std::move(clone)); - } - else if (const MetatableTypeVar* mtv = log->getMutable(ty)) - { - MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; - clone.syntheticName = mtv->syntheticName; - result = addType(std::move(clone)); - } - else if (const UnionTypeVar* utv = log->getMutable(ty)) - { - UnionTypeVar clone; - clone.options = utv->options; - result = addType(std::move(clone)); - } - else if (const IntersectionTypeVar* itv = log->getMutable(ty)) - { - IntersectionTypeVar clone; - clone.parts = itv->parts; - result = addType(std::move(clone)); - } - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; + return shallowClone(ty, *arena, log); } TypePackId Substitution::clone(TypePackId tp) { tp = log->follow(tp); - if (const TypePack* tpp = log->getMutable(tp)) + + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; + + if (const TypePack* tpp = get(tp)) { TypePack clone; clone.head = tpp->head; clone.tail = tpp->tail; return addTypePack(std::move(clone)); } - else if (const VariadicTypePack* vtp = log->getMutable(tp)) + else if (const VariadicTypePack* vtp = get(tp)) { VariadicTypePack clone; clone.ty = vtp->ty; @@ -416,24 +392,27 @@ TypePackId Substitution::clone(TypePackId tp) void Substitution::foundDirty(TypeId ty) { ty = log->follow(ty); + if (isDirty(ty)) - newTypes[ty] = clean(ty); + newTypes[ty] = follow(clean(ty)); else - newTypes[ty] = clone(ty); + newTypes[ty] = follow(clone(ty)); } void Substitution::foundDirty(TypePackId tp) { tp = log->follow(tp); + if (isDirty(tp)) - newPacks[tp] = clean(tp); + newPacks[tp] = follow(clean(tp)); else - newPacks[tp] = clone(tp); + newPacks[tp] = follow(clone(tp)); } TypeId Substitution::replace(TypeId ty) { ty = log->follow(ty); + if (TypeId* prevTy = newTypes.find(ty)) return *prevTy; else @@ -443,6 +422,7 @@ TypeId Substitution::replace(TypeId ty) TypePackId Substitution::replace(TypePackId tp) { tp = log->follow(tp); + if (TypePackId* prevTp = newPacks.find(tp)) return *prevTp; else @@ -451,7 +431,10 @@ TypePackId Substitution::replace(TypePackId tp) void Substitution::replaceChildren(TypeId ty) { - ty = log->follow(ty); + if (BoundTypeVar* btv = log->getMutable(ty); FFlag::LuauLowerBoundsCalculation && btv) + btv->boundTo = replace(btv->boundTo); + + LUAU_ASSERT(ty == log->follow(ty)); if (ignoreChildren(ty)) return; @@ -459,7 +442,7 @@ void Substitution::replaceChildren(TypeId ty) if (FunctionTypeVar* ftv = getMutable(ty)) { ftv->argTypes = replace(ftv->argTypes); - ftv->retType = replace(ftv->retType); + ftv->retTypes = replace(ftv->retTypes); } else if (TableTypeVar* ttv = getMutable(ty)) { @@ -493,11 +476,16 @@ void Substitution::replaceChildren(TypeId ty) for (TypeId& part : itv->parts) part = replace(part); } + else if (ConstrainedTypeVar* ctv = getMutable(ty)) + { + for (TypeId& part : ctv->parts) + part = replace(part); + } } void Substitution::replaceChildren(TypePackId tp) { - tp = log->follow(tp); + LUAU_ASSERT(tp == log->follow(tp)); if (ignoreChildren(tp)) return; diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index df9d4188..6b677bb8 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -154,7 +154,7 @@ void StateDot::visitChildren(TypeId ty, int index) finishNode(); visitChild(ftv->argTypes, index, "arg"); - visitChild(ftv->retType, index, "ret"); + visitChild(ftv->retTypes, index, "ret"); } else if (const TableTypeVar* ttv = get(ty)) { @@ -237,6 +237,15 @@ void StateDot::visitChildren(TypeId ty, int index) finishNodeLabel(ty); finishNode(); } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + formatAppend(result, "ConstrainedTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId part : ctv->parts) + visitChild(part, index); + } else if (get(ty)) { formatAppend(result, "ErrorTypeVar %d", index); @@ -258,6 +267,28 @@ void StateDot::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable, index, "[metatable]"); } + else if (const SingletonTypeVar* stv = get(ty)) + { + std::string res; + + if (const StringSingleton* ss = get(stv)) + { + // Don't put in quotes anywhere. If it's outside of the call to escape, + // then it's invalid syntax. If it's inside, then escaping is super noisy. + res = "string: " + escape(ss->value); + } + else if (const BooleanSingleton* bs = get(stv)) + { + res = "boolean: "; + res += bs->value ? "true" : "false"; + } + else + LUAU_ASSERT(!"unknown singleton type"); + + formatAppend(result, "SingletonTypeVar %s", res.c_str()); + finishNodeLabel(ty); + finishNode(); + } else { LUAU_ASSERT(!"unknown type kind"); @@ -296,7 +327,7 @@ void StateDot::visitChildren(TypePackId tp, int index) } else if (const VariadicTypePack* vtp = get(tp)) { - formatAppend(result, "VariadicTypePack %d", index); + formatAppend(result, "VariadicTypePack %s%d", vtp->hidden ? "hidden " : "", index); finishNodeLabel(tp); finishNode(); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 59ee6de2..7a458964 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,13 +10,15 @@ #include #include +LUAU_FASTFLAG(LuauLowerBoundsCalculation) + /* * Prefix generic typenames with gen- * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 * Fair warning: Setting this will break a lot of Luau unit tests. */ LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) -LUAU_FASTFLAGVARIABLE(LuauDocFuncParameters, false) +LUAU_FASTFLAGVARIABLE(LuauToStringTableBracesNewlines, false) namespace Luau { @@ -24,7 +26,7 @@ namespace Luau namespace { -struct FindCyclicTypes +struct FindCyclicTypes final : TypeVarVisitor { FindCyclicTypes() = default; FindCyclicTypes(const FindCyclicTypes&) = delete; @@ -33,28 +35,30 @@ struct FindCyclicTypes bool exhaustive = false; std::unordered_set visited; std::unordered_set visitedPacks; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; - void cycle(TypeId ty) + void cycle(TypeId ty) override { cycles.insert(ty); } - void cycle(TypePackId tp) + void cycle(TypePackId tp) override { cycleTPs.insert(tp); } - template - bool operator()(TypeId ty, const T&) + bool visit(TypeId ty) override { return visited.insert(ty).second; } - bool operator()(TypeId ty, const TableTypeVar& ttv) = delete; + bool visit(TypePackId tp) override + { + return visitedPacks.insert(tp).second; + } - bool operator()(TypeId ty, const TableTypeVar& ttv, std::unordered_set& seen) + bool visit(TypeId ty, const TableTypeVar& ttv) override { if (!visited.insert(ty).second) return false; @@ -62,10 +66,10 @@ struct FindCyclicTypes if (ttv.name || ttv.syntheticName) { for (TypeId itp : ttv.instantiatedTypeParams) - visitTypeVar(itp, *this, seen); + traverse(itp); for (TypePackId itp : ttv.instantiatedTypePackParams) - visitTypeVar(itp, *this, seen); + traverse(itp); return exhaustive; } @@ -73,24 +77,18 @@ struct FindCyclicTypes return true; } - bool operator()(TypeId, const ClassTypeVar&) + bool visit(TypeId ty, const ClassTypeVar&) override { return false; } - - template - bool operator()(TypePackId tp, const T&) - { - return visitedPacks.insert(tp).second; - } }; template -void findCyclicTypes(std::unordered_set& cycles, std::unordered_set& cycleTPs, TID ty, bool exhaustive) +void findCyclicTypes(std::set& cycles, std::set& cycleTPs, TID ty, bool exhaustive) { FindCyclicTypes fct; fct.exhaustive = exhaustive; - visitTypeVar(ty, fct); + fct.traverse(ty); cycles = std::move(fct.cycles); cycleTPs = std::move(fct.cycleTPs); @@ -124,6 +122,7 @@ struct StringifierState std::unordered_map cycleTpNames; std::unordered_set seen; std::unordered_set usedNames; + size_t indentation = 0; bool exhaustive; @@ -180,6 +179,8 @@ struct StringifierState return generateName(s); } + int previousNameIndex = 0; + std::string getName(TypePackId ty) { const size_t s = result.nameMap.typePacks.size(); @@ -189,9 +190,10 @@ struct StringifierState for (int count = 0; count < 256; ++count) { - std::string candidate = generateName(usedNames.size() + count); + std::string candidate = generateName(previousNameIndex + count); if (!usedNames.count(candidate)) { + previousNameIndex += count; usedNames.insert(candidate); n = candidate; return candidate; @@ -209,6 +211,13 @@ struct StringifierState result.name += s; } + void emit(TypeLevel level) + { + emit(std::to_string(level.level)); + emit("-"); + emit(std::to_string(level.subLevel)); + } + void emit(const char* s) { if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) @@ -216,6 +225,39 @@ struct StringifierState result.name += s; } + + void emit(int i) + { + emit(std::to_string(i).c_str()); + } + + void indent() + { + indentation += 4; + } + + void dedent() + { + indentation -= 4; + } + + void newline() + { + if (!opts.useLineBreaks) + return emit(" "); + + emit("\n"); + emitIndentation(); + } + +private: + void emitIndentation() + { + if (!opts.indent) + return; + + emit(std::string(indentation, ' ')); + } }; struct TypeVarStringifier @@ -247,7 +289,8 @@ struct TypeVarStringifier } Luau::visit( - [this, tv](auto&& t) { + [this, tv](auto&& t) + { return (*this)(tv, t); }, tv->ty); @@ -312,7 +355,7 @@ struct TypeVarStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(std::to_string(ftv.level.level)); + state.emit(ftv.level); } } @@ -321,10 +364,11 @@ struct TypeVarStringifier stringify(btv.boundTo); } - void operator()(TypeId ty, const Unifiable::Generic& gtv) + void operator()(TypeId ty, const GenericTypeVar& gtv) { if (gtv.explicitName) { + state.usedNames.insert(gtv.name); state.result.nameMap.typeVars[ty] = gtv.name; state.emit(gtv.name); } @@ -332,6 +376,36 @@ struct TypeVarStringifier state.emit(state.getName(ty)); } + void operator()(TypeId, const ConstrainedTypeVar& ctv) + { + state.result.invalid = true; + + state.emit("["); + if (FFlag::DebugLuauVerboseTypeNames) + state.emit(ctv.level); + state.emit("["); + + bool first = true; + for (TypeId ty : ctv.parts) + { + if (first) + first = false; + else + state.emit("|"); + + stringify(ty); + } + + state.emit("]]"); + } + + void operator()(TypeId, const BlockedTypeVar& btv) + { + state.emit("*blocked-"); + state.emit(btv.index); + state.emit("*"); + } + void operator()(TypeId, const PrimitiveTypeVar& ptv) { switch (ptv.type) @@ -415,16 +489,31 @@ struct TypeVarStringifier state.emit(") -> "); bool plural = true; - if (auto retPack = get(follow(ftv.retType))) + + if (FFlag::LuauLowerBoundsCalculation) { - if (retPack->head.size() == 1 && !retPack->tail) - plural = false; + auto retBegin = begin(ftv.retTypes); + auto retEnd = end(ftv.retTypes); + if (retBegin != retEnd) + { + ++retBegin; + if (retBegin == retEnd && !retBegin.tail()) + plural = false; + } + } + else + { + if (auto retPack = get(follow(ftv.retTypes))) + { + if (retPack->head.size() == 1 && !retPack->tail) + plural = false; + } } if (plural) state.emit("("); - stringify(ftv.retType); + stringify(ftv.retTypes); if (plural) state.emit(")"); @@ -482,22 +571,54 @@ struct TypeVarStringifier { case TableState::Sealed: state.result.invalid = true; - openbrace = "{| "; - closedbrace = " |}"; + if (FFlag::LuauToStringTableBracesNewlines) + { + openbrace = "{|"; + closedbrace = "|}"; + } + else + { + openbrace = "{| "; + closedbrace = " |}"; + } break; case TableState::Unsealed: - openbrace = "{ "; - closedbrace = " }"; + if (FFlag::LuauToStringTableBracesNewlines) + { + openbrace = "{"; + closedbrace = "}"; + } + else + { + openbrace = "{ "; + closedbrace = " }"; + } break; case TableState::Free: state.result.invalid = true; - openbrace = "{- "; - closedbrace = " -}"; + if (FFlag::LuauToStringTableBracesNewlines) + { + openbrace = "{-"; + closedbrace = "-}"; + } + else + { + openbrace = "{- "; + closedbrace = " -}"; + } break; case TableState::Generic: state.result.invalid = true; - openbrace = "{+ "; - closedbrace = " +}"; + if (FFlag::LuauToStringTableBracesNewlines) + { + openbrace = "{+"; + closedbrace = "+}"; + } + else + { + openbrace = "{+ "; + closedbrace = " +}"; + } break; } @@ -511,10 +632,13 @@ struct TypeVarStringifier } state.emit(openbrace); + state.indent(); bool comma = false; if (ttv.indexer) { + if (FFlag::LuauToStringTableBracesNewlines) + state.newline(); state.emit("["); stringify(ttv.indexer->indexType); state.emit("]: "); @@ -527,7 +651,14 @@ struct TypeVarStringifier for (const auto& [name, prop] : ttv.props) { if (comma) - state.emit(state.opts.useLineBreaks ? ",\n" : ", "); + { + state.emit(","); + state.newline(); + } + else if (FFlag::LuauToStringTableBracesNewlines) + { + state.newline(); + } size_t length = state.result.name.length() - oldLength; @@ -553,6 +684,14 @@ struct TypeVarStringifier ++index; } + state.dedent(); + if (FFlag::LuauToStringTableBracesNewlines) + { + if (comma) + state.newline(); + else + state.emit(" "); + } state.emit(closedbrace); state.unsee(&ttv); @@ -563,7 +702,8 @@ struct TypeVarStringifier state.result.invalid = true; state.emit("{ @metatable "); stringify(mtv.metatable); - state.emit(state.opts.useLineBreaks ? ",\n" : ", "); + state.emit(","); + state.newline(); stringify(mtv.table); state.emit(" }"); } @@ -627,7 +767,10 @@ struct TypeVarStringifier for (std::string& ss : results) { if (!first) - state.emit(" | "); + { + state.newline(); + state.emit("| "); + } state.emit(ss); first = false; } @@ -680,7 +823,10 @@ struct TypeVarStringifier for (std::string& ss : results) { if (!first) - state.emit(" & "); + { + state.newline(); + state.emit("& "); + } state.emit(ss); first = false; } @@ -746,7 +892,8 @@ struct TypePackStringifier } Luau::visit( - [this, tp](auto&& t) { + [this, tp](auto&& t) + { return (*this)(tp, t); }, tp->ty); @@ -784,13 +931,16 @@ struct TypePackStringifier if (tp.tail && !isEmpty(*tp.tail)) { - const auto& tail = *tp.tail; - if (first) - first = false; - else - state.emit(", "); + TypePackId tail = follow(*tp.tail); + if (auto vtp = get(tail); !vtp || (!FFlag::DebugLuauVerboseTypeNames && !vtp->hidden)) + { + if (first) + first = false; + else + state.emit(", "); - stringify(tail); + stringify(tail); + } } state.unsee(&tp); @@ -805,6 +955,8 @@ struct TypePackStringifier void operator()(TypePackId, const VariadicTypePack& pack) { state.emit("..."); + if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) + state.emit(""); stringify(pack.ty); } @@ -814,6 +966,7 @@ struct TypePackStringifier state.emit("gen-"); if (pack.explicitName) { + state.usedNames.insert(pack.name); state.result.nameMap.typePacks[tp] = pack.name; state.emit(pack.name); } @@ -834,7 +987,7 @@ struct TypePackStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(std::to_string(pack.level.level)); + state.emit(pack.level); } state.emit("..."); @@ -858,15 +1011,12 @@ void TypeVarStringifier::stringify(TypePackId tpid, const std::vector& cycles, const std::unordered_set& cycleTPs, +static void assignCycleNames(const std::set& cycles, const std::set& cycleTPs, std::unordered_map& cycleNames, std::unordered_map& cycleTpNames, bool exhaustive) { int nextIndex = 1; - std::vector sortedCycles{cycles.begin(), cycles.end()}; - std::sort(sortedCycles.begin(), sortedCycles.end(), std::less{}); - - for (TypeId cycleTy : sortedCycles) + for (TypeId cycleTy : cycles) { std::string name; @@ -874,9 +1024,11 @@ static void assignCycleNames(const std::unordered_set& cycles, const std if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) { // If we have a cycle type in type parameters, assign a cycle name for this named table - if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { - return cycles.count(follow(el)); - }) != ttv->instantiatedTypeParams.end()) + if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), + [&](auto&& el) + { + return cycles.count(follow(el)); + }) != ttv->instantiatedTypeParams.end()) cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; continue; @@ -888,10 +1040,7 @@ static void assignCycleNames(const std::unordered_set& cycles, const std cycleNames[cycleTy] = std::move(name); } - std::vector sortedCycleTps{cycleTPs.begin(), cycleTPs.end()}; - std::sort(sortedCycleTps.begin(), sortedCycleTps.end(), std::less()); - - for (TypePackId tp : sortedCycleTps) + for (TypePackId tp : cycleTPs) { std::string name = "tp" + std::to_string(nextIndex); ++nextIndex; @@ -913,8 +1062,8 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) StringifierState state{opts, result, opts.nameMap}; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; findCyclicTypes(cycles, cycleTPs, ty, opts.exhaustive); @@ -975,9 +1124,11 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) state.exhaustive = true; std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; - std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { - return a.second < b.second; - }); + std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + }); bool semi = false; for (const auto& [cycleTy, name] : sortedCycleNames) @@ -988,7 +1139,8 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto&& t) { + [&tvs, cycleTy = cycleTy](auto&& t) + { return tvs(cycleTy, t); }, cycleTy->ty); @@ -1016,8 +1168,8 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) ToStringResult result; StringifierState state{opts, result, opts.nameMap}; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; findCyclicTypes(cycles, cycleTPs, tp, opts.exhaustive); @@ -1045,9 +1197,11 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) state.exhaustive = true; std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; - std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { - return a.second < b.second; - }); + std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + }); bool semi = false; for (const auto& [cycleTy, name] : sortedCycleNames) @@ -1058,7 +1212,8 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto&& t) { + [&tvs, cycleTy = cycleTy](auto t) + { return tvs(cycleTy, t); }, cycleTy->ty); @@ -1108,81 +1263,66 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp auto argPackIter = begin(ftv.argTypes); bool first = true; - if (FFlag::LuauDocFuncParameters) + size_t idx = 0; + while (argPackIter != end(ftv.argTypes)) { - size_t idx = 0; - while (argPackIter != end(ftv.argTypes)) + // ftv takes a self parameter as the first argument, skip it if specified in option + if (idx == 0 && ftv.hasSelf && opts.hideFunctionSelfArgument) { - if (!first) - state.emit(", "); - first = false; - - // We don't respect opts.functionTypeArguments - if (idx < opts.namedFunctionOverrideArgNames.size()) - { - state.emit(opts.namedFunctionOverrideArgNames[idx] + ": "); - } - else if (idx < ftv.argNames.size() && ftv.argNames[idx]) - { - state.emit(ftv.argNames[idx]->name + ": "); - } - else - { - state.emit("_: "); - } - tvs.stringify(*argPackIter); - ++argPackIter; ++idx; + continue; } - } - else - { - auto argNameIter = ftv.argNames.begin(); - while (argPackIter != end(ftv.argTypes)) + + if (!first) + state.emit(", "); + first = false; + + // We don't respect opts.functionTypeArguments + if (idx < opts.namedFunctionOverrideArgNames.size()) { - if (!first) - state.emit(", "); - first = false; - - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (argNameIter != ftv.argNames.end()) - { - state.emit((*argNameIter ? (*argNameIter)->name : "_") + ": "); - ++argNameIter; - } - else - { - state.emit("_: "); - } - - tvs.stringify(*argPackIter); - ++argPackIter; + state.emit(opts.namedFunctionOverrideArgNames[idx] + ": "); } + else if (idx < ftv.argNames.size() && ftv.argNames[idx]) + { + state.emit(ftv.argNames[idx]->name + ": "); + } + else + { + state.emit("_: "); + } + tvs.stringify(*argPackIter); + + ++argPackIter; + ++idx; } if (argPackIter.tail()) { - if (!first) - state.emit(", "); + if (auto vtp = get(*argPackIter.tail()); !vtp || !vtp->hidden) + { + if (!first) + state.emit(", "); - state.emit("...: "); - if (auto vtp = get(*argPackIter.tail())) - tvs.stringify(vtp->ty); - else - tvs.stringify(*argPackIter.tail()); + state.emit("...: "); + + if (vtp) + tvs.stringify(vtp->ty); + else + tvs.stringify(*argPackIter.tail()); + } } state.emit("): "); - size_t retSize = size(ftv.retType); - bool hasTail = !finite(ftv.retType); - bool wrap = get(follow(ftv.retType)) && (hasTail ? retSize != 0 : retSize != 1); + size_t retSize = size(ftv.retTypes); + bool hasTail = !finite(ftv.retTypes); + bool wrap = get(follow(ftv.retTypes)) && (hasTail ? retSize != 0 : retSize != 1); if (wrap) state.emit("("); - tvs.stringify(ftv.retType); + tvs.stringify(ftv.retTypes); if (wrap) state.emit(")"); @@ -1210,6 +1350,24 @@ std::string dump(TypePackId ty) return s; } +std::string dump(const ScopePtr& scope, const char* name) +{ + auto binding = scope->linearSearchForBinding(name); + if (!binding) + { + printf("No binding %s\n", name); + return {}; + } + + TypeId ty = binding->typeId; + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string s = toString(ty, opts); + printf("%s\n", s.c_str()); + return s; +} + std::string generateName(size_t i) { std::string n; @@ -1219,4 +1377,61 @@ std::string generateName(size_t i) return n; } +std::string toString(const Constraint& c, ToStringOptions& opts) +{ + if (const SubtypeConstraint* sc = Luau::get_if(&c.c)) + { + ToStringResult subStr = toStringDetailed(sc->subType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(sc->superType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " <: " + superStr.name; + } + else if (const PackSubtypeConstraint* psc = Luau::get_if(&c.c)) + { + ToStringResult subStr = toStringDetailed(psc->subPack, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(psc->superPack, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " <: " + superStr.name; + } + else if (const GeneralizationConstraint* gc = Luau::get_if(&c.c)) + { + ToStringResult subStr = toStringDetailed(gc->generalizedType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(gc->sourceType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " ~ gen " + superStr.name; + } + else if (const InstantiationConstraint* ic = Luau::get_if(&c.c)) + { + ToStringResult subStr = toStringDetailed(ic->subType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(ic->superType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " ~ inst " + superStr.name; + } + else if (const NameConstraint* nc = Luau::get(c)) + { + ToStringResult namedStr = toStringDetailed(nc->namedType, opts); + opts.nameMap = std::move(namedStr.nameMap); + return "@name(" + namedStr.name + ") = " + nc->name; + } + else + { + LUAU_ASSERT(false); + return ""; + } +} + +std::string dump(const Constraint& c) +{ + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string s = toString(c, opts); + printf("%s\n", s.c_str()); + return s; +} + } // namespace Luau diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 678001bf..1ea2e27d 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -215,6 +215,7 @@ struct ArcCollector : public AstVisitor } } + // Adds a dependency from the current node to the named node. void add(const Identifier& name) { Node** it = map.find(name); diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 92ed241e..1577bd63 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -1025,31 +1025,42 @@ struct Printer } else if (const auto& a = typeAnnotation.as()) { - CommaSeparatorInserter comma(writer); + AstTypeReference* indexType = a->indexer ? a->indexer->indexType->as() : nullptr; - writer.symbol("{"); - - for (std::size_t i = 0; i < a->props.size; ++i) + if (a->props.size == 0 && indexType && indexType->name == "number") { - comma(); - advance(a->props.data[i].location.begin); - writer.identifier(a->props.data[i].name.value); - if (a->props.data[i].type) - { - writer.symbol(":"); - visualizeTypeAnnotation(*a->props.data[i].type); - } - } - if (a->indexer) - { - comma(); - writer.symbol("["); - visualizeTypeAnnotation(*a->indexer->indexType); - writer.symbol("]"); - writer.symbol(":"); + writer.symbol("{"); visualizeTypeAnnotation(*a->indexer->resultType); + writer.symbol("}"); + } + else + { + CommaSeparatorInserter comma(writer); + + writer.symbol("{"); + + for (std::size_t i = 0; i < a->props.size; ++i) + { + comma(); + advance(a->props.data[i].location.begin); + writer.identifier(a->props.data[i].name.value); + if (a->props.data[i].type) + { + writer.symbol(":"); + visualizeTypeAnnotation(*a->props.data[i].type); + } + } + if (a->indexer) + { + comma(); + writer.symbol("["); + visualizeTypeAnnotation(*a->indexer->indexType); + writer.symbol("]"); + writer.symbol(":"); + visualizeTypeAnnotation(*a->indexer->resultType); + } + writer.symbol("}"); } - writer.symbol("}"); } else if (auto a = typeAnnotation.as()) { diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 876f5f05..4c6d54e0 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,6 +7,8 @@ #include #include +LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) + namespace Luau { @@ -79,10 +81,34 @@ void TxnLog::concat(TxnLog rhs) void TxnLog::commit() { for (auto& [ty, rep] : typeVarChanges) - *asMutable(ty) = rep.get()->pending; + { + if (FFlag::LuauNonCopyableTypeVarFields) + { + asMutable(ty)->reassign(rep.get()->pending); + } + else + { + TypeArena* owningArena = ty->owningArena; + TypeVar* mtv = asMutable(ty); + *mtv = rep.get()->pending; + mtv->owningArena = owningArena; + } + } for (auto& [tp, rep] : typePackChanges) - *asMutable(tp) = rep.get()->pending; + { + if (FFlag::LuauNonCopyableTypeVarFields) + { + asMutable(tp)->reassign(rep.get()->pending); + } + else + { + TypeArena* owningArena = tp->owningArena; + TypePackVar* mpv = asMutable(tp); + *mpv = rep.get()->pending; + mpv->owningArena = owningArena; + } + } clear(); } @@ -144,11 +170,6 @@ bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const return true; } - if (parent) - { - return parent->haveSeen(lhs, rhs); - } - return false; } @@ -173,8 +194,13 @@ PendingType* TxnLog::queue(TypeId ty) // about this type, we don't want to mutate the parent's state. auto& pending = typeVarChanges[ty]; if (!pending) + { pending = std::make_unique(*ty); + if (FFlag::LuauNonCopyableTypeVarFields) + pending->pending.owningArena = nullptr; + } + return pending.get(); } @@ -186,8 +212,13 @@ PendingTypePack* TxnLog::queue(TypePackId tp) // about this type, we don't want to mutate the parent's state. auto& pending = typePackChanges[tp]; if (!pending) + { pending = std::make_unique(*tp); + if (FFlag::LuauNonCopyableTypeVarFields) + pending->pending.owningArena = nullptr; + } + return pending.get(); } @@ -199,8 +230,8 @@ PendingType* TxnLog::pending(TypeId ty) const for (const TxnLog* current = this; current; current = current->parent) { - if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end()) - return it->second.get(); + if (auto it = current->typeVarChanges.find(ty)) + return it->get(); } return nullptr; @@ -214,8 +245,8 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const for (const TxnLog* current = this; current; current = current->parent) { - if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end()) - return it->second.get(); + if (auto it = current->typePackChanges.find(tp)) + return it->get(); } return nullptr; @@ -224,14 +255,24 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) { PendingType* newTy = queue(ty); - newTy->pending = replacement; + + if (FFlag::LuauNonCopyableTypeVarFields) + newTy->pending.reassign(replacement); + else + newTy->pending = replacement; + return newTy; } PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) { PendingTypePack* newTp = queue(tp); - newTp->pending = replacement; + + if (FFlag::LuauNonCopyableTypeVarFields) + newTp->pending.reassign(replacement); + else + newTp->pending = replacement; + return newTp; } diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp new file mode 100644 index 00000000..0c89d130 --- /dev/null +++ b/Analysis/src/TypeArena.cpp @@ -0,0 +1,88 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeArena.h" + +LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false); + +namespace Luau +{ + +void TypeArena::clear() +{ + typeVars.clear(); + typePacks.clear(); +} + +TypeId TypeArena::addTV(TypeVar&& tv) +{ + TypeId allocated = typeVars.allocate(std::move(tv)); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(TypeLevel level) +{ + TypeId allocated = typeVars.allocate(FreeTypeVar{level}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(std::initializer_list types) +{ + TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(std::vector types) +{ + TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(TypePack tp) +{ + TypePackId allocated = typePacks.allocate(std::move(tp)); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(TypePackVar tp) +{ + TypePackId allocated = typePacks.allocate(std::move(tp)); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +void freeze(TypeArena& arena) +{ + if (!FFlag::DebugLuauFreezeArena) + return; + + arena.typeVars.freeze(); + arena.typePacks.freeze(); +} + +void unfreeze(TypeArena& arena) +{ + if (!FFlag::DebugLuauFreezeArena) + return; + + arena.typeVars.unfreeze(); + arena.typePacks.unfreeze(); +} + +} // namespace Luau diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index d575e023..6cca7127 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -94,6 +94,21 @@ public: } } + AstType* operator()(const BlockedTypeVar& btv) + { + return allocator->alloc(Location(), std::nullopt, AstName("*blocked*")); + } + + AstType* operator()(const ConstrainedTypeVar& ctv) + { + AstArray types; + types.size = ctv.parts.size(); + types.data = static_cast(allocator->allocate(sizeof(AstType*) * ctv.parts.size())); + for (size_t i = 0; i < ctv.parts.size(); ++i) + types.data[i] = Luau::visit(*this, ctv.parts[i]->ty); + return allocator->alloc(Location(), types); + } + AstType* operator()(const SingletonTypeVar& stv) { if (const BooleanSingleton* bs = get(&stv)) @@ -261,7 +276,7 @@ public: } AstArray returnTypes; - const auto& [retVector, retTail] = flatten(ftv.retType); + const auto& [retVector, retTail] = flatten(ftv.retTypes); returnTypes.size = retVector.size(); returnTypes.data = static_cast(allocator->allocate(sizeof(AstType*) * returnTypes.size)); for (size_t i = 0; i < returnTypes.size; ++i) @@ -364,6 +379,9 @@ public: AstTypePack* operator()(const VariadicTypePack& vtp) const { + if (vtp.hidden) + return nullptr; + return allocator->alloc(Location(), Luau::visit(*typeVisitor, vtp.ty->ty)); } @@ -466,6 +484,20 @@ public: { return visitLocal(al->local); } + + virtual bool visit(AstStatFor* stat) override + { + visitLocal(stat->var); + return true; + } + + virtual bool visit(AstStatForIn* stat) override + { + for (size_t i = 0; i < stat->vars.size; ++i) + visitLocal(stat->vars.data[i]); + return true; + } + virtual bool visit(AstExprFunction* fn) override { // TODO: add generics if the inferred type of the function is generic CLI-39908 diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp new file mode 100644 index 00000000..63e5800f --- /dev/null +++ b/Analysis/src/TypeChecker2.cpp @@ -0,0 +1,333 @@ + +#include "Luau/TypeChecker2.h" + +#include + +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" +#include "Luau/Clone.h" +#include "Luau/Normalize.h" +#include "Luau/ConstraintGraphBuilder.h" // FIXME move Scope2 into its own header +#include "Luau/Unifier.h" +#include "Luau/ToString.h" + +namespace Luau +{ + +struct TypeChecker2 : public AstVisitor +{ + const SourceModule* sourceModule; + Module* module; + InternalErrorReporter ice; // FIXME accept a pointer from Frontend + + TypeChecker2(const SourceModule* sourceModule, Module* module) + : sourceModule(sourceModule) + , module(module) + { + } + + using AstVisitor::visit; + + TypePackId lookupPack(AstExpr* expr) + { + TypePackId* tp = module->astTypePacks.find(expr); + LUAU_ASSERT(tp); + return follow(*tp); + } + + TypeId lookupType(AstExpr* expr) + { + TypeId* ty = module->astTypes.find(expr); + LUAU_ASSERT(ty); + return follow(*ty); + } + + TypeId lookupAnnotation(AstType* annotation) + { + TypeId* ty = module->astResolvedTypes.find(annotation); + LUAU_ASSERT(ty); + return follow(*ty); + } + + TypePackId reconstructPack(AstArray exprs, TypeArena& arena) + { + std::vector head; + + for (size_t i = 0; i < exprs.size - 1; ++i) + { + head.push_back(lookupType(exprs.data[i])); + } + + TypePackId tail = lookupPack(exprs.data[exprs.size - 1]); + return arena.addTypePack(TypePack{head, tail}); + } + + Scope2* findInnermostScope(Location location) + { + Scope2* bestScope = module->getModuleScope2(); + Location bestLocation = module->scope2s[0].first; + + for (size_t i = 0; i < module->scope2s.size(); ++i) + { + auto& [scopeBounds, scope] = module->scope2s[i]; + if (scopeBounds.encloses(location)) + { + if (scopeBounds.begin > bestLocation.begin || scopeBounds.end < bestLocation.end) + { + bestScope = scope.get(); + bestLocation = scopeBounds; + } + } + else + { + // TODO: Is this sound? This relies on the fact that scopes are inserted + // into the scope list in the order that they appear in the AST. + break; + } + } + + return bestScope; + } + + bool visit(AstStatLocal* local) override + { + for (size_t i = 0; i < local->values.size; ++i) + { + AstExpr* value = local->values.data[i]; + if (i == local->values.size - 1) + { + if (i < local->values.size) + { + TypePackId valueTypes = lookupPack(value); + auto it = begin(valueTypes); + for (size_t j = i; j < local->vars.size; ++j) + { + if (it == end(valueTypes)) + { + break; + } + + AstLocal* var = local->vars.data[i]; + if (var->annotation) + { + TypeId varType = lookupAnnotation(var->annotation); + if (!isSubtype(*it, varType, ice)) + { + reportError(TypeMismatch{varType, *it}, value->location); + } + } + + ++it; + } + } + } + else + { + TypeId valueType = lookupType(value); + AstLocal* var = local->vars.data[i]; + + if (var->annotation) + { + TypeId varType = lookupAnnotation(var->annotation); + if (!isSubtype(varType, valueType, ice)) + { + reportError(TypeMismatch{varType, valueType}, value->location); + } + } + } + } + + return true; + } + + bool visit(AstStatAssign* assign) override + { + size_t count = std::min(assign->vars.size, assign->values.size); + + for (size_t i = 0; i < count; ++i) + { + AstExpr* lhs = assign->vars.data[i]; + TypeId* lhsType = module->astTypes.find(lhs); + LUAU_ASSERT(lhsType); + + AstExpr* rhs = assign->values.data[i]; + TypeId* rhsType = module->astTypes.find(rhs); + LUAU_ASSERT(rhsType); + + if (!isSubtype(*rhsType, *lhsType, ice)) + { + reportError(TypeMismatch{*lhsType, *rhsType}, rhs->location); + } + } + + return true; + } + + bool visit(AstStatReturn* ret) override + { + Scope2* scope = findInnermostScope(ret->location); + TypePackId expectedRetType = scope->returnType; + + TypeArena arena; + TypePackId actualRetType = reconstructPack(ret->list, arena); + + UnifierSharedState sharedState{&ice}; + Unifier u{&arena, Mode::Strict, ret->location, Covariant, sharedState}; + u.anyIsTop = true; + + u.tryUnify(actualRetType, expectedRetType); + const bool ok = u.errors.empty() && u.log.empty(); + + if (!ok) + { + for (const TypeError& e : u.errors) + module->errors.push_back(e); + } + + return true; + } + + bool visit(AstExprCall* call) override + { + TypePackId expectedRetType = lookupPack(call); + TypeId functionType = lookupType(call->func); + + TypeArena arena; + TypePack args; + for (const auto& arg : call->args) + { + TypeId argTy = module->astTypes[arg]; + LUAU_ASSERT(argTy); + args.head.push_back(argTy); + } + + TypePackId argsTp = arena.addTypePack(args); + FunctionTypeVar ftv{argsTp, expectedRetType}; + TypeId expectedType = arena.addType(ftv); + if (!isSubtype(expectedType, functionType, ice)) + { + unfreeze(module->interfaceTypes); + CloneState cloneState; + expectedType = clone(expectedType, module->interfaceTypes, cloneState); + freeze(module->interfaceTypes); + reportError(TypeMismatch{expectedType, functionType}, call->location); + } + + return true; + } + + bool visit(AstExprFunction* fn) override + { + TypeId inferredFnTy = lookupType(fn); + const FunctionTypeVar* inferredFtv = get(inferredFnTy); + LUAU_ASSERT(inferredFtv); + + auto argIt = begin(inferredFtv->argTypes); + for (const auto& arg : fn->args) + { + if (argIt == end(inferredFtv->argTypes)) + break; + + if (arg->annotation) + { + TypeId inferredArgTy = *argIt; + TypeId annotatedArgTy = lookupAnnotation(arg->annotation); + + if (!isSubtype(annotatedArgTy, inferredArgTy, ice)) + { + reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); + } + } + + ++argIt; + } + + return true; + } + + bool visit(AstExprIndexName* indexName) override + { + TypeId leftType = lookupType(indexName->expr); + TypeId resultType = lookupType(indexName); + + // leftType must have a property called indexName->index + + if (auto ttv = get(leftType)) + { + auto it = ttv->props.find(indexName->index.value); + if (it == ttv->props.end()) + { + reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location); + } + else if (!isSubtype(resultType, it->second.type, ice)) + { + reportError(TypeMismatch{resultType, it->second.type}, indexName->location); + } + } + else + { + reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location); + } + + return true; + } + + bool visit(AstExprConstantNumber* number) override + { + TypeId actualType = lookupType(number); + TypeId numberType = getSingletonTypes().numberType; + + if (!isSubtype(actualType, numberType, ice)) + { + reportError(TypeMismatch{actualType, numberType}, number->location); + } + + return true; + } + + bool visit(AstExprConstantString* string) override + { + TypeId actualType = lookupType(string); + TypeId stringType = getSingletonTypes().stringType; + + if (!isSubtype(actualType, stringType, ice)) + { + reportError(TypeMismatch{actualType, stringType}, string->location); + } + + return true; + } + + bool visit(AstType* ty) override + { + return true; + } + + bool visit(AstTypeReference* ty) override + { + Scope2* scope = findInnermostScope(ty->location); + + // TODO: Imported types + // TODO: Generic types + if (!scope->lookupTypeBinding(ty->name.value)) + { + reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location); + } + + return true; + } + + void reportError(TypeErrorData&& data, const Location& location) + { + module->errors.emplace_back(location, sourceModule->name, std::move(data)); + } +}; + +void check(const SourceModule& sourceModule, Module* module) +{ + TypeChecker2 typeChecker{&sourceModule, module}; + + sourceModule.root->visit(&typeChecker); +} + +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 9965d5aa..44635e88 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1,57 +1,57 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" +#include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/Instantiation.h" #include "Luau/ModuleResolver.h" +#include "Luau/Normalize.h" +#include "Luau/Parser.h" #include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/Substitution.h" +#include "Luau/TimeTrace.h" #include "Luau/TopoSortStatements.h" +#include "Luau/ToString.h" +#include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" -#include "Luau/ToString.h" #include "Luau/TypeVar.h" -#include "Luau/TimeTrace.h" +#include "Luau/VisitTypeVar.h" #include #include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) -LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) +LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) +LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) -LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) +LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) +LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) -LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. +LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) +LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) -LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) -LUAU_FASTFLAGVARIABLE(LuauSealExports, false) -LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) -LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) -LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) -LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) -LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) -LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify2, false) -LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) -LUAU_FASTFLAG(LuauTypeMismatchModuleName) -LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) -LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) +LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix2, false) +LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. -LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree2) -LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) -LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) -LUAU_FASTFLAGVARIABLE(LuauFixArgumentCountMismatchAmountWithGenericTypes, false) -LUAU_FASTFLAGVARIABLE(LuauFixIncorrectLineNumberDuplicateType, false) -LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) -LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) -LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) +LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) +LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) +LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); +LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) +LUAU_FASTFLAG(LuauQuantifyConstrained) +LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) +LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) namespace Luau { +const char* TimeLimitError::what() const throw() +{ + return "Typeinfer failed to complete in allotted time"; +} + static bool typeCouldHaveMetatable(TypeId ty) { return get(follow(ty)) || get(follow(ty)) || get(follow(ty)); @@ -130,6 +130,34 @@ bool hasBreak(AstStat* node) } } +static bool hasReturn(const AstStat* node) +{ + struct Searcher : AstVisitor + { + bool result = false; + + bool visit(AstStat*) override + { + return !result; // if we've already found a return statement, don't bother to traverse inward anymore + } + + bool visit(AstStatReturn*) override + { + result = true; + return false; + } + + bool visit(AstExprFunction*) override + { + return false; // We don't care if the function uses a lambda that itself returns + } + }; + + Searcher searcher; + const_cast(node)->visit(&searcher); + return searcher.result; +} + // returns the last statement before the block exits, or nullptr if the block never exits const AstStat* getFallthrough(const AstStat* node) { @@ -228,7 +256,6 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , booleanType(getSingletonTypes().booleanType) , threadType(getSingletonTypes().threadType) , anyType(getSingletonTypes().anyType) - , optionalNumberType(getSingletonTypes().optionalNumberType) , anyTypePack(getSingletonTypes().anyTypePack) , duplicateTypeAliases{{false, {}}} { @@ -243,15 +270,36 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan } ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) +{ + try + { + return checkWithoutRecursionCheck(module, mode, environmentScope); + } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(module.root->location); + return std::move(currentModule); + } +} + +ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope) { LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); currentModule.reset(new Module()); currentModule->type = module.type; + currentModule->allocator = module.allocator; + currentModule->names = module.names; iceHandler->moduleName = module.name; + if (FFlag::LuauAutocompleteDynamicLimits) + { + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; + } + ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr moduleScope = std::make_shared(parentScope); @@ -270,7 +318,14 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona if (prepareModuleScope) prepareModuleScope(module.name, currentModule->getModuleScope()); - checkBlock(moduleScope, *module.root); + try + { + checkBlock(moduleScope, *module.root); + } + catch (const TimeLimitError&) + { + currentModule->timeout = true; + } if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); @@ -282,12 +337,7 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona prepareErrorsForDisplay(currentModule->errors); - bool encounteredFreeType = currentModule->clonePublicInterface(); - if (encounteredFreeType) - { - reportError(TypeError{module.root->location, - GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); - } + currentModule->clonePublicInterface(*iceHandler); // Clear unifier cache since it's keyed off internal types that get deallocated // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. @@ -295,8 +345,7 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona unifierState.cachedUnifyError.clear(); unifierState.skipCacheForType.clear(); - if (FFlag::LuauTwoPassAliasDefinitionFix) - duplicateTypeAliases.clear(); + duplicateTypeAliases.clear(); return std::move(currentModule); } @@ -365,6 +414,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) } else ice("Unknown AstStat"); + + if (finishTime && TimeTrace::getClock() > *finishTime) + throw TimeLimitError(); } // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. @@ -382,7 +434,19 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) reportErrorCodeTooComplex(block.location); return; } + try + { + checkBlockWithoutRecursionCheck(scope, block); + } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(block.location); + return; + } +} +void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) +{ int subLevel = 0; std::vector sorted(block.body.data, block.body.data + block.body.size); @@ -402,6 +466,16 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) std::unordered_map> functionDecls; + auto isLocalLambda = [](AstStat* stat) -> AstStatLocal* { + AstStatLocal* local = stat->as(); + + if (FFlag::LuauLowerBoundsCalculation && local && local->vars.size == 1 && local->values.size == 1 && + local->values.data[0]->is()) + return local; + else + return nullptr; + }; + auto checkBody = [&](AstStat* stat) { if (auto fun = stat->as()) { @@ -449,7 +523,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end // ``` - if (containsFunctionCallOrReturn(**protoIter)) + if (containsFunctionCallOrReturn(**protoIter) || (FFlag::LuauLowerBoundsCalculation && isLocalLambda(*protoIter))) { while (checkIter != protoIter) { @@ -463,13 +537,25 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } else if (auto fun = (*protoIter)->as()) { - auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt); + std::optional expectedType; + + if (!fun->func->self) + { + if (auto name = fun->name->as()) + { + TypeId exprTy = checkExpr(scope, *name->expr).type; + expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false); + } + } + + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, expectedType); auto [funTy, funScope] = pair; functionDecls[*protoIter] = pair; ++subLevel; - TypeId leftType = checkFunctionName(scope, *fun->name, funScope->level); + TypeId leftType = follow(checkFunctionName(scope, *fun->name, funScope->level)); + unify(funTy, leftType, fun->location); } else if (auto fun = (*protoIter)->as()) @@ -503,7 +589,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { - if (FFlag::LuauTwoPassAliasDefinitionFix && typealias->name == kParseNameError) + if (typealias->name == kParseNameError) continue; auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; @@ -512,7 +598,16 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std TypeId type = bindings[name].type; if (get(follow(type))) { - *asMutable(type) = *errorRecoveryType(anyType); + if (FFlag::LuauNonCopyableTypeVarFields) + { + TypeVar* mty = asMutable(follow(type)); + mty->reassign(*errorRecoveryType(anyType)); + } + else + { + *asMutable(type) = *errorRecoveryType(anyType); + } + reportError(TypeError{typealias->location, OccursCheckFailed{}}); } } @@ -565,10 +660,10 @@ static std::optional tryGetTypeGuardPredicate(const AstExprBinary& ex void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) { - ExprResult result = checkExpr(scope, *statement.condition); + WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr ifScope = childScope(scope, statement.thenbody->location); - reportErrors(resolve(result.predicates, ifScope, true)); + resolve(result.predicates, ifScope, true); check(ifScope, *statement.thenbody); if (statement.elsebody) @@ -598,10 +693,10 @@ ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const Locat void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) { - ExprResult result = checkExpr(scope, *statement.condition); + WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr whileScope = childScope(scope, statement.body->location); - reportErrors(resolve(result.predicates, whileScope, true)); + resolve(result.predicates, whileScope, true); check(whileScope, *statement.body); } @@ -614,6 +709,64 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) checkExpr(repScope, *statement.condition); } +void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location) +{ + Unifier state = mkUnifier(location); + state.unifyLowerBound(subTy, superTy, demotedLevel); + + state.log.commit(); + + reportErrors(state.errors); +} + +struct Demoter : Substitution +{ + Demoter(TypeArena* arena) + : Substitution(TxnLog::empty(), arena) + { + } + + bool isDirty(TypeId ty) override + { + return get(ty); + } + + bool isDirty(TypePackId tp) override + { + return get(tp); + } + + TypeId clean(TypeId ty) override + { + auto ftv = get(ty); + LUAU_ASSERT(ftv); + return addType(FreeTypeVar{demotedLevel(ftv->level)}); + } + + TypePackId clean(TypePackId tp) override + { + auto ftp = get(tp); + LUAU_ASSERT(ftp); + return addTypePack(TypePackVar{FreeTypePack{demotedLevel(ftp->level)}}); + } + + TypeLevel demotedLevel(TypeLevel level) + { + return TypeLevel{level.level + 5000, level.subLevel}; + } + + void demote(std::vector>& expectedTypes) + { + if (!FFlag::LuauQuantifyConstrained) + return; + for (std::optional& ty : expectedTypes) + { + if (ty) + ty = substitute(*ty); + } + } +}; + void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; @@ -636,8 +789,17 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) } } + Demoter demoter{¤tModule->internalTypes}; + demoter.demote(expectedTypes); + TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; + if (FFlag::LuauReturnTypeInferenceInNonstrict ? FFlag::LuauLowerBoundsCalculation : useConstrainedIntersections()) + { + unifyLowerBound(retPack, scope->returnType, demoter.demotedLevel(scope->level), return_.location); + return; + } + // HACK: Nonstrict mode gets a bit too smart and strict for us when we // start typechecking everything across module boundaries. if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType)) @@ -731,9 +893,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) TypeId right = nullptr; - Location loc = 0 == assign.values.size - ? assign.location - : i < assign.values.size ? assign.values.data[i]->location : assign.values.data[assign.values.size - 1]->location; + Location loc = 0 == assign.values.size ? assign.location + : i < assign.values.size ? assign.values.data[i]->location + : assign.values.data[assign.values.size - 1]->location; if (valueIter != valueEnd) { @@ -1005,7 +1167,46 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) } else { - iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location)); + iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); + } + + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location)) + { + // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions + // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments + // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types + for (TypeId var : varTypes) + unify(anyType, var, forin.location); + + return check(loopScope, *forin.body); + } + + if (const TableTypeVar* iterTable = get(iterTy)) + { + // TODO: note that this doesn't cleanly handle iteration over mixed tables and tables without an indexer + // this behavior is more or less consistent with what we do for pairs(), but really both are pretty wrong and need revisiting + if (iterTable->indexer) + { + if (varTypes.size() > 0) + unify(iterTable->indexer->indexType, varTypes[0], forin.location); + + if (varTypes.size() > 1) + unify(iterTable->indexer->indexResultType, varTypes[1], forin.location); + + for (size_t i = 2; i < varTypes.size(); ++i) + unify(nilType, varTypes[i], forin.location); + } + else + { + TypeId varTy = errorRecoveryType(loopScope); + + for (TypeId var : varTypes) + unify(varTy, var, forin.location); + + reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); + } + + return check(loopScope, *forin.body); } const FunctionTypeVar* iterFunc = get(iterTy); @@ -1017,7 +1218,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) unify(varTy, var, forin.location); if (!get(iterTy) && !get(iterTy) && !get(iterTy)) - reportError(TypeError{firstValue->location, TypeMismatch{globalScope->bindings[AstName{"next"}].typeId, iterTy}}); + reportError(firstValue->location, CannotCallNonFunction{iterTy}); return check(loopScope, *forin.body); } @@ -1061,7 +1262,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) unify(retPack, varPack, forin.location); } else - unify(iterFunc->retType, varPack, forin.location); + unify(iterFunc->retTypes, varPack, forin.location); check(loopScope, *forin.body); } @@ -1088,7 +1289,12 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco // If in nonstrict mode and allowing redefinition of global function, restore the previous definition type // in case this function has a differing signature. The signature discrepancy will be caught in checkBlock. if (previouslyDefined) + { + if (FFlag::LuauReturnTypeInferenceInNonstrict && FFlag::LuauLowerBoundsCalculation) + quantify(funScope, ty, exprName->location); + globalBindings[name] = oldBinding; + } else globalBindings[name] = {quantify(funScope, ty, exprName->location), exprName->location}; @@ -1103,22 +1309,18 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; return; } - else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify2) + else if (auto name = function.name->as()) { TypeId exprTy = checkExpr(scope, *name->expr).type; TableTypeVar* ttv = getMutableTableType(exprTy); - if (!ttv) + + if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false)) { - if (isTableIntersection(exprTy)) + if (ttv || isTableIntersection(exprTy)) reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); - else if (!get(exprTy) && !get(exprTy)) + else reportError(TypeError{function.location, OnlyTablesCanHaveMethods{exprTy}}); } - else if (ttv->state == TableState::Sealed) - { - if (!ttv->indexer || !isPrim(ttv->indexer->indexType, PrimitiveTypeVar::String)) - reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); - } ty = follow(ty); @@ -1141,7 +1343,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } - else if (FFlag::LuauStatFunctionSimplify2) + else { LUAU_ASSERT(function.name->is()); @@ -1149,69 +1351,6 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); } - else if (function.func->self) - { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify2); - - AstExprIndexName* indexName = function.name->as(); - if (!indexName) - ice("member function declaration has malformed name expression"); - - TypeId selfTy = checkExpr(scope, *indexName->expr).type; - TableTypeVar* tableSelf = getMutableTableType(selfTy); - if (!tableSelf) - { - if (isTableIntersection(selfTy)) - reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); - else if (!get(selfTy) && !get(selfTy)) - reportError(TypeError{function.location, OnlyTablesCanHaveMethods{selfTy}}); - } - else if (tableSelf->state == TableState::Sealed) - reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); - - ty = follow(ty); - - if (tableSelf && tableSelf->state != TableState::Sealed) - tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation}; - - const FunctionTypeVar* funTy = get(ty); - if (!funTy) - ice("Methods should be functions"); - - std::optional arg0 = first(funTy->argTypes); - if (!arg0) - ice("Methods should always have at least 1 argument (self)"); - - checkFunctionBody(funScope, ty, *function.func); - - if (tableSelf && tableSelf->state != TableState::Sealed) - tableSelf->props[indexName->index.value] = { - follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation}; - } - else - { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify2); - - TypeId leftType = checkLValueBinding(scope, *function.name); - - checkFunctionBody(funScope, ty, *function.func); - - unify(ty, leftType, function.location); - - LUAU_ASSERT(function.name->is() || function.name->is()); - - if (auto exprIndexName = function.name->as()) - { - if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) - { - if (auto ttv = getMutableTableType(*typeIt)) - { - if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) - it->second.type = follow(quantify(funScope, leftType, function.name->location)); - } - } - } - } } void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function) @@ -1232,7 +1371,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias Name name = typealias.name.value; // If the alias is missing a name, we can't do anything with it. Ignore it. - if (FFlag::LuauTwoPassAliasDefinitionFix && name == kParseNameError) + if (name == kParseNameError) return; std::optional binding; @@ -1251,8 +1390,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; - if (FFlag::LuauTwoPassAliasDefinitionFix) - duplicateTypeAliases.insert({typealias.exported, name}); + duplicateTypeAliases.insert({typealias.exported, name}); } else { @@ -1269,15 +1407,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias ftv->forwardedTypeAlias = true; bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; - if (FFlag::LuauFixIncorrectLineNumberDuplicateType) - scope->typeAliasLocations[name] = typealias.location; + scope->typeAliasLocations[name] = typealias.location; } } else { // If the first pass failed (this should mean a duplicate definition), the second pass isn't going to be // interesting. - if (FFlag::LuauTwoPassAliasDefinitionFix && duplicateTypeAliases.find({typealias.exported, name})) + if (duplicateTypeAliases.find({typealias.exported, name})) return; if (!binding) @@ -1321,8 +1458,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = name; @@ -1332,7 +1467,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias for (auto param : binding->typePackParams) clone.instantiatedTypePackParams.push_back(param.tp); + bool isNormal = ty->normal; ty = addType(std::move(clone)); + + if (FFlag::LuauLowerBoundsCalculation) + asMutable(ty)->normal = isNormal; } } else @@ -1356,10 +1495,17 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } TypeId& bindingType = bindingsMap[name].type; - bool ok = unify(ty, bindingType, typealias.location); - if (FFlag::LuauTwoPassAliasDefinitionFix && ok) + if (unify(ty, bindingType, typealias.location)) bindingType = ty; + + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(bindingType, currentModule, *iceHandler); + bindingType = t; + if (!ok) + reportError(typealias.location, NormalizationTooComplex{}); + } } } @@ -1392,7 +1538,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar Name className(declaredClass.name.value); - TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {})); + TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); ClassTypeVar* ctv = getMutable(classTy); TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); @@ -1418,7 +1564,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) ftv->hasSelf = true; } } @@ -1497,7 +1643,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) @@ -1506,7 +1652,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& return {errorRecoveryType(scope)}; } - ExprResult result; + WithPredicate result; if (auto a = expr.as()) result = checkExpr(scope, *a->expr, expectedType); @@ -1568,7 +1714,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& return result; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLocal& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLocal& expr) { std::optional lvalue = tryGetLValue(expr); LUAU_ASSERT(lvalue); // Guaranteed to not be nullopt - AstExprLocal is an LValue. @@ -1582,7 +1728,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLo return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) { std::optional lvalue = tryGetLValue(expr); LUAU_ASSERT(lvalue); // Guaranteed to not be nullopt - AstExprGlobal is an LValue. @@ -1594,7 +1740,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGl return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) { TypePackId varargPack = checkExprPack(scope, expr).type; @@ -1624,19 +1770,20 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCall& expr) { - ExprResult result = checkExprPack(scope, expr); + WithPredicate result = checkExprPack(scope, expr); TypePackId retPack = follow(result.type); if (auto pack = get(retPack)) { return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; } - else if (get(retPack)) + else if (const FreeTypePack* ftp = get(retPack)) { - TypeId head = freshType(scope); - TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); + TypeLevel level = FFlag::LuauLowerBoundsCalculation ? ftp->level : scope->level; + TypeId head = freshType(level); + TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(level)}}); unify(pack, retPack, expr.location); return {head, std::move(result.predicates)}; } @@ -1655,7 +1802,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa ice("Unknown TypePack type!", expr.location); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) { Name name = expr.index.value; @@ -1691,7 +1838,7 @@ std::optional TypeChecker::findMetatableEntry(TypeId type, std::string e } std::optional TypeChecker::getIndexTypeFromType( - const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) + const ScopePtr& scope, TypeId type, const std::string& name, const Location& location, bool addErrors) { type = follow(type); @@ -1700,34 +1847,34 @@ std::optional TypeChecker::getIndexTypeFromType( tablify(type); - if (FFlag::LuauDiscriminableUnions2) + if (isString(type)) { - if (isString(type)) - { - std::optional mtIndex = findMetatableEntry(stringType, "__index", location); - LUAU_ASSERT(mtIndex); - type = *mtIndex; - } - } - else - { - const PrimitiveTypeVar* primitiveType = get(type); - if (primitiveType && primitiveType->type == PrimitiveTypeVar::String) - { - if (std::optional mtIndex = findMetatableEntry(type, "__index", location)) - type = *mtIndex; - } + std::optional mtIndex = findMetatableEntry(stringType, "__index", location); + LUAU_ASSERT(mtIndex); + type = *mtIndex; } if (TableTypeVar* tableType = getMutableTableType(type)) { - const auto& it = tableType->props.find(name); - if (it != tableType->props.end()) + if (auto it = tableType->props.find(name); it != tableType->props.end()) return it->second.type; else if (auto indexer = tableType->indexer) { - tryUnify(stringType, indexer->indexType, location); - return indexer->indexResultType; + // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. + ErrorVec errors = tryUnify(stringType, indexer->indexType, location); + + if (FFlag::LuauReportErrorsOnIndexerKeyMismatch) + { + if (errors.empty()) + return indexer->indexResultType; + + if (addErrors) + reportError(location, UnknownProperty{type, name}); + + return std::nullopt; + } + else + return indexer->indexResultType; } else if (tableType->state == TableState::Free) { @@ -1736,8 +1883,7 @@ std::optional TypeChecker::getIndexTypeFromType( return result; } - auto found = findTablePropertyRespectingMeta(type, name, location); - if (found) + if (auto found = findTablePropertyRespectingMeta(type, name, location)) return *found; } else if (const ClassTypeVar* cls = get(type)) @@ -1777,12 +1923,25 @@ std::optional TypeChecker::getIndexTypeFromType( return std::nullopt; } - std::vector result = reduceUnion(goodOptions); + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(addType(UnionTypeVar{std::move(goodOptions)}), currentModule, + *iceHandler); // FIXME Inefficient. We craft a UnionTypeVar and immediately throw it away. - if (result.size() == 1) - return result[0]; + if (!ok) + reportError(location, NormalizationTooComplex{}); - return addType(UnionTypeVar{std::move(result)}); + return t; + } + else + { + std::vector result = reduceUnion(goodOptions); + + if (result.size() == 1) + return result[0]; + + return addType(UnionTypeVar{std::move(result)}); + } } else if (const IntersectionTypeVar* itv = get(type)) { @@ -1804,23 +1963,10 @@ std::optional TypeChecker::getIndexTypeFromType( return std::nullopt; } - if (FFlag::LuauDoNotTryToReduce) - { - if (parts.size() == 1) - return parts[0]; + if (parts.size() == 1) + return parts[0]; - return addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. - } - else - { - // TODO(amccord): Write some logic to correctly handle intersections. CLI-34659 - std::vector result = reduceUnion(parts); - - if (result.size() == 1) - return result[0]; - - return addType(IntersectionTypeVar{result}); - } + return addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. } if (addErrors) @@ -1831,16 +1977,29 @@ std::optional TypeChecker::getIndexTypeFromType( std::vector TypeChecker::reduceUnion(const std::vector& types) { - if (FFlag::LuauDoNotAccidentallyDependOnPointerOrdering) + std::vector result; + for (TypeId t : types) { - std::vector result; - for (TypeId t : types) - { - t = follow(t); - if (get(t) || get(t)) - return {t}; + t = follow(t); + if (get(t) || get(t)) + return {t}; - if (const UnionTypeVar* utv = get(t)) + if (const UnionTypeVar* utv = get(t)) + { + if (FFlag::LuauReduceUnionRecursion) + { + for (TypeId ty : utv) + { + if (FFlag::LuauNormalizeFlagIsConservative) + ty = follow(ty); + if (get(ty) || get(ty)) + return {ty}; + + if (result.end() == std::find(result.begin(), result.end(), ty)) + result.push_back(ty); + } + } + else { std::vector r = reduceUnion(utv->options); for (TypeId ty : r) @@ -1853,67 +2012,20 @@ std::vector TypeChecker::reduceUnion(const std::vector& types) result.push_back(ty); } } - else if (std::find(result.begin(), result.end(), t) == result.end()) - result.push_back(t); } - - return result; + else if (std::find(result.begin(), result.end(), t) == result.end()) + result.push_back(t); } - else - { - std::set s; - for (TypeId t : types) - { - if (const UnionTypeVar* utv = get(follow(t))) - { - std::vector r = reduceUnion(utv->options); - for (TypeId ty : r) - s.insert(ty); - } - else - s.insert(t); - } - - // If any of them are ErrorTypeVars/AnyTypeVars, decay into them. - for (TypeId t : s) - { - t = follow(t); - if (get(t) || get(t)) - return {t}; - } - - std::vector r(s.begin(), s.end()); - std::sort(r.begin(), r.end()); - return r; - } + return result; } std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) { if (const UnionTypeVar* utv = get(ty)) { - if (FFlag::LuauAnyInIsOptionalIsOptional) - { - if (!std::any_of(begin(utv), end(utv), isNil)) - return ty; - } - else - { - bool hasNil = false; - - for (TypeId option : utv) - { - if (isNil(option)) - { - hasNil = true; - break; - } - } - - if (!hasNil) - return ty; - } + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; std::vector result; @@ -1934,39 +2046,24 @@ std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) { - if (FFlag::LuauAnyInIsOptionalIsOptional) + ty = follow(ty); + + if (auto utv = get(ty)) { - ty = follow(ty); - - if (auto utv = get(ty)) - { - if (!std::any_of(begin(utv), end(utv), isNil)) - return ty; - - } - - if (std::optional strippedUnion = tryStripUnionFromNil(ty)) - { - reportError(location, OptionalValueAccess{ty}); - return follow(*strippedUnion); - } + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; } - else + + if (std::optional strippedUnion = tryStripUnionFromNil(ty)) { - if (isOptional(ty)) - { - if (std::optional strippedUnion = tryStripUnionFromNil(follow(ty))) - { - reportError(location, OptionalValueAccess{ty}); - return follow(*strippedUnion); - } - } + reportError(location, OptionalValueAccess{ty}); + return follow(*strippedUnion); } return ty; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) { TypeId ty = checkLValue(scope, expr); @@ -1977,7 +2074,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn return {ty}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) { auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, expectedType); @@ -2018,7 +2115,10 @@ TypeId TypeChecker::checkExprTable( indexer = expectedTable->indexer; if (indexer) + { + unify(numberType, indexer->indexType, value->location); unify(valueType, indexer->indexResultType, value->location); + } else indexer = TableIndexer{numberType, anyIfNonstrict(valueType)}; } @@ -2030,7 +2130,7 @@ TypeId TypeChecker::checkExprTable( if (isNonstrictMode() && !getTableType(exprType) && !get(exprType)) exprType = anyType; - if (FFlag::LuauPropertiesGetExpectedType && expectedTable) + if (expectedTable) { auto it = expectedTable->props.find(key->value.data); if (it != expectedTable->props.end()) @@ -2040,7 +2140,7 @@ TypeId TypeChecker::checkExprTable( if (errors.empty()) exprType = expectedProp.type; } - else if (expectedTable->indexer && isString(expectedTable->indexer->indexType)) + else if (expectedTable->indexer && maybeString(expectedTable->indexer->indexType)) { ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, k->location); if (errors.empty()) @@ -2072,15 +2172,21 @@ TypeId TypeChecker::checkExprTable( } } - TableState state = (expr.items.size == 0 || isNonstrictMode() || FFlag::LuauUnsealedTableLiteral) ? TableState::Unsealed : TableState::Sealed; + TableState state = TableState::Unsealed; TableTypeVar table = TableTypeVar{std::move(props), indexer, scope->level, state}; table.definitionModuleName = currentModuleName; return addType(table); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionCounter _rc(&checkRecursionCount); + if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) + { + reportErrorCodeTooComplex(expr.location); + return {errorRecoveryType(scope)}; + } + std::vector> fieldTypes(expr.items.size); const TableTypeVar* expectedTable = nullptr; @@ -2103,9 +2209,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa } } } - else if (FFlag::LuauExpectedTypesOfProperties) - if (const UnionTypeVar* utv = get(follow(*expectedType))) - expectedUnion = utv; + else if (const UnionTypeVar* utv = get(follow(*expectedType))) + expectedUnion = utv; } for (size_t i = 0; i < expr.items.size; ++i) @@ -2127,8 +2232,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa { if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; + else if (expectedIndexType && maybeString(*expectedIndexType)) + expectedResultType = expectedIndexResultType; } - else if (FFlag::LuauExpectedTypesOfProperties && expectedUnion) + else if (expectedUnion) { std::vector expectedResultTypes; for (TypeId expectedOption : expectedUnion) @@ -2160,9 +2267,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa return {checkExprTable(scope, expr, fieldTypes, expectedType)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) { - ExprResult result = checkExpr(scope, *expr.expr); + WithPredicate result = checkExpr(scope, *expr.expr); TypeId operandType = follow(result.type); switch (expr.op) @@ -2332,7 +2439,7 @@ TypeId TypeChecker::checkRelationalOperation( if (expr.op == AstExprBinary::Or && subexp->op == AstExprBinary::And) { ScopePtr subScope = childScope(scope, subexp->location); - reportErrors(resolve(predicates, subScope, true)); + resolve(predicates, subScope, true); return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); } } @@ -2390,15 +2497,50 @@ TypeId TypeChecker::checkRelationalOperation( std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); - // TODO: this check seems odd, the second part is redundant - // is it meant to be if (leftMetatable && rightMetatable && leftMetatable != rightMetatable) - if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) + if (leftMetatable != rightMetatable) { - reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); - return errorRecoveryType(booleanType); + bool matches = false; + if (isEquality) + { + if (const UnionTypeVar* utv = get(leftType); utv && rightMetatable) + { + for (TypeId leftOption : utv) + { + if (getMetatable(follow(leftOption)) == rightMetatable) + { + matches = true; + break; + } + } + } + + if (!matches) + { + if (const UnionTypeVar* utv = get(rhsType); utv && leftMetatable) + { + for (TypeId rightOption : utv) + { + if (getMetatable(follow(rightOption)) == leftMetatable) + { + matches = true; + break; + } + } + } + } + } + + + if (!matches) + { + reportError( + expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + return errorRecoveryType(booleanType); + } } + if (leftMetatable) { std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location); @@ -2409,7 +2551,7 @@ TypeId TypeChecker::checkRelationalOperation( if (isEquality) { Unifier state = mkUnifier(expr.location); - state.tryUnify(addTypePack({booleanType}), ftv->retType); + state.tryUnify(addTypePack({booleanType}), ftv->retTypes); if (!state.errors.empty()) { @@ -2492,24 +2634,11 @@ TypeId TypeChecker::checkBinaryOperation( lhsType = follow(lhsType); rhsType = follow(rhsType); - if (FFlag::LuauDecoupleOperatorInferenceFromUnifiedTypeInference) + if (!isNonstrictMode() && get(lhsType)) { - if (!isNonstrictMode() && get(lhsType)) - { - auto name = getIdentifierOfBaseVar(expr.left); - reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - // We will fall-through to the `return anyType` check below. - } - } - else - { - if (!isNonstrictMode() && get(lhsType)) - { - auto name = getIdentifierOfBaseVar(expr.left); - reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); - } + auto name = getIdentifierOfBaseVar(expr.left); + reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); + // We will fall-through to the `return anyType` check below. } // If we know nothing at all about the lhs type, we can usually say nothing about the result. @@ -2548,7 +2677,7 @@ TypeId TypeChecker::checkBinaryOperation( reportErrors(state.errors); bool hasErrors = !state.errors.empty(); - if (FFlag::LuauErrorRecoveryType && hasErrors) + if (hasErrors) { // If there are unification errors, the return type may still be unknown // so we loosen the argument types to see if that helps. @@ -2562,8 +2691,7 @@ TypeId TypeChecker::checkBinaryOperation( if (state.errors.empty()) state.log.commit(); } - - if (!hasErrors) + else { state.log.commit(); } @@ -2612,7 +2740,7 @@ TypeId TypeChecker::checkBinaryOperation( } } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) { if (expr.op == AstExprBinary::And) { @@ -2623,8 +2751,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); - return {checkBinaryOperation(FFlag::LuauDiscriminableUnions2 ? scope : innerScope, expr, lhsTy, rhsTy), - {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; + return {checkBinaryOperation(scope, expr, lhsTy, rhsTy), {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::Or) { @@ -2636,7 +2763,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); // Because of C++, I'm not sure if lhsPredicates was not moved out by the time we call checkBinaryOperation. - TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions2 ? scope : innerScope, expr, lhsTy, rhsTy, lhsPredicates); + TypeId result = checkBinaryOperation(scope, expr, lhsTy, rhsTy, lhsPredicates); return {result, {OrPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) @@ -2644,8 +2771,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; - ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions2); - ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions2); + WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); + WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); PredicateVec predicates; @@ -2662,18 +2789,18 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi } else { - ExprResult lhs = checkExpr(scope, *expr.left); - ExprResult rhs = checkExpr(scope, *expr.right); + WithPredicate lhs = checkExpr(scope, *expr.left); + WithPredicate rhs = checkExpr(scope, *expr.right); // Intentionally discarding predicates with other operators. return {checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)}; } } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) { TypeId annotationType = resolveType(scope, *expr.annotation); - ExprResult result = checkExpr(scope, *expr.expr, annotationType); + WithPredicate result = checkExpr(scope, *expr.expr, annotationType); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. if (canUnify(annotationType, result.type, expr.location).empty()) @@ -2686,7 +2813,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy return {errorRecoveryType(annotationType), std::move(result.predicates)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) { const size_t oldSize = currentModule->errors.size(); @@ -2700,17 +2827,17 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) { - ExprResult result = checkExpr(scope, *expr.condition); + WithPredicate result = checkExpr(scope, *expr.condition); + ScopePtr trueScope = childScope(scope, expr.trueExpr->location); - reportErrors(resolve(result.predicates, trueScope, true)); - ExprResult trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); + resolve(result.predicates, trueScope, true); + WithPredicate trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); ScopePtr falseScope = childScope(scope, expr.falseExpr->location); - // Don't report errors for this scope to avoid potentially duplicating errors reported for the first scope. resolve(result.predicates, falseScope, false); - ExprResult falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); + WithPredicate falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); if (falseType.type == trueType.type) return {trueType.type}; @@ -2972,49 +3099,25 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T else if (auto indexName = funName.as()) { TypeId lhsType = checkExpr(scope, *indexName->expr).type; - if (get(lhsType) || get(lhsType)) - return lhsType; - TableTypeVar* ttv = getMutableTableType(lhsType); - if (!ttv) + + if (!ttv || ttv->state == TableState::Sealed) { - if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType)) - // This error now gets reported when we check the function body. - reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); + if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, false)) + return *ty; return errorRecoveryType(scope); } - if (FFlag::LuauStatFunctionSimplify2) - { - if (lhsType->persistent) - return errorRecoveryType(scope); - - // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check - if (ttv->state == TableState::Sealed) - { - if (ttv->indexer && isPrim(ttv->indexer->indexType, PrimitiveTypeVar::String)) - return ttv->indexer->indexResultType; - else - return errorRecoveryType(scope); - } - } - else - { - if (lhsType->persistent || ttv->state == TableState::Sealed) - return errorRecoveryType(scope); - } - Name name = indexName->index.value; if (ttv->props.count(name)) - return errorRecoveryType(scope); + return ttv->props[name].type; Property& property = ttv->props[name]; property.type = freshTy(); property.location = indexName->indexLocation; - ttv->methodDefinitionLocations[name] = funName.location; return property.type; } else if (funName.is()) @@ -3082,11 +3185,11 @@ std::pair TypeChecker::checkFunctionSignature( TypePackId retPack; if (expr.returnAnnotation) retPack = resolveTypePack(funScope, *expr.returnAnnotation); - else if (isNonstrictMode()) + else if (FFlag::LuauReturnTypeInferenceInNonstrict ? (!FFlag::LuauLowerBoundsCalculation && isNonstrictMode()) : isNonstrictMode()) retPack = anyTypePack; else if (expectedFunctionType) { - auto [head, tail] = flatten(expectedFunctionType->retType); + auto [head, tail] = flatten(expectedFunctionType->retTypes); // Do not infer 'nil' as function return type if (!tail && head.size() == 1 && isNil(head[0])) @@ -3129,6 +3232,10 @@ std::pair TypeChecker::checkFunctionSignature( funScope->varargPack = anyTypePack; } } + else if (FFlag::LuauLowerBoundsCalculation && !isNonstrictMode()) + { + funScope->varargPack = addTypePack(TypePackVar{VariadicTypePack{anyType, /*hidden*/ true}}); + } std::vector argTypes; @@ -3264,20 +3371,35 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE { check(scope, *function.body); - // We explicitly don't follow here to check if we have a 'true' free type instead of bound one - if (get_if(&funTy->retType->ty)) - *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; + if (useConstrainedIntersections()) + { + TypePackId retPack = follow(funTy->retTypes); + // It is possible for a function to have no annotation and no return statement, and yet still have an ascribed return type + // if it is expected to conform to some other interface. (eg the function may be a lambda passed as a callback) + if (!hasReturn(function.body) && !function.returnAnnotation.has_value() && get(retPack)) + { + auto level = getLevel(retPack); + if (level && scope->level.subsumes(*level)) + *asMutable(retPack) = TypePack{{}, std::nullopt}; + } + } + else + { + // We explicitly don't follow here to check if we have a 'true' free type instead of bound one + if (get_if(&funTy->retTypes->ty)) + *asMutable(funTy->retTypes) = TypePack{{}, std::nullopt}; + } bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; - if (reachesImplicitReturn && !allowsNoReturnValues(follow(funTy->retType))) + if (reachesImplicitReturn && !allowsNoReturnValues(follow(funTy->retTypes))) { // If we're in nonstrict mode we want to only report this missing return // statement if there are type annotations on the function. In strict mode // we report it regardless. if (!isNonstrictMode() || function.returnAnnotation) { - reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retType}); + reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retTypes}); } } } @@ -3285,7 +3407,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE ice("Checking non functional type"); } -ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) +WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) { if (auto a = expr.as()) return checkExprPack(scope, *a); @@ -3304,32 +3426,6 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } // Returns the minimum number of arguments the argument list can accept. -static size_t getMinParameterCount_DEPRECATED(TypePackId tp) -{ - size_t minCount = 0; - size_t optionalCount = 0; - - auto it = begin(tp); - auto endIter = end(tp); - - while (it != endIter) - { - TypeId ty = *it; - if (isOptional(ty)) - ++optionalCount; - else - { - minCount += optionalCount; - optionalCount = 0; - minCount++; - } - - ++it; - } - - return minCount; -} - static size_t getMinParameterCount(TxnLog* log, TypePackId tp) { size_t minCount = 0; @@ -3369,7 +3465,15 @@ void TypeChecker::checkArgumentList( size_t paramIndex = 0; - size_t minParams = FFlag::LuauFixIncorrectLineNumberDuplicateType ? 0 : getMinParameterCount_DEPRECATED(paramPack); + auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack]() { + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + + size_t minParams = getMinParameterCount(&state.log, paramPack); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + }; while (true) { @@ -3425,6 +3529,8 @@ void TypeChecker::checkArgumentList( } else if (auto vtp = state.log.getMutable(tail)) { + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. while (paramIter != endIter) { state.tryUnify(vtp->ty, *paramIter); @@ -3459,14 +3565,13 @@ void TypeChecker::checkArgumentList( else if (state.log.getMutable(t)) { } // ok - else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.getMutable(t)) - { - } // ok else { - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - minParams = getMinParameterCount(&state.log, paramPack); - bool isVariadic = FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic && !finite(paramPack, &state.log); + size_t minParams = getMinParameterCount(&state.log, paramPack); + + std::optional tail = flatten(paramPack, state.log).second; + bool isVariadic = tail && Luau::isVariadic(*tail); + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex, CountMismatch::Context::Arg, isVariadic}}); return; } @@ -3485,14 +3590,7 @@ void TypeChecker::checkArgumentList( unify(errorRecoveryType(scope), *argIter, state.location); ++argIter; } - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - minParams = getMinParameterCount(&state.log, paramPack); - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + reportCountMismatchError(); return; } TypePackId tail = state.log.follow(*paramIter.tail()); @@ -3504,6 +3602,21 @@ void TypeChecker::checkArgumentList( } else if (auto vtp = state.log.getMutable(tail)) { + if (FFlag::LuauLowerBoundsCalculation && vtp->hidden) + { + // We know that this function can technically be oversaturated, but we have its definition and we + // know that it's useless. + + TypeId e = errorRecoveryType(scope); + while (argIter != endIter) + { + unify(e, *argIter, state.location); + ++argIter; + } + + reportCountMismatchError(); + return; + } // Function is variadic and requires that all subsequent parameters // be compatible with a type. size_t argIndex = paramIndex; @@ -3534,10 +3647,7 @@ void TypeChecker::checkArgumentList( } TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); - if (FFlag::LuauWidenIfSupertypeIsFree2) - state.tryUnify(varPack, tail); - else - state.tryUnify(tail, varPack); + state.tryUnify(varPack, tail); return; } @@ -3548,14 +3658,7 @@ void TypeChecker::checkArgumentList( } else if (state.log.getMutable(tail)) { - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - // TODO: Better error message? - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - minParams = getMinParameterCount(&state.log, paramPack); - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + reportCountMismatchError(); return; } } @@ -3570,7 +3673,7 @@ void TypeChecker::checkArgumentList( } } -ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) { // evaluate type of function // decompose an intersection into its component overloads @@ -3611,10 +3714,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A actualFunctionType = instantiate(scope, functionType, expr.func->location); } - actualFunctionType = follow(actualFunctionType); - TypePackId retPack; - if (!FFlag::LuauWidenIfSupertypeIsFree2) + if (FFlag::LuauLowerBoundsCalculation) { retPack = freshTypePack(scope->level); } @@ -3624,7 +3725,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A { retPack = freshTypePack(free->level); TypePackId freshArgPack = freshTypePack(free->level); - *asMutable(actualFunctionType) = FunctionTypeVar(free->level, freshArgPack, retPack); + asMutable(actualFunctionType)->ty.emplace(free->level, freshArgPack, retPack); } else retPack = freshTypePack(scope->level); @@ -3640,7 +3741,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); - ExprResult argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); + WithPredicate argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); TypePackId argPack = argListResult.type; if (get(argPack)) @@ -3678,16 +3779,13 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); - if (FFlag::LuauErrorRecoveryType) - { - const FunctionTypeVar* overload = nullptr; - if (!overloadsThatMatchArgCount.empty()) - overload = get(overloadsThatMatchArgCount[0]); - if (!overload && !overloadsThatDont.empty()) - overload = get(overloadsThatDont[0]); - if (overload) - return {errorRecoveryTypePack(overload->retType)}; - } + const FunctionTypeVar* overload = nullptr; + if (!overloadsThatMatchArgCount.empty()) + overload = get(overloadsThatMatchArgCount[0]); + if (!overload && !overloadsThatDont.empty()) + overload = get(overloadsThatDont[0]); + if (overload) + return {errorRecoveryTypePack(overload->retTypes)}; return {errorRecoveryTypePack(retPack)}; } @@ -3696,7 +3794,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st { std::vector> expectedTypes; - auto assignOption = [this, &expectedTypes](size_t index, std::optional ty) { + auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { if (index == expectedTypes.size()) { expectedTypes.push_back(ty); @@ -3711,7 +3809,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } else { - std::vector result = reduceUnion({*el, *ty}); + std::vector result = reduceUnion({*el, ty}); el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); } } @@ -3731,7 +3829,8 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st if (argsTail) { - if (const VariadicTypePack* vtp = get(follow(*argsTail))) + argsTail = follow(*argsTail); + if (const VariadicTypePack* vtp = get(*argsTail)) { while (index < argumentCount) assignOption(index++, vtp->ty); @@ -3740,11 +3839,14 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } } + Demoter demoter{¤tModule->internalTypes}; + demoter.demote(expectedTypes); + return expectedTypes; } -std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, +std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { LUAU_ASSERT(argLocations); @@ -3762,21 +3864,44 @@ std::optional> TypeChecker::checkCallOverload(const Scope return {{errorRecoveryTypePack(scope)}}; } - if (get(fn)) + if (auto ftv = get(fn)) { // fn is one of the overloads of actualFunctionType, which // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. - TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - if (FFlag::LuauWidenIfSupertypeIsFree2) + if (useConstrainedIntersections()) { + // This ternary is phrased deliberately. We need ties between sibling scopes to bias toward ftv->level. + const TypeLevel level = scope->level.subsumes(ftv->level) ? scope->level : ftv->level; + + std::vector adjustedArgTypes; + auto it = begin(argPack); + auto endIt = end(argPack); + Widen widen{¤tModule->internalTypes}; + for (; it != endIt; ++it) + { + adjustedArgTypes.push_back(addType(ConstrainedTypeVar{level, {widen(*it)}})); + } + + TypePackId adjustedArgPack = addTypePack(TypePack{std::move(adjustedArgTypes), it.tail()}); + + TxnLog log; + promoteTypeLevels(log, ¤tModule->internalTypes, level, retPack); + log.commit(); + + *asMutable(fn) = FunctionTypeVar{level, adjustedArgPack, retPack}; + return {{retPack}}; + } + else + { + TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); + UnifierOptions options; options.isFunctionCall = true; unify(r, fn, expr.location, options); + + return {{retPack}}; } - else - unify(fn, r, expr.location); - return {{retPack}}; } std::vector metaArgLocations; @@ -3816,14 +3941,14 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (ftv->magicFunction) { // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 - if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) + if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) return *ret; } Unifier state = mkUnifier(expr.location); // Unify return types - checkArgumentList(scope, state, retPack, ftv->retType, /*argLocations*/ {}); + checkArgumentList(scope, state, retPack, ftv->retTypes, /*argLocations*/ {}); if (!state.errors.empty()) { return {}; @@ -3849,7 +3974,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (!argMismatch) overloadsThatMatchArgCount.push_back(fn); - else if (FFlag::LuauErrorRecoveryType) + else overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); @@ -3858,32 +3983,6 @@ std::optional> TypeChecker::checkCallOverload(const Scope { state.log.commit(); - if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) - { - // If we are running in nonstrict mode, passing fewer arguments than the function is declared to take AND - // the function is declared with colon notation AND we use dot notation, warn. - auto [providedArgs, providedTail] = flatten(argPack); - - // If we have a variadic tail, we can't say how many arguments were actually provided - if (!providedTail) - { - std::vector actualArgs = flatten(ftv->argTypes).first; - - size_t providedCount = providedArgs.size(); - size_t requiredCount = actualArgs.size(); - - // Ignore optional arguments - while (providedCount < requiredCount && requiredCount != 0 && isOptional(actualArgs[requiredCount - 1])) - requiredCount--; - - if (providedCount < requiredCount) - { - int requiredExtraNils = int(requiredCount - providedCount); - reportError(TypeError{expr.func->location, FunctionRequiresSelf{requiredExtraNils}}); - } - } - } - currentModule->astOverloadResolvedTypes[&expr] = fn; // We select this overload @@ -3920,7 +4019,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // we eagerly assume that that's what you actually meant and we commit to it. // This could be incorrect if the function has an additional overload that // actually works. - // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); + // checkArgumentList(scope, editedState, retPack, ftv->retTypes, retLocations, CountMismatch::Return); return true; } } @@ -3951,7 +4050,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // we eagerly assume that that's what you actually meant and we commit to it. // This could be incorrect if the function has an additional overload that // actually works. - // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); + // checkArgumentList(scope, editedState, retPack, ftv->retTypes, retLocations, CountMismatch::Return); return true; } } @@ -4009,7 +4108,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast // Unify return types if (const FunctionTypeVar* ftv = get(overload)) { - checkArgumentList(scope, state, retPack, ftv->retType, {}); + checkArgumentList(scope, state, retPack, ftv->retTypes, {}); checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); } @@ -4034,7 +4133,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast return; } -ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, +WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil, const std::vector& instantiateGenerics, const std::vector>& expectedTypes) { TypePackId pack = addTypePack(TypePack{}); @@ -4143,6 +4242,15 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return anyType; } + // Types of requires that transitively refer to current module have to be replaced with 'any' + std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); + + for (const auto& [location, path] : requireCycles) + { + if (!path.empty() && path.front() == humanReadableName) + return anyType; + } + ModulePtr module = resolver->getModule(moduleInfo.name); if (!module) { @@ -4150,17 +4258,13 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module // either the file does not exist or there's a cycle. If there's a cycle // we will already have reported the error. if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional) - { - std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(TypeError{location, UnknownRequire{reportedModulePath}}); - } + reportError(TypeError{location, UnknownRequire{humanReadableName}}); return errorRecoveryType(scope); } if (module->type != SourceCode::Module) { - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); return errorRecoveryType(scope); } @@ -4173,7 +4277,6 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module std::optional moduleType = first(modulePack); if (!moduleType) { - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); return errorRecoveryType(scope); } @@ -4279,119 +4382,26 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s } } -bool Instantiation::isDirty(TypeId ty) -{ - if (log->getMutable(ty)) - return true; - else - return false; -} - -bool Instantiation::isDirty(TypePackId tp) -{ - return false; -} - -bool Instantiation::ignoreChildren(TypeId ty) -{ - if (log->getMutable(ty)) - return true; - else - return false; -} - -TypeId Instantiation::clean(TypeId ty) -{ - const FunctionTypeVar* ftv = log->getMutable(ty); - LUAU_ASSERT(ftv); - - FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; - clone.magicFunction = ftv->magicFunction; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - TypeId result = addType(std::move(clone)); - - // Annoyingly, we have to do this even if there are no generics, - // to replace any generic tables. - ReplaceGenerics replaceGenerics{log, arena, level, ftv->generics, ftv->genericPacks}; - - // TODO: What to do if this returns nullopt? - // We don't have access to the error-reporting machinery - result = replaceGenerics.substitute(result).value_or(result); - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; -} - -TypePackId Instantiation::clean(TypePackId tp) -{ - LUAU_ASSERT(false); - return tp; -} - -bool ReplaceGenerics::ignoreChildren(TypeId ty) -{ - if (const FunctionTypeVar* ftv = log->getMutable(ty)) - // We aren't recursing in the case of a generic function which - // binds the same generics. This can happen if, for example, there's recursive types. - // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. - // It's OK to use vector equality here, since we always generate fresh generics - // whenever we quantify, so the vectors overlap if and only if they are equal. - return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); - else - return false; -} - -bool ReplaceGenerics::isDirty(TypeId ty) -{ - if (const TableTypeVar* ttv = log->getMutable(ty)) - return ttv->state == TableState::Generic; - else if (log->getMutable(ty)) - return std::find(generics.begin(), generics.end(), ty) != generics.end(); - else - return false; -} - -bool ReplaceGenerics::isDirty(TypePackId tp) -{ - if (log->getMutable(tp)) - return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); - else - return false; -} - -TypeId ReplaceGenerics::clean(TypeId ty) -{ - LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = log->getMutable(ty)) - { - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - return addType(std::move(clone)); - } - else - return addType(FreeTypeVar{level}); -} - -TypePackId ReplaceGenerics::clean(TypePackId tp) -{ - LUAU_ASSERT(isDirty(tp)); - return addTypePack(TypePackVar(FreeTypePack{level})); -} - bool Anyification::isDirty(TypeId ty) { + if (ty->persistent) + return false; + if (const TableTypeVar* ttv = log->getMutable(ty)) - return (ttv->state == TableState::Free || (FFlag::LuauSealExports && ttv->state == TableState::Unsealed)); + return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed); else if (log->getMutable(ty)) return true; + else if (get(ty)) + return true; else return false; } bool Anyification::isDirty(TypePackId tp) { + if (tp->persistent) + return false; + if (log->getMutable(tp)) return true; else @@ -4404,15 +4414,34 @@ TypeId Anyification::clean(TypeId ty) if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; - if (FFlag::LuauSealExports) + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.tags = ttv->tags; + TypeId res = addType(std::move(clone)); + asMutable(res)->normal = ty->normal; + return res; + } + else if (auto ctv = get(ty)) + { + if (FFlag::LuauQuantifyConstrained) { - clone.name = ttv->name; - clone.syntheticName = ttv->syntheticName; - clone.tags = ttv->tags; + std::vector copy = ctv->parts; + for (TypeId& ty : copy) + ty = replace(ty); + TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)}); + auto [t, ok] = normalize(res, *arena, *iceHandler); + if (!ok) + normalizationTooComplex = true; + return t; + } + else + { + auto [t, ok] = normalize(ty, *arena, *iceHandler); + if (!ok) + normalizationTooComplex = true; + return t; } - return addType(std::move(clone)); } else return anyType; @@ -4429,16 +4458,42 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location ty = follow(ty); const FunctionTypeVar* ftv = get(ty); - if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty()) - return ty; - Luau::quantify(ty, scope->level); + if (FFlag::LuauAlwaysQuantify) + { + if (ftv) + Luau::quantify(ty, scope->level); + } + else + { + if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) + Luau::quantify(ty, scope->level); + } + + if (FFlag::LuauLowerBoundsCalculation && ftv) + { + auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + return t; + } + return ty; } TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { + ty = follow(ty); + + const FunctionTypeVar* ftv = get(ty); + if (ftv && ftv->hasNoGenerics) + return ty; + Instantiation instantiation{log, ¤tModule->internalTypes, scope->level}; + + if (FFlag::LuauAutocompleteDynamicLimits && instantiationChildLimit) + instantiation.childLimit = *instantiationChildLimit; + std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; @@ -4451,8 +4506,18 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { - Anyification anyification{¤tModule->internalTypes, anyType, anyTypePack}; + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + ty = t; + } + + Anyification anyification{¤tModule->internalTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); + if (anyification.normalizationTooComplex) + reportError(location, NormalizationTooComplex{}); if (any.has_value()) return *any; else @@ -4464,7 +4529,15 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location) { - Anyification anyification{¤tModule->internalTypes, anyType, anyTypePack}; + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + ty = t; + } + + Anyification anyification{¤tModule->internalTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) return *any; @@ -4617,8 +4690,7 @@ TypeId TypeChecker::freshType(TypeLevel level) TypeId TypeChecker::singletonType(bool value) { - // TODO: cache singleton types - return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BooleanSingleton{value}))); + return value ? getSingletonTypes().trueType : getSingletonTypes().falseType; } TypeId TypeChecker::singletonType(std::string value) @@ -4666,8 +4738,11 @@ TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) if (isNil(ty)) return sense ? std::nullopt : std::optional(ty); - // at this point, anything else is kept if sense is true, or eliminated otherwise - return sense ? std::optional(ty) : std::nullopt; + // at this point, anything else is kept if sense is true, or replaced by nil + if (FFlag::LuauFalsyPredicateReturnsNilInstead) + return sense ? ty : nilType; + else + return sense ? std::optional(ty) : std::nullopt; }; } @@ -4746,6 +4821,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation ToStringOptions opts; opts.exhaustive = true; opts.maxTableLength = 0; + opts.useLineBreaks = true; TypeId param = resolveType(scope, *lit->parameters.data[0].type); luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str())); @@ -4790,8 +4866,6 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); parameterCountErrorReported = true; - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); } } @@ -4909,33 +4983,25 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation reportError( TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); - if (FFlag::LuauErrorRecoveryType) - { - // Pad the types out with error recovery types - while (typeParams.size() < tf->typeParams.size()) - typeParams.push_back(errorRecoveryType(scope)); - while (typePackParams.size() < tf->typePackParams.size()) - typePackParams.push_back(errorRecoveryTypePack(scope)); - } - else - return errorRecoveryType(scope); + // Pad the types out with error recovery types + while (typeParams.size() < tf->typeParams.size()) + typeParams.push_back(errorRecoveryType(scope)); + while (typePackParams.size() < tf->typePackParams.size()) + typePackParams.push_back(errorRecoveryTypePack(scope)); } - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { - return itp == tp.ty; + bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal( + typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; }); - bool sameTps = std::equal( - typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { - return itpp == tpp.tp; - }); - // If the generic parameters and the type arguments are the same, we are about to - // perform an identity substitution, which we can just short-circuit. - if (sameTys && sameTps) - return tf->type; - } + // If the generic parameters and the type arguments are the same, we are about to + // perform an identity substitution, which we can just short-circuit. + if (sameTys && sameTps) + return tf->type; return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); } @@ -4950,19 +5016,9 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (const auto& indexer = table->indexer) tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); - if (FFlag::LuauTypeMismatchModuleName) - { - TableTypeVar ttv{props, tableIndexer, scope->level, TableState::Sealed}; - ttv.definitionModuleName = currentModuleName; - return addType(std::move(ttv)); - } - else - { - return addType(TableTypeVar{ - props, tableIndexer, scope->level, - TableState::Sealed // FIXME: probably want a way to annotate other kinds of tables maybe - }); - } + TableTypeVar ttv{props, tableIndexer, scope->level, TableState::Sealed}; + ttv.definitionModuleName = currentModuleName; + return addType(std::move(ttv)); } else if (const auto& func = annotation.as()) { @@ -5101,14 +5157,11 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack bool ApplyTypeFunction::isDirty(TypeId ty) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics. - if (get(ty)) + if (typeArguments.count(ty)) return true; else if (const FreeTypeVar* ftv = get(ty)) { - if (FFlag::LuauRecursiveTypeParameterRestriction && ftv->forwardedTypeAlias) + if (ftv->forwardedTypeAlias) encounteredForwardedType = true; return false; } @@ -5118,10 +5171,7 @@ bool ApplyTypeFunction::isDirty(TypeId ty) bool ApplyTypeFunction::isDirty(TypePackId tp) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics. - if (get(tp)) + if (typePackArguments.count(tp)) return true; else return false; @@ -5145,26 +5195,16 @@ bool ApplyTypeFunction::ignoreChildren(TypePackId tp) TypeId ApplyTypeFunction::clean(TypeId ty) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics by free type variables. TypeId& arg = typeArguments[ty]; - if (arg) - return arg; - else - return addType(FreeTypeVar{level}); + LUAU_ASSERT(arg); + return arg; } TypePackId ApplyTypeFunction::clean(TypePackId tp) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics by free type variables. TypePackId& arg = typePackArguments[tp]; - if (arg) - return arg; - else - return addTypePack(FreeTypePack{level}); + LUAU_ASSERT(arg); + return arg; } TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, @@ -5187,7 +5227,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, reportError(location, UnificationTooComplex{}); return errorRecoveryType(scope); } - if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType) + if (applyTypeFunction.encounteredForwardedType) { reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}}); return errorRecoveryType(scope); @@ -5197,9 +5237,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId target = follow(instantiated); bool needsClone = follow(tf.type) == target; - bool shouldMutate = (!FFlag::LuauOnlyMutateInstantiatedTables || getTableType(tf.type)); + bool shouldMutate = getTableType(tf.type); TableTypeVar* ttv = getMutableTableType(target); - + if (shouldMutate && ttv && needsClone) { // Substitution::clone is a shallow clone. If this is a metatable type, we @@ -5224,9 +5264,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, { ttv->instantiatedTypeParams = typeParams; ttv->instantiatedTypePackParams = typePackParams; - - if (FFlag::LuauTypeMismatchModuleName) - ttv->definitionModuleName = currentModuleName; + ttv->definitionModuleName = currentModuleName; } return instantiated; @@ -5259,7 +5297,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st } TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction && (!FFlag::LuauGenericFunctionsDontCacheTypeParams || useCache)) + if (useCache) { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) @@ -5294,21 +5332,12 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypePackId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; - if (!cached) - cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); - g = cached; - } - else - { - g = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); - } + TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; + if (!cached) + cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); - genericPacks.push_back({g, defaultValue}); - scope->privateTypePackBindings[n] = g; + genericPacks.push_back({cached, defaultValue}); + scope->privateTypePackBindings[n] = cached; } return {generics, genericPacks}; @@ -5316,8 +5345,6 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate) { - LUAU_ASSERT(FFlag::LuauDiscriminableUnions2 || FFlag::LuauAssertStripsFalsyTypes); - const LValue* target = &lvalue; std::optional key; // If set, we know we took the base of the lvalue path and should be walking down each option of the base's type. @@ -5403,25 +5430,22 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV // We need to search in the provided Scope. Find t.x.y first. // We fail to find t.x.y. Try t.x. We found it. Now we must return the type of the property y from the mapped-to type of t.x. // If we completely fail to find the Symbol t but the Scope has that entry, then we should walk that all the way through and terminate. - const auto& [symbol, keys] = getFullName(lvalue); + const Symbol symbol = getBaseSymbol(lvalue); ScopePtr currentScope = scope; while (currentScope) { std::optional found; - std::vector childKeys; - const LValue* currentLValue = &lvalue; - while (currentLValue) + const LValue* topLValue = nullptr; + + for (topLValue = &lvalue; topLValue; topLValue = baseof(*topLValue)) { - if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end()) + if (auto it = currentScope->refinements.find(*topLValue); it != currentScope->refinements.end()) { found = it->second; break; } - - childKeys.push_back(*currentLValue); - currentLValue = baseof(*currentLValue); } if (!found) @@ -5437,9 +5461,15 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV } } + // We need to walk the l-value path in reverse, so we collect components into a vector + std::vector childKeys; + + for (const LValue* curr = &lvalue; curr != topLValue; curr = baseof(*curr)) + childKeys.push_back(curr); + for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) { - const LValue& key = *it; + const LValue& key = **it; // Symbol can happen. Skip. if (get(key)) @@ -5477,85 +5507,47 @@ static bool isUndecidable(TypeId ty) return get(ty) || get(ty) || get(ty); } -ErrorVec TypeChecker::resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense) { - ErrorVec errVec; - resolve(predicates, errVec, scope->refinements, scope, sense); - return errVec; + resolve(predicates, scope->refinements, scope, sense); } -void TypeChecker::resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +void TypeChecker::resolve(const PredicateVec& predicates, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { for (const Predicate& c : predicates) - resolve(c, errVec, refis, scope, sense, fromOr); + resolve(c, refis, scope, sense, fromOr); } -void TypeChecker::resolve(const Predicate& predicate, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +void TypeChecker::resolve(const Predicate& predicate, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { if (auto truthyP = get(predicate)) - resolve(*truthyP, errVec, refis, scope, sense, fromOr); + resolve(*truthyP, refis, scope, sense, fromOr); else if (auto andP = get(predicate)) - resolve(*andP, errVec, refis, scope, sense); + resolve(*andP, refis, scope, sense); else if (auto orP = get(predicate)) - resolve(*orP, errVec, refis, scope, sense); + resolve(*orP, refis, scope, sense); else if (auto notP = get(predicate)) - resolve(notP->predicates, errVec, refis, scope, !sense, fromOr); + resolve(notP->predicates, refis, scope, !sense, fromOr); else if (auto isaP = get(predicate)) - resolve(*isaP, errVec, refis, scope, sense); + resolve(*isaP, refis, scope, sense); else if (auto typeguardP = get(predicate)) - resolve(*typeguardP, errVec, refis, scope, sense); + resolve(*typeguardP, refis, scope, sense); else if (auto eqP = get(predicate)) - resolve(*eqP, errVec, refis, scope, sense); + resolve(*eqP, refis, scope, sense); else ice("Unhandled predicate kind"); } -void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +void TypeChecker::resolve(const TruthyPredicate& truthyP, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { - if (FFlag::LuauAssertStripsFalsyTypes) - { - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (ty && fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (ty && fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); - refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense)); - } - else - { - auto predicate = [sense](TypeId option) -> std::optional { - if (isUndecidable(option) || isBoolean(option) || isNil(option) != sense) - return option; - - return std::nullopt; - }; - - if (FFlag::LuauDiscriminableUnions2) - { - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (ty && fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); - - refineLValue(truthyP.lvalue, refis, scope, predicate); - } - else - { - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (!ty) - return; - - // This is a hack. :( - // Without this, the expression 'a or b' might refine 'b' to be falsy. - // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. - if (fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); - - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, truthyP.lvalue, *result); - } - } + refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense)); } -void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense) { if (!sense) { @@ -5564,14 +5556,14 @@ void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, Refinement {NotPredicate{std::move(andP.rhs)}}, }; - return resolve(orP, errVec, refis, scope, !sense); + return resolve(orP, refis, scope, !sense); } - resolve(andP.lhs, errVec, refis, scope, sense); - resolve(andP.rhs, errVec, refis, scope, sense); + resolve(andP.lhs, refis, scope, sense); + resolve(andP.rhs, refis, scope, sense); } -void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const OrPredicate& orP, RefinementMap& refis, const ScopePtr& scope, bool sense) { if (!sense) { @@ -5580,28 +5572,24 @@ void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMa {NotPredicate{std::move(orP.rhs)}}, }; - return resolve(andP, errVec, refis, scope, !sense); + return resolve(andP, refis, scope, !sense); } - ErrorVec discarded; - RefinementMap leftRefis; - resolve(orP.lhs, errVec, leftRefis, scope, sense); + resolve(orP.lhs, leftRefis, scope, sense); RefinementMap rightRefis; - resolve(orP.lhs, discarded, rightRefis, scope, !sense); - resolve(orP.rhs, errVec, rightRefis, scope, sense, true); // :( + resolve(orP.lhs, rightRefis, scope, !sense); + resolve(orP.rhs, rightRefis, scope, sense, true); // :( merge(refis, leftRefis); merge(refis, rightRefis); } -void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense) { auto predicate = [&](TypeId option) -> std::optional { // This by itself is not truly enough to determine that A is stronger than B or vice versa. - // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. - // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) bool optionIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); bool targetIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); @@ -5642,32 +5630,15 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return res; }; - if (FFlag::LuauDiscriminableUnions2) - { - refineLValue(isaP.lvalue, refis, scope, predicate); - } - else - { - std::optional ty = resolveLValue(refis, scope, isaP.lvalue); - if (!ty) - return; - - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, isaP.lvalue, *result); - else - { - addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); - errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); - } - } + refineLValue(isaP.lvalue, refis, scope, predicate); } -void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& refis, const ScopePtr& scope, bool sense) { // Rewrite the predicate 'type(foo) == "vector"' to be 'typeof(foo) == "Vector3"'. They're exactly identical. // This allows us to avoid writing in edge cases. if (!typeguardP.isTypeof && typeguardP.kind == "vector") - return resolve(TypeGuardPredicate{std::move(typeguardP.lvalue), typeguardP.location, "Vector3", true}, errVec, refis, scope, sense); + return resolve(TypeGuardPredicate{std::move(typeguardP.lvalue), typeguardP.location, "Vector3", true}, refis, scope, sense); std::optional ty = resolveLValue(refis, scope, typeguardP.lvalue); if (!ty) @@ -5717,52 +5688,29 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec if (auto it = primitives.find(typeguardP.kind); it != primitives.end()) { - if (FFlag::LuauDiscriminableUnions2) - { - refineLValue(typeguardP.lvalue, refis, scope, it->second(sense)); - return; - } - else - { - if (std::optional result = filterMap(*ty, it->second(sense))) - addRefinement(refis, typeguardP.lvalue, *result); - else - { - addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - if (sense) - errVec.push_back( - TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); - } - - return; - } + refineLValue(typeguardP.lvalue, refis, scope, it->second(sense)); + return; } - auto fail = [&](const TypeErrorData& err) { - if (!FFlag::LuauDiscriminableUnions2) - errVec.push_back(TypeError{typeguardP.location, err}); - addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - }; - if (!typeguardP.isTypeof) - return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); auto typeFun = globalScope->lookupType(typeguardP.kind); if (!typeFun || !typeFun->typeParams.empty() || !typeFun->typePackParams.empty()) - return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); TypeId type = follow(typeFun->type); // We're only interested in the root class of any classes. if (auto ctv = get(type); !ctv || ctv->parent) - return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. // Until then, we rewrite this to be the same as using IsA. - return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense); + return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, refis, scope, sense); } -void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. auto options = [](TypeId ty) -> std::vector { @@ -5771,82 +5719,30 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa return {ty}; }; - if (FFlag::LuauDiscriminableUnions2) - { - std::vector rhs = options(eqP.type); + std::vector rhs = options(eqP.type); - if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) - return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - auto predicate = [&](TypeId option) -> std::optional { - if (sense && isUndecidable(option)) - return FFlag::LuauWeakEqConstraint ? option : eqP.type; + auto predicate = [&](TypeId option) -> std::optional { + if (!sense && isNil(eqP.type)) + return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; - if (!sense && isNil(eqP.type)) - return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; - - if (maybeSingleton(eqP.type)) - { - // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. - if (!sense || canUnify(eqP.type, option, eqP.location).empty()) - return sense ? eqP.type : option; - - // local variable works around an odd gcc 9.3 warning: may be used uninitialized - std::optional res = std::nullopt; - return res; - } - - return option; - }; - - refineLValue(eqP.lvalue, refis, scope, predicate); - } - else - { - if (FFlag::LuauWeakEqConstraint) + if (maybeSingleton(eqP.type)) { - if (!sense && isNil(eqP.type)) - resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false); + // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. + if (!sense || canUnify(eqP.type, option, eqP.location).empty()) + return sense ? eqP.type : option; - return; + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; } - if (FFlag::LuauEqConstraint) - { - std::optional ty = resolveLValue(refis, scope, eqP.lvalue); - if (!ty) - return; + return option; + }; - std::vector lhs = options(*ty); - std::vector rhs = options(eqP.type); - - if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) - { - addRefinement(refis, eqP.lvalue, eqP.type); - return; - } - else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) - return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - - std::unordered_set set; - for (TypeId left : lhs) - { - for (TypeId right : rhs) - { - // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left))) - set.insert(left); - } - } - - if (set.empty()) - return; - - std::vector viable(set.begin(), set.end()); - TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); - addRefinement(refis, eqP.lvalue, result); - } - } + refineLValue(eqP.lvalue, refis, scope, predicate); } bool TypeChecker::isNonstrictMode() const @@ -5854,6 +5750,11 @@ bool TypeChecker::isNonstrictMode() const return (currentModule->mode == Mode::Nonstrict) || (currentModule->mode == Mode::NoCheck); } +bool TypeChecker::useConstrainedIntersections() const +{ + return FFlag::LuauLowerBoundsCalculation && !isNonstrictMode(); +} + std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp, size_t expectedLength, const Location& location) { TypePackId expectedTypePack = addTypePack({}); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 5bb05234..82451bd1 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) + namespace Luau { @@ -36,6 +38,25 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp) return *this; } +TypePackVar& TypePackVar::operator=(const TypePackVar& rhs) +{ + if (FFlag::LuauNonCopyableTypeVarFields) + { + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); + + reassign(rhs); + } + else + { + ty = rhs.ty; + persistent = rhs.persistent; + owningArena = rhs.owningArena; + } + + return *this; +} + TypePackIterator::TypePackIterator(TypePackId typePack) : TypePackIterator(typePack, TxnLog::empty()) { @@ -104,7 +125,7 @@ TypePackIterator begin(TypePackId tp) return TypePackIterator{tp}; } -TypePackIterator begin(TypePackId tp, TxnLog* log) +TypePackIterator begin(TypePackId tp, const TxnLog* log) { return TypePackIterator{tp, log}; } @@ -256,7 +277,7 @@ size_t size(const TypePack& tp, TxnLog* log) return result; } -std::optional first(TypePackId tp) +std::optional first(TypePackId tp, bool ignoreHiddenVariadics) { auto it = begin(tp); auto endIter = end(tp); @@ -266,7 +287,7 @@ std::optional first(TypePackId tp) if (auto tail = it.tail()) { - if (auto vtp = get(*tail)) + if (auto vtp = get(*tail); vtp && (!vtp->hidden || !ignoreHiddenVariadics)) return vtp->ty; } @@ -299,6 +320,46 @@ std::pair, std::optional> flatten(TypePackId tp) return {res, iter.tail()}; } +std::pair, std::optional> flatten(TypePackId tp, const TxnLog& log) +{ + tp = log.follow(tp); + + std::vector flattened; + std::optional tail = std::nullopt; + + TypePackIterator it(tp, &log); + + for (; it != end(tp); ++it) + { + flattened.push_back(*it); + } + + tail = it.tail(); + + return {flattened, tail}; +} + +bool isVariadic(TypePackId tp) +{ + return isVariadic(tp, *TxnLog::empty()); +} + +bool isVariadic(TypePackId tp, const TxnLog& log) +{ + std::optional tail = flatten(tp, log).second; + + if (!tail) + return false; + + if (log.get(*tail)) + return true; + + if (auto vtp = log.get(*tail); vtp && !vtp->hidden) + return true; + + return false; +} + TypePackVar* asMutable(TypePackId tp) { return const_cast(tp); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index c2435890..3d97e6eb 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -5,8 +5,6 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" -LUAU_FASTFLAGVARIABLE(LuauTerminateCyclicMetatableIndexLookup, false) - namespace Luau { @@ -55,13 +53,10 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t { TypeId index = follow(*mtIndex); - if (FFlag::LuauTerminateCyclicMetatableIndexLookup) - { - if (count >= 100) - return std::nullopt; + if (count >= 100) + return std::nullopt; - ++count; - } + ++count; if (const auto& itt = getTableType(index)) { @@ -71,7 +66,7 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t } else if (const auto& itf = get(index)) { - std::optional r = first(follow(itf->retType)); + std::optional r = first(follow(itf->retTypes)); if (!r) return getSingletonTypes().nilType; else diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 36545ad9..ade70d72 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -23,16 +23,13 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauErrorRecoveryType) -LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) -LUAU_FASTFLAG(LuauDiscriminableUnions2) -LUAU_FASTFLAGVARIABLE(LuauAnyInIsOptionalIsOptional, false) +LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) namespace Luau { -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); TypeId follow(TypeId t) { @@ -174,22 +171,15 @@ bool isString(TypeId ty) // Returns true when ty is a supertype of string bool maybeString(TypeId ty) { - if (FFlag::LuauSubtypingAddOptPropsToUnsealedTables) - { - ty = follow(ty); - - if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) - return true; + ty = follow(ty); - if (auto utv = get(ty)) - return std::any_of(begin(utv), end(utv), maybeString); + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) + return true; - return false; - } - else - { - return isString(ty); - } + if (auto utv = get(ty)) + return std::any_of(begin(utv), end(utv), maybeString); + + return false; } bool isThread(TypeId ty) @@ -204,14 +194,14 @@ bool isOptional(TypeId ty) ty = follow(ty); - if (FFlag::LuauAnyInIsOptionalIsOptional && get(ty)) + if (get(ty)) return true; auto utv = get(ty); if (!utv) return false; - return std::any_of(begin(utv), end(utv), FFlag::LuauAnyInIsOptionalIsOptional ? isOptional : isNil); + return std::any_of(begin(utv), end(utv), isOptional); } bool isTableIntersection(TypeId ty) @@ -304,6 +294,11 @@ std::optional getDefinitionModuleName(TypeId type) if (ftv->definition) return ftv->definition->definitionModuleName; } + else if (auto ctv = get(type)) + { + if (!ctv->definitionModuleName.empty()) + return ctv->definitionModuleName; + } return std::nullopt; } @@ -373,8 +368,7 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) if (seen.contains(ty)) return true; - bool isStr = FFlag::LuauDiscriminableUnions2 ? isString(ty) : isPrim(ty, PrimitiveTypeVar::String); - if (isStr || get(ty) || get(ty) || get(ty)) + if (isString(ty) || get(ty) || get(ty) || get(ty)) return true; if (auto uty = get(ty)) @@ -406,41 +400,48 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return false; } -FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) +BlockedTypeVar::BlockedTypeVar() + : index(++nextIndex) +{ +} + +int BlockedTypeVar::nextIndex = 0; + +FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { } -FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) +FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : level(level) , argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { } -FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, +FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { } FunctionTypeVar::FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, - TypePackId retType, std::optional defn, bool hasSelf) + TypePackId retTypes, std::optional defn, bool hasSelf) : level(level) , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { @@ -486,7 +487,7 @@ bool areEqual(SeenSet& seen, const FunctionTypeVar& lhs, const FunctionTypeVar& if (!areEqual(seen, *lhs.argTypes, *rhs.argTypes)) return false; - if (!areEqual(seen, *lhs.retType, *rhs.retType)) + if (!areEqual(seen, *lhs.retTypes, *rhs.retTypes)) return false; return true; @@ -643,6 +644,26 @@ TypeVar& TypeVar::operator=(TypeVariant&& rhs) return *this; } +TypeVar& TypeVar::operator=(const TypeVar& rhs) +{ + if (FFlag::LuauNonCopyableTypeVarFields) + { + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); + + reassign(rhs); + } + else + { + ty = rhs.ty; + persistent = rhs.persistent; + normal = rhs.normal; + owningArena = rhs.owningArena; + } + + return *this; +} + TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); @@ -652,9 +673,10 @@ static TypeVar numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persist static TypeVar stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true}; static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true}; static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}; -static TypeVar anyType_{AnyTypeVar{}}; -static TypeVar errorType_{ErrorTypeVar{}}; -static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}}; +static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}; +static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; +static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true}; +static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true}; static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; static TypePackVar errorTypePack_{Unifiable::Error{}}; @@ -665,8 +687,9 @@ SingletonTypes::SingletonTypes() , stringType(&stringType_) , booleanType(&booleanType_) , threadType(&threadType_) + , trueType(&trueType_) + , falseType(&falseType_) , anyType(&anyType_) - , optionalNumberType(&optionalNumberType_) , anyTypePack(&anyTypePack_) , arena(new TypeArena) { @@ -694,7 +717,7 @@ TypeId SingletonTypes::makeStringMetatable() { const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}}); const TypeId optionalString = arena->addType(UnionTypeVar{{nilType, stringType}}); - const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, &booleanType_}}); + const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, booleanType}}); const TypePackId oneStringPack = arena->addTypePack({stringType}); const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); @@ -718,14 +741,16 @@ TypeId SingletonTypes::makeStringMetatable() TableTypeVar::Props stringLib = { {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, - {"char", {arena->addType(FunctionTypeVar{arena->addTypePack(TypePack{{numberType}, numberVariadicList}), arena->addTypePack({stringType})})}}, - {"find", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber, optionalBoolean}, {}, {optionalNumber, optionalNumber})}}, + {"char", {arena->addType(FunctionTypeVar{numberVariadicList, arena->addTypePack({stringType})})}}, + {"find", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})})}}, {"format", {formatFn}}, // FIXME {"gmatch", {gmatchFunc}}, {"gsub", {gsubFunc}}, {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, {"lower", {stringToStringType}}, - {"match", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber}, {}, {optionalString})}}, + {"match", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), + arena->addTypePack(TypePackVar{VariadicTypePack{optionalString}})})}}, {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, {"reverse", {stringToStringType}}, {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, @@ -765,18 +790,12 @@ TypePackId SingletonTypes::errorRecoveryTypePack() TypeId SingletonTypes::errorRecoveryType(TypeId guess) { - if (FFlag::LuauErrorRecoveryType) - return guess; - else - return &errorType_; + return guess; } TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) { - if (FFlag::LuauErrorRecoveryType) - return guess; - else - return &errorTypePack_; + return guess; } SingletonTypes& getSingletonTypes() @@ -798,13 +817,14 @@ void persist(TypeId ty) continue; asMutable(t)->persistent = true; + asMutable(t)->normal = true; // all persistent types are assumed to be normal if (auto btv = get(t)) queue.push_back(btv->boundTo); else if (auto ftv = get(t)) { persist(ftv->argTypes); - persist(ftv->retType); + persist(ftv->retTypes); } else if (auto ttv = get(t)) { @@ -834,6 +854,11 @@ void persist(TypeId ty) for (TypeId opt : itv->parts) queue.push_back(opt); } + else if (auto ctv = get(t)) + { + for (TypeId opt : ctv->parts) + queue.push_back(opt); + } else if (auto mtv = get(t)) { queue.push_back(mtv->table); @@ -895,6 +920,16 @@ TypeLevel* getMutableLevel(TypeId ty) return const_cast(getLevel(ty)); } +std::optional getLevel(TypePackId tp) +{ + tp = follow(tp); + + if (auto ftv = get(tp)) + return ftv->level; + else + return std::nullopt; +} + const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) { while (cls) @@ -1064,10 +1099,10 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha return result; } -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -1106,7 +1141,7 @@ std::optional> magicFunctionFormat( if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); - return ExprResult{arena.addTypePack({typechecker.stringType})}; + return WithPredicate{arena.addTypePack({typechecker.stringType})}; } std::vector filterMap(TypeId type, TypeIdPredicate predicate) diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index dc554664..8d23aa49 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -12,12 +12,16 @@ Free::Free(TypeLevel level) { } +Free::Free(Scope2* scope) + : scope(scope) +{ +} + int Free::nextIndex = 0; Generic::Generic() : index(++nextIndex) , name("g" + std::to_string(index)) - , explicitName(false) { } @@ -25,7 +29,6 @@ Generic::Generic(TypeLevel level) : index(++nextIndex) , level(level) , name("g" + std::to_string(index)) - , explicitName(false) { } @@ -36,6 +39,12 @@ Generic::Generic(const Name& name) { } +Generic::Generic(Scope2* scope) + : index(++nextIndex) + , scope(scope) +{ +} + Generic::Generic(TypeLevel level, const Name& name) : index(++nextIndex) , level(level) @@ -44,6 +53,14 @@ Generic::Generic(TypeLevel level, const Name& name) { } +Generic::Generic(Scope2* scope, const Name& name) + : index(++nextIndex) + , scope(scope) + , name(name) + , explicitName(true) +{ +} + int Generic::nextIndex = 0; Error::Error() diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 398dc9e2..6147e118 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,24 +14,17 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); -LUAU_FASTFLAGVARIABLE(LuauExtendedIndexerError, false); -LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); +LUAU_FASTINT(LuauTypeInferIterationLimit); +LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) +LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); +LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); -LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) -LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) -LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogCheckForInvalidation, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) -LUAU_FASTFLAGVARIABLE(LuauUnifierCacheErrors, false) -LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) +LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau { -struct PromoteTypeLevels +struct PromoteTypeLevels final : TypeVarOnceVisitor { TxnLog& log; const TypeArena* typeArena = nullptr; @@ -54,13 +47,34 @@ struct PromoteTypeLevels } } + // TODO cycle and operator() need to be clipped when FFlagLuauUseVisitRecursionLimit is clipped template void cycle(TID) { } - template bool operator()(TID ty, const T&) + { + return visit(ty); + } + bool operator()(TypeId ty, const FreeTypeVar& ftv) + { + return visit(ty, ftv); + } + bool operator()(TypeId ty, const FunctionTypeVar& ftv) + { + return visit(ty, ftv); + } + bool operator()(TypeId ty, const TableTypeVar& ttv) + { + return visit(ty, ttv); + } + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + return visit(tp, ftp); + } + + bool visit(TypeId ty) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) @@ -69,7 +83,16 @@ struct PromoteTypeLevels return true; } - bool operator()(TypeId ty, const FreeTypeVar&) + bool visit(TypePackId tp) override + { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (tp->owningArena != typeArena) + return false; + + return true; + } + + bool visit(TypeId ty, const FreeTypeVar&) override { // Surprise, it's actually a BoundTypeVar that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. @@ -80,7 +103,7 @@ struct PromoteTypeLevels return true; } - bool operator()(TypeId ty, const FunctionTypeVar&) + bool visit(TypeId ty, const FunctionTypeVar&) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) @@ -90,7 +113,7 @@ struct PromoteTypeLevels return true; } - bool operator()(TypeId ty, const TableTypeVar& ttv) + bool visit(TypeId ty, const TableTypeVar& ttv) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) @@ -103,7 +126,7 @@ struct PromoteTypeLevels return true; } - bool operator()(TypePackId tp, const FreeTypePack&) + bool visit(TypePackId tp, const FreeTypePack&) override { // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. @@ -122,11 +145,9 @@ static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel return; PromoteTypeLevels ptl{log, typeArena, minLevel}; - DenseHashSet seen{nullptr}; - visitTypeVarOnce(ty, ptl, seen); + ptl.traverse(ty); } -// TODO: use this and make it static. void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -134,11 +155,10 @@ void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLev return; PromoteTypeLevels ptl{log, typeArena, minLevel}; - DenseHashSet seen{nullptr}; - visitTypeVarOnce(tp, ptl, seen); + ptl.traverse(tp); } -struct SkipCacheForType +struct SkipCacheForType final : TypeVarOnceVisitor { SkipCacheForType(const DenseHashMap& skipCacheForType, const TypeArena* typeArena) : skipCacheForType(skipCacheForType) @@ -146,28 +166,25 @@ struct SkipCacheForType { } - void cycle(TypeId) {} - void cycle(TypePackId) {} - - bool operator()(TypeId ty, const FreeTypeVar& ftv) + bool visit(TypeId, const FreeTypeVar&) override { result = true; return false; } - bool operator()(TypeId ty, const BoundTypeVar& btv) + bool visit(TypeId, const BoundTypeVar&) override { result = true; return false; } - bool operator()(TypeId ty, const GenericTypeVar& btv) + bool visit(TypeId, const GenericTypeVar&) override { result = true; return false; } - bool operator()(TypeId ty, const TableTypeVar&) + bool visit(TypeId ty, const TableTypeVar&) override { // Types from other modules don't contain mutable elements and are ok to cache if (ty->owningArena != typeArena) @@ -190,8 +207,7 @@ struct SkipCacheForType return true; } - template - bool operator()(TypeId ty, const T& t) + bool visit(TypeId ty) override { // Types from other modules don't contain mutable elements and are ok to cache if (ty->owningArena != typeArena) @@ -208,8 +224,7 @@ struct SkipCacheForType return true; } - template - bool operator()(TypePackId tp, const T&) + bool visit(TypePackId tp) override { // Types from other modules don't contain mutable elements and are ok to cache if (tp->owningArena != typeArena) @@ -218,19 +233,19 @@ struct SkipCacheForType return true; } - bool operator()(TypePackId tp, const FreeTypePack& ftp) + bool visit(TypePackId tp, const FreeTypePack&) override { result = true; return false; } - bool operator()(TypePackId tp, const BoundTypePack& ftp) + bool visit(TypePackId tp, const BoundTypePack&) override { result = true; return false; } - bool operator()(TypePackId tp, const GenericTypePack& ftp) + bool visit(TypePackId tp, const GenericTypePack&) override { result = true; return false; @@ -277,6 +292,16 @@ bool Widen::ignoreChildren(TypeId ty) return !log->is(ty); } +TypeId Widen::operator()(TypeId ty) +{ + return substitute(ty).value_or(ty); +} + +TypePackId Widen::operator()(TypePackId tp) +{ + return substitute(tp).value_or(tp); +} + static std::optional hasUnificationTooComplex(const ErrorVec& errors) { auto isUnificationTooComplex = [](const TypeError& te) { @@ -305,8 +330,7 @@ static std::optional> getTableMat return std::nullopt; } -Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, - TxnLog* parentLog) +Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) , log(parentLog) @@ -317,18 +341,6 @@ Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance LUAU_ASSERT(sharedState.iceHandler); } -Unifier::Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) - : types(types) - , mode(mode) - , log(parentLog, sharedSeen) - , location(location) - , variance(variance) - , sharedState(sharedState) -{ - LUAU_ASSERT(sharedState.iceHandler); -} - void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { sharedState.counters.iterationCount = 0; @@ -338,14 +350,26 @@ void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool i void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); ++sharedState.counters.iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + if (FFlag::LuauAutocompleteDynamicLimits) { - reportError(TypeError{location, UnificationTooComplex{}}); - return; + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } + } + else + { + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } } superTy = log.follow(superTy); @@ -354,6 +378,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superTy == subTy) return; + if (log.get(superTy)) + return tryUnifyWithConstrainedSuperTypeVar(subTy, superTy); + auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); @@ -409,6 +436,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursFailed) { promoteTypeLevels(log, types, superLevel, subTy); + + Widen widen{types}; log.replace(superTy, BoundTypeVar(widen(subTy))); } @@ -442,35 +471,35 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (get(superTy) || get(superTy)) return tryUnifyWithAny(subTy, superTy); - if (get(subTy) || get(subTy)) + if (get(subTy)) + { + if (anyIsTop) + { + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + return; + } + else + return tryUnifyWithAny(superTy, subTy); + } + + if (get(subTy)) return tryUnifyWithAny(superTy, subTy); - bool cacheEnabled; auto& cache = sharedState.cachedUnify; // What if the types are immutable and we proved their relation before - if (FFlag::LuauUnifierCacheErrors) + bool cacheEnabled = !isFunctionCall && !isIntersection && variance == Invariant; + + if (cacheEnabled) { - cacheEnabled = !isFunctionCall && !isIntersection && variance == Invariant; - - if (cacheEnabled) - { - if (cache.contains({subTy, superTy})) - return; - - if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) - { - reportError(TypeError{location, *error}); - return; - } - } - } - else - { - cacheEnabled = !isFunctionCall && !isIntersection; - - if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy}))) + if (cache.contains({subTy, superTy})) return; + + if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) + { + reportError(TypeError{location, *error}); + return; + } } // If we have seen this pair of types before, we are currently recursing into cyclic types. @@ -484,7 +513,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (const UnionTypeVar* uv = log.getMutable(subTy)) + if (log.get(subTy)) + tryUnifyWithConstrainedSubTypeVar(subTy, superTy); + else if (const UnionTypeVar* uv = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, uv, superTy); } @@ -512,12 +543,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(superTy) && log.getMutable(subTy)) { tryUnifyTables(subTy, superTy, isIntersection); - - if (!FFlag::LuauUnifierCacheErrors) - { - if (cacheEnabled && errors.empty()) - cacheResult_DEPRECATED(subTy, superTy); - } } // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. @@ -537,7 +562,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - if (FFlag::LuauUnifierCacheErrors && cacheEnabled) + if (cacheEnabled) cacheResult(subTy, superTy, errorCount); log.popSeen(superTy, subTy); @@ -550,9 +575,6 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId std::optional unificationTooComplex; std::optional firstFailedOption; - size_t count = uv->options.size(); - size_t i = 0; - for (TypeId type : uv->options) { Unifier innerState = makeChildUnifier(); @@ -568,52 +590,44 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId failed = true; } - - if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) - { - } - else - { - if (i == count - 1) - { - log.concat(std::move(innerState.log)); - } - - ++i; - } } // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. - if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) - { - auto tryBind = [this, subTy](TypeId superOption) { - superOption = log.follow(superOption); + auto tryBind = [this, subTy](TypeId superOption) { + superOption = log.follow(superOption); - // just skip if the superOption is not free-ish. - auto ttv = log.getMutable(superOption); - if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) - return; + // just skip if the superOption is not free-ish. + auto ttv = log.getMutable(superOption); + if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) + return; - // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. - // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. - if (log.haveSeen(subTy, superOption)) - { - // TODO: would it be nice for TxnLog::replace to do this? - if (log.is(superOption)) - log.bindTable(superOption, subTy); - else - log.replace(superOption, *subTy); - } - }; - - if (auto utv = log.getMutable(superTy)) + // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype + // test is successful. + if (auto subUnion = get(subTy)) { - for (TypeId ty : utv) - tryBind(ty); + if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) + return; } - else - tryBind(superTy); + + // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. + // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. + if (log.haveSeen(subTy, superOption)) + { + // TODO: would it be nice for TxnLog::replace to do this? + if (log.is(superOption)) + log.bindTable(superOption, subTy); + else + log.replace(superOption, *subTy); + } + }; + + if (auto utv = log.getMutable(superTy)) + { + for (TypeId ty : utv) + tryBind(ty); } + else + tryBind(superTy); if (unificationTooComplex) reportError(*unificationTooComplex); @@ -674,21 +688,10 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { TypeId type = uv->options[i]; - if (FFlag::LuauUnifierCacheErrors) + if (cache.contains({subTy, type})) { - if (cache.contains({subTy, type})) - { - startIndex = i; - break; - } - } - else - { - if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) - { - startIndex = i; - break; - } + startIndex = i; + break; } } } @@ -776,21 +779,10 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV { TypeId type = uv->parts[i]; - if (FFlag::LuauUnifierCacheErrors) + if (cache.contains({type, superTy})) { - if (cache.contains({type, superTy})) - { - startIndex = i; - break; - } - } - else - { - if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) - { - startIndex = i; - break; - } + startIndex = i; + break; } } } @@ -835,7 +827,7 @@ bool Unifier::canCacheResult(TypeId subTy, TypeId superTy) auto skipCacheFor = [this](TypeId ty) { SkipCacheForType visitor{sharedState.skipCacheForType, types}; - visitTypeVarOnce(ty, visitor, sharedState.seenAny); + visitor.traverse(ty); sharedState.skipCacheForType[ty] = visitor.result; @@ -865,19 +857,6 @@ void Unifier::cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount) } } -void Unifier::cacheResult_DEPRECATED(TypeId subTy, TypeId superTy) -{ - LUAU_ASSERT(!FFlag::LuauUnifierCacheErrors); - - if (!canCacheResult(subTy, superTy)) - return; - - sharedState.cachedUnify.insert({superTy, subTy}); - - if (variance == Invariant) - sharedState.cachedUnify.insert({subTy, superTy}); -} - struct WeirdIter { TypePackId packId; @@ -946,7 +925,7 @@ struct WeirdIter LUAU_ASSERT(log.getMutable(newTail)); level = log.getMutable(packId)->level; - log.replace(packId, Unifiable::Bound(newTail)); + log.replace(packId, BoundTypePack(newTail)); packId = newTail; pack = log.getMutable(newTail); index = 0; @@ -994,39 +973,32 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall tryUnify_(subTp, superTp, isFunctionCall); } -static std::pair, std::optional> logAwareFlatten(TypePackId tp, const TxnLog& log) -{ - tp = log.follow(tp); - - std::vector flattened; - std::optional tail = std::nullopt; - - TypePackIterator it(tp, &log); - - for (; it != end(tp); ++it) - { - flattened.push_back(*it); - } - - tail = it.tail(); - - return {flattened, tail}; -} - /* * This is quite tricky: we are walking two rope-like structures and unifying corresponding elements. * If one is longer than the other, but the short end is free, we grow it to the required length. */ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); ++sharedState.counters.iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + if (FFlag::LuauAutocompleteDynamicLimits) { - reportError(TypeError{location, UnificationTooComplex{}}); - return; + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } + } + else + { + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } } superTp = log.follow(superTp); @@ -1051,7 +1023,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (superTp == subTp) return; - if (FFlag::LuauTxnLogSeesTypePacks2 && log.haveSeen(superTp, subTp)) + if (log.haveSeen(superTp, subTp)) return; if (log.getMutable(superTp)) @@ -1060,6 +1032,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (!log.getMutable(superTp)) { + Widen widen{types}; log.replace(superTp, Unifiable::Bound(widen(subTp))); } } @@ -1087,8 +1060,8 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If the size of two heads does not match, but both packs have free tail // We set the sentinel variable to say so to avoid growing it forever. - auto [superTypes, superTail] = logAwareFlatten(superTp, log); - auto [subTypes, subTail] = logAwareFlatten(subTp, log); + auto [superTypes, superTail] = flatten(superTp, log); + auto [subTypes, subTail] = flatten(subTp, log); bool noInfiniteGrowth = (superTypes.size() != subTypes.size()) && (superTail && log.getMutable(*superTail)) && (subTail && log.getMutable(*subTail)); @@ -1165,24 +1138,17 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal else { // A union type including nil marks an optional argument - if (superIter.good() && isOptional(*superIter)) + if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && superIter.good() && isOptional(*superIter)) { superIter.advance(); continue; } - else if (subIter.good() && isOptional(*subIter)) + else if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && subIter.good() && isOptional(*subIter)) { subIter.advance(); continue; } - // In nonstrict mode, any also marks an optional argument. - else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) - { - superIter.advance(); - continue; - } - if (log.getMutable(superIter.packId)) { tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); @@ -1195,7 +1161,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal return; } - if (!isFunctionCall && subIter.good()) + if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && !isFunctionCall && subIter.good()) { // Sometimes it is ok to pass too many arguments return; @@ -1294,12 +1260,9 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal log.pushSeen(superFunction->generics[i], subFunction->generics[i]); } - if (FFlag::LuauTxnLogSeesTypePacks2) + for (size_t i = 0; i < numGenericPacks; i++) { - for (size_t i = 0; i < numGenericPacks; i++) - { - log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); - } + log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); } CountMismatch::Context context = ctx; @@ -1323,13 +1286,13 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); innerState.ctx = CountMismatch::Result; - innerState.tryUnify_(subFunction->retType, superFunction->retType); + innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); if (!reported) { if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); - else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) + else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) reportError( @@ -1347,24 +1310,18 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); ctx = CountMismatch::Result; - tryUnify_(subFunction->retType, superFunction->retType); + tryUnify_(subFunction->retTypes, superFunction->retTypes); } - if (FFlag::LuauTxnLogRefreshFunctionPointers) - { - // Updating the log may have invalidated the function pointers - superFunction = log.getMutable(superTy); - subFunction = log.getMutable(subTy); - } + // Updating the log may have invalidated the function pointers + superFunction = log.getMutable(superTy); + subFunction = log.getMutable(subTy); ctx = context; - if (FFlag::LuauTxnLogSeesTypePacks2) + for (int i = int(numGenericPacks) - 1; 0 <= i; i--) { - for (int i = int(numGenericPacks) - 1; 0 <= i; i--) - { - log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); - } + log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); } for (int i = int(numGenerics) - 1; 0 <= i; i--) @@ -1397,9 +1354,6 @@ struct Resetter void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { - if (!FFlag::LuauTableSubtypingVariance2) - return DEPRECATED_tryUnifyTables(subTy, superTy, isIntersection); - TableTypeVar* superTable = log.getMutable(superTy); TableTypeVar* subTable = log.getMutable(subTy); @@ -1416,18 +1370,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto subIter = subTable->props.find(propName); - if (FFlag::LuauAnyInIsOptionalIsOptional) - { - if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type)) - missingProperties.push_back(propName); - } - else - { - bool isAny = log.getMutable(log.follow(superProp.type)); - - if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny) - missingProperties.push_back(propName); - } + if (subIter == subTable->props.end() && subTable->state == TableState::Unsealed && !isOptional(superProp.type)) + missingProperties.push_back(propName); } if (!missingProperties.empty()) @@ -1438,24 +1382,14 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } // And vice versa if we're invariant - if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && - superTable->state != TableState::Free) + if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && superTable->state != TableState::Free) { for (const auto& [propName, subProp] : subTable->props) { auto superIter = superTable->props.find(propName); - if (FFlag::LuauAnyInIsOptionalIsOptional) - { - if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || !isOptional(subProp.type))) - extraProperties.push_back(propName); - } - else - { - bool isAny = log.is(log.follow(subProp.type)); - if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) - extraProperties.push_back(propName); - } + if (superIter == superTable->props.end()) + extraProperties.push_back(propName); } if (!extraProperties.empty()) @@ -1499,19 +1433,12 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); } - else if (FFlag::LuauAnyInIsOptionalIsOptional && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) + else if (subTable->state == TableState::Unsealed && isOptional(prop.type)) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) { } - else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && (isOptional(prop.type) || get(follow(prop.type)))) - // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` - // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. - // TODO: should isOptional(anyType) be true? - // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) - { - } else if (subTable->state == TableState::Free) { PendingType* pendingSub = log.queue(subTy); @@ -1523,20 +1450,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else missingProperties.push_back(name); - if (FFlag::LuauTxnLogCheckForInvalidation) + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTy); + TableTypeVar* newSubTable = log.getMutable(subTy); + if (superTable != newSuperTable || subTable != newSubTable) { - // Recursive unification can change the txn log, and invalidate the old - // table. If we detect that this has happened, we start over, with the updated - // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; } } @@ -1578,12 +1502,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else if (variance == Covariant) { } - else if (FFlag::LuauAnyInIsOptionalIsOptional && !FFlag::LuauSubtypingAddOptPropsToUnsealedTables && isOptional(prop.type)) - { - } - else if (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables && (isOptional(prop.type) || get(follow(prop.type)))) - { - } else if (superTable->state == TableState::Free) { PendingType* pendingSuper = log.queue(superTy); @@ -1594,20 +1512,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else extraProperties.push_back(name); - if (FFlag::LuauTxnLogCheckForInvalidation) + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTy); + TableTypeVar* newSubTable = log.getMutable(subTy); + if (superTable != newSuperTable || subTable != newSubTable) { - // Recursive unification can change the txn log, and invalidate the old - // table. If we detect that this has happened, we start over, with the updated - // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; } } @@ -1620,24 +1535,16 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) Unifier innerState = makeChildUnifier(); - if (FFlag::LuauExtendedIndexerError) - { - innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); + innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); - bool reported = !innerState.errors.empty(); + bool reported = !innerState.errors.empty(); - checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); + checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); - innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); + innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); - if (!reported) - checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); - } - else - { - innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - } + if (!reported) + checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -1662,27 +1569,9 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } } - if (FFlag::LuauTxnLogDontRetryForIndexers) - { - // Changing the indexer can invalidate the table pointers. - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); - } - else if (FFlag::LuauTxnLogCheckForInvalidation) - { - // Recursive unification can change the txn log, and invalidate the old - // table. If we detect that this has happened, we start over, with the updated - // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } - } + // Changing the indexer can invalidate the table pointers. + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); if (!missingProperties.empty()) { @@ -1717,34 +1606,10 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } } -TypeId Unifier::widen(TypeId ty) -{ - if (!FFlag::LuauWidenIfSupertypeIsFree2) - return ty; - - Widen widen{types}; - std::optional result = widen.substitute(ty); - // TODO: what does it mean for substitution to fail to widen? - return result.value_or(ty); -} - -TypePackId Unifier::widen(TypePackId tp) -{ - if (!FFlag::LuauWidenIfSupertypeIsFree2) - return tp; - - Widen widen{types}; - std::optional result = widen.substitute(tp); - // TODO: what does it mean for substitution to fail to widen? - return result.value_or(tp); -} - TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) { ty = follow(ty); - if (!FFlag::LuauAnyInIsOptionalIsOptional && get(ty)) - return ty; - else if (isOptional(ty)) + if (isOptional(ty)) return ty; else if (const TableTypeVar* ttv = get(ty)) { @@ -1761,299 +1626,6 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see return types->addType(UnionTypeVar{{getSingletonTypes().nilType, ty}}); } -void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) -{ - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); - Resetter resetter{&variance}; - variance = Invariant; - - TableTypeVar* superTable = log.getMutable(superTy); - TableTypeVar* subTable = log.getMutable(subTy); - - if (!superTable || !subTable) - ice("passed non-table types to unifyTables"); - - if (superTable->state == TableState::Sealed && subTable->state == TableState::Sealed) - return tryUnifySealedTables(subTy, superTy, isIntersection); - else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Unsealed) || - (superTable->state == TableState::Unsealed && subTable->state == TableState::Sealed)) - return tryUnifySealedTables(subTy, superTy, isIntersection); - else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Generic) || - (superTable->state == TableState::Generic && subTable->state == TableState::Sealed)) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - else if ((superTable->state == TableState::Free) != (subTable->state == TableState::Free)) // one table is free and the other is not - { - TypeId freeTypeId = subTable->state == TableState::Free ? subTy : superTy; - TypeId otherTypeId = subTable->state == TableState::Free ? superTy : subTy; - - return tryUnifyFreeTable(otherTypeId, freeTypeId); - } - else if (superTable->state == TableState::Free && subTable->state == TableState::Free) - { - tryUnifyFreeTable(subTy, superTy); - - // avoid creating a cycle when the types are already pointing at each other - if (follow(superTy) != follow(subTy)) - { - log.bindTable(superTy, subTy); - } - return; - } - else if (superTable->state != TableState::Sealed && subTable->state != TableState::Sealed) - { - // All free tables are checked in one of the branches above - LUAU_ASSERT(superTable->state != TableState::Free); - LUAU_ASSERT(subTable->state != TableState::Free); - - // Tables must have exactly the same props and their types must all unify - // I honestly have no idea if this is remotely close to reasonable. - for (const auto& [name, prop] : superTable->props) - { - const auto& r = subTable->props.find(name); - if (r == subTable->props.end()) - reportError(TypeError{location, UnknownProperty{subTy, name}}); - else - tryUnify_(r->second.type, prop.type); - } - - if (superTable->indexer && subTable->indexer) - tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - else if (superTable->indexer) - { - // passing/assigning a table without an indexer to something that has one - // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. - if (subTable->state == TableState::Unsealed) - { - log.changeIndexer(subTy, superTable->indexer); - } - else - reportError(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); - } - } - else if (superTable->state == TableState::Sealed) - { - // lt is sealed and so it must be possible for rt to have precisely the same shape - // Verify that this is the case, then bind rt to lt. - ice("unsealed tables are not working yet", location); - } - else if (subTable->state == TableState::Sealed) - return tryUnifyTables(superTy, subTy, isIntersection); - else - ice("tryUnifyTables"); -} - -void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) -{ - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); - TableTypeVar* freeTable = log.getMutable(superTy); - TableTypeVar* subTable = log.getMutable(subTy); - - if (!freeTable || !subTable) - ice("passed non-table types to tryUnifyFreeTable"); - - // Any properties in freeTable must unify with those in otherTable. - // Then bind freeTable to otherTable. - for (const auto& [freeName, freeProp] : freeTable->props) - { - if (auto subProp = findTablePropertyRespectingMeta(subTy, freeName)) - { - if (FFlag::LuauWidenIfSupertypeIsFree2) - tryUnify_(*subProp, freeProp.type); - else - tryUnify_(freeProp.type, *subProp); - - /* - * TypeVars are commonly cyclic, so it is entirely possible - * for unifying a property of a table to change the table itself! - * We need to check for this and start over if we notice this occurring. - * - * I believe this is guaranteed to terminate eventually because this will - * only happen when a free table is bound to another table. - */ - if (!log.getMutable(superTy) || !log.getMutable(subTy)) - return tryUnify_(subTy, superTy); - - if (TableTypeVar* pendingFreeTtv = log.getMutable(superTy); pendingFreeTtv && pendingFreeTtv->boundTo) - return tryUnify_(subTy, superTy); - } - else - { - // If the other table is also free, then we are learning that it has more - // properties than we previously thought. Else, it is an error. - if (subTable->state == TableState::Free) - { - PendingType* pendingSub = log.queue(subTy); - TableTypeVar* pendingSubTtv = getMutable(pendingSub); - LUAU_ASSERT(pendingSubTtv); - pendingSubTtv->props.insert({freeName, freeProp}); - } - else - reportError(TypeError{location, UnknownProperty{subTy, freeName}}); - } - } - - if (freeTable->indexer && subTable->indexer) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnifyIndexer(*subTable->indexer, *freeTable->indexer); - - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - - log.concat(std::move(innerState.log)); - } - else if (subTable->state == TableState::Free && freeTable->indexer) - { - log.changeIndexer(superTy, subTable->indexer); - } - - if (!freeTable->boundTo && subTable->state != TableState::Free) - { - log.bindTable(superTy, subTy); - } -} - -void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection) -{ - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); - TableTypeVar* superTable = log.getMutable(superTy); - TableTypeVar* subTable = log.getMutable(subTy); - - if (!superTable || !subTable) - ice("passed non-table types to unifySealedTables"); - - Unifier innerState = makeChildUnifier(); - - std::vector missingPropertiesInSuper; - bool isUnnamedTable = subTable->name == std::nullopt && subTable->syntheticName == std::nullopt; - bool errorReported = false; - - // Optimization: First test that the property sets are compatible without doing any recursive unification - if (!subTable->indexer) - { - for (const auto& [propName, superProp] : superTable->props) - { - auto subIter = subTable->props.find(propName); - if (subIter == subTable->props.end() && !isOptional(superProp.type)) - missingPropertiesInSuper.push_back(propName); - } - - if (!missingPropertiesInSuper.empty()) - { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); - return; - } - } - - // Tables must have exactly the same props and their types must all unify - for (const auto& it : superTable->props) - { - const auto& r = subTable->props.find(it.first); - if (r == subTable->props.end()) - { - if (isOptional(it.second.type)) - continue; - - missingPropertiesInSuper.push_back(it.first); - - innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - } - else - { - if (isUnnamedTable && r->second.location) - { - size_t oldErrorSize = innerState.errors.size(); - Location old = innerState.location; - innerState.location = *r->second.location; - innerState.tryUnify_(r->second.type, it.second.type); - innerState.location = old; - - if (oldErrorSize != innerState.errors.size() && !errorReported) - { - errorReported = true; - reportError(innerState.errors.back()); - } - } - else - { - innerState.tryUnify_(r->second.type, it.second.type); - } - } - } - - if (superTable->indexer || subTable->indexer) - { - if (superTable->indexer && subTable->indexer) - innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - else if (subTable->state == TableState::Unsealed) - { - if (superTable->indexer && !subTable->indexer) - { - log.changeIndexer(subTy, superTable->indexer); - } - } - else if (superTable->state == TableState::Unsealed) - { - if (subTable->indexer && !superTable->indexer) - { - log.changeIndexer(superTy, subTable->indexer); - } - } - else if (superTable->indexer) - { - innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); - for (const auto& [name, type] : subTable->props) - { - const auto& it = superTable->props.find(name); - if (it == superTable->props.end()) - innerState.tryUnify_(type.type, superTable->indexer->indexResultType); - } - } - else - innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - } - - if (!errorReported) - log.concat(std::move(innerState.log)); - else - return; - - if (!missingPropertiesInSuper.empty()) - { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); - return; - } - - // If the superTy is an immediate part of an intersection type, do not do extra-property check. - // Otherwise, we would falsely generate an extra-property-error for 's' in this code: - // local a: {n: number} & {s: string} = {n=1, s=""} - // When checking against the table '{n: number}'. - if (!isIntersection && superTable->state != TableState::Unsealed && !superTable->indexer) - { - // Check for extra properties in the subTy - std::vector extraPropertiesInSub; - - for (const auto& [subKey, subProp] : subTable->props) - { - const auto& superIt = superTable->props.find(subKey); - if (superIt == superTable->props.end()) - { - if (isOptional(subProp.type)) - continue; - - extraPropertiesInSub.push_back(subKey); - } - } - - if (!extraPropertiesInSub.empty()) - { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}}); - return; - } - } - - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); -} - void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { const MetatableTypeVar* superMetatable = get(superTy); @@ -2193,14 +1765,6 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) return fail(); } -void Unifier::tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer) -{ - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2 || !FFlag::LuauExtendedIndexerError); - - tryUnify_(subIndexer.indexType, superIndexer.indexType); - tryUnify_(subIndexer.indexResultType, superIndexer.indexResultType); -} - static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { while (true) @@ -2303,7 +1867,7 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas else if (auto fun = state.log.getMutable(ty)) { queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retTypes, anyTypePack); } else if (auto table = state.log.getMutable(ty)) { @@ -2376,6 +1940,180 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N return Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); } +void Unifier::tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy) +{ + const ConstrainedTypeVar* subConstrained = get(subTy); + if (!subConstrained) + ice("tryUnifyWithConstrainedSubTypeVar received non-ConstrainedTypeVar subTy!"); + + const std::vector& subTyParts = subConstrained->parts; + + // A | B <: T if A <: T and B <: T + bool failed = false; + std::optional unificationTooComplex; + + const size_t count = subTyParts.size(); + + for (size_t i = 0; i < count; ++i) + { + TypeId type = subTyParts[i]; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, superTy); + + if (i == count - 1) + log.concat(std::move(innerState.log)); + + ++i; + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + + if (!innerState.errors.empty()) + { + failed = true; + break; + } + } + + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (failed) + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + else + log.replace(subTy, BoundTypeVar{superTy}); +} + +void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy) +{ + ConstrainedTypeVar* superC = log.getMutable(superTy); + if (!superC) + ice("tryUnifyWithConstrainedSuperTypeVar received non-ConstrainedTypeVar superTy!"); + + // subTy could be a + // table + // metatable + // class + // function + // primitive + // free + // generic + // intersection + // union + // Do we really just tack it on? I think we might! + // We can certainly do some deduplication. + // Is there any point to deducing Player|Instance when we could just reduce to Instance? + // Is it actually ok to have multiple free types in a single intersection? What if they are later unified into the same type? + // Maybe we do a simplification step during quantification. + + auto it = std::find(superC->parts.begin(), superC->parts.end(), subTy); + if (it != superC->parts.end()) + return; + + superC->parts.push_back(subTy); +} + +void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel) +{ + // The duplication between this and regular typepack unification is tragic. + + auto superIter = begin(superTy, &log); + auto superEndIter = end(superTy); + + auto subIter = begin(subTy, &log); + auto subEndIter = end(subTy); + + int count = FInt::LuauTypeInferLowerBoundsIterationLimit; + + for (; subIter != subEndIter; ++subIter) + { + if (0 >= --count) + ice("Internal recursion counter limit exceeded in Unifier::unifyLowerBound"); + + if (superIter != superEndIter) + { + tryUnify_(*subIter, *superIter); + ++superIter; + continue; + } + + if (auto t = superIter.tail()) + { + TypePackId tailPack = follow(*t); + + if (log.get(tailPack)) + occursCheck(tailPack, subTy); + + FreeTypePack* freeTailPack = log.getMutable(tailPack); + if (!freeTailPack) + return; + + TypeLevel level = FFlag::LuauQuantifyConstrained ? demotedLevel : freeTailPack->level; + + TypePack* tp = getMutable(log.replace(tailPack, TypePack{})); + + for (; subIter != subEndIter; ++subIter) + { + tp->head.push_back(types->addType(ConstrainedTypeVar{level, {follow(*subIter)}})); + } + + tp->tail = subIter.tail(); + } + + return; + } + + if (superIter != superEndIter) + { + if (auto subTail = subIter.tail()) + { + TypePackId subTailPack = follow(*subTail); + if (get(subTailPack)) + { + TypePack* tp = getMutable(log.replace(subTailPack, TypePack{})); + + for (; superIter != superEndIter; ++superIter) + tp->head.push_back(*superIter); + } + } + else + { + while (superIter != superEndIter) + { + if (!isOptional(*superIter)) + { + errors.push_back(TypeError{location, CountMismatch{size(superTy), size(subTy), CountMismatch::Return}}); + return; + } + ++superIter; + } + } + + return; + } + + // Both iters are at their respective tails + auto subTail = subIter.tail(); + auto superTail = superIter.tail(); + if (subTail && superTail) + tryUnify(*subTail, *superTail); + else if (subTail) + { + const FreeTypePack* freeSubTail = log.getMutable(*subTail); + if (freeSubTail) + { + log.replace(*subTail, TypePack{}); + } + } + else if (superTail) + { + const FreeTypePack* freeSuperTail = log.getMutable(*superTail); + if (freeSuperTail) + { + log.replace(*superTail, TypePack{}); + } + } +} + void Unifier::occursCheck(TypeId needle, TypeId haystack) { sharedState.tempSeenTy.clear(); @@ -2385,7 +2123,8 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack) void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); auto check = [&](TypeId tv) { occursCheck(seen, needle, tv); @@ -2425,6 +2164,11 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays for (TypeId ty : a->parts) check(ty); } + else if (auto a = log.getMutable(haystack)) + { + for (TypeId ty : a->parts) + check(ty); + } } void Unifier::occursCheck(TypePackId needle, TypePackId haystack) @@ -2450,7 +2194,8 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (!log.getMutable(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); while (!log.getMutable(haystack)) { @@ -2474,7 +2219,16 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; + Unifier u = Unifier{types, mode, location, variance, sharedState, &log}; + u.anyIsTop = anyIsTop; + return u; +} + +// A utility function that appends the given error to the unifier's error log. +// This allows setting a breakpoint wherever the unifier reports an error. +void Unifier::reportError(TypeError err) +{ + errors.push_back(std::move(err)); } bool Unifier::isNonstrictMode() const diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 31cd01cc..6f39e3fd 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -313,7 +313,7 @@ template struct AstArray { T* data; - std::size_t size; + size_t size; const T* begin() const { diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h index 65939bee..f8543111 100644 --- a/Ast/include/Luau/DenseHash.h +++ b/Ast/include/Luau/DenseHash.h @@ -32,6 +32,7 @@ class DenseHashTable { public: class const_iterator; + class iterator; DenseHashTable(const Key& empty_key, size_t buckets = 0) : count(0) @@ -43,7 +44,7 @@ public: // don't move this to initializer list! this works around an MSVC codegen issue on AMD CPUs: // https://developercommunity.visualstudio.com/t/stdvector-constructor-from-size-t-is-25-times-slow/1546547 if (buckets) - data.resize(buckets, ItemInterface::create(empty_key)); + resize_data(buckets); } void clear() @@ -125,7 +126,7 @@ public: if (data.empty() && data.capacity() >= newsize) { LUAU_ASSERT(count == 0); - data.resize(newsize, ItemInterface::create(empty_key)); + resize_data(newsize); return; } @@ -169,6 +170,21 @@ public: return const_iterator(this, data.size()); } + iterator begin() + { + size_t start = 0; + + while (start < data.size() && eq(ItemInterface::getKey(data[start]), empty_key)) + start++; + + return iterator(this, start); + } + + iterator end() + { + return iterator(this, data.size()); + } + size_t size() const { return count; @@ -233,7 +249,82 @@ public: size_t index; }; + class iterator + { + public: + iterator() + : set(0) + , index(0) + { + } + + iterator(DenseHashTable* set, size_t index) + : set(set) + , index(index) + { + } + + MutableItem& operator*() const + { + return *reinterpret_cast(&set->data[index]); + } + + MutableItem* operator->() const + { + return reinterpret_cast(&set->data[index]); + } + + bool operator==(const iterator& other) const + { + return set == other.set && index == other.index; + } + + bool operator!=(const iterator& other) const + { + return set != other.set || index != other.index; + } + + iterator& operator++() + { + size_t size = set->data.size(); + + do + { + index++; + } while (index < size && set->eq(ItemInterface::getKey(set->data[index]), set->empty_key)); + + return *this; + } + + iterator operator++(int) + { + iterator res = *this; + ++*this; + return res; + } + + private: + DenseHashTable* set; + size_t index; + }; + private: + template + void resize_data(size_t count, typename std::enable_if_t>* dummy = nullptr) + { + data.resize(count, ItemInterface::create(empty_key)); + } + + template + void resize_data(size_t count, typename std::enable_if_t>* dummy = nullptr) + { + size_t size = data.size(); + data.resize(count); + + for (size_t i = size; i < count; i++) + data[i].first = empty_key; + } + std::vector data; size_t count; Key empty_key; @@ -290,6 +381,7 @@ class DenseHashSet public: typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::iterator iterator; DenseHashSet(const Key& empty_key, size_t buckets = 0) : impl(empty_key, buckets) @@ -336,6 +428,16 @@ public: { return impl.end(); } + + iterator begin() + { + return impl.begin(); + } + + iterator end() + { + return impl.end(); + } }; // This is a faster alternative of unordered_map, but it does not implement the same interface (i.e. it does not support erasing and has @@ -348,6 +450,7 @@ class DenseHashMap public: typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::iterator iterator; DenseHashMap(const Key& empty_key, size_t buckets = 0) : impl(empty_key, buckets) @@ -401,10 +504,21 @@ public: { return impl.begin(); } + const_iterator end() const { return impl.end(); } + + iterator begin() + { + return impl.begin(); + } + + iterator end() + { + return impl.end(); + } }; } // namespace Luau diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index d7d867f4..4f3dbbd5 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -173,7 +173,7 @@ public: } const Lexeme& next(); - const Lexeme& next(bool skipComments); + const Lexeme& next(bool skipComments, bool updatePrevLocation); void nextline(); Lexeme lookahead(); diff --git a/Ast/include/Luau/StringUtils.h b/Ast/include/Luau/StringUtils.h index 6ecf0606..6ae9e977 100644 --- a/Ast/include/Luau/StringUtils.h +++ b/Ast/include/Luau/StringUtils.h @@ -19,6 +19,7 @@ std::string format(const char* fmt, ...) LUAU_PRINTF_ATTR(1, 2); std::string vformat(const char* fmt, va_list args); void formatAppend(std::string& str, const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); +void vformatAppend(std::string& ret, const char* fmt, va_list args); std::string join(const std::vector& segments, std::string_view delimiter); std::string join(const std::vector& segments, std::string_view delimiter); diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 503eca61..be282827 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Common.h" +#include "Luau/Common.h" #include @@ -9,14 +9,21 @@ LUAU_FASTFLAG(DebugLuauTimeTracing) +namespace Luau +{ +namespace TimeTrace +{ +double getClock(); +uint32_t getClockMicroseconds(); +} // namespace TimeTrace +} // namespace Luau + #if defined(LUAU_ENABLE_TIME_TRACE) namespace Luau { namespace TimeTrace { -uint32_t getClockMicroseconds(); - struct Token { const char* name; @@ -130,8 +137,8 @@ ThreadContext& getThreadContext(); struct Scope { - explicit Scope(ThreadContext& context, uint16_t token) - : context(context) + explicit Scope(uint16_t token) + : context(getThreadContext()) { if (!FFlag::DebugLuauTimeTracing) return; @@ -152,8 +159,8 @@ struct Scope struct OptionalTailScope { - explicit OptionalTailScope(ThreadContext& context, uint16_t token, uint32_t threshold) - : context(context) + explicit OptionalTailScope(uint16_t token, uint32_t threshold) + : context(getThreadContext()) , token(token) , threshold(threshold) { @@ -188,27 +195,27 @@ struct OptionalTailScope uint32_t pos; }; -LUAU_NOINLINE std::pair createScopeData(const char* name, const char* category); +LUAU_NOINLINE uint16_t createScopeData(const char* name, const char* category); } // namespace TimeTrace } // namespace Luau // Regular scope #define LUAU_TIMETRACE_SCOPE(name, category) \ - static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ - Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first) + static uint16_t lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::Scope lttScope(lttScopeStatic) // A scope without nested scopes that may be skipped if the time it took is less than the threshold #define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ - static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ - Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec) + static uint16_t lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail, microsec) // Extra key/value data can be added to regular scopes #define LUAU_TIMETRACE_ARGUMENT(name, value) \ do \ { \ if (FFlag::DebugLuauTimeTracing) \ - lttScopeStatic.second.eventArgument(name, value); \ + lttScope.context.eventArgument(name, value); \ } while (false) #else diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index d56c8860..a1f1d469 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -347,10 +347,10 @@ void Lexer::setReadNames(bool read) const Lexeme& Lexer::next() { - return next(this->skipComments); + return next(this->skipComments, true); } -const Lexeme& Lexer::next(bool skipComments) +const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation) { // in skipComments mode we reject valid comments do @@ -359,9 +359,11 @@ const Lexeme& Lexer::next(bool skipComments) while (isSpace(peekch())) consume(); - prevLocation = lexeme.location; + if (updatePrevLocation) + prevLocation = lexeme.location; lexeme = readNext(); + updatePrevLocation = false; } while (skipComments && (lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment)); return lexeme; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index f6dfd904..95bce3ee 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -11,6 +11,9 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) +LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) +LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false) + namespace Luau { @@ -165,6 +168,7 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc Function top; top.vararg = true; + functionStack.reserve(8); functionStack.push_back(top); nameSelf = names.addStatic("self"); @@ -184,6 +188,13 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc // all hot comments parsed after the first non-comment lexeme are special in that they don't affect type checking / linting mode hotcommentHeader = false; + + // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays + localStack.reserve(16); + scratchStat.reserve(16); + scratchExpr.reserve(16); + scratchLocal.reserve(16); + scratchBinding.reserve(16); } bool Parser::blockFollow(const Lexeme& l) @@ -1108,8 +1119,12 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector Parser::parseOptionalReturnTypeAnnotation() { - if (options.allowTypeAnnotations && lexer.current().type == ':') + if (options.allowTypeAnnotations && + (lexer.current().type == ':' || (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow))) { + if (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow) + report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); + nextLexeme(); unsigned int oldRecursionCount = recursionCounter; @@ -1340,8 +1355,12 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) AstArray paramTypes = copy(params); + bool returnTypeIntroducer = + FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false; + // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element - if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow) + if (params.size() == 1 && !varargAnnotation && monomorphic && + (FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow)) { if (allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; @@ -1349,7 +1368,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) return {params[0], {}}; } - if (lexer.current().type != Lexeme::SkinnyArrow && monomorphic && allowPack) + if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && monomorphic && allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; AstArray> paramNames = copy(names); @@ -1363,8 +1382,13 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray' instead of ':'"); + lexer.next(); + } // Users occasionally write '()' as the 'unit' type when they actually want to use 'nil', here we'll try to give a more specific error - if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0) + else if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0) { report(Location(begin.location, lexer.previousLocation()), "Expected '->' after '()' when parsing function type; did you mean 'nil'?"); @@ -1420,6 +1444,11 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isIntersection = true; } + else if (c == Lexeme::Dot3) + { + report(lexer.current().location, "Unexpected '...' after type annotation"); + nextLexeme(); + } else break; } @@ -1536,6 +1565,11 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) prefix = name.name; name = parseIndexName("field name", pointPosition); } + else if (lexer.current().type == Lexeme::Dot3) + { + report(lexer.current().location, "Unexpected '...' after type name; type pack is not allowed in this context"); + nextLexeme(); + } else if (name.name == "typeof") { Lexeme typeofBegin = lexer.current(); @@ -1571,6 +1605,17 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { return parseFunctionTypeAnnotation(allowPack); } + else if (FFlag::LuauParserFunctionKeywordAsTypeHelp && lexer.current().type == Lexeme::ReservedFunction) + { + Location location = lexer.current().location; + + nextLexeme(); + + return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, + "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " + "...any'"), + {}}; + } else { Location location = lexer.current().location; @@ -2778,7 +2823,7 @@ void Parser::nextLexeme() { if (options.captureComments) { - Lexeme::Type type = lexer.next(/* skipComments= */ false).type; + Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type; while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) { @@ -2802,7 +2847,7 @@ void Parser::nextLexeme() hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); } - type = lexer.next(/* skipComments= */ false).type; + type = lexer.next(/* skipComments= */ false, /* updatePrevLocation= */ false).type; } } else diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 9c7fed31..0dc3f3f5 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -11,7 +11,7 @@ namespace Luau { -static void vformatAppend(std::string& ret, const char* fmt, va_list args) +void vformatAppend(std::string& ret, const char* fmt, va_list args) { va_list argscopy; va_copy(argscopy, args); diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index 8079830b..e3807683 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -26,9 +26,6 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false) - -#if defined(LUAU_ENABLE_TIME_TRACE) - namespace Luau { namespace TimeTrace @@ -67,6 +64,14 @@ static double getClockTimestamp() #endif } +double getClock() +{ + static double period = getClockPeriod(); + static double start = getClockTimestamp(); + + return (getClockTimestamp() - start) * period; +} + uint32_t getClockMicroseconds() { static double period = getClockPeriod() * 1e6; @@ -74,7 +79,15 @@ uint32_t getClockMicroseconds() return uint32_t((getClockTimestamp() - start) * period); } +} // namespace TimeTrace +} // namespace Luau +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ struct GlobalContext { GlobalContext() = default; @@ -246,10 +259,9 @@ ThreadContext& getThreadContext() return context; } -std::pair createScopeData(const char* name, const char* category) +uint16_t createScopeData(const char* name, const char* category) { - uint16_t token = createToken(Luau::TimeTrace::getGlobalContext(), name, category); - return {token, Luau::TimeTrace::getThreadContext()}; + return createToken(Luau::TimeTrace::getGlobalContext(), name, category); } } // namespace TimeTrace } // namespace Luau diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 8b03ea1a..81db7c35 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -9,6 +9,7 @@ #include "FileUtils.h" LUAU_FASTFLAG(DebugLuauTimeTracing) +LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution) enum class ReportFormat { @@ -49,6 +50,9 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); + else if (FFlag::LuauTypeMismatchModuleNameResolution) + report(format, humanReadableName.c_str(), error.location, "TypeError", + Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); else report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str()); } diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index fb6ac373..39a14ec7 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -240,7 +240,7 @@ std::optional getParentPath(const std::string& path) return std::nullopt; #endif - std::string::size_type slash = path.find_last_of("\\/", path.size() - 1); + size_t slash = path.find_last_of("\\/", path.size() - 1); if (slash == 0) return "/"; @@ -253,7 +253,7 @@ std::optional getParentPath(const std::string& path) static std::string getExtension(const std::string& path) { - std::string::size_type dot = path.find_last_of(".\\/"); + size_t dot = path.find_last_of(".\\/"); if (dot == std::string::npos || path[dot] != '.') return ""; diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 5fd6d341..83060f5b 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -21,6 +21,8 @@ #include #endif +#include + LUAU_FASTFLAG(DebugLuauTimeTracing) enum class CliMode @@ -34,7 +36,8 @@ enum class CliMode enum class CompileFormat { Text, - Binary + Binary, + Null }; constexpr int MaxTraversalLimit = 50; @@ -434,6 +437,9 @@ static void runReplImpl(lua_State* L) { ic_set_default_completer(completeRepl, L); + // Reset the locale to C + setlocale(LC_ALL, "C"); + // Make brace matching easier to see ic_style_def("ic-bracematch", "teal"); @@ -579,7 +585,8 @@ static bool compileFile(const char* name, CompileFormat format) if (format == CompileFormat::Text) { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals); + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpSource(*source); } @@ -593,6 +600,8 @@ static bool compileFile(const char* name, CompileFormat format) case CompileFormat::Binary: fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); break; + case CompileFormat::Null: + break; } return true; @@ -636,13 +645,60 @@ static int assertionHandler(const char* expr, const char* file, int line, const return 1; } +static void setLuauFlags(bool state) +{ + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + { + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = state; + } +} + +static void setFlag(std::string_view name, bool state) +{ + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + { + if (name == flag->name) + { + flag->value = state; + return; + } + } + + fprintf(stderr, "Warning: --fflag unrecognized flag '%.*s'.\n\n", int(name.length()), name.data()); +} + +static void applyFlagKeyValue(std::string_view element) +{ + if (size_t separator = element.find('='); separator != std::string_view::npos) + { + std::string_view key = element.substr(0, separator); + std::string_view value = element.substr(separator + 1); + + if (value == "true") + setFlag(key, true); + else if (value == "false") + setFlag(key, false); + else + fprintf(stderr, "Warning: --fflag unrecognized value '%.*s' for flag '%.*s'.\n\n", int(value.length()), value.data(), int(key.length()), + key.data()); + } + else + { + if (element == "true") + setLuauFlags(true); + else if (element == "false") + setLuauFlags(false); + else + setFlag(element, true); + } +} + int replMain(int argc, char** argv) { Luau::assertHandler() = assertionHandler; - for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) - if (strncmp(flag->name, "Luau", 4) == 0) - flag->value = true; + setLuauFlags(true); CliMode mode = CliMode::Unknown; CompileFormat compileFormat{}; @@ -668,6 +724,10 @@ int replMain(int argc, char** argv) { compileFormat = CompileFormat::Text; } + else if (strcmp(argv[1], "--compile=null") == 0) + { + compileFormat = CompileFormat::Null; + } else { fprintf(stderr, "Error: Unrecognized value for '--compile' specified.\n"); @@ -727,6 +787,22 @@ int replMain(int argc, char** argv) return 1; #endif } + else if (strncmp(argv[i], "--fflags=", 9) == 0) + { + std::string_view list = argv[i] + 9; + + while (!list.empty()) + { + size_t ending = list.find(","); + + applyFlagKeyValue(list.substr(0, ending)); + + if (ending != std::string_view::npos) + list.remove_prefix(ending + 1); + else + break; + } + } else if (argv[i][0] == '-') { fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); diff --git a/CMakeLists.txt b/CMakeLists.txt index c6ccebc5..e256e234 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,7 @@ option(LUAU_BUILD_TESTS "Build tests" ON) option(LUAU_BUILD_WEB "Build Web module" OFF) option(LUAU_WERROR "Warnings as errors" OFF) option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF) +option(LUAU_EXTERN_C "Use extern C for all APIs" OFF) if(LUAU_STATIC_CRT) cmake_minimum_required(VERSION 3.15) @@ -19,9 +20,11 @@ if(LUAU_STATIC_CRT) endif() project(Luau LANGUAGES CXX C) +add_library(Luau.Common INTERFACE) add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) add_library(Luau.Analysis STATIC) +add_library(Luau.CodeGen STATIC) add_library(Luau.VM STATIC) add_library(isocline STATIC) @@ -48,8 +51,11 @@ endif() include(Sources.cmake) +target_include_directories(Luau.Common INTERFACE Common/include) + target_compile_features(Luau.Ast PUBLIC cxx_std_17) target_include_directories(Luau.Ast PUBLIC Ast/include) +target_link_libraries(Luau.Ast PUBLIC Luau.Common) target_compile_features(Luau.Compiler PUBLIC cxx_std_17) target_include_directories(Luau.Compiler PUBLIC Compiler/include) @@ -59,8 +65,13 @@ target_compile_features(Luau.Analysis PUBLIC cxx_std_17) target_include_directories(Luau.Analysis PUBLIC Analysis/include) target_link_libraries(Luau.Analysis PUBLIC Luau.Ast) +target_compile_features(Luau.CodeGen PRIVATE cxx_std_17) +target_include_directories(Luau.CodeGen PUBLIC CodeGen/include) +target_link_libraries(Luau.CodeGen PUBLIC Luau.Common) + target_compile_features(Luau.VM PRIVATE cxx_std_11) target_include_directories(Luau.VM PUBLIC VM/include) +target_link_libraries(Luau.VM PUBLIC Luau.Common) target_include_directories(isocline PUBLIC extern/isocline/include) @@ -73,6 +84,12 @@ else() list(APPEND LUAU_OPTIONS -Wall) # All warnings endif() +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + # Some gcc versions treat var in `if (type var = val)` as unused + # Some gcc versions treat variables used in constexpr if blocks as unused + list(APPEND LUAU_OPTIONS -Wno-unused) +endif() + # Enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere if(LUAU_WERROR) if(MSVC) @@ -95,19 +112,35 @@ endif() target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) +target_compile_options(Luau.CodeGen PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) target_compile_options(isocline PRIVATE ${LUAU_OPTIONS} ${ISOCLINE_OPTIONS}) +if(LUAU_EXTERN_C) + # enable extern "C" for VM (lua.h, lualib.h) and Compiler (luacode.h) to make Luau friendlier to use from non-C++ languages + # note that we enable LUA_USE_LONGJMP=1 as well; otherwise functions like luaL_error will throw C++ exceptions, which can't be done from extern "C" functions + target_compile_definitions(Luau.VM PUBLIC LUA_USE_LONGJMP=1) + target_compile_definitions(Luau.VM PUBLIC LUA_API=extern\"C\") + target_compile_definitions(Luau.Compiler PUBLIC LUACODE_API=extern\"C\") +endif() + if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) # disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022: # https://developercommunity.visualstudio.com/t/performance-regression-on-a-complex-interpreter-lo/1631863 set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) endif() +if(MSVC AND LUAU_BUILD_CLI) + # the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger + set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) + set_target_properties(Luau.Repl.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) +endif() + # embed .natvis inside the library debug information if(MSVC) target_link_options(Luau.Ast INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Ast.natvis) target_link_options(Luau.Analysis INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Analysis.natvis) + target_link_options(Luau.CodeGen INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/CodeGen.natvis) target_link_options(Luau.VM INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/VM.natvis) endif() @@ -115,6 +148,7 @@ endif() if(MSVC_IDE) target_sources(Luau.Ast PRIVATE tools/natvis/Ast.natvis) target_sources(Luau.Analysis PRIVATE tools/natvis/Analysis.natvis) + target_sources(Luau.CodeGen PRIVATE tools/natvis/CodeGen.natvis) target_sources(Luau.VM PRIVATE tools/natvis/VM.natvis) endif() @@ -142,7 +176,7 @@ endif() if(LUAU_BUILD_TESTS) target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.UnitTest PRIVATE extern) - target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler) + target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen) target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Conformance PRIVATE extern) diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h new file mode 100644 index 00000000..c5979d3c --- /dev/null +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -0,0 +1,169 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/Condition.h" +#include "Luau/Label.h" +#include "Luau/OperandX64.h" +#include "Luau/RegisterX64.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +class AssemblyBuilderX64 +{ +public: + explicit AssemblyBuilderX64(bool logText); + ~AssemblyBuilderX64(); + + // Base two operand instructions with 9 opcode selection + void add(OperandX64 lhs, OperandX64 rhs); + void sub(OperandX64 lhs, OperandX64 rhs); + void cmp(OperandX64 lhs, OperandX64 rhs); + void and_(OperandX64 lhs, OperandX64 rhs); + void or_(OperandX64 lhs, OperandX64 rhs); + void xor_(OperandX64 lhs, OperandX64 rhs); + + // Binary shift instructions with special rhs handling + void sal(OperandX64 lhs, OperandX64 rhs); + void sar(OperandX64 lhs, OperandX64 rhs); + void shl(OperandX64 lhs, OperandX64 rhs); + void shr(OperandX64 lhs, OperandX64 rhs); + + // Two operand mov instruction has additional specialized encodings + void mov(OperandX64 lhs, OperandX64 rhs); + void mov64(RegisterX64 lhs, int64_t imm); + + // Base one operand instruction with 2 opcode selection + void div(OperandX64 op); + void idiv(OperandX64 op); + void mul(OperandX64 op); + void neg(OperandX64 op); + void not_(OperandX64 op); + + void test(OperandX64 lhs, OperandX64 rhs); + void lea(OperandX64 lhs, OperandX64 rhs); + + void push(OperandX64 op); + void pop(OperandX64 op); + void ret(); + + // Control flow + void jcc(Condition cond, Label& label); + void jmp(Label& label); + void jmp(OperandX64 op); + + // AVX + void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vaddsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vaddss(OperandX64 dst, OperandX64 src1, OperandX64 src2); + + void vsqrtpd(OperandX64 dst, OperandX64 src); + void vsqrtps(OperandX64 dst, OperandX64 src); + void vsqrtsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vsqrtss(OperandX64 dst, OperandX64 src1, OperandX64 src2); + + void vmovsd(OperandX64 dst, OperandX64 src); + void vmovsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vmovss(OperandX64 dst, OperandX64 src); + void vmovss(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vmovapd(OperandX64 dst, OperandX64 src); + void vmovaps(OperandX64 dst, OperandX64 src); + void vmovupd(OperandX64 dst, OperandX64 src); + void vmovups(OperandX64 dst, OperandX64 src); + + // Run final checks + void finalize(); + + // Places a label at current location and returns it + Label setLabel(); + + // Assigns label position to the current location + void setLabel(Label& label); + + // Constant allocation (uses rip-relative addressing) + OperandX64 i64(int64_t value); + OperandX64 f32(float value); + OperandX64 f64(double value); + OperandX64 f32x4(float x, float y, float z, float w); + + // Resulting data and code that need to be copied over one after the other + // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' + std::vector data; + std::vector code; + + std::string text; + +private: + // Instruction archetypes + void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev, + uint8_t coderev, uint8_t code8, uint8_t code, uint8_t opreg); + void placeBinaryRegMemAndImm(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code, uint8_t codeImm8, uint8_t opreg); + void placeBinaryRegAndRegMem(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code); + void placeBinaryRegMemAndReg(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code); + + void placeUnaryModRegMem(const char* name, OperandX64 op, uint8_t code8, uint8_t code, uint8_t opreg); + + void placeShift(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t opreg); + + void placeJcc(const char* name, Label& label, uint8_t cc); + + void placeAvx(const char* name, OperandX64 dst, OperandX64 src, uint8_t code, bool setW, uint8_t mode, uint8_t prefix); + void placeAvx(const char* name, OperandX64 dst, OperandX64 src, uint8_t code, uint8_t coderev, bool setW, uint8_t mode, uint8_t prefix); + void placeAvx(const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t code, bool setW, uint8_t mode, uint8_t prefix); + + // Instruction components + void placeRegAndModRegMem(OperandX64 lhs, OperandX64 rhs); + void placeModRegMem(OperandX64 rhs, uint8_t regop); + void placeRex(RegisterX64 op); + void placeRex(OperandX64 op); + void placeRex(RegisterX64 lhs, OperandX64 rhs); + void placeVex(OperandX64 dst, OperandX64 src1, OperandX64 src2, bool setW, uint8_t mode, uint8_t prefix); + void placeImm8Or32(int32_t imm); + void placeImm8(int32_t imm); + void placeImm32(int32_t imm); + void placeImm64(int64_t imm); + void placeLabel(Label& label); + void place(uint8_t byte); + + void commit(); + LUAU_NOINLINE void extend(); + uint32_t getCodeSize(); + + // Data + size_t allocateData(size_t size, size_t align); + + // Logging of assembly in text form (Intel asm with VS disassembly formatting) + LUAU_NOINLINE void log(const char* opcode); + LUAU_NOINLINE void log(const char* opcode, OperandX64 op); + LUAU_NOINLINE void log(const char* opcode, OperandX64 op1, OperandX64 op2); + LUAU_NOINLINE void log(const char* opcode, OperandX64 op1, OperandX64 op2, OperandX64 op3); + LUAU_NOINLINE void log(Label label); + LUAU_NOINLINE void log(const char* opcode, Label label); + void log(OperandX64 op); + void logAppend(const char* fmt, ...); + + const char* getSizeName(SizeX64 size); + const char* getRegisterName(RegisterX64 reg); + + uint32_t nextLabel = 1; + std::vector

=1 are all candidates to be kept in the array part. The actual size of the array is the + * largest n such that at least half the slots between 0 and n are in use. + * Hash uses a mix of chained scatter table with Brent's variation. + * + * A main invariant of these tables is that, if an element is not in its main position (i.e. the original + * position that its hash gives to it), then the colliding element is in its own main position. + * Hence even when the load factor reaches 100%, performance remains good. + * + * Table keys can be arbitrary values unless they contain NaN. Keys are hashed and compared using raw equality, + * so even if the key is a userdata with an overridden __eq, it's not used during hash lookups. + * + * Each table has a "boundary", defined as the index k where t[k] ~= nil and t[k+1] == nil. The boundary can be + * computed using a binary search and can be adjusted when the table is modified; crucially, Luau enforces an + * invariant where the boundary must be in the array part - this enforces a consistent iteration order through the + * prefix of the table when using pairs(), and allows to implement algorithms that access elements in 1..#t range + * more efficiently. + */ #include "ltable.h" @@ -24,8 +33,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauTableRehashRework, false) - // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) @@ -38,7 +45,7 @@ static_assert(TKey{{NULL}, {0}, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not static_assert(TKey{{NULL}, {0}, LUA_TNIL, -(MAXSIZE - 1)}.next == -(MAXSIZE - 1), "not enough bits for next"); // reset cache of absent metamethods, cache is updated in luaT_gettm -#define invalidateTMcache(t) t->flags = 0 +#define invalidateTMcache(t) t->tmcache = 0 // empty hash data points to dummynode so that we can always dereference it const LuaNode luaH_dummynode = { @@ -380,6 +387,8 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) setarrayvector(L, t, nasize); /* create new hash part with appropriate size */ setnodevector(L, t, nhsize); + /* used for the migration check at the end */ + LuaNode* nnew = t->node; if (nasize < oldasize) { /* array part must shrink? */ t->sizearray = nasize; @@ -388,57 +397,51 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) { if (!ttisnil(&t->array[i])) { - if (FFlag::LuauTableRehashRework) - { - TValue ok; - setnvalue(&ok, cast_num(i + 1)); - setobjt2t(L, newkey(L, t, &ok), &t->array[i]); - } - else - { - setobjt2t(L, luaH_setnum(L, t, i + 1), &t->array[i]); - } + TValue ok; + setnvalue(&ok, cast_num(i + 1)); + setobjt2t(L, newkey(L, t, &ok), &t->array[i]); } } /* shrink array */ luaM_reallocarray(L, t->array, oldasize, nasize, TValue, t->memcat); } + /* used for the migration check at the end */ + TValue* anew = t->array; /* re-insert elements from hash part */ - if (FFlag::LuauTableRehashRework) + for (int i = twoto(oldhsize) - 1; i >= 0; i--) { - for (int i = twoto(oldhsize) - 1; i >= 0; i--) + LuaNode* old = nold + i; + if (!ttisnil(gval(old))) { - LuaNode* old = nold + i; - if (!ttisnil(gval(old))) - { - TValue ok; - getnodekey(L, &ok, old); - setobjt2t(L, arrayornewkey(L, t, &ok), gval(old)); - } - } - } - else - { - for (int i = twoto(oldhsize) - 1; i >= 0; i--) - { - LuaNode* old = nold + i; - if (!ttisnil(gval(old))) - { - TValue ok; - getnodekey(L, &ok, old); - setobjt2t(L, luaH_set(L, t, &ok), gval(old)); - } + TValue ok; + getnodekey(L, &ok, old); + setobjt2t(L, arrayornewkey(L, t, &ok), gval(old)); } } + /* make sure we haven't recursively rehashed during element migration */ + LUAU_ASSERT(nnew == t->node); + LUAU_ASSERT(anew == t->array); + if (nold != dummynode) luaM_freearray(L, nold, twoto(oldhsize), LuaNode, t->memcat); /* free old array */ } +static int adjustasize(Table* t, int size, const TValue* ek) +{ + bool tbound = t->node != dummynode || size < t->sizearray; + int ekindex = ek && ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; + /* move the array size up until the boundary is guaranteed to be inside the array part */ + while (size + 1 == ekindex || (tbound && !ttisnil(luaH_getnum(t, size + 1)))) + size++; + return size; +} + void luaH_resizearray(lua_State* L, Table* t, int nasize) { int nsize = (t->node == dummynode) ? 0 : sizenode(t); - resize(L, t, nasize, nsize); + int asize = adjustasize(t, nasize, NULL); + resize(L, t, asize, nsize); } void luaH_resizehash(lua_State* L, Table* t, int nhsize) @@ -460,8 +463,11 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) totaluse++; /* compute new size for array part */ int na = computesizes(nums, &nasize); + int nh = totaluse - na; + /* enforce the boundary invariant; for performance, only do hash lookups if we must */ + nasize = adjustasize(t, nasize, ek); /* resize the table to new computed sizes */ - resize(L, t, nasize, totaluse - na); + resize(L, t, nasize, nh); } /* @@ -473,7 +479,7 @@ Table* luaH_new(lua_State* L, int narray, int nhash) Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); luaC_init(L, t, LUA_TTABLE); t->metatable = NULL; - t->flags = cast_byte(~0); + t->tmcache = cast_byte(~0); t->array = NULL; t->sizearray = 0; t->lastfree = 0; @@ -520,29 +526,30 @@ static LuaNode* getfreepos(Table* t) */ static TValue* newkey(lua_State* L, Table* t, const TValue* key) { + /* enforce boundary invariant */ + if (ttisnumber(key) && nvalue(key) == t->sizearray + 1) + { + rehash(L, t, key); /* grow table */ + + /* after rehash, numeric keys might be located in the new array part, but won't be found in the node part */ + return arrayornewkey(L, t, key); + } + LuaNode* mp = mainposition(t, key); if (!ttisnil(gval(mp)) || mp == dummynode) { - LuaNode* othern; LuaNode* n = getfreepos(t); /* get a free place */ if (n == NULL) { /* cannot find a free place? */ rehash(L, t, key); /* grow table */ - if (!FFlag::LuauTableRehashRework) - { - return luaH_set(L, t, key); /* re-insert key into grown table */ - } - else - { - // after rehash, numeric keys might be located in the new array part, but won't be found in the node part - return arrayornewkey(L, t, key); - } + /* after rehash, numeric keys might be located in the new array part, but won't be found in the node part */ + return arrayornewkey(L, t, key); } LUAU_ASSERT(n != dummynode); TValue mk; getnodekey(L, &mk, mp); - othern = mainposition(t, &mk); + LuaNode* othern = mainposition(t, &mk); if (othern != mp) { /* is colliding node out of its main position? */ /* yes; move colliding node into free position */ @@ -702,36 +709,6 @@ TValue* luaH_setstr(lua_State* L, Table* t, TString* key) } } -static LUAU_NOINLINE int unbound_search(Table* t, unsigned int j) -{ - unsigned int i = j; /* i is zero or a present index */ - j++; - /* find `i' and `j' such that i is present and j is not */ - while (!ttisnil(luaH_getnum(t, j))) - { - i = j; - j *= 2; - if (j > cast_to(unsigned int, INT_MAX)) - { /* overflow? */ - /* table was built with bad purposes: resort to linear search */ - i = 1; - while (!ttisnil(luaH_getnum(t, i))) - i++; - return i - 1; - } - } - /* now do a binary search between them */ - while (j - i > 1) - { - unsigned int m = (i + j) / 2; - if (ttisnil(luaH_getnum(t, m))) - j = m; - else - i = m; - } - return i; -} - static int updateaboundary(Table* t, int boundary) { if (boundary < t->sizearray && ttisnil(&t->array[boundary - 1])) @@ -788,11 +765,12 @@ int luaH_getn(Table* t) maybesetaboundary(t, boundary); return boundary; } - /* else must find a boundary in hash part */ - else if (t->node == dummynode) /* hash part is empty? */ - return j; /* that is easy... */ else - return unbound_search(t, j); + { + /* validate boundary invariant */ + LUAU_ASSERT(t->node == dummynode || ttisnil(luaH_getnum(t, j + 1))); + return j; + } } Table* luaH_clone(lua_State* L, Table* tt) @@ -800,7 +778,7 @@ Table* luaH_clone(lua_State* L, Table* tt) Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); luaC_init(L, t, LUA_TTABLE); t->metatable = tt->metatable; - t->flags = tt->flags; + t->tmcache = tt->tmcache; t->array = NULL; t->sizearray = 0; t->lsizenode = 0; @@ -857,5 +835,5 @@ void luaH_clear(Table* tt) } /* back to empty -> no tag methods present */ - tt->flags = cast_byte(~0); + tt->tmcache = cast_byte(~0); } diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 00753742..27187c61 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,8 +10,6 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauTableClone, false) - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -512,9 +510,6 @@ static int tisfrozen(lua_State* L) static int tclone(lua_State* L) { - if (!FFlag::LuauTableClone) - luaG_runerror(L, "table.clone is not available"); - luaL_checktype(L, 1, LUA_TTABLE); luaL_argcheck(L, !luaL_getmetafield(L, 1, "__metatable"), 1, "table has a protected metatable"); diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 106efb2b..e7df4e53 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -37,6 +37,8 @@ const char* const luaT_eventname[] = { "__newindex", "__mode", "__namecall", + "__call", + "__iter", "__eq", @@ -54,13 +56,13 @@ const char* const luaT_eventname[] = { "__lt", "__le", "__concat", - "__call", "__type", }; // clang-format on static_assert(sizeof(luaT_typenames) / sizeof(luaT_typenames[0]) == LUA_T_COUNT, "luaT_typenames size mismatch"); static_assert(sizeof(luaT_eventname) / sizeof(luaT_eventname[0]) == TM_N, "luaT_eventname size mismatch"); +static_assert(TM_EQ < 8, "fasttm optimization stores a bitfield with metamethods in a byte"); void luaT_init(lua_State* L) { @@ -86,8 +88,8 @@ const TValue* luaT_gettm(Table* events, TMS event, TString* ename) const TValue* tm = luaH_getstr(events, ename); LUAU_ASSERT(event <= TM_EQ); if (ttisnil(tm)) - { /* no tag method? */ - events->flags |= cast_byte(1u << event); /* cache this fact */ + { /* no tag method? */ + events->tmcache |= cast_byte(1u << event); /* cache this fact */ return NULL; } else diff --git a/VM/src/ltm.h b/VM/src/ltm.h index 0e4e915d..a5223941 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -16,6 +16,8 @@ typedef enum TM_NEWINDEX, TM_MODE, TM_NAMECALL, + TM_CALL, + TM_ITER, TM_EQ, /* last tag method with `fast' access */ @@ -33,17 +35,16 @@ typedef enum TM_LT, TM_LE, TM_CONCAT, - TM_CALL, TM_TYPE, TM_N /* number of elements in the enum */ } TMS; // clang-format on -#define gfasttm(g, et, e) ((et) == NULL ? NULL : ((et)->flags & (1u << (e))) ? NULL : luaT_gettm(et, e, (g)->tmname[e])) +#define gfasttm(g, et, e) ((et) == NULL ? NULL : ((et)->tmcache & (1u << (e))) ? NULL : luaT_gettm(et, e, (g)->tmname[e])) #define fasttm(l, et, e) gfasttm(l->global, et, e) -#define fastnotm(et, e) ((et) == NULL || ((et)->flags & (1u << (e)))) +#define fastnotm(et, e) ((et) == NULL || ((et)->tmcache & (1u << (e)))) LUAI_DATA const char* const luaT_typenames[]; LUAI_DATA const char* const luaT_eventname[]; diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp index 819d1863..c2110cb3 100644 --- a/VM/src/ludata.cpp +++ b/VM/src/ludata.cpp @@ -22,14 +22,23 @@ Udata* luaU_newudata(lua_State* L, size_t s, int tag) void luaU_freeudata(lua_State* L, Udata* u, lua_Page* page) { - void (*dtor)(void*) = nullptr; if (u->tag < LUA_UTAG_LIMIT) + { + void (*dtor)(lua_State*, void*) = nullptr; dtor = L->global->udatagc[u->tag]; + // TODO: access to L here is highly unsafe since this is called during internal GC traversal + // certain operations such as lua_getthreaddata are okay, but by and large this risks crashes on improper use + if (dtor) + dtor(L, u->data); + } else if (u->tag == UTAG_IDTOR) + { + void (*dtor)(void*) = nullptr; memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); + if (dtor) + dtor(u->data); + } - if (dtor) - dtor(u->data); luaM_freegco(L, u, sizeudata(u->len), u->memcat, page); } diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 96a87b7e..e0a96474 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -108,7 +108,7 @@ VM_DISPATCH_OP(LOP_FORGLOOP_NEXT), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_JUMPIFEQK), VM_DISPATCH_OP(LOP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \ - VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), + VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FORGPREP), #if defined(__GNUC__) || defined(__clang__) #define VM_USE_CGOTO 1 @@ -148,8 +148,9 @@ LUAU_NOINLINE static void luau_prepareFORN(lua_State* L, StkId plimit, StkId pst LUAU_NOINLINE static bool luau_loopFORG(lua_State* L, int a, int c) { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) StkId ra = &L->base[a]; - LUAU_ASSERT(ra + 6 <= L->top); + LUAU_ASSERT(ra + 3 <= L->top); setobjs2s(L, ra + 3 + 2, ra + 2); setobjs2s(L, ra + 3 + 1, ra + 1); @@ -178,7 +179,7 @@ LUAU_NOINLINE static void luau_callTM(lua_State* L, int nparams, int res) ++L->nCcalls; if (L->nCcalls >= LUAI_MAXCCALLS) - luaG_runerror(L, "C stack overflow"); + luaD_checkCstack(L); luaD_checkstack(L, LUA_MINSTACK); @@ -691,7 +692,7 @@ static void luau_execute(lua_State* L) } else { - // slow-path, may invoke Lua calls via __index metamethod + // slow-path, may invoke Lua calls via __newindex metamethod L->cachedslot = slot; VM_PROTECT(luaV_settable(L, rb, kv, ra)); // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ @@ -701,7 +702,7 @@ static void luau_execute(lua_State* L) } else { - // fast-path: user data with C __index TM + // fast-path: user data with C __newindex TM const TValue* fn = 0; if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_NEWINDEX)) && ttisfunction(fn) && clvalue(fn)->isC) { @@ -722,7 +723,7 @@ static void luau_execute(lua_State* L) } else { - // slow-path, may invoke Lua calls via __index metamethod + // slow-path, may invoke Lua calls via __newindex metamethod VM_PROTECT(luaV_settable(L, rb, kv, ra)); VM_NEXT(); } @@ -2202,20 +2203,138 @@ static void luau_execute(lua_State* L) } } + VM_CASE(LOP_FORGPREP) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + if (ttisfunction(ra)) + { + /* will be called during FORGLOOP */ + } + else + { + Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); + + if (const TValue* fn = fasttm(L, mt, TM_ITER)) + { + setobj2s(L, ra + 1, ra); + setobj2s(L, ra, fn); + + L->top = ra + 2; /* func + self arg */ + LUAU_ASSERT(L->top <= L->stack_last); + + VM_PROTECT(luaD_call(L, ra, 3)); + L->top = L->ci->top; + } + else if (fasttm(L, mt, TM_CALL)) + { + /* table or userdata with __call, will be called during FORGLOOP */ + /* TODO: we might be able to stop supporting this depending on whether it's used in practice */ + } + else if (ttistable(ra)) + { + /* set up registers for builtin iteration */ + setobj2s(L, ra + 1, ra); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setnilvalue(ra); + } + else + { + VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); + } + } + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + VM_CASE(LOP_FORGLOOP) { VM_INTERRUPT(); Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); uint32_t aux = *pc; - // note: this is a slow generic path, fast-path is FORGLOOP_INEXT/NEXT - bool stop; - VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), aux)); + // fast-path: builtin table iteration + if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) + { + Table* h = hvalue(ra + 1); + int index = int(reinterpret_cast(pvalue(ra + 2))); - // note that we need to increment pc by 1 to exit the loop since we need to skip over aux - pc += stop ? 1 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); + int sizearray = h->sizearray; + int sizenode = 1 << h->lsizenode; + + // clear extra variables since we might have more than two + if (LUAU_UNLIKELY(aux > 2)) + for (int i = 2; i < int(aux); ++i) + setnilvalue(ra + 3 + i); + + // first we advance index through the array portion + while (unsigned(index) < unsigned(sizearray)) + { + if (!ttisnil(&h->array[index])) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setnvalue(ra + 3, double(index + 1)); + setobj2s(L, ra + 4, &h->array[index]); + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + index++; + } + + // then we advance index through the hash portion + while (unsigned(index - sizearray) < unsigned(sizenode)) + { + LuaNode* n = &h->node[index - sizearray]; + + if (!ttisnil(gval(n))) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + getnodekey(L, ra + 3, n); + setobj2s(L, ra + 4, gval(n)); + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + index++; + } + + // fallthrough to exit + pc++; + VM_NEXT(); + } + else + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + setobjs2s(L, ra + 3 + 2, ra + 2); + setobjs2s(L, ra + 3 + 1, ra + 1); + setobjs2s(L, ra + 3, ra); + + L->top = ra + 3 + 3; /* func + 2 args (state and index) */ + LUAU_ASSERT(L->top <= L->stack_last); + + VM_PROTECT(luaD_call(L, ra + 3, aux)); + L->top = L->ci->top; + + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + + // copy first variable back into the iteration index + setobjs2s(L, ra + 2, ra + 3); + + // note that we need to increment pc by 1 to exit the loop since we need to skip over aux + pc += ttisnil(ra + 3) ? 1 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } } VM_CASE(LOP_FORGPREP_INEXT) @@ -2226,8 +2345,14 @@ static void luau_execute(lua_State* L) // fast-path: ipairs/inext if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) { + setnilvalue(ra); + /* ra+1 is already the table */ setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } + else if (!ttisfunction(ra)) + { + VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); + } pc += LUAU_INSN_D(insn); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -2241,7 +2366,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); // fast-path: ipairs/inext - if (ttistable(ra + 1) && ttislightuserdata(ra + 2)) + if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) { Table* h = hvalue(ra + 1); int index = int(reinterpret_cast(pvalue(ra + 2))); @@ -2266,23 +2391,9 @@ static void luau_execute(lua_State* L) VM_NEXT(); } } - else if (h->lsizenode == 0 && ttisnil(gval(h->node))) - { - // hash part is empty: fallthrough to exit - VM_NEXT(); - } else { - // the table has a hash part; index + 1 may appear in it in which case we need to iterate through the hash portion as well - const TValue* val = luaH_getnum(h, index + 1); - - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); - setnvalue(ra + 3, double(index + 1)); - setobj2s(L, ra + 4, val); - - // note that nil elements inside the array terminate the traversal - pc += ttisnil(ra + 4) ? 0 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + // fallthrough to exit VM_NEXT(); } } @@ -2306,8 +2417,14 @@ static void luau_execute(lua_State* L) // fast-path: pairs/next if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) { + setnilvalue(ra); + /* ra+1 is already the table */ setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } + else if (!ttisfunction(ra)) + { + VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); + } pc += LUAU_INSN_D(insn); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -2321,7 +2438,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); // fast-path: pairs/next - if (ttistable(ra + 1) && ttislightuserdata(ra + 2)) + if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) { Table* h = hvalue(ra + 1); int index = int(reinterpret_cast(pvalue(ra + 2))); @@ -2702,7 +2819,7 @@ static void luau_execute(lua_State* L) { VM_PROTECT_PC(); - int n = f(L, ra, arg, nresults, nullptr, nparams); + int n = f(L, ra, arg, nresults, NULL, nparams); if (n >= 0) { diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 4e5435b7..86afddd2 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Force, false) - // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens template struct TempBuffer @@ -156,12 +154,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size return 1; } - if (FFlag::LuauBytecodeV2Force ? (version != LBC_VERSION_FUTURE) : (version != LBC_VERSION && version != LBC_VERSION_FUTURE)) + if (version < LBC_VERSION_MIN || version > LBC_VERSION_MAX) { char chunkid[LUA_IDSIZE]; luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); - lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, - FFlag::LuauBytecodeV2Force ? LBC_VERSION_FUTURE : LBC_VERSION, version); + lua_pushfstring(L, "%s: bytecode version mismatch (expected [%d..%d], got %d)", chunkid, LBC_VERSION_MIN, LBC_VERSION_MAX, version); return 1; } @@ -292,11 +289,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->p[j] = protos[fid]; } - if (FFlag::LuauBytecodeV2Force || version == LBC_VERSION_FUTURE) - p->linedefined = readVarInt(data, size, offset); - else - p->linedefined = -1; - + p->linedefined = readVarInt(data, size, offset); p->debugname = readString(strings, data, size, offset); uint8_t lineinfo = read(data, size, offset); diff --git a/bench/bench.py b/bench/bench.py index 39f219f3..67fc8cf7 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -814,13 +814,12 @@ def run(args, argsubcb): analyzeResult('', mainResult, compareResults) else: - for subdir, dirs, files in os.walk(arguments.folder): - for filename in files: - filepath = subdir + os.sep + filename - - if filename.endswith(".lua"): - if arguments.run_test == None or re.match(arguments.run_test, filename[:-4]): - runTest(subdir, filename, filepath) + all_files = [subdir + os.sep + filename for subdir, dirs, files in os.walk(arguments.folder) for filename in files] + for filepath in sorted(all_files): + subdir, filename = os.path.split(filepath) + if filename.endswith(".lua"): + if arguments.run_test == None or re.match(arguments.run_test, filename[:-4]): + runTest(subdir, filename, filepath) if arguments.sort and len(plotValueLists) > 1: rearrange(rearrangeSortKeyForComparison) diff --git a/bench/measure_time.py b/bench/measure_time.py new file mode 100644 index 00000000..c41c7d2c --- /dev/null +++ b/bench/measure_time.py @@ -0,0 +1,43 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +import os, sys, time, numpy + +try: + import scipy + from scipy import mean, stats +except ModuleNotFoundError: + print("Warning: scipy package is not installed, confidence values will not be available") + stats = None + +duration_list = [] + +DEFAULT_CYCLES_TO_RUN = 100 +cycles_to_run = DEFAULT_CYCLES_TO_RUN + +try: + cycles_to_run = sys.argv[3] if sys.argv[3] else DEFAULT_CYCLES_TO_RUN + cycles_to_run = int(cycles_to_run) +except IndexError: + pass +except (ValueError, TypeError): + cycles_to_run = DEFAULT_CYCLES_TO_RUN + print("Error: Cycles to run argument must be an integer. Using default value of {}".format(DEFAULT_CYCLES_TO_RUN)) + +# Numpy complains if we provide a cycle count of less than 3 ~ default to 3 whenever a lower value is provided +cycles_to_run = cycles_to_run if cycles_to_run > 2 else 3 + +for i in range(1,cycles_to_run): + start = time.perf_counter() + + # Run the code you want to measure here + os.system(sys.argv[1]) + + end = time.perf_counter() + + duration_ms = (end - start) * 1000 + duration_list.append(duration_ms) + +# Stats +mean = numpy.mean(duration_list) +std_err = stats.sem(duration_list) + +print("SUCCESS: {} : {:.2f}ms +/- {:.2f}% on luau ".format('duration', mean,std_err)) diff --git a/bench/micro_tests/test_LargeTableSum_loop_iter.lua b/bench/micro_tests/test_LargeTableSum_loop_iter.lua new file mode 100644 index 00000000..057420f6 --- /dev/null +++ b/bench/micro_tests/test_LargeTableSum_loop_iter.lua @@ -0,0 +1,17 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local t = {} + + for i=1,1000000 do t[i] = i end + + local ts0 = os.clock() + local sum = 0 + for k,v in t do sum = sum + v end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "LargeTableSum: for k,v in {}") diff --git a/bench/static_analysis/LuauPolyfillMap.lua b/bench/static_analysis/LuauPolyfillMap.lua new file mode 100644 index 00000000..1cfd0181 --- /dev/null +++ b/bench/static_analysis/LuauPolyfillMap.lua @@ -0,0 +1,962 @@ +-- This file is part of the Roblox luau-polyfill repository and is licensed under MIT License; see LICENSE.txt for details +--!nonstrict +-- #region Array +-- Array related +local Array = {} +local Object = {} +local Map = {} + +type Array = { [number]: T } +type callbackFn = (element: V, key: K, map: Map) -> () +type callbackFnWithThisArg = (thisArg: Object, value: V, key: K, map: Map) -> () +type Map = { + size: number, + -- method definitions + set: (self: Map, K, V) -> Map, + get: (self: Map, K) -> V | nil, + clear: (self: Map) -> (), + delete: (self: Map, K) -> boolean, + forEach: (self: Map, callback: callbackFn | callbackFnWithThisArg, thisArg: Object?) -> (), + has: (self: Map, K) -> boolean, + keys: (self: Map) -> Array, + values: (self: Map) -> Array, + entries: (self: Map) -> Array>, + ipairs: (self: Map) -> any, + [K]: V, + _map: { [K]: V }, + _array: { [number]: K }, +} +type mapFn = (element: T, index: number) -> U +type mapFnWithThisArg = (thisArg: any, element: T, index: number) -> U +type Object = { [string]: any } +type Table = { [T]: V } +type Tuple = Array + +local Set = {} + +-- #region Array +function Array.isArray(value: any): boolean + if typeof(value) ~= "table" then + return false + end + if next(value) == nil then + -- an empty table is an empty array + return true + end + + local length = #value + + if length == 0 then + return false + end + + local count = 0 + local sum = 0 + for key in pairs(value) do + if typeof(key) ~= "number" then + return false + end + if key % 1 ~= 0 or key < 1 then + return false + end + count += 1 + sum += key + end + + return sum == (count * (count + 1) / 2) +end + +function Array.from( + value: string | Array | Object, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? +): Array + if value == nil then + error("cannot create array from a nil value") + end + local valueType = typeof(value) + + local array = {} + + if valueType == "table" and Array.isArray(value) then + if mapFn then + for i = 1, #(value :: Array) do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, (value :: Array)[i], i) + else + array[i] = (mapFn :: mapFn)((value :: Array)[i], i) + end + end + else + for i = 1, #(value :: Array) do + array[i] = (value :: Array)[i] + end + end + elseif instanceOf(value, Set) then + if mapFn then + for i, v in (value :: any):ipairs() do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, v, i) + else + array[i] = (mapFn :: mapFn)(v, i) + end + end + else + for i, v in (value :: any):ipairs() do + array[i] = v + end + end + elseif instanceOf(value, Map) then + if mapFn then + for i, v in (value :: any):ipairs() do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, v, i) + else + array[i] = (mapFn :: mapFn)(v, i) + end + end + else + for i, v in (value :: any):ipairs() do + array[i] = v + end + end + elseif valueType == "string" then + if mapFn then + for i = 1, (value :: string):len() do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, (value :: any):sub(i, i), i) + else + array[i] = (mapFn :: mapFn)((value :: any):sub(i, i), i) + end + end + else + for i = 1, (value :: string):len() do + array[i] = (value :: any):sub(i, i) + end + end + end + + return array +end + +type callbackFnArrayMap = (element: T, index: number, array: Array) -> U +type callbackFnWithThisArgArrayMap = (thisArg: V, element: T, index: number, array: Array) -> U + +-- Implements Javascript's `Array.prototype.map` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/map +function Array.map( + t: Array, + callback: callbackFnArrayMap | callbackFnWithThisArgArrayMap, + thisArg: V? +): Array + if typeof(t) ~= "table" then + error(string.format("Array.map called on %s", typeof(t))) + end + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + local len = #t + local A = {} + local k = 1 + + while k <= len do + local kValue = t[k] + + if kValue ~= nil then + local mappedValue + + if thisArg ~= nil then + mappedValue = (callback :: callbackFnWithThisArgArrayMap)(thisArg, kValue, k, t) + else + mappedValue = (callback :: callbackFnArrayMap)(kValue, k, t) + end + + A[k] = mappedValue + end + k += 1 + end + + return A +end + +type Function = (any, any, number, any) -> any + +-- Implements Javascript's `Array.prototype.reduce` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/reduce +function Array.reduce(array: Array, callback: Function, initialValue: any?): any + if typeof(array) ~= "table" then + error(string.format("Array.reduce called on %s", typeof(array))) + end + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + local length = #array + + local value + local initial = 1 + + if initialValue ~= nil then + value = initialValue + else + initial = 2 + if length == 0 then + error("reduce of empty array with no initial value") + end + value = array[1] + end + + for i = initial, length do + value = callback(value, array[i], i, array) + end + + return value +end + +type callbackFnArrayForEach = (element: T, index: number, array: Array) -> () +type callbackFnWithThisArgArrayForEach = (thisArg: U, element: T, index: number, array: Array) -> () + +-- Implements Javascript's `Array.prototype.forEach` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/forEach +function Array.forEach( + t: Array, + callback: callbackFnArrayForEach | callbackFnWithThisArgArrayForEach, + thisArg: U? +): () + if typeof(t) ~= "table" then + error(string.format("Array.forEach called on %s", typeof(t))) + end + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + local len = #t + local k = 1 + + while k <= len do + local kValue = t[k] + + if thisArg ~= nil then + (callback :: callbackFnWithThisArgArrayForEach)(thisArg, kValue, k, t) + else + (callback :: callbackFnArrayForEach)(kValue, k, t) + end + + if #t < len then + -- don't iterate on removed items, don't iterate more than original length + len = #t + end + k += 1 + end +end +-- #endregion + +-- #region Set +Set.__index = Set + +type callbackFnSet = (value: T, key: T, set: Set) -> () +type callbackFnWithThisArgSet = (thisArg: Object, value: T, key: T, set: Set) -> () + +export type Set = { + size: number, + -- method definitions + add: (self: Set, T) -> Set, + clear: (self: Set) -> (), + delete: (self: Set, T) -> boolean, + forEach: (self: Set, callback: callbackFnSet | callbackFnWithThisArgSet, thisArg: Object?) -> (), + has: (self: Set, T) -> boolean, + ipairs: (self: Set) -> any, +} + +type Iterable = { ipairs: (any) -> any } + +function Set.new(iterable: Array | Set | Iterable | string | nil): Set + local array = {} + local map = {} + if iterable ~= nil then + local arrayIterable: Array + -- ROBLOX TODO: remove type casting from (iterable :: any).ipairs in next release + if typeof(iterable) == "table" then + if Array.isArray(iterable) then + arrayIterable = Array.from(iterable :: Array) + elseif typeof((iterable :: Iterable).ipairs) == "function" then + -- handle in loop below + elseif _G.__DEV__ then + error("cannot create array from an object-like table") + end + elseif typeof(iterable) == "string" then + arrayIterable = Array.from(iterable :: string) + else + error(("cannot create array from value of type `%s`"):format(typeof(iterable))) + end + + if arrayIterable then + for _, element in ipairs(arrayIterable) do + if not map[element] then + map[element] = true + table.insert(array, element) + end + end + elseif typeof(iterable) == "table" and typeof((iterable :: Iterable).ipairs) == "function" then + for _, element in (iterable :: Iterable):ipairs() do + if not map[element] then + map[element] = true + table.insert(array, element) + end + end + end + end + + return (setmetatable({ + size = #array, + _map = map, + _array = array, + }, Set) :: any) :: Set +end + +function Set:add(value) + if not self._map[value] then + -- Luau FIXME: analyze should know self is Set which includes size as a number + self.size = self.size :: number + 1 + self._map[value] = true + table.insert(self._array, value) + end + return self +end + +function Set:clear() + self.size = 0 + table.clear(self._map) + table.clear(self._array) +end + +function Set:delete(value): boolean + if not self._map[value] then + return false + end + -- Luau FIXME: analyze should know self is Map which includes size as a number + self.size = self.size :: number - 1 + self._map[value] = nil + local index = table.find(self._array, value) + if index then + table.remove(self._array, index) + end + return true +end + +-- Implements Javascript's `Map.prototype.forEach` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Set/forEach +function Set:forEach(callback: callbackFnSet | callbackFnWithThisArgSet, thisArg: Object?): () + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + return Array.forEach(self._array, function(value: T) + if thisArg ~= nil then + (callback :: callbackFnWithThisArgSet)(thisArg, value, value, self) + else + (callback :: callbackFnSet)(value, value, self) + end + end) +end + +function Set:has(value): boolean + return self._map[value] ~= nil +end + +function Set:ipairs() + return ipairs(self._array) +end + +-- #endregion Set + +-- #region Object +function Object.entries(value: string | Object | Array): Array + assert(value :: any ~= nil, "cannot get entries from a nil value") + local valueType = typeof(value) + + local entries: Array> = {} + if valueType == "table" then + for key, keyValue in pairs(value :: Object) do + -- Luau FIXME: Luau should see entries as Array, given object is [string]: any, but it sees it as Array> despite all the manual annotation + table.insert(entries, { key :: string, keyValue :: any }) + end + elseif valueType == "string" then + for i = 1, string.len(value :: string) do + entries[i] = { tostring(i), string.sub(value :: string, i, i) } + end + end + + return entries +end + +-- #endregion + +-- #region instanceOf + +-- ROBLOX note: Typed tbl as any to work with strict type analyze +-- polyfill for https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/instanceof +function instanceOf(tbl: any, class) + assert(typeof(class) == "table", "Received a non-table as the second argument for instanceof") + + if typeof(tbl) ~= "table" then + return false + end + + local ok, hasNew = pcall(function() + return class.new ~= nil and tbl.new == class.new + end) + if ok and hasNew then + return true + end + + local seen = { tbl = true } + + while tbl and typeof(tbl) == "table" do + tbl = getmetatable(tbl) + if typeof(tbl) == "table" then + tbl = tbl.__index + + if tbl == class then + return true + end + end + + -- if we still have a valid table then check against seen + if typeof(tbl) == "table" then + if seen[tbl] then + return false + end + seen[tbl] = true + end + end + + return false +end +-- #endregion + +function Map.new(iterable: Array>?): Map + local array = {} + local map = {} + if iterable ~= nil then + local arrayFromIterable + local iterableType = typeof(iterable) + if iterableType == "table" then + if #iterable > 0 and typeof(iterable[1]) ~= "table" then + error("cannot create Map from {K, V} form, it must be { {K, V}... }") + end + + arrayFromIterable = Array.from(iterable) + else + error(("cannot create array from value of type `%s`"):format(iterableType)) + end + + for _, entry in ipairs(arrayFromIterable) do + local key = entry[1] + if _G.__DEV__ then + if key == nil then + error("cannot create Map from a table that isn't an array.") + end + end + local val = entry[2] + -- only add to array if new + if map[key] == nil then + table.insert(array, key) + end + -- always assign + map[key] = val + end + end + + return (setmetatable({ + size = #array, + _map = map, + _array = array, + }, Map) :: any) :: Map +end + +function Map:set(key: K, value: V): Map + -- preserve initial insertion order + if self._map[key] == nil then + -- Luau FIXME: analyze should know self is Map which includes size as a number + self.size = self.size :: number + 1 + table.insert(self._array, key) + end + -- always update value + self._map[key] = value + return self +end + +function Map:get(key) + return self._map[key] +end + +function Map:clear() + local table_: any = table + self.size = 0 + table_.clear(self._map) + table_.clear(self._array) +end + +function Map:delete(key): boolean + if self._map[key] == nil then + return false + end + -- Luau FIXME: analyze should know self is Map which includes size as a number + self.size = self.size :: number - 1 + self._map[key] = nil + local index = table.find(self._array, key) + if index then + table.remove(self._array, index) + end + return true +end + +-- Implements Javascript's `Map.prototype.forEach` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Map/forEach +function Map:forEach(callback: callbackFn | callbackFnWithThisArg, thisArg: Object?): () + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + return Array.forEach(self._array, function(key: K) + local value: V = self._map[key] :: V + + if thisArg ~= nil then + (callback :: callbackFnWithThisArg)(thisArg, value, key, self) + else + (callback :: callbackFn)(value, key, self) + end + end) +end + +function Map:has(key): boolean + return self._map[key] ~= nil +end + +function Map:keys() + return self._array +end + +function Map:values() + return Array.map(self._array, function(key) + return self._map[key] + end) +end + +function Map:entries() + return Array.map(self._array, function(key) + return { key, self._map[key] } + end) +end + +function Map:ipairs() + return ipairs(self:entries()) +end + +function Map.__index(self, key) + local mapProp = rawget(Map, key) + if mapProp ~= nil then + return mapProp + end + + return Map.get(self, key) +end + +function Map.__newindex(table_, key, value) + table_:set(key, value) +end + +local function coerceToMap(mapLike: Map | Table): Map + return instanceOf(mapLike, Map) and mapLike :: Map -- ROBLOX: order is preservered + or Map.new(Object.entries(mapLike)) -- ROBLOX: order is not preserved +end + +-- local function coerceToTable(mapLike: Map | Table): Table +-- if not instanceOf(mapLike, Map) then +-- return mapLike +-- end + +-- -- create table from map +-- return Array.reduce(mapLike:entries(), function(tbl, entry) +-- tbl[entry[1]] = entry[2] +-- return tbl +-- end, {}) +-- end + +-- #region Tests to verify it works as expected +local function it(description: string, fn: () -> ()) + local ok, result = pcall(fn) + + if not ok then + error("Failed test: " .. description .. "\n" .. result) + end +end + +local AN_ITEM = "bar" +local ANOTHER_ITEM = "baz" + +-- #region [Describe] "Map" +-- #region [Child Describe] "constructors" +it("creates an empty array", function() + local foo = Map.new() + assert(foo.size == 0) +end) + +it("creates a Map from an array", function() + local foo = Map.new({ + { AN_ITEM, "foo" }, + { ANOTHER_ITEM, "val" }, + }) + assert(foo.size == 2) + assert(foo:has(AN_ITEM) == true) + assert(foo:has(ANOTHER_ITEM) == true) +end) + +it("creates a Map from an array with duplicate keys", function() + local foo = Map.new({ + { AN_ITEM, "foo1" }, + { AN_ITEM, "foo2" }, + }) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == "foo2") + + assert(#foo:keys() == 1 and foo:keys()[1] == AN_ITEM) + assert(#foo:values() == 1 and foo:values()[1] == "foo2") + assert(#foo:entries() == 1) + assert(#foo:entries()[1] == 2) + + assert(foo:entries()[1][1] == AN_ITEM) + assert(foo:entries()[1][2] == "foo2") +end) + +it("preserves the order of keys first assignment", function() + local foo = Map.new({ + { AN_ITEM, "foo1" }, + { ANOTHER_ITEM, "bar" }, + { AN_ITEM, "foo2" }, + }) + assert(foo.size == 2) + assert(foo:get(AN_ITEM) == "foo2") + assert(foo:get(ANOTHER_ITEM) == "bar") + + assert(foo:keys()[1] == AN_ITEM) + assert(foo:keys()[2] == ANOTHER_ITEM) + assert(foo:values()[1] == "foo2") + assert(foo:values()[2] == "bar") + assert(foo:entries()[1][1] == AN_ITEM) + assert(foo:entries()[1][2] == "foo2") + assert(foo:entries()[2][1] == ANOTHER_ITEM) + assert(foo:entries()[2][2] == "bar") +end) +-- #endregion + +-- #region [Child Describe] "type" +it("instanceOf return true for an actual Map object", function() + local foo = Map.new() + assert(instanceOf(foo, Map) == true) +end) + +it("instanceOf return false for an regular plain object", function() + local foo = {} + assert(instanceOf(foo, Map) == false) +end) +-- #endregion + +-- #region [Child Describe] "set" +it("returns the Map object", function() + local foo = Map.new() + assert(foo:set(1, "baz") == foo) +end) + +it("increments the size if the element is added for the first time", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo.size == 1) +end) + +it("does not increment the size the second time an element is added", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(AN_ITEM, "val") + assert(foo.size == 1) +end) + +it("sets values correctly to true/false", function() + -- Luau FIXME: Luau insists that arrays can't be mixed type + local foo = Map.new({ { AN_ITEM, false :: any } }) + foo:set(AN_ITEM, false) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == false) + + foo:set(AN_ITEM, true) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == true) + + foo:set(AN_ITEM, false) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == false) +end) + +-- #endregion + +-- #region [Child Describe] "get" +it("returns value of item from provided key", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo:get(AN_ITEM) == "foo") +end) + +it("returns nil if the item is not in the Map", function() + local foo = Map.new() + assert(foo:get(AN_ITEM) == nil) +end) +-- #endregion + +-- #region [Child Describe] "clear" +it("sets the size to zero", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:clear() + assert(foo.size == 0) +end) + +it("removes the items from the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:clear() + assert(foo:has(AN_ITEM) == false) +end) +-- #endregion + +-- #region [Child Describe] "delete" +it("removes the items from the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:delete(AN_ITEM) + assert(foo:has(AN_ITEM) == false) +end) + +it("returns true if the item was in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo:delete(AN_ITEM) == true) +end) + +it("returns false if the item was not in the Map", function() + local foo = Map.new() + assert(foo:delete(AN_ITEM) == false) +end) + +it("decrements the size if the item was in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:delete(AN_ITEM) + assert(foo.size == 0) +end) + +it("does not decrement the size if the item was not in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:delete(ANOTHER_ITEM) + assert(foo.size == 1) +end) + +it("deletes value set to false", function() + -- Luau FIXME: Luau insists arrays can't be mixed type + local foo = Map.new({ { AN_ITEM, false :: any } }) + + foo:delete(AN_ITEM) + + assert(foo.size == 0) + assert(foo:get(AN_ITEM) == nil) +end) +-- #endregion + +-- #region [Child Describe] "has" +it("returns true if the item is in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo:has(AN_ITEM) == true) +end) + +it("returns false if the item is not in the Map", function() + local foo = Map.new() + assert(foo:has(AN_ITEM) == false) +end) + +it("returns correctly with value set to false", function() + -- Luau FIXME: Luau insists arrays can't be mixed type + local foo = Map.new({ { AN_ITEM, false :: any } }) + + assert(foo:has(AN_ITEM) == true) +end) +-- #endregion + +-- #region [Child Describe] "keys / values / entries" +it("returns array of elements", function() + local myMap = Map.new() + myMap:set(AN_ITEM, "foo") + myMap:set(ANOTHER_ITEM, "val") + + assert(myMap:keys()[1] == AN_ITEM) + assert(myMap:keys()[2] == ANOTHER_ITEM) + + assert(myMap:values()[1] == "foo") + assert(myMap:values()[2] == "val") + + assert(myMap:entries()[1][1] == AN_ITEM) + assert(myMap:entries()[1][2] == "foo") + assert(myMap:entries()[2][1] == ANOTHER_ITEM) + assert(myMap:entries()[2][2] == "val") +end) +-- #endregion + +-- #region [Child Describe] "__index" +it("can access fields directly without using get", function() + local typeName = "size" + + local foo = Map.new({ + { AN_ITEM, "foo" }, + { ANOTHER_ITEM, "val" }, + { typeName, "buzz" }, + }) + + assert(foo.size == 3) + assert(foo[AN_ITEM] == "foo") + assert(foo[ANOTHER_ITEM] == "val") + assert(foo:get(typeName) == "buzz") +end) +-- #endregion + +-- #region [Child Describe] "__newindex" +it("can set fields directly without using set", function() + local foo = Map.new() + + assert(foo.size == 0) + + foo[AN_ITEM] = "foo" + foo[ANOTHER_ITEM] = "val" + foo.fizz = "buzz" + + assert(foo.size == 3) + assert(foo:get(AN_ITEM) == "foo") + assert(foo:get(ANOTHER_ITEM) == "val") + assert(foo:get("fizz") == "buzz") +end) +-- #endregion + +-- #region [Child Describe] "ipairs" +local function makeArray(...) + local array = {} + for _, item in ... do + table.insert(array, item) + end + return array +end + +it("iterates on the elements by their insertion order", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(ANOTHER_ITEM, "val") + assert(makeArray(foo:ipairs())[1][1] == AN_ITEM) + assert(makeArray(foo:ipairs())[1][2] == "foo") + assert(makeArray(foo:ipairs())[2][1] == ANOTHER_ITEM) + assert(makeArray(foo:ipairs())[2][2] == "val") +end) + +it("does not iterate on removed elements", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(ANOTHER_ITEM, "val") + foo:delete(AN_ITEM) + assert(makeArray(foo:ipairs())[1][1] == ANOTHER_ITEM) + assert(makeArray(foo:ipairs())[1][2] == "val") +end) + +it("iterates on elements if the added back to the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(ANOTHER_ITEM, "val") + foo:delete(AN_ITEM) + foo:set(AN_ITEM, "food") + assert(makeArray(foo:ipairs())[1][1] == ANOTHER_ITEM) + assert(makeArray(foo:ipairs())[1][2] == "val") + assert(makeArray(foo:ipairs())[2][1] == AN_ITEM) + assert(makeArray(foo:ipairs())[2][2] == "food") +end) +-- #endregion + +-- #region [Child Describe] "Integration Tests" +-- it("MDN Examples", function() +-- local myMap = Map.new() :: Map + +-- local keyString = "a string" +-- local keyObj = {} +-- local keyFunc = function() end + +-- -- setting the values +-- myMap:set(keyString, "value associated with 'a string'") +-- myMap:set(keyObj, "value associated with keyObj") +-- myMap:set(keyFunc, "value associated with keyFunc") + +-- assert(myMap.size == 3) + +-- -- getting the values +-- assert(myMap:get(keyString) == "value associated with 'a string'") +-- assert(myMap:get(keyObj) == "value associated with keyObj") +-- assert(myMap:get(keyFunc) == "value associated with keyFunc") + +-- assert(myMap:get("a string") == "value associated with 'a string'") + +-- assert(myMap:get({}) == nil) -- nil, because keyObj !== {} +-- assert(myMap:get(function() -- nil because keyFunc !== function () {} +-- end) == nil) +-- end) + +it("handles non-traditional keys", function() + local myMap = Map.new() :: Map + + local falseKey = false + local trueKey = true + local negativeKey = -1 + local emptyKey = "" + + myMap:set(falseKey, "apple") + myMap:set(trueKey, "bear") + myMap:set(negativeKey, "corgi") + myMap:set(emptyKey, "doge") + + assert(myMap.size == 4) + + assert(myMap:get(falseKey) == "apple") + assert(myMap:get(trueKey) == "bear") + assert(myMap:get(negativeKey) == "corgi") + assert(myMap:get(emptyKey) == "doge") + + myMap:delete(falseKey) + myMap:delete(trueKey) + myMap:delete(negativeKey) + myMap:delete(emptyKey) + + assert(myMap.size == 0) +end) +-- #endregion + +-- #endregion [Describe] "Map" + +-- #region [Describe] "coerceToMap" +it("returns the same object if instance of Map", function() + local map = Map.new() + assert(coerceToMap(map) == map) + + map = Map.new({}) + assert(coerceToMap(map) == map) + + map = Map.new({ { AN_ITEM, "foo" } }) + assert(coerceToMap(map) == map) +end) +-- #endregion [Describe] "coerceToMap" + +-- #endregion Tests to verify it works as expected diff --git a/bench/tests/sunspider/3d-cube.lua b/bench/tests/sunspider/3d-cube.lua index 5d162ab9..77fa0854 100644 --- a/bench/tests/sunspider/3d-cube.lua +++ b/bench/tests/sunspider/3d-cube.lua @@ -25,7 +25,7 @@ local DisplArea = {} DisplArea.Width = 300; DisplArea.Height = 300; -function DrawLine(From, To) +local function DrawLine(From, To) local x1 = From.V[1]; local x2 = To.V[1]; local y1 = From.V[2]; @@ -81,7 +81,7 @@ function DrawLine(From, To) Q.LastPx = NumPix; end -function CalcCross(V0, V1) +local function CalcCross(V0, V1) local Cross = {}; Cross[1] = V0[2]*V1[3] - V0[3]*V1[2]; Cross[2] = V0[3]*V1[1] - V0[1]*V1[3]; @@ -89,7 +89,7 @@ function CalcCross(V0, V1) return Cross; end -function CalcNormal(V0, V1, V2) +local function CalcNormal(V0, V1, V2) local A = {}; local B = {}; for i = 1,3 do A[i] = V0[i] - V1[i]; @@ -102,14 +102,14 @@ function CalcNormal(V0, V1, V2) return A; end -function CreateP(X,Y,Z) +local function CreateP(X,Y,Z) local result = {} result.V = {X,Y,Z,1}; return result end -- multiplies two matrices -function MMulti(M1, M2) +local function MMulti(M1, M2) local M = {{},{},{},{}}; for i = 1,4 do for j = 1,4 do @@ -120,7 +120,7 @@ function MMulti(M1, M2) end -- multiplies matrix with vector -function VMulti(M, V) +local function VMulti(M, V) local Vect = {}; for i = 1,4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; @@ -128,7 +128,7 @@ function VMulti(M, V) return Vect; end -function VMulti2(M, V) +local function VMulti2(M, V) local Vect = {}; for i = 1,3 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; @@ -137,7 +137,7 @@ function VMulti2(M, V) end -- add to matrices -function MAdd(M1, M2) +local function MAdd(M1, M2) local M = {{},{},{},{}}; for i = 1,4 do for j = 1,4 do @@ -147,7 +147,7 @@ function MAdd(M1, M2) return M; end -function Translate(M, Dx, Dy, Dz) +local function Translate(M, Dx, Dy, Dz) local T = { {1,0,0,Dx}, {0,1,0,Dy}, @@ -157,7 +157,7 @@ function Translate(M, Dx, Dy, Dz) return MMulti(T, M); end -function RotateX(M, Phi) +local function RotateX(M, Phi) local a = Phi; a = a * math.pi / 180; local Cos = math.cos(a); @@ -171,7 +171,7 @@ function RotateX(M, Phi) return MMulti(R, M); end -function RotateY(M, Phi) +local function RotateY(M, Phi) local a = Phi; a = a * math.pi / 180; local Cos = math.cos(a); @@ -185,7 +185,7 @@ function RotateY(M, Phi) return MMulti(R, M); end -function RotateZ(M, Phi) +local function RotateZ(M, Phi) local a = Phi; a = a * math.pi / 180; local Cos = math.cos(a); @@ -199,7 +199,7 @@ function RotateZ(M, Phi) return MMulti(R, M); end -function DrawQube() +local function DrawQube() -- calc current normals local CurN = {}; local i = 5; @@ -245,7 +245,7 @@ function DrawQube() Q.LastPx = 0; end -function Loop() +local function Loop() if (Testing.LoopCount > Testing.LoopMax) then return; end local TestingStr = tostring(Testing.LoopCount); while (#TestingStr < 3) do TestingStr = "0" .. TestingStr; end @@ -265,7 +265,7 @@ function Loop() Loop(); end -function Init(CubeSize) +local function Init(CubeSize) -- init/reset vars Origin.V = {150,150,20,1}; Testing.LoopCount = 0; diff --git a/bench/tests/sunspider/3d-morph.lua b/bench/tests/sunspider/3d-morph.lua index f73f173b..79e91419 100644 --- a/bench/tests/sunspider/3d-morph.lua +++ b/bench/tests/sunspider/3d-morph.lua @@ -31,7 +31,7 @@ local loops = 15 local nx = 120 local nz = 120 -function morph(a, f) +local function morph(a, f) local PI2nx = math.pi * 8/nx local sin = math.sin local f30 = -(50 * sin(f*math.pi*2)) diff --git a/bench/tests/sunspider/3d-raytrace.lua b/bench/tests/sunspider/3d-raytrace.lua index c8f6b5dc..3d5276c7 100644 --- a/bench/tests/sunspider/3d-raytrace.lua +++ b/bench/tests/sunspider/3d-raytrace.lua @@ -28,40 +28,40 @@ function test() local size = 30 -function createVector(x,y,z) +local function createVector(x,y,z) return { x,y,z }; end -function sqrLengthVector(self) +local function sqrLengthVector(self) return self[1] * self[1] + self[2] * self[2] + self[3] * self[3]; end -function lengthVector(self) +local function lengthVector(self) return math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]); end -function addVector(self, v) +local function addVector(self, v) self[1] = self[1] + v[1]; self[2] = self[2] + v[2]; self[3] = self[3] + v[3]; return self; end -function subVector(self, v) +local function subVector(self, v) self[1] = self[1] - v[1]; self[2] = self[2] - v[2]; self[3] = self[3] - v[3]; return self; end -function scaleVector(self, scale) +local function scaleVector(self, scale) self[1] = self[1] * scale; self[2] = self[2] * scale; self[3] = self[3] * scale; return self; end -function normaliseVector(self) +local function normaliseVector(self) local len = math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]); self[1] = self[1] / len; self[2] = self[2] / len; @@ -69,39 +69,39 @@ function normaliseVector(self) return self; end -function add(v1, v2) +local function add(v1, v2) return { v1[1] + v2[1], v1[2] + v2[2], v1[3] + v2[3] }; end -function sub(v1, v2) +local function sub(v1, v2) return { v1[1] - v2[1], v1[2] - v2[2], v1[3] - v2[3] }; end -function scalev(v1, v2) +local function scalev(v1, v2) return { v1[1] * v2[1], v1[2] * v2[2], v1[3] * v2[3] }; end -function dot(v1, v2) +local function dot(v1, v2) return v1[1] * v2[1] + v1[2] * v2[2] + v1[3] * v2[3]; end -function scale(v, scale) +local function scale(v, scale) return { v[1] * scale, v[2] * scale, v[3] * scale }; end -function cross(v1, v2) +local function cross(v1, v2) return { v1[2] * v2[3] - v1[3] * v2[2], v1[3] * v2[1] - v1[1] * v2[3], v1[1] * v2[2] - v1[2] * v2[1] }; end -function normalise(v) +local function normalise(v) local len = lengthVector(v); return { v[1] / len, v[2] / len, v[3] / len }; end -function transformMatrix(self, v) +local function transformMatrix(self, v) local vals = self; local x = vals[1] * v[1] + vals[2] * v[2] + vals[3] * v[3] + vals[4]; local y = vals[5] * v[1] + vals[6] * v[2] + vals[7] * v[3] + vals[8]; @@ -109,7 +109,7 @@ function transformMatrix(self, v) return { x, y, z }; end -function invertMatrix(self) +local function invertMatrix(self) local temp = {} local tx = -self[4]; local ty = -self[8]; @@ -131,7 +131,7 @@ function invertMatrix(self) end -- Triangle intersection using barycentric coord method -function Triangle(p1, p2, p3) +local function Triangle(p1, p2, p3) local this = {} local edge1 = sub(p3, p1); @@ -205,7 +205,7 @@ function Triangle(p1, p2, p3) return this end -function Scene(a_triangles) +local function Scene(a_triangles) local this = {} this.triangles = a_triangles; this.lights = {}; @@ -302,7 +302,7 @@ local zero = { 0,0,0 }; -- this camera code is from notes i made ages ago, it is from *somewhere* -- i cannot remember where -- that somewhere is -function Camera(origin, lookat, up) +local function Camera(origin, lookat, up) local this = {} local zaxis = normaliseVector(subVector(lookat, origin)); @@ -357,7 +357,7 @@ function Camera(origin, lookat, up) return this end -function raytraceScene() +local function raytraceScene() local startDate = 13154863; local numTriangles = 2 * 6; local triangles = {}; -- numTriangles); @@ -450,7 +450,7 @@ function raytraceScene() return pixels; end -function arrayToCanvasCommands(pixels) +local function arrayToCanvasCommands(pixels) local s = {}; table.insert(s, 'Test\nvar pixels = ['); for y = 0,size-1 do @@ -485,7 +485,7 @@ for (var y = 0; y < size; y++) {\n\ return table.concat(s); end -testOutput = arrayToCanvasCommands(raytraceScene()); +local testOutput = arrayToCanvasCommands(raytraceScene()); --local f = io.output("output.html") --f:write(testOutput) diff --git a/bench/tests/sunspider/access-binary-trees.lua b/bench/tests/sunspider/access-binary-trees.lua deleted file mode 100644 index 9eb93588..00000000 --- a/bench/tests/sunspider/access-binary-trees.lua +++ /dev/null @@ -1,69 +0,0 @@ ---[[ - The Great Computer Language Shootout - http://shootout.alioth.debian.org/ - contributed by Isaac Gouy -]] - -local bench = script and require(script.Parent.bench_support) or require("bench_support") - -function test() - -function TreeNode(left,right,item) - local this = {} - this.left = left; - this.right = right; - this.item = item; - - this.itemCheck = function(self) - if (self.left==nil) then return self.item; - else return self.item + self.left:itemCheck() - self.right:itemCheck(); end - end - - return this -end - -function bottomUpTree(item,depth) - if (depth>0) then - return TreeNode( - bottomUpTree(2*item-1, depth-1) - ,bottomUpTree(2*item, depth-1) - ,item - ); - else - return TreeNode(nil,nil,item); - end -end - -local ret = 0; - -for n = 4,7,1 do - local minDepth = 4; - local maxDepth = math.max(minDepth + 2, n); - local stretchDepth = maxDepth + 1; - - local check = bottomUpTree(0,stretchDepth):itemCheck(); - - local longLivedTree = bottomUpTree(0,maxDepth); - - for depth = minDepth,maxDepth,2 do - local iterations = 2.0 ^ (maxDepth - depth + minDepth - 1) -- 1 << (maxDepth - depth + minDepth); - - check = 0; - for i = 1,iterations do - check = check + bottomUpTree(i,depth):itemCheck(); - check = check + bottomUpTree(-i,depth):itemCheck(); - end - end - - ret = ret + longLivedTree:itemCheck(); -end - -local expected = -4; - -if (ret ~= expected) then - assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. ret); -end - -end - -bench.runCode(test, "access-binary-trees") diff --git a/bench/tests/sunspider/controlflow-recursive.lua b/bench/tests/sunspider/controlflow-recursive.lua index d0791626..a2591b2f 100644 --- a/bench/tests/sunspider/controlflow-recursive.lua +++ b/bench/tests/sunspider/controlflow-recursive.lua @@ -7,18 +7,18 @@ local bench = script and require(script.Parent.bench_support) or require("bench_ function test() -function ack(m,n) +local function ack(m,n) if (m==0) then return n+1; end if (n==0) then return ack(m-1,1); end return ack(m-1, ack(m,n-1) ); end -function fib(n) +local function fib(n) if (n < 2) then return 1; end return fib(n-2) + fib(n-1); end -function tak(x,y,z) +local function tak(x,y,z) if (y >= x) then return z; end return tak(tak(x-1,y,z), tak(y-1,z,x), tak(z-1,x,y)); end @@ -27,7 +27,7 @@ local result = 0; for i = 3,5 do result = result + ack(3,i); - result = result + fib(17.0+i); + result = result + fib(17+i); result = result + tak(3*i+3,2*i+2,i+1); end diff --git a/bench/tests/sunspider/crypto-aes.lua b/bench/tests/sunspider/crypto-aes.lua index 3b289729..8dd0cec6 100644 --- a/bench/tests/sunspider/crypto-aes.lua +++ b/bench/tests/sunspider/crypto-aes.lua @@ -42,7 +42,68 @@ local Rcon = { { 0x00, 0x00, 0x00, 0x00 }, {0x1b, 0x00, 0x00, 0x00}, {0x36, 0x00, 0x00, 0x00} }; -function Cipher(input, w) -- main Cipher function [§5.1] +local function SubBytes(s, Nb) -- apply SBox to state S [§5.1.1] + for r = 0,3 do + for c = 0,Nb-1 do s[r + 1][c + 1] = Sbox[s[r + 1][c + 1] + 1]; end + end + return s; +end + + +local function ShiftRows(s, Nb) -- shift row r of state S left by r bytes [§5.1.2] + local t = {}; + for r = 1,3 do + for c = 0,3 do t[c + 1] = s[r + 1][((c + r) % Nb) + 1] end; -- shift into temp copy + for c = 0,3 do s[r + 1][c + 1] = t[c + 1]; end -- and copy back + end -- note that this will work for Nb=4,5,6, but not 7,8 (always 4 for AES): + return s; -- see fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf +end + + +local function MixColumns(s, Nb) -- combine bytes of each col of state S [§5.1.3] + for c = 0,3 do + local a = {}; -- 'a' is a copy of the current column from 's' + local b = {}; -- 'b' is a•{02} in GF(2^8) + for i = 0,3 do + a[i + 1] = s[i + 1][c + 1]; + + if bit32.band(s[i + 1][c + 1], 0x80) ~= 0 then + b[i + 1] = bit32.bxor(bit32.lshift(s[i + 1][c + 1], 1), 0x011b); + else + b[i + 1] = bit32.lshift(s[i + 1][c + 1], 1); + end + end + -- a[n] ^ b[n] is a•{03} in GF(2^8) + s[1][c + 1] = bit32.bxor(b[1], a[2], b[2], a[3], a[4]); -- 2*a0 + 3*a1 + a2 + a3 + s[2][c + 1] = bit32.bxor(a[1], b[2], a[3], b[3], a[4]); -- a0 * 2*a1 + 3*a2 + a3 + s[3][c + 1] = bit32.bxor(a[1], a[2], b[3], a[4], b[4]); -- a0 + a1 + 2*a2 + 3*a3 + s[4][c + 1] = bit32.bxor(a[1], b[1], a[2], a[3], b[4]); -- 3*a0 + a1 + a2 + 2*a3 +end + return s; +end + + +local function SubWord(w) -- apply SBox to 4-byte word w + for i = 0,3 do w[i + 1] = Sbox[w[i + 1] + 1]; end + return w; +end + +local function RotWord(w) -- rotate 4-byte word w left by one byte + w[5] = w[1]; + for i = 0,3 do w[i + 1] = w[i + 2]; end + return w; +end + + + +local function AddRoundKey(state, w, rnd, Nb) -- xor Round Key into state S [§5.1.4] + for r = 0,3 do + for c = 0,Nb-1 do state[r + 1][c + 1] = bit32.bxor(state[r + 1][c + 1], w[rnd*4+c + 1][r + 1]); end + end + return state; +end + +local function Cipher(input, w) -- main Cipher function [§5.1] local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) local Nr = #w / Nb - 1; -- no of rounds: 10/12/14 for 128/192/256-bit keys @@ -69,56 +130,7 @@ function Cipher(input, w) -- main Cipher function [§5.1] end -function SubBytes(s, Nb) -- apply SBox to state S [§5.1.1] - for r = 0,3 do - for c = 0,Nb-1 do s[r + 1][c + 1] = Sbox[s[r + 1][c + 1] + 1]; end - end - return s; -end - - -function ShiftRows(s, Nb) -- shift row r of state S left by r bytes [§5.1.2] - local t = {}; - for r = 1,3 do - for c = 0,3 do t[c + 1] = s[r + 1][((c + r) % Nb) + 1] end; -- shift into temp copy - for c = 0,3 do s[r + 1][c + 1] = t[c + 1]; end -- and copy back - end -- note that this will work for Nb=4,5,6, but not 7,8 (always 4 for AES): - return s; -- see fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf -end - - -function MixColumns(s, Nb) -- combine bytes of each col of state S [§5.1.3] - for c = 0,3 do - local a = {}; -- 'a' is a copy of the current column from 's' - local b = {}; -- 'b' is a•{02} in GF(2^8) - for i = 0,3 do - a[i + 1] = s[i + 1][c + 1]; - - if bit32.band(s[i + 1][c + 1], 0x80) ~= 0 then - b[i + 1] = bit32.bxor(bit32.lshift(s[i + 1][c + 1], 1), 0x011b); - else - b[i + 1] = bit32.lshift(s[i + 1][c + 1], 1); - end - end - -- a[n] ^ b[n] is a•{03} in GF(2^8) - s[1][c + 1] = bit32.bxor(b[1], a[2], b[2], a[3], a[4]); -- 2*a0 + 3*a1 + a2 + a3 - s[2][c + 1] = bit32.bxor(a[1], b[2], a[3], b[3], a[4]); -- a0 * 2*a1 + 3*a2 + a3 - s[3][c + 1] = bit32.bxor(a[1], a[2], b[3], a[4], b[4]); -- a0 + a1 + 2*a2 + 3*a3 - s[4][c + 1] = bit32.bxor(a[1], b[1], a[2], a[3], b[4]); -- 3*a0 + a1 + a2 + 2*a3 -end - return s; -end - - -function AddRoundKey(state, w, rnd, Nb) -- xor Round Key into state S [§5.1.4] - for r = 0,3 do - for c = 0,Nb-1 do state[r + 1][c + 1] = bit32.bxor(state[r + 1][c + 1], w[rnd*4+c + 1][r + 1]); end - end - return state; -end - - -function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from Key [§5.2] +local function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from Key [§5.2] local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) local Nk = #key / 4 -- key length (in words): 4/6/8 for 128/192/256-bit keys local Nr = Nk + 6; -- no of rounds: 10/12/14 for 128/192/256-bit keys @@ -146,17 +158,17 @@ function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from return w; end -function SubWord(w) -- apply SBox to 4-byte word w - for i = 0,3 do w[i + 1] = Sbox[w[i + 1] + 1]; end - return w; +local function escCtrlChars(str) -- escape control chars which might cause problems handling ciphertext + return string.gsub(str, "[\0\t\n\v\f\r\'\"!-]", function(c) return '!' .. string.byte(c, 1) .. '!'; end); end -function RotWord(w) -- rotate 4-byte word w left by one byte - w[5] = w[1]; - for i = 0,3 do w[i + 1] = w[i + 2]; end - return w; -end +local function unescCtrlChars(str) -- unescape potentially problematic control characters + return string.gsub(str, "!%d%d?%d?!", function(c) + local sc = string.sub(c, 2,-2) + return string.char(tonumber(sc)); + end); +end --[[ * Use AES to encrypt 'plaintext' with 'password' using 'nBits' key, in 'Counter' mode of operation @@ -166,7 +178,7 @@ end * - cipherblock = plaintext xor outputblock ]] -function AESEncryptCtr(plaintext, password, nBits) +local function AESEncryptCtr(plaintext, password, nBits) if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys -- for this example script, generate the key by applying Cipher to 1st 16/24/32 chars of password; @@ -243,7 +255,7 @@ end * - cipherblock = plaintext xor outputblock ]] -function AESDecryptCtr(ciphertext, password, nBits) +local function AESDecryptCtr(ciphertext, password, nBits) if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys local nBytes = nBits/8; -- no bytes in key @@ -300,19 +312,7 @@ function AESDecryptCtr(ciphertext, password, nBits) return table.concat(plaintext) end -function escCtrlChars(str) -- escape control chars which might cause problems handling ciphertext - return string.gsub(str, "[\0\t\n\v\f\r\'\"!-]", function(c) return '!' .. string.byte(c, 1) .. '!'; end); -end - -function unescCtrlChars(str) -- unescape potentially problematic control characters - return string.gsub(str, "!%d%d?%d?!", function(c) - local sc = string.sub(c, 2,-2) - - return string.char(tonumber(sc)); - end); -end - -function test() +local function test() local plainText = "ROMEO: But, soft! what light through yonder window breaks?\n\ It is the east, and Juliet is the sun.\n\ diff --git a/bench/tests/sunspider/math-cordic.lua b/bench/tests/sunspider/math-cordic.lua index 94a64f45..cdb10fa2 100644 --- a/bench/tests/sunspider/math-cordic.lua +++ b/bench/tests/sunspider/math-cordic.lua @@ -31,15 +31,15 @@ function test() local AG_CONST = 0.6072529350; -function FIXED(X) +local function FIXED(X) return X * 65536.0; end -function FLOAT(X) +local function FLOAT(X) return X / 65536.0; end -function DEG2RAD(X) +local function DEG2RAD(X) return 0.017453 * (X); end @@ -52,7 +52,7 @@ local Angles = { local Target = 28.027; -function cordicsincos(Target) +local function cordicsincos(Target) local X; local Y; local TargetAngle; @@ -85,7 +85,7 @@ end local total = 0; -function cordic( runs ) +local function cordic( runs ) for i = 1,runs do total = total + cordicsincos(Target); end diff --git a/bench/tests/sunspider/math-partial-sums.lua b/bench/tests/sunspider/math-partial-sums.lua index 3c222876..9977ceff 100644 --- a/bench/tests/sunspider/math-partial-sums.lua +++ b/bench/tests/sunspider/math-partial-sums.lua @@ -7,7 +7,7 @@ local bench = script and require(script.Parent.bench_support) or require("bench_ function test() -function partial(n) +local function partial(n) local a1, a2, a3, a4, a5, a6, a7, a8, a9 = 0, 0, 0, 0, 0, 0, 0, 0, 0; local twothirds = 2.0/3.0; local alt = -1.0; diff --git a/bench/tests/sunspider/math-spectral-norm.lua b/bench/tests/sunspider/math-spectral-norm.lua deleted file mode 100644 index 7d7ec163..00000000 --- a/bench/tests/sunspider/math-spectral-norm.lua +++ /dev/null @@ -1,72 +0,0 @@ ---[[ -The Great Computer Language Shootout -http://shootout.alioth.debian.org/ - -contributed by Ian Osgood -]] -local bench = script and require(script.Parent.bench_support) or require("bench_support") - -function test() - -function A(i,j) - return 1/((i+j)*(i+j+1)/2+i+1); -end - -function Au(u,v) - for i = 0,#u-1 do - local t = 0; - for j = 0,#u-1 do - t = t + A(i,j) * u[j + 1]; - end - v[i + 1] = t; - end -end - -function Atu(u,v) - for i = 0,#u-1 do - local t = 0; - for j = 0,#u-1 do - t = t + A(j,i) * u[j + 1]; - end - v[i + 1] = t; - end -end - -function AtAu(u,v,w) - Au(u,w); - Atu(w,v); -end - -function spectralnorm(n) - local u, v, w, vv, vBv = {}, {}, {}, 0, 0; - for i = 1,n do - u[i] = 1; v[i] = 0; w[i] = 0; - end - for i = 0,9 do - AtAu(u,v,w); - AtAu(v,u,w); - end - for i = 1,n do - vBv = vBv + u[i]*v[i]; - vv = vv + v[i]*v[i]; - end - return math.sqrt(vBv/vv); -end - -local total = 0; -local i = 6 - -while i <= 48 do - total = total + spectralnorm(i); - i = i * 2 -end - -local expected = 5.086694231303284; - -if (total ~= expected) then - assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. total) -end - -end - -bench.runCode(test, "math-spectral-norm") diff --git a/docs/_config.yml b/docs/_config.yml index 71308686..33a85609 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -10,7 +10,7 @@ logo: /assets/images/luau-88.png plugins: ["jekyll-include-cache", "jekyll-feed"] include: ["_pages"] atom_feed: - path: feed.xml + path: "/feed.xml" defaults: # _docs diff --git a/docs/_pages/compatibility.md b/docs/_pages/compatibility.md index 00d883e2..d1686c2f 100644 --- a/docs/_pages/compatibility.md +++ b/docs/_pages/compatibility.md @@ -49,30 +49,31 @@ Sandboxing challenges are [covered in the dedicated section](sandbox). |---------|--------|------| | yieldable pcall/xpcall | ✔️ | | | yieldable metamethods | ❌ | significant performance implications | -| ephemeron tables | ❌ | this complicates the garbage collector esp. for large weak tables | -| emergency garbage collector | ❌ | Luau runs in environments where handling memory exhaustion in emergency situations is not tenable | +| ephemeron tables | ❌ | this complicates and slows down the garbage collector esp. for large weak tables | +| emergency garbage collector | 🤷‍ | Luau runs in environments where handling memory exhaustion in emergency situations is not tenable | | goto statement | ❌ | this complicates the compiler, makes control flow unstructured and doesn't address a significant need | | finalizers for tables | ❌ | no `__gc` support due to sandboxing and performance/complexity | | no more fenv for threads or functions | 😞 | we love this, but it breaks compatibility | -| tables honor the `__len` metamethod | ❌ | performance implications, no strong use cases +| tables honor the `__len` metamethod | 🤷‍♀️ | performance implications, no strong use cases | hex and `\z` escapes in strings | ✔️ | | | support for hexadecimal floats | 🤷‍♀️ | no strong use cases | -| order metamethods work for different types | ❌ | no strong use cases and more complicated semantics + compat | +| order metamethods work for different types | ❌ | no strong use cases and more complicated semantics, compatibility and performance implications | | empty statement | 🤷‍♀️ | less useful in Lua than in JS/C#/C/C++ | -| `break` statement may appear in the middle of a block | 🤷‍♀️ | we'd like to do it for return/continue as well but there be dragons | +| `break` statement may appear in the middle of a block | 🤷‍♀️ | we'd like to do it consistently for `break`/`return`/`continue` but there be dragons | | arguments for function called through `xpcall` | ✔️ | | | optional base in `math.log` | ✔️ | | -| optional separator in `string.rep` | 🤷‍♀️ | no real use cases | -| new metamethods `__pairs` and `__ipairs` | ❌ | would like to reevaluate iteration design long term | +| optional separator in `string.rep` | 🤷‍♀️ | no strong use cases | +| new metamethods `__pairs` and `__ipairs` | ❌ | superseded by `__iter` | | frontier patterns | ✔️ | | | `%g` in patterns | ✔️ | | | `\0` in patterns | ✔️ | | | `bit32` library | ✔️ | | | `string.gsub` is stricter about using `%` on special characters only | ✔️ | | +| light C functions | 😞 | this changes semantics of fenv on C functions and has complex implications wrt runtime performance | Two things that are important to call out here are various new metamethods for tables and yielding in metamethods. In both cases, there are performance implications to supporting this - our implementation is *very* highly tuned for performance, so any changes that affect the core fundamentals of how Lua works have a price. To support yielding in metamethods we'd need to make the core of the VM more involved, since almost every single "interesting" opcode would need to learn how to be resumable - which also complicates future JIT/AOT story. Metamethods in general are important for extensibility, but very challenging to deal with in implementation, so we err on the side of not supporting any new metamethods unless a strong need arises. -For `__pairs`/`__ipairs`, we aren't sure that this is the right design choice - self-iterating tables via `__iter` are very appealing, and if we can resolve some challenges with array iteration order, that would make the language more accessible so we may go that route instead. +For `__pairs`/`__ipairs`, we felt that extending library functions to enable custom containers wasn't the right choice. Instead we revisited iteration design to allow for self-iterating objects via `__iter` metamethod, which results in a cleaner iteration design that also makes it easier to iterate over tables. As such, we have no plans to support `__pairs`/`__ipairs` as all use cases for it can now be solved by `__iter`. Ephemeron tables may be implemented at some point since they do have valid uses and they make weak tables semantically cleaner, however the cleanup mechanism for these is expensive and complicated, and as such this can only be considered after the pending GC rework is complete. @@ -82,7 +83,7 @@ Ephemeron tables may be implemented at some point since they do have valid uses |---------|--------|------| | `\u` escapes in strings | ✔️ | | | integers (64-bit by default) | ❌ | backwards compatibility and performance implications | -| bitwise operators | ❌ | `bit32` library covers this | +| bitwise operators | ❌ | `bit32` library covers this in absence of 64-bit integers | | basic utf-8 support | ✔️ | we include `utf8` library and other UTF8 features | | functions for packing and unpacking values (string.pack/unpack/packsize) | ✔️ | | | floor division | ❌ | no strong use cases, syntax overlaps with C comments | @@ -94,16 +95,16 @@ Ephemeron tables may be implemented at some point since they do have valid uses It's important to highlight integer support and bitwise operators. For Luau, it's rare that a full 64-bit integer type is necessary - double-precision types support integers up to 2^53 (in Lua which is used in embedded space, integers may be more appealing in environments without a native 64-bit FPU). However, there's a *lot* of value in having a single number type, both from performance perspective and for consistency. Notably, Lua doesn't handle integer overflow properly, so using integers also carries compatibility implications. -If integers are taken out of the equation, bitwise operators make much less sense; additionally, `bit32` library is more fully featured (includes commonly used operations such as rotates and arithmetic shift; bit extraction/replacement is also more readable). Adding operators along with metamethods for all of them increases complexity, which means this feature isn't worth it on the balance. +If integers are taken out of the equation, bitwise operators make less sense, as integers aren't a first class feature; additionally, `bit32` library is more fully featured (includes commonly used operations such as rotates and arithmetic shift; bit extraction/replacement is also more readable). Adding operators along with metamethods for all of them increases complexity, which means this feature isn't worth it on the balance. Common arguments for this include a more familiar syntax, which, while true, gets more nuanced as `^` isn't available as a xor operator, and arithmetic right shift isn't expressible without yet another operator, and performance, which in Luau is substantially better than in Lua because `bit32` library uses VM builtins instead of expensive function calls. -Floor division is less harmful, but it's used rarely enough that `math.floor(a/b)` seems like an adequate replacement; additionally, `//` is a comment in C-derived languages and we may decide to adopt it in addition to `--` at some point. +Floor division is much less complex, but it's used rarely enough that `math.floor(a/b)` seems like an adequate replacement; additionally, `//` is a comment in C-derived languages and we may decide to adopt it in addition to `--` at some point. ## Lua 5.4 | feature | status | notes | |--|--|--| | new generational mode for garbage collection | 🔜 | we're working on gc optimizations and generational mode is on our radar -| to-be-closed variables | ❌ | the syntax is ugly and inconsistent with how we'd like to do attributes long-term; no strong use cases in our domain | +| to-be-closed variables | ❌ | the syntax is inconsistent with how we'd like to do attributes long-term; no strong use cases in our domain | | const variables | ❌ | while there's some demand for const variables, we'd never adopt this syntax | | new implementation for math.random | ✔️ | our RNG is based on PCG, unlike Lua 5.4 which uses Xoroshiro | | optional `init` argument to `string.gmatch` | 🤷‍♀️ | no strong use cases | @@ -111,14 +112,14 @@ Floor division is less harmful, but it's used rarely enough that `math.floor(a/b | coercions string-to-number moved to the string library | 😞 | we love this, but it breaks compatibility | | new format `%p` in `string.format` | 🤷‍♀️ | no strong use cases | | `utf8` library accepts codepoints up to 2^31 | 🤷‍♀️ | no strong use cases | -| The use of the `__lt` metamethod to emulate `__le` has been removed | 😞 | breaks compatibility and doesn't seem very interesting otherwise | +| The use of the `__lt` metamethod to emulate `__le` has been removed | ❌ | breaks compatibility and complicates comparison overloading story | | When finalizing objects, Lua will call `__gc` metamethods that are not functions | ❌ | no `__gc` support due to sandboxing and performance/complexity | | The function print calls `__tostring` instead of tostring to format its arguments. | ✔️ | | | By default, the decoding functions in the utf8 library do not accept surrogates. | 😞 | breaks compatibility and doesn't seem very interesting otherwise | -Lua has a beautiful syntax and frankly we're disappointed in the ``/`` which takes away from that beauty. Taking syntax aside, `` isn't very useful in Luau - its dominant use case is for code that works with external resources like files or sockets, but we don't provide such APIs - and has a very large complexity cost, evidences by a lot of bug fixes since the initial implementation in 5.4 work versions. `` in Luau doesn't matter for performance - our multi-pass compiler is already able to analyze the usage of the variable to know if it's modified or not and extract all performance gains from it - so the only use here is for code readability, where the `` syntax is... suboptimal. +Taking syntax aside (which doesn't feel idiomatic or beautiful), `` isn't very useful in Luau - its dominant use case is for code that works with external resources like files or sockets, but we don't provide such APIs - and has a very large complexity cost, evidences by a lot of bug fixes since the initial implementation in 5.4 work versions. `` in Luau doesn't matter for performance - our multi-pass compiler is already able to analyze the usage of the variable to know if it's modified or not and extract all performance gains from it - so the only use here is for code readability, where the `` syntax is... suboptimal. -If we do end up introducing const variables, it would be through a `const var = value` syntax, which is backwards compatible through a context-sensitive keyword similar to `type`. +If we do end up introducing const variables, it would be through a `const var = value` syntax, which is backwards compatible through a context-sensitive keyword similar to `type`. That said, there's ambiguity wrt whether `const` should simply behave like a read-only variable, ala JavaScript, or if it should represent a stronger contract, for example by limiting the expressions on the right hand side to ones compiler can evaluate ahead of time, or by freezing table values and thus guaranteeing immutability. ## Differences from Lua diff --git a/docs/_pages/library.md b/docs/_pages/library.md index eeada336..f419d2bf 100644 --- a/docs/_pages/library.md +++ b/docs/_pages/library.md @@ -488,7 +488,7 @@ function string.char(args: ...number): string Returns the string that contains a byte for every input number; all inputs must be integers in `[0..255]` range. ``` -function string.find(s: string, p: string, init: number?, plain: boolean?): (number?, number?) +function string.find(s: string, p: string, init: number?, plain: boolean?): (number?, number?, ...string) ``` Tries to find an instance of pattern `p` in the string `s`, starting from position `init` (defaults to 1). When `plain` is true, the search is using raw case-insensitive string equality, otherwise `p` should be a [string pattern](https://www.lua.org/manual/5.3/manual.html#6.4.1). If a match is found, returns the position of the match and the length of the match, followed by the pattern captures; otherwise returns `nil`. @@ -536,7 +536,7 @@ function string.lower(s: string): string Returns a string where each byte corresponds to the lower-case ASCII version of the input byte in the source string. ``` -function string.match(s: string, p: string, init: number?): (number?, number?) +function string.match(s: string, p: string, init: number?): ...string? ``` Tries to find an instance of pattern `p` in the string `s`, starting from position `init` (defaults to 1). `p` should be a [string pattern](https://www.lua.org/manual/5.3/manual.html#6.4.1). If a match is found, returns all pattern captures, or entire matching substring if no captures are present, otherwise returns `nil`. @@ -647,7 +647,7 @@ All functions in the `bit32` library treat input numbers as 32-bit unsigned inte function bit32.arshift(n: number, i: number): number ``` -Shifts `n` by `i` bits to the right (if `i` is negative, a left shift is performed instead). The most significant bit of `n` is propagated during the shift. +Shifts `n` by `i` bits to the right (if `i` is negative, a left shift is performed instead). The most significant bit of `n` is propagated during the shift. When `i` is larger than 31, returns an integer with all bits set to the sign bit of `n`. When `i` is smaller than `-31`, 0 is returned. ``` function bit32.band(args: ...number): number @@ -695,7 +695,7 @@ Rotates `n` to the left by `i` bits (if `i` is negative, a right rotate is perfo function bit32.lshift(n: number, i: number): number ``` -Shifts `n` to the left by `i` bits (if `i` is negative, a right shift is performed instead). +Shifts `n` to the left by `i` bits (if `i` is negative, a right shift is performed instead). When `i` is outside of `[-31..31]` range, returns 0. ``` function bit32.replace(n: number, r: number, f: number, w: number?): number @@ -713,7 +713,7 @@ Rotates `n` to the right by `i` bits (if `i` is negative, a left rotate is perfo function bit32.rshift(n: number, i: number): number ``` -Shifts `n` to the right by `i` bits (if `i` is negative, a left shift is performed instead). +Shifts `n` to the right by `i` bits (if `i` is negative, a left shift is performed instead). When `i` is outside of `[-31..31]` range, returns 0. ``` function bit32.countlz(n: number): number diff --git a/docs/_pages/performance.md b/docs/_pages/performance.md index b4fd3a7b..34b24b03 100644 --- a/docs/_pages/performance.md +++ b/docs/_pages/performance.md @@ -92,12 +92,22 @@ As a result, builtin calls are very fast in Luau - they are still slightly slowe ## Optimized table iteration -Luau implements a fully generic iteration protocol; however, for iteration through tables it recognizes three common idioms (`for .. in ipairs(t)`, `for .. in pairs(t)` and `for .. in next, t`) and emits specialized bytecode that is carefully optimized using custom internal iterators. +Luau implements a fully generic iteration protocol; however, for iteration through tables in addition to generalized iteration (`for .. in t`) it recognizes three common idioms (`for .. in ipairs(t)`, `for .. in pairs(t)` and `for .. in next, t`) and emits specialized bytecode that is carefully optimized using custom internal iterators. -As a result, iteration through tables typically doesn't result in function calls for every iteration; the performance of iteration using `pairs` and `ipairs` is comparable, so it's recommended to pick the iteration style based on readability instead of performance. +As a result, iteration through tables typically doesn't result in function calls for every iteration; the performance of iteration using generalized iteration, `pairs` and `ipairs` is comparable, so generalized iteration (without the use of `pairs`/`ipairs`) is recommended unless the code needs to be compatible with vanilla Lua or the specific semantics of `ipairs` (which stops at the first `nil` element) is required. Additionally, using generalized iteration avoids calling `pairs` when the loop starts which can be noticeable when the table is very short. Iterating through array-like tables using `for i=1,#t` tends to be slightly slower because of extra cost incurred when reading elements from the table. +## Optimized table length + +Luau tables use a hybrid array/hash storage, like in Lua; in some sense "arrays" don't truly exist and are an internal optimization, but some operations, notably `#t` and functions that depend on it, like `table.insert`, are defined by the Luau/Lua language to allow internal optimizations. Luau takes advantage of that fact. + +Unlike Lua, Luau guarantees that the element at index `#t` is stored in the array part of the table. This can accelerate various table operations that use indices limited by `#t`, and this makes `#t` worst-case complexity O(logN), unlike Lua where the worst case complexity is O(N). This also accelerates computation of this value for small tables like `{ [1] = 1 }` since we never need to look at the hash part. + +The "default" implementation of `#t` in both Lua and Luau is a binary search. Luau uses a special branch-free (depending on the compiler...) implementation of the binary search which results in 50+% faster computation of table length when it needs to be computed from scratch. + +Additionally, Luau can cache the length of the table and adjust it following operations like `table.insert`/`table.remove`; this means that in practice, `#t` is almost always a constant time operation. + ## Creating and modifying tables Luau implements several optimizations for table creation. When creating object-like tables, it's recommended to use table literals (`{ ... }`) and to specify all table fields in the literal in one go instead of assigning fields later; this triggers an optimization inspired by LuaJIT's "table templates" and results in higher performance when creating objects. When creating array-like tables, if the maximum size of the table is known up front, it's recommended to use `table.create` function which can create an empty table with preallocated storage, and optionally fill it with a given value. @@ -112,7 +122,7 @@ v.z = 3 return v ``` -When appending elements to tables, it's recommended to use `table.insert` (which is the fastest method to append an element to a table if the table size is not known). In cases when a table is filled sequentially, however, it's much more efficient to use a known index for insertion - together with preallocating tables using `table.create` this can result in much faster code, for example this is the fastest way to build a table of squares: +When appending elements to tables, it's recommended to use `table.insert` (which is the fastest method to append an element to a table if the table size is not known). In cases when a table is filled sequentially, however, it can be more efficient to use a known index for insertion - together with preallocating tables using `table.create` this can result in much faster code, for example this is the fastest way to build a table of squares: ```lua local t = table.create(N) @@ -175,3 +185,13 @@ While large tables can be a problem for incremental GC in general since currentl The incremental garbage collector in Luau runs three phases for each cycle: mark, atomic and sweep. Mark incrementally traverses all live objects, atomic finishes various operations that need to happen without mutator intervention (see previous section), and sweep traverses all objects in the heap, reclaiming memory used by dead objects and performing minor fixup for live objects. While objects allocated during the mark phase are traversed in the same cycle and thus may get reclaimed, objects allocated during the sweep phase are considered live. Because of this, the faster the sweep phase completes, the less garbage will accumulate; and, of course, the less time sweeping takes the less overhead there is from this phase of garbage collection on the process. Since sweeping traverses the whole heap, we maximize the efficiency of this traversal by allocating garbage-collected objects of the same size in 16 KB pages, and traversing each page at a time, which is otherwise known as a paged sweeper. This ensures good locality of reference as consecutively swept objects are contiugous in memory, and allows us to spend no memory for each object on sweep-related data or allocation metadata, since paged sweeper doesn't need to be able to free objects without knowing which page they are in. Compared to linked list based sweeping that Lua/LuaJIT implement, paged sweeper is 2-3x faster, and saves 16 bytes per object on 64-bit platforms. + +## Function inlining and loop unrolling + +By default, the bytecode compiler performs a series of optimizations that result in faster execution of the code, but they preserve both execution semantics and debuggability. For example, a function call is compiled as a function call, which may be observable via `debug.traceback`; a loop is compiled as a loop, which may be observable via `lua_getlocal`. To help improve performance in cases where these restrictions can be relaxed, the bytecode compiler implements additional optimizations when optimization level 2 is enabled (which requires using `-O2` switch when using Luau CLI), namely function inlining and loop unrolling. + +Only loops with loop bounds known at compile time, such as `for i=1,4 do`, can be unrolled. The loop body must be simple enough for the optimization to be profitable; compiler uses heuristics to estimate the performance benefit and automatically decide if unrolling should be performed. + +Only local functions (defined either as `local function foo` or `local foo = function`) can be inlined. The function body must be simple enough for the optimization to be profitable; compiler uses heuristics to estimate the performance benefit and automatically decide if each call to the function should be inlined instead. Additionally recursive invocations of a function can’t be inlined at this time, and inlining is completely disabled for modules that use `getfenv`/`setfenv` functions. + +In both cases, in addition to removing the overhead associated with function calls or loop iteration, these optimizations can additionally benefit by enabling additional optimizations, such as constant folding of expressions dependent on loop iteration variable or constant function arguments, or using more efficient instructions for certain expressions when the inputs to these instructions are constants. diff --git a/docs/_pages/sandbox.md b/docs/_pages/sandbox.md index 409a0929..a7ed7476 100644 --- a/docs/_pages/sandbox.md +++ b/docs/_pages/sandbox.md @@ -4,11 +4,11 @@ title: Sandboxing toc: true --- -Luau is safe to embed. Broadly speaking, this means that even in the face of untrusted (and in Roblox case, actively malicious) code, the language and the standard library don't allow any unsafe access to the underlying system, and don't have any bugs that allow escaping out of the sandbox (e.g. to gain native code execution through ROP gadgets et al). Additionally, the VM provides extra features to implement isolation of privileged code from unprivileged code and protect one from the other; this is important if the embedding environment (Roblox) decides to expose some APIs that may not be safe to call from untrusted code, for example because they do provide controlled access to the underlying system or risk PII exposure through fingerprinting etc. +Luau is safe to embed. Broadly speaking, this means that even in the face of untrusted (and in Roblox case, actively malicious) code, the language and the standard library don't allow unsafe access to the underlying system, and don't have known bugs that allow escaping out of the sandbox (e.g. to gain native code execution through ROP gadgets et al). Additionally, the VM provides extra features to implement isolation of privileged code from unprivileged code and protect one from the other; this is important if the embedding environment decides to expose some APIs that may not be safe to call from untrusted code, for example because they do provide controlled access to the underlying system or risk PII exposure through fingerprinting etc. This safety is achieved through a combination of removing features from the standard library that are unsafe, adding features to the VM that make it possible to implement sandboxing and isolation, and making sure the implementation is safe from memory safety issues using fuzzing. -Of course, since the entire stack is implemented in C++, the sandboxing isn't formally proven - in theory, compiler or the standard library can have exploitable vulnerabilities. In practice these are usually found and fixed quickly. While implementing the stack in a safer language such as Rust would make it easier to provide these guarantees, to our knowledge (based on prior art) this would make it difficult to reach the level of performance required. +Of course, since the entire stack is implemented in C++, the sandboxing isn't formally proven - in theory, compiler or the standard library can have exploitable vulnerabilities. In practice these are very rare and usually found and fixed quickly. While implementing the stack in a safer language such as Rust would make it easier to provide these guarantees, to our knowledge (based on prior art) this would make it difficult to reach the level of performance required. ## Library @@ -19,7 +19,7 @@ The following libraries and global functions have been removed as a result: - `io.` library has been removed entirely, as it gives access to files and allows running processes - `package.` library has been removed entirely, as it gives access to files and allows loading native modules - `os.` library has been cleaned up from file and environment access functions (`execute`, `exit`, etc.). The only supported functions in the library are `clock`, `date`, `difftime` and `time`. -- `debug.` library has been removed to a large extent, as it has functions that aren't memory safe and other functions break isolation; the only supported functions are `traceback` ~~and `getinfo` (with reduced functionality)~~. +- `debug.` library has been removed to a large extent, as it has functions that aren't memory safe and other functions break isolation; the only supported functions are `traceback` and `info` (which is similar to `debug.getinfo` but has a slightly different interface). - `dofile` and `loadfile` allowed access to file system and have been removed. To achieve memory safety, access to function bytecode has been removed. Bytecode is hard to validate and using untrusted bytecode may lead to exploits. Thus, `loadstring` doesn't work with bytecode inputs, and `string.dump`/`load` have been removed as they aren't necessary anymore. When embedding Luau, bytecode should be encrypted/signed to prevent MITM attacks as well, as the VM assumes that the bytecode was generated by the Luau compiler (which never produces invalid/unsafe bytecode). @@ -54,7 +54,7 @@ This mechanism is bad for performance, memory safety and isolation: - In Lua 5.1, `__gc` support requires traversing userdata lists redundantly during garbage collection to filter out finalizable objects - In later versions of Lua, userdata that implement `__gc` are split into separate lists; however, finalization prolongs the lifetime of the finalized objects which results in less prompt memory reclamation, and two-step destruction results in extra cache misses for userdata -- `__gc` runs during garbage collection in context of an arbitrary thread which makes the thread identity mechanism described above invalid +- `__gc` runs during garbage collection in context of an arbitrary thread which makes the thread identity mechanism used in Roblox to support trusted Luau code invalid - Objects can be removed from weak tables *after* being finalized, which means that accessing these objects can result in memory safety bugs, unless all exposed userdata methods guard against use-after-gc. - If `__gc` method ever leaks to scripts, they can call it directly on an object and use any method exposed by that object after that. This means that `__gc` and all other exposed methods must support memory safety when called on a destroyed object. diff --git a/docs/_pages/syntax.md b/docs/_pages/syntax.md index 4d39e462..fe825fda 100644 --- a/docs/_pages/syntax.md +++ b/docs/_pages/syntax.md @@ -196,3 +196,26 @@ local sign = if x < 0 then -1 elseif x > 0 then 1 else 0 ``` **Note:** In Luau, the `if-then-else` expression is preferred vs the standard Lua idiom of writing `a and b or c` (which roughly simulates a ternary operator). However, the Lua idiom may return an unexpected result if `b` evaluates to false. The `if-then-else` expression will behave as expected in all situations. + +## Generalized iteration + +Luau uses the standard Lua syntax for iterating through containers, `for vars in values`, but extends the semantics with support for generalized iteration. In Lua, to iterate over a table you need to use an iterator like `next` or a function that returns one like `pairs` or `ipairs`. In Luau, you can simply iterate over a table: + +```lua +for k, v in {1, 4, 9} do + assert(k * k == v) +end +``` + +This works for tables but can also be extended for tables or userdata by implementing `__iter` metamethod that is called before the iteration begins, and should return an iterator function like `next` (or a custom one): + +```lua +local obj = { items = {1, 4, 9} } +setmetatable(obj, { __iter = function(o) return next, o.items end }) + +for k, v in obj do + assert(k * k == v) +end +``` + +The default iteration order for tables is specified to be consecutive for elements `1..#t` and unordered after that, visiting every element; similarly to iteration using `pairs`, modifying the table entries for keys other than the current one results in unspecified behavior. diff --git a/docs/_pages/typecheck.md b/docs/_pages/typecheck.md index 3580d66e..63e4c8bb 100644 --- a/docs/_pages/typecheck.md +++ b/docs/_pages/typecheck.md @@ -31,20 +31,6 @@ foo = 1 However, given the second snippet in strict mode, the type checker would be able to infer `number` for `foo`. -## Unknown symbols - -Consider how often you're likely to assign a new value to a local variable. What if you accidentally misspelled it? Oops, it's now assigned globally and your local variable is still using the old value. - -```lua -local someLocal = 1 - -soeLocal = 2 -- the bug - -print(someLocal) -``` - -Because of this, Luau type checker currently emits an error in strict mode; use local variables instead. - ## Structural type system Luau's type system is structural by default, which is to say that we inspect the shape of two tables to see if they are similar enough. This was the obvious choice because Lua 5.1 is inherently structural. @@ -151,6 +137,15 @@ local t = {x = 1} -- {x: number} t.y = 2 -- not ok ``` +Sealed tables support *width subtyping*, which allows a table with more properties to be used as a table with fewer + +```lua +type Point1D = { x : number } +type Point2D = { x : number, y : number } +local p : Point2D = { x = 5, y = 37 } +local q : Point1D = p -- ok because Point2D has more properties than Point1D +``` + ### Generic tables This typically occurs when the symbol does not have any annotated types or were not inferred anything concrete. In this case, when you index on a parameter, you're requesting that there is a table with a matching interface. @@ -267,6 +262,23 @@ Note: it's impossible to create an intersection type of some primitive types, e. Note: Luau still does not support user-defined overloaded functions. Some of Roblox and Lua 5.1 functions have different function signature, so inherently requires overloaded functions. +## Singleton types (aka literal types) + +Luau's type system also supports singleton types, which means it's a type that represents one single value at runtime. At this time, both string and booleans are representable in types. + +> We do not currently support numbers as types. For now, this is intentional. + +```lua +local foo: "Foo" = "Foo" -- ok +local bar: "Bar" = foo -- not ok +local baz: string = foo -- ok + +local t: true = true -- ok +local f: false = false -- ok +``` + +This happens all the time, especially through [type refinements](#type-refinements) and is also incredibly useful when you want to enforce program invariants in the type system! See [tagged unions](#tagged-unions) for more information. + ## Variadic types Luau permits assigning a type to the `...` variadic symbol like any other parameter: @@ -375,22 +387,42 @@ local account: Account = Account.new("Alexander", 500) --^^^^^^^ not ok, 'Account' does not exist ``` +## Tagged unions + +Tagged unions are just union types! In particular, they're union types of tables where they have at least _some_ common properties but the structure of the tables are different enough. Here's one example: + +```lua +type Ok = { type: "ok", value: T } +type Err = { type: "err", error: E } +type Result = Ok | Err +``` + +This `Result` type can be discriminated by using type refinements on the property `type`, like so: + +```lua +if result.type == "ok" then + -- result is known to be Ok + -- and attempting to index for error here will fail + print(result.value) +elseif result.type == "err" then + -- result is known to be Err + -- and attempting to index for value here will fail + print(result.error) +end +``` + +Which works out because `value: T` exists only when `type` is in actual fact `"ok"`, and `error: E` exists only when `type` is in actual fact `"err"`. + ## Type refinements -When we check the type of a value, what we're doing is we're refining the type, hence "type refinement." Currently, the support for this is somewhat basic. +When we check the type of any lvalue (a global, a local, or a property), what we're doing is we're refining the type, hence "type refinement." The support for this is arbitrarily complex, so go crazy! -Using `type` comparison: -```lua -local stringOrNumber: string | number = "foo" +Here are all the ways you can refine: +1. Truthy test: `if x then` will refine `x` to be truthy. +2. Type guards: `if type(x) == "number" then` will refine `x` to be `number`. +3. Equality: `x == "hello"` will refine `x` to be a singleton type `"hello"`. -if type(x) == "string" then - local onlyString: string = stringOrNumber -- ok - local onlyNumber: number = stringOrNumber -- not ok -end - -local onlyString: string = stringOrNumber -- not ok -local onlyNumber: number = stringOrNumber -- not ok -``` +And they can be composed with many of `and`/`or`/`not`. `not`, just like `~=`, will flip the resulting refinements, that is `not x` will refine `x` to be falsy. Using truthy test: ```lua @@ -398,10 +430,55 @@ local maybeString: string? = nil if maybeString then local onlyString: string = maybeString -- ok + local onlyNil: nil = maybeString -- not ok +end + +if not maybeString then + local onlyString: string = maybeString -- not ok + local onlyNil: nil = maybeString -- ok end ``` -And using `assert` will work with the above type guards: +Using `type` test: +```lua +local stringOrNumber: string | number = "foo" + +if type(stringOrNumber) == "string" then + local onlyString: string = stringOrNumber -- ok + local onlyNumber: number = stringOrNumber -- not ok +end + +if type(stringOrNumber) ~= "string" then + local onlyString: string = stringOrNumber -- not ok + local onlyNumber: number = stringOrNumber -- ok +end +``` + +Using equality test: +```lua +local myString: string = f() + +if myString == "hello" then + local hello: "hello" = myString -- ok because it is absolutely "hello"! + local copy: string = myString -- ok +end +``` + +And as said earlier, we can compose as many of `and`/`or`/`not` as we wish with these refinements: +```lua +local function f(x: any, y: any) + if (x == "hello" or x == "bye") and type(y) == "string" then + -- x is of type "hello" | "bye" + -- y is of type string + end + + if not (x ~= "hi") then + -- x is of type "hi" + end +end +``` + +`assert` can also be used to refine in all the same ways: ```lua local stringOrNumber: string | number = "foo" @@ -411,7 +488,7 @@ local onlyString: string = stringOrNumber -- ok local onlyNumber: number = stringOrNumber -- not ok ``` -## Typecasts +## Type casts Expressions may be typecast using `::`. Typecasting is useful for specifying the type of an expression when the automatically inferred type is too generic. @@ -487,4 +564,4 @@ There are some caveats here though. For instance, the require path must be resol Cyclic module dependencies can cause problems for the type checker. In order to break a module dependency cycle a typecast of the module to `any` may be used: ```lua local myModule = require(MyModule) :: any -``` \ No newline at end of file +``` diff --git a/docs/_posts/2022-03-31-luau-recap-march-2022.md b/docs/_posts/2022-03-31-luau-recap-march-2022.md new file mode 100644 index 00000000..ff3a4d0f --- /dev/null +++ b/docs/_posts/2022-03-31-luau-recap-march-2022.md @@ -0,0 +1,109 @@ +--- +layout: single +title: "Luau Recap: March 2022" +--- + +Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). + +[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-march-2022/).] + +## Singleton types + +We added support for singleton types! These allow you to use string or +boolean literals in types. These types are only inhabited by the +literal, for example if a variable `x` has type `"foo"`, then `x == +"foo"` is guaranteed to be true. + +Singleton types are particularly useful when combined with union types, +for example: + +```lua +type Animals = "Dog" | "Cat" | "Bird" +``` + +or: + +```lua +type Falsey = false | nil +``` + +In particular, singleton types play well with unions of tables, +allowing tagged unions (also known as discriminated unions): + +```lua +type Ok = { type: "ok", value: T } +type Err = { type: "error", error: E } +type Result = Ok | Err + +local result: Result = ... +if result.type == "ok" then + -- result :: Ok + print(result.value) +elseif result.type == "error" then + -- result :: Err + error(result.error) +end +``` + +The RFC for singleton types is https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md + +## Width subtyping + +A common idiom for programming with tables is to provide a public interface type, but to keep some of the concrete implementation private, for example: + +```lua +type Interface = { + name: string, +} + +type Concrete = { + name: string, + id: number, +} +``` + +Within a module, a developer might use the concrete type, but export functions using the interface type: + +```lua +local x: Concrete = { + name = "foo", + id = 123, +} + +local function get(): Interface + return x +end +``` + +Previously examples like this did not typecheck but now they do! + +This language feature is called *width subtyping* (it allows tables to get *wider*, that is to have more properties). + +The RFC for width subtyping is https://github.com/Roblox/luau/blob/master/rfcs/sealed-table-subtyping.md + +## Typechecking improvements + + * Generic function type inference now works the same for generic types and generic type packs. + * We improved some error messages. + * There are now fewer crashes (hopefully none!) due to mutating types inside the Luau typechecker. + * We fixed a bug that could cause two incompatible copies of the same class to be created. + * Luau now copes better with cyclic metatable types (it gives a type error rather than hanging). + * Fixed a case where types are not properly bound to all of the subtype when the subtype is a union. + * We fixed a bug that confused union and intersection types of table properties. + * Functions declared as `function f(x : any)` can now be called as `f()` without a type error. + +## API improvements + + * Implement `table.clone` which takes a table and returns a new table that has the same keys/values/metatable. The cloning is shallow - if some keys refer to tables that need to be cloned, that can be done manually by modifying the resulting table. + +## Debugger improvements + + * Use the property name as the name of methods in the debugger. + +## Performance improvements + + * Optimize table rehashing (~15% faster dictionary table resize on average) + * Improve performance of freeing tables (~5% lift on some GC benchmarks) + * Improve gathering performance metrics for GC. + * Reduce stack memory reallocation. + diff --git a/docs/_posts/2022-05-02-luau-recap-april-2022.md b/docs/_posts/2022-05-02-luau-recap-april-2022.md new file mode 100644 index 00000000..dd6b2c0c --- /dev/null +++ b/docs/_posts/2022-05-02-luau-recap-april-2022.md @@ -0,0 +1,51 @@ +--- +layout: single +title: "Luau Recap: April 2022" +--- + +Luau is our new language that you can read more about at [https://luau-lang.org](https://luau-lang.org). + +[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-april-2022/).] + +It's been a bit of a quiet month. We mostly have small optimizations and bugfixes for you. + +It is now allowed to define functions on sealed tables that have string indexers. These functions will be typechecked against the indexer type. For example, the following is now valid: + +```lua +local a : {[string]: () -> number} = {} + +function a.y() return 4 end -- OK +``` + +Autocomplete will now provide string literal suggestions for singleton types. eg + +```lua +local function f(x: "a" | "b") end +f("_") -- suggest "a" and "b" +``` + +Improve error recovery in the case where we encounter a type pack variable in a place where one is not allowed. eg `type Foo = { value: A... }` + +When code does not pass enough arguments to a variadic function, the error feedback is now better. + +For example, the following script now produces a much nicer error message: +```lua +type A = { [number]: number } +type B = { [number]: string } + +local a: A = { 1, 2, 3 } + +-- ERROR: Type 'A' could not be converted into 'B' +-- caused by: +-- Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string' +local b: B = a +``` + +If the following code were to error because `Hello` was undefined, we would erroneously include the comment in the span of the error. This is now fixed. +```lua +type Foo = Hello -- some comment over here +``` + +Fix a crash that could occur when strict scripts have cyclic require() dependencies. + +Add an option to autocomplete to cause it to abort processing after a certain amount of time has elapsed. diff --git a/docs/_posts/2022-06-01-luau-recap-may-2022.md b/docs/_posts/2022-06-01-luau-recap-may-2022.md new file mode 100644 index 00000000..500e6e4a --- /dev/null +++ b/docs/_posts/2022-06-01-luau-recap-may-2022.md @@ -0,0 +1,97 @@ +--- +layout: single +title: "Luau Recap: May 2022" +--- + +This month Luau team has worked to bring you a new language feature together with more typechecking improvements and bugfixes! + +[Cross-posted to the [Roblox Developer Forum](https://devforum.roblox.com/t/luau-recap-may-2022/).] + +## Generalized iteration + +We have extended the semantics of standard Lua syntax for iterating through containers, `for vars in values` with support for generalized iteration. +In Lua, to iterate over a table you need to use an iterator like `next` or a function that returns one like `pairs` or `ipairs`. In Luau, you can now simply iterate over a table: + +```lua +for k, v in {1, 4, 9} do + assert(k * k == v) +end +``` + +This works for tables but can also be customized for tables or userdata by implementing `__iter` metamethod. It is called before the iteration begins, and should return an iterator function like `next` (or a custom one): + +```lua +local obj = { items = {1, 4, 9} } +setmetatable(obj, { __iter = function(o) return next, o.items end }) + +for k, v in obj do + assert(k * k == v) +end +``` + +The default iteration order for tables is specified to be consecutive for elements `1..#t` and unordered after that, visiting every element. +Similar to iteration using `pairs`, modifying the table entries for keys other than the current one results in unspecified behavior. + +## Typechecking improvements + +We have added a missing check to compare implicit table keys against the key type of the table indexer: + +```lua +-- error is correctly reported, implicit keys (1,2,3) are not compatible with [string] +local t: { [string]: boolean } = { true, true, false } +``` + +Rules for `==` and `~=` have been relaxed for union types, if any of the union parts can be compared, operation succeeds: + +```lua +--!strict +local function compare(v1: Vector3, v2: Vector3?) + return v1 == v2 -- no longer an error +end +``` + +Table value type propagation now correctly works with `[any]` key type: + +```lua +--!strict +type X = {[any]: string | boolean} +local x: X = { key = "str" } -- no longer gives an incorrect error +``` + +If a generic function doesn't provide type annotations for all arguments and the return value, additional generic type parameters might be added automatically: + +```lua +-- previously it was foo, now it's foo, because second argument is also generic +function foo(x: T, y) end +``` + +We have also fixed various issues that have caused crashes, with many of them coming from your bug reports. + +## Linter improvements + +`GlobalUsedAsLocal` lint warning has been extended to notice when global variable writes always happen before their use in a local scope, suggesting that they can be replaced with a local variable: + +```lua +function bar() + foo = 6 -- Global 'foo' is never read before being written. Consider changing it to local + return foo +end +function baz() + foo = 10 + return foo +end +``` + +## Performance improvements + +Garbage collection CPU utilization has been tuned to further reduce frame time spikes of individual collection steps and to bring different GC stages to the same level of CPU utilization. + +Returning a type-cast local (`return a :: type`) as well as returning multiple local variables (`return a, b, c`) is now a little bit more efficient. + +### Function inlining and loop unrolling + +In the open-source release of Luau, when optimization level 2 is enabled, the compiler will now perform function inlining and loop unrolling. + +Only loops with loop bounds known at compile time, such as `for i=1,4 do`, can be unrolled. The loop body must be simple enough for the optimization to be profitable; compiler uses heuristics to estimate the performance benefit and automatically decide if unrolling should be performed. + +Only local functions (defined either as `local function foo` or `local foo = function`) can be inlined. The function body must be simple enough for the optimization to be profitable; compiler uses heuristics to estimate the performance benefit and automatically decide if each call to the function should be inlined instead. Additionally recursive invocations of a function can't be inlined at this time, and inlining is completely disabled for modules that use `getfenv`/`setfenv` functions. diff --git a/extern/isocline/src/bbcode.c b/extern/isocline/src/bbcode.c index 4d11ac38..8722cbd6 100644 --- a/extern/isocline/src/bbcode.c +++ b/extern/isocline/src/bbcode.c @@ -575,6 +575,7 @@ ic_private const char* parse_tag_value( tag_t* tag, char* idbuf, const char* s, } // limit name and attr to 128 bytes char valbuf[128]; + valbuf[0] = 0; // fixes gcc uninitialized warning ic_strncpy( idbuf, 128, id, idend - id); ic_strncpy( valbuf, 128, val, valend - val); ic_str_tolower(idbuf); diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 1022831b..22483f9e 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -103,7 +103,7 @@ int registerTypes(Luau::TypeChecker& env) // Vector3 stub TypeId vector3MetaType = arena.addType(TableTypeVar{}); - TypeId vector3InstanceType = arena.addType(ClassTypeVar{"Vector3", {}, nullopt, vector3MetaType, {}, {}}); + TypeId vector3InstanceType = arena.addType(ClassTypeVar{"Vector3", {}, nullopt, vector3MetaType, {}, {}, "Test"}); getMutable(vector3InstanceType)->props = { {"X", {env.numberType}}, {"Y", {env.numberType}}, @@ -117,7 +117,7 @@ int registerTypes(Luau::TypeChecker& env) env.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vector3InstanceType}; // Instance stub - TypeId instanceType = arena.addType(ClassTypeVar{"Instance", {}, nullopt, nullopt, {}, {}}); + TypeId instanceType = arena.addType(ClassTypeVar{"Instance", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(instanceType)->props = { {"Name", {env.stringType}}, }; @@ -125,7 +125,7 @@ int registerTypes(Luau::TypeChecker& env) env.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; // Part stub - TypeId partType = arena.addType(ClassTypeVar{"Part", {}, instanceType, nullopt, {}, {}}); + TypeId partType = arena.addType(ClassTypeVar{"Part", {}, instanceType, nullopt, {}, {}, "Test"}); getMutable(partType)->props = { {"Position", {vector3InstanceType}}, }; @@ -137,6 +137,21 @@ int registerTypes(Luau::TypeChecker& env) return 0; } + +static void setupFrontend(Luau::Frontend& frontend) +{ + registerTypes(frontend.typeChecker); + Luau::freeze(frontend.typeChecker.globalTypes); + + registerTypes(frontend.typeCheckerForAutocomplete); + Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + + frontend.iceHandler.onInternalError = [](const char* error) { + printf("ICE: %s\n", error); + LUAU_ASSERT(!"ICE"); + }; +} + struct FuzzFileResolver : Luau::FileResolver { std::optional readSource(const Luau::ModuleName& name) override @@ -173,7 +188,7 @@ struct FuzzConfigResolver : Luau::ConfigResolver { FuzzConfigResolver() { - defaultConfig.mode = Luau::Mode::Nonstrict; // typecheckTwice option will cover Strict mode + defaultConfig.mode = Luau::Mode::Nonstrict; defaultConfig.enabledLint.warningMask = ~0ull; defaultConfig.parseOptions.captureComments = true; } @@ -238,19 +253,11 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) if (kFuzzTypeck) { static FuzzFileResolver fileResolver; - static Luau::NullConfigResolver configResolver; + static FuzzConfigResolver configResolver; static Luau::FrontendOptions options{true, true}; static Luau::Frontend frontend(&fileResolver, &configResolver, options); - static int once = registerTypes(frontend.typeChecker); - (void)once; - static int once2 = (Luau::freeze(frontend.typeChecker.globalTypes), 0); - (void)once2; - - frontend.iceHandler.onInternalError = [](const char* error) { - printf("ICE: %s\n", error); - LUAU_ASSERT(!"ICE"); - }; + static int once = (setupFrontend(frontend), 0); // restart frontend.clear(); @@ -275,6 +282,11 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) // lint (note that we need access to types so we need to do this with typeck in scope) if (kFuzzLinter && result.errors.empty()) frontend.lint(name, std::nullopt); + + // Second pass in strict mode (forced by auto-complete) + Luau::FrontendOptions opts; + opts.forAutocomplete = true; + frontend.check(name, opts); } catch (std::exception&) { diff --git a/prototyping/Luau/ResolveOverloads.agda b/prototyping/Luau/ResolveOverloads.agda new file mode 100644 index 00000000..67175176 --- /dev/null +++ b/prototyping/Luau/ResolveOverloads.agda @@ -0,0 +1,98 @@ +{-# OPTIONS --rewriting #-} + +module Luau.ResolveOverloads where + +open import FFI.Data.Either using (Left; Right) +open import Luau.Subtyping using (_<:_; _≮:_; Language; witness; scalar; unknown; never; function-ok) +open import Luau.Type using (Type ; _⇒_; _∩_; _∪_; unknown; never) +open import Luau.TypeSaturation using (saturate) +open import Luau.TypeNormalization using (normalize) +open import Properties.Contradiction using (CONTRADICTION) +open import Properties.DecSubtyping using (dec-subtyping; dec-subtypingⁿ; <:-impl-<:ᵒ) +open import Properties.Functions using (_∘_) +open import Properties.Subtyping using (<:-refl; <:-trans; <:-trans-≮:; ≮:-trans-<:; <:-∩-left; <:-∩-right; <:-∩-glb; <:-impl-¬≮:; <:-unknown; <:-function; function-≮:-never; <:-never; unknown-≮:-function; scalar-≮:-function; ≮:-∪-right; scalar-≮:-never; <:-∪-left; <:-∪-right) +open import Properties.TypeNormalization using (Normal; FunType; normal; _⇒_; _∩_; _∪_; never; unknown; <:-normalize; normalize-<:; fun-≮:-never; unknown-≮:-fun; scalar-≮:-fun) +open import Properties.TypeSaturation using (Overloads; Saturated; _⊆ᵒ_; _<:ᵒ_; normal-saturate; saturated; <:-saturate; saturate-<:; defn; here; left; right) + +-- The domain of a normalized type +srcⁿ : Type → Type +srcⁿ (S ⇒ T) = S +srcⁿ (S ∩ T) = srcⁿ S ∪ srcⁿ T +srcⁿ never = unknown +srcⁿ T = never + +-- To get the domain of a type, we normalize it first We need to do +-- this, since if we try to use it on non-normalized types, we get +-- +-- src(number ∩ string) = src(number) ∪ src(string) = never ∪ never +-- src(never) = unknown +-- +-- so src doesn't respect type equivalence. +src : Type → Type +src (S ⇒ T) = S +src T = srcⁿ(normalize T) + +-- Calculate the result of applying a function type `F` to an argument type `V`. +-- We do this by finding an overload of `F` that has the most precise type, +-- that is an overload `(Sʳ ⇒ Tʳ)` where `V <: Sʳ` and moreover +-- for any other such overload `(S ⇒ T)` we have that `Tʳ <: T`. + +-- For example if `F` is `(number -> number) & (nil -> nil) & (number? -> number?)` +-- then to resolve `F` with argument type `number`, we pick the `number -> number` +-- overload, but if the argument is `number?`, we pick `number? -> number?`./ + +-- Not all types have such a most precise overload, but saturated ones do. + +data ResolvedTo F G V : Set where + + yes : ∀ Sʳ Tʳ → + + Overloads F (Sʳ ⇒ Tʳ) → + (V <: Sʳ) → + (∀ {S T} → Overloads G (S ⇒ T) → (V <: S) → (Tʳ <: T)) → + -------------------------------------------- + ResolvedTo F G V + + no : + + (∀ {S T} → Overloads G (S ⇒ T) → (V ≮: S)) → + -------------------------------------------- + ResolvedTo F G V + +Resolved : Type → Type → Set +Resolved F V = ResolvedTo F F V + +target : ∀ {F V} → Resolved F V → Type +target (yes _ T _ _ _) = T +target (no _) = unknown + +-- We can resolve any saturated function type +resolveˢ : ∀ {F G V} → FunType G → Saturated F → Normal V → (G ⊆ᵒ F) → ResolvedTo F G V +resolveˢ (Sⁿ ⇒ Tⁿ) (defn sat-∩ sat-∪) Vⁿ G⊆F with dec-subtypingⁿ Vⁿ Sⁿ +resolveˢ (Sⁿ ⇒ Tⁿ) (defn sat-∩ sat-∪) Vⁿ G⊆F | Left V≮:S = no (λ { here → V≮:S }) +resolveˢ (Sⁿ ⇒ Tⁿ) (defn sat-∩ sat-∪) Vⁿ G⊆F | Right V<:S = yes _ _ (G⊆F here) V<:S (λ { here _ → <:-refl }) +resolveˢ (Gᶠ ∩ Hᶠ) (defn sat-∩ sat-∪) Vⁿ G⊆F with resolveˢ Gᶠ (defn sat-∩ sat-∪) Vⁿ (G⊆F ∘ left) | resolveˢ Hᶠ (defn sat-∩ sat-∪) Vⁿ (G⊆F ∘ right) +resolveˢ (Gᶠ ∩ Hᶠ) (defn sat-∩ sat-∪) Vⁿ G⊆F | yes S₁ T₁ o₁ V<:S₁ tgt₁ | yes S₂ T₂ o₂ V<:S₂ tgt₂ with sat-∩ o₁ o₂ +resolveˢ (Gᶠ ∩ Hᶠ) (defn sat-∩ sat-∪) Vⁿ G⊆F | yes S₁ T₁ o₁ V<:S₁ tgt₁ | yes S₂ T₂ o₂ V<:S₂ tgt₂ | defn o p₁ p₂ = + yes _ _ o (<:-trans (<:-∩-glb V<:S₁ V<:S₂) p₁) (λ { (left o) p → <:-trans p₂ (<:-trans <:-∩-left (tgt₁ o p)) ; (right o) p → <:-trans p₂ (<:-trans <:-∩-right (tgt₂ o p)) }) +resolveˢ (Gᶠ ∩ Hᶠ) (defn sat-∩ sat-∪) Vⁿ G⊆F | yes S₁ T₁ o₁ V<:S₁ tgt₁ | no src₂ = + yes _ _ o₁ V<:S₁ (λ { (left o) p → tgt₁ o p ; (right o) p → CONTRADICTION (<:-impl-¬≮: p (src₂ o)) }) +resolveˢ (Gᶠ ∩ Hᶠ) (defn sat-∩ sat-∪) Vⁿ G⊆F | no src₁ | yes S₂ T₂ o₂ V<:S₂ tgt₂ = + yes _ _ o₂ V<:S₂ (λ { (left o) p → CONTRADICTION (<:-impl-¬≮: p (src₁ o)) ; (right o) p → tgt₂ o p }) +resolveˢ (Gᶠ ∩ Hᶠ) (defn sat-∩ sat-∪) Vⁿ G⊆F | no src₁ | no src₂ = + no (λ { (left o) → src₁ o ; (right o) → src₂ o }) + +-- Which means we can resolve any normalized type, by saturating it first +resolveᶠ : ∀ {F V} → FunType F → Normal V → Type +resolveᶠ Fᶠ Vⁿ = target (resolveˢ (normal-saturate Fᶠ) (saturated Fᶠ) Vⁿ (λ o → o)) + +resolveⁿ : ∀ {F V} → Normal F → Normal V → Type +resolveⁿ (Sⁿ ⇒ Tⁿ) Vⁿ = resolveᶠ (Sⁿ ⇒ Tⁿ) Vⁿ +resolveⁿ (Fᶠ ∩ Gᶠ) Vⁿ = resolveᶠ (Fᶠ ∩ Gᶠ) Vⁿ +resolveⁿ (Sⁿ ∪ Tˢ) Vⁿ = unknown +resolveⁿ unknown Vⁿ = unknown +resolveⁿ never Vⁿ = never + +-- Which means we can resolve any type, by normalizing it first +resolve : Type → Type → Type +resolve F V = resolveⁿ (normal F) (normal V) diff --git a/prototyping/Luau/StrictMode.agda b/prototyping/Luau/StrictMode.agda index 0b5fe0da..0628951b 100644 --- a/prototyping/Luau/StrictMode.agda +++ b/prototyping/Luau/StrictMode.agda @@ -5,18 +5,16 @@ module Luau.StrictMode where open import Agda.Builtin.Equality using (_≡_) open import FFI.Data.Maybe using (just; nothing) open import Luau.Syntax using (Expr; Stat; Block; BinaryOperator; yes; nil; addr; var; binexp; var_∈_; _⟨_⟩∈_; function_is_end; _$_; block_is_end; local_←_; _∙_; done; return; name; +; -; *; /; <; >; <=; >=; ··) -open import Luau.Type using (Type; strict; nil; number; string; boolean; none; any; _⇒_; _∪_; _∩_; tgt) +open import Luau.Type using (Type; nil; number; string; boolean; _⇒_; _∪_; _∩_) +open import Luau.ResolveOverloads using (src; resolve) open import Luau.Subtyping using (_≮:_) open import Luau.Heap using (Heap; function_is_end) renaming (_[_] to _[_]ᴴ) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_) renaming (_[_] to _[_]ⱽ) -open import Luau.TypeCheck(strict) using (_⊢ᴮ_∈_; _⊢ᴱ_∈_; ⊢ᴴ_; ⊢ᴼ_; _⊢ᴴᴱ_▷_∈_; _⊢ᴴᴮ_▷_∈_; var; addr; app; binexp; block; return; local; function; srcBinOp) +open import Luau.TypeCheck using (_⊢ᴮ_∈_; _⊢ᴱ_∈_; ⊢ᴴ_; ⊢ᴼ_; _⊢ᴴᴱ_▷_∈_; _⊢ᴴᴮ_▷_∈_; var; addr; app; binexp; block; return; local; function; srcBinOp) open import Properties.Contradiction using (¬) -open import Properties.TypeCheck(strict) using (typeCheckᴮ) +open import Properties.TypeCheck using (typeCheckᴮ) open import Properties.Product using (_,_) -src : Type → Type -src = Luau.Type.src strict - data Warningᴱ (H : Heap yes) {Γ} : ∀ {M T} → (Γ ⊢ᴱ M ∈ T) → Set data Warningᴮ (H : Heap yes) {Γ} : ∀ {B T} → (Γ ⊢ᴮ B ∈ T) → Set diff --git a/prototyping/Luau/StrictMode/ToString.agda b/prototyping/Luau/StrictMode/ToString.agda index 08ee13b8..7c5f0253 100644 --- a/prototyping/Luau/StrictMode/ToString.agda +++ b/prototyping/Luau/StrictMode/ToString.agda @@ -4,11 +4,11 @@ module Luau.StrictMode.ToString where open import Agda.Builtin.Nat using (Nat; suc) open import FFI.Data.String using (String; _++_) -open import Luau.Subtyping using (_≮:_; Tree; witness; scalar; function; function-ok; function-err) +open import Luau.Subtyping using (_≮:_; Tree; witness; scalar; function; function-ok; function-err; function-tgt) open import Luau.StrictMode using (Warningᴱ; Warningᴮ; UnallocatedAddress; UnboundVariable; FunctionCallMismatch; FunctionDefnMismatch; BlockMismatch; app₁; app₂; BinOpMismatch₁; BinOpMismatch₂; bin₁; bin₂; block₁; return; LocalVarMismatch; local₁; local₂; function₁; function₂; heap; expr; block; addr) open import Luau.Syntax using (Expr; val; yes; var; var_∈_; _⟨_⟩∈_; _$_; addr; number; binexp; nil; function_is_end; block_is_end; done; return; local_←_; _∙_; fun; arg; name) -open import Luau.Type using (strict; number; boolean; string; nil) -open import Luau.TypeCheck(strict) using (_⊢ᴮ_∈_; _⊢ᴱ_∈_) +open import Luau.Type using (number; boolean; string; nil) +open import Luau.TypeCheck using (_⊢ᴮ_∈_; _⊢ᴱ_∈_) open import Luau.Addr.ToString using (addrToString) open import Luau.Var.ToString using (varToString) open import Luau.Type.ToString using (typeToString) @@ -27,8 +27,9 @@ treeToString (scalar boolean) n v = v ++ " is a boolean" treeToString (scalar string) n v = v ++ " is a string" treeToString (scalar nil) n v = v ++ " is nil" treeToString function n v = v ++ " is a function" -treeToString (function-ok t) n v = treeToString t n (v ++ "()") +treeToString (function-ok s t) n v = treeToString t (suc n) (v ++ "(" ++ w ++ ")") ++ " when\n " ++ treeToString s (suc n) w where w = tmp n treeToString (function-err t) n v = v ++ "(" ++ w ++ ") can error when\n " ++ treeToString t (suc n) w where w = tmp n +treeToString (function-tgt t) n v = treeToString t n (v ++ "()") subtypeWarningToString : ∀ {T U} → (T ≮: U) → String subtypeWarningToString (witness t p q) = "\n because provided type contains v, where " ++ treeToString t 0 "v" diff --git a/prototyping/Luau/Subtyping.agda b/prototyping/Luau/Subtyping.agda index 7d67eb43..dc2abed0 100644 --- a/prototyping/Luau/Subtyping.agda +++ b/prototyping/Luau/Subtyping.agda @@ -1,6 +1,6 @@ {-# OPTIONS --rewriting #-} -open import Luau.Type using (Type; Scalar; nil; number; string; boolean; none; any; _⇒_; _∪_; _∩_) +open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) open import Properties.Equality using (_≢_) module Luau.Subtyping where @@ -13,8 +13,9 @@ data Tree : Set where scalar : ∀ {T} → Scalar T → Tree function : Tree - function-ok : Tree → Tree + function-ok : Tree → Tree → Tree function-err : Tree → Tree + function-tgt : Tree → Tree data Language : Type → Tree → Set data ¬Language : Type → Tree → Set @@ -23,26 +24,30 @@ data Language where scalar : ∀ {T} → (s : Scalar T) → Language T (scalar s) function : ∀ {T U} → Language (T ⇒ U) function - function-ok : ∀ {T U u} → (Language U u) → Language (T ⇒ U) (function-ok u) + function-ok₁ : ∀ {T U t u} → (¬Language T t) → Language (T ⇒ U) (function-ok t u) + function-ok₂ : ∀ {T U t u} → (Language U u) → Language (T ⇒ U) (function-ok t u) function-err : ∀ {T U t} → (¬Language T t) → Language (T ⇒ U) (function-err t) - scalar-function-err : ∀ {S t} → (Scalar S) → Language S (function-err t) + function-tgt : ∀ {T U t} → (Language U t) → Language (T ⇒ U) (function-tgt t) left : ∀ {T U t} → Language T t → Language (T ∪ U) t right : ∀ {T U u} → Language U u → Language (T ∪ U) u _,_ : ∀ {T U t} → Language T t → Language U t → Language (T ∩ U) t - any : ∀ {t} → Language any t + unknown : ∀ {t} → Language unknown t data ¬Language where scalar-scalar : ∀ {S T} → (s : Scalar S) → (Scalar T) → (S ≢ T) → ¬Language T (scalar s) scalar-function : ∀ {S} → (Scalar S) → ¬Language S function - scalar-function-ok : ∀ {S u} → (Scalar S) → ¬Language S (function-ok u) + scalar-function-ok : ∀ {S t u} → (Scalar S) → ¬Language S (function-ok t u) + scalar-function-err : ∀ {S t} → (Scalar S) → ¬Language S (function-err t) + scalar-function-tgt : ∀ {S t} → (Scalar S) → ¬Language S (function-tgt t) function-scalar : ∀ {S T U} (s : Scalar S) → ¬Language (T ⇒ U) (scalar s) - function-ok : ∀ {T U u} → (¬Language U u) → ¬Language (T ⇒ U) (function-ok u) + function-ok : ∀ {T U t u} → (Language T t) → (¬Language U u) → ¬Language (T ⇒ U) (function-ok t u) function-err : ∀ {T U t} → (Language T t) → ¬Language (T ⇒ U) (function-err t) + function-tgt : ∀ {T U t} → (¬Language U t) → ¬Language (T ⇒ U) (function-tgt t) _,_ : ∀ {T U t} → ¬Language T t → ¬Language U t → ¬Language (T ∪ U) t left : ∀ {T U t} → ¬Language T t → ¬Language (T ∩ U) t right : ∀ {T U u} → ¬Language U u → ¬Language (T ∩ U) u - none : ∀ {t} → ¬Language none t + never : ∀ {t} → ¬Language never t -- Subtyping as language inclusion diff --git a/prototyping/Luau/Type.agda b/prototyping/Luau/Type.agda index 87815ddb..1d0ec9e5 100644 --- a/prototyping/Luau/Type.agda +++ b/prototyping/Luau/Type.agda @@ -9,8 +9,8 @@ open import FFI.Data.Maybe using (Maybe; just; nothing) data Type : Set where nil : Type _⇒_ : Type → Type → Type - none : Type - any : Type + never : Type + unknown : Type boolean : Type number : Type string : Type @@ -24,13 +24,15 @@ data Scalar : Type → Set where string : Scalar string nil : Scalar nil +skalar = number ∪ (string ∪ (nil ∪ boolean)) + lhs : Type → Type lhs (T ⇒ _) = T lhs (T ∪ _) = T lhs (T ∩ _) = T lhs nil = nil -lhs none = none -lhs any = any +lhs never = never +lhs unknown = unknown lhs number = number lhs boolean = boolean lhs string = string @@ -40,8 +42,8 @@ rhs (_ ⇒ T) = T rhs (_ ∪ T) = T rhs (_ ∩ T) = T rhs nil = nil -rhs none = none -rhs any = any +rhs never = never +rhs unknown = unknown rhs number = number rhs boolean = boolean rhs string = string @@ -49,16 +51,16 @@ rhs string = string _≡ᵀ_ : ∀ (T U : Type) → Dec(T ≡ U) nil ≡ᵀ nil = yes refl nil ≡ᵀ (S ⇒ T) = no (λ ()) -nil ≡ᵀ none = no (λ ()) -nil ≡ᵀ any = no (λ ()) +nil ≡ᵀ never = no (λ ()) +nil ≡ᵀ unknown = no (λ ()) nil ≡ᵀ number = no (λ ()) nil ≡ᵀ boolean = no (λ ()) nil ≡ᵀ (S ∪ T) = no (λ ()) nil ≡ᵀ (S ∩ T) = no (λ ()) nil ≡ᵀ string = no (λ ()) (S ⇒ T) ≡ᵀ string = no (λ ()) -none ≡ᵀ string = no (λ ()) -any ≡ᵀ string = no (λ ()) +never ≡ᵀ string = no (λ ()) +unknown ≡ᵀ string = no (λ ()) boolean ≡ᵀ string = no (λ ()) number ≡ᵀ string = no (λ ()) (S ∪ T) ≡ᵀ string = no (λ ()) @@ -68,48 +70,48 @@ number ≡ᵀ string = no (λ ()) (S ⇒ T) ≡ᵀ (S ⇒ T) | yes refl | yes refl = yes refl (S ⇒ T) ≡ᵀ (U ⇒ V) | _ | no p = no (λ q → p (cong rhs q)) (S ⇒ T) ≡ᵀ (U ⇒ V) | no p | _ = no (λ q → p (cong lhs q)) -(S ⇒ T) ≡ᵀ none = no (λ ()) -(S ⇒ T) ≡ᵀ any = no (λ ()) +(S ⇒ T) ≡ᵀ never = no (λ ()) +(S ⇒ T) ≡ᵀ unknown = no (λ ()) (S ⇒ T) ≡ᵀ number = no (λ ()) (S ⇒ T) ≡ᵀ boolean = no (λ ()) (S ⇒ T) ≡ᵀ (U ∪ V) = no (λ ()) (S ⇒ T) ≡ᵀ (U ∩ V) = no (λ ()) -none ≡ᵀ nil = no (λ ()) -none ≡ᵀ (U ⇒ V) = no (λ ()) -none ≡ᵀ none = yes refl -none ≡ᵀ any = no (λ ()) -none ≡ᵀ number = no (λ ()) -none ≡ᵀ boolean = no (λ ()) -none ≡ᵀ (U ∪ V) = no (λ ()) -none ≡ᵀ (U ∩ V) = no (λ ()) -any ≡ᵀ nil = no (λ ()) -any ≡ᵀ (U ⇒ V) = no (λ ()) -any ≡ᵀ none = no (λ ()) -any ≡ᵀ any = yes refl -any ≡ᵀ number = no (λ ()) -any ≡ᵀ boolean = no (λ ()) -any ≡ᵀ (U ∪ V) = no (λ ()) -any ≡ᵀ (U ∩ V) = no (λ ()) +never ≡ᵀ nil = no (λ ()) +never ≡ᵀ (U ⇒ V) = no (λ ()) +never ≡ᵀ never = yes refl +never ≡ᵀ unknown = no (λ ()) +never ≡ᵀ number = no (λ ()) +never ≡ᵀ boolean = no (λ ()) +never ≡ᵀ (U ∪ V) = no (λ ()) +never ≡ᵀ (U ∩ V) = no (λ ()) +unknown ≡ᵀ nil = no (λ ()) +unknown ≡ᵀ (U ⇒ V) = no (λ ()) +unknown ≡ᵀ never = no (λ ()) +unknown ≡ᵀ unknown = yes refl +unknown ≡ᵀ number = no (λ ()) +unknown ≡ᵀ boolean = no (λ ()) +unknown ≡ᵀ (U ∪ V) = no (λ ()) +unknown ≡ᵀ (U ∩ V) = no (λ ()) number ≡ᵀ nil = no (λ ()) number ≡ᵀ (T ⇒ U) = no (λ ()) -number ≡ᵀ none = no (λ ()) -number ≡ᵀ any = no (λ ()) +number ≡ᵀ never = no (λ ()) +number ≡ᵀ unknown = no (λ ()) number ≡ᵀ number = yes refl number ≡ᵀ boolean = no (λ ()) number ≡ᵀ (T ∪ U) = no (λ ()) number ≡ᵀ (T ∩ U) = no (λ ()) boolean ≡ᵀ nil = no (λ ()) boolean ≡ᵀ (T ⇒ U) = no (λ ()) -boolean ≡ᵀ none = no (λ ()) -boolean ≡ᵀ any = no (λ ()) +boolean ≡ᵀ never = no (λ ()) +boolean ≡ᵀ unknown = no (λ ()) boolean ≡ᵀ boolean = yes refl boolean ≡ᵀ number = no (λ ()) boolean ≡ᵀ (T ∪ U) = no (λ ()) boolean ≡ᵀ (T ∩ U) = no (λ ()) string ≡ᵀ nil = no (λ ()) string ≡ᵀ (x ⇒ x₁) = no (λ ()) -string ≡ᵀ none = no (λ ()) -string ≡ᵀ any = no (λ ()) +string ≡ᵀ never = no (λ ()) +string ≡ᵀ unknown = no (λ ()) string ≡ᵀ boolean = no (λ ()) string ≡ᵀ number = no (λ ()) string ≡ᵀ string = yes refl @@ -117,8 +119,8 @@ string ≡ᵀ (U ∪ V) = no (λ ()) string ≡ᵀ (U ∩ V) = no (λ ()) (S ∪ T) ≡ᵀ nil = no (λ ()) (S ∪ T) ≡ᵀ (U ⇒ V) = no (λ ()) -(S ∪ T) ≡ᵀ none = no (λ ()) -(S ∪ T) ≡ᵀ any = no (λ ()) +(S ∪ T) ≡ᵀ never = no (λ ()) +(S ∪ T) ≡ᵀ unknown = no (λ ()) (S ∪ T) ≡ᵀ number = no (λ ()) (S ∪ T) ≡ᵀ boolean = no (λ ()) (S ∪ T) ≡ᵀ (U ∪ V) with (S ≡ᵀ U) | (T ≡ᵀ V) @@ -128,8 +130,8 @@ string ≡ᵀ (U ∩ V) = no (λ ()) (S ∪ T) ≡ᵀ (U ∩ V) = no (λ ()) (S ∩ T) ≡ᵀ nil = no (λ ()) (S ∩ T) ≡ᵀ (U ⇒ V) = no (λ ()) -(S ∩ T) ≡ᵀ none = no (λ ()) -(S ∩ T) ≡ᵀ any = no (λ ()) +(S ∩ T) ≡ᵀ never = no (λ ()) +(S ∩ T) ≡ᵀ unknown = no (λ ()) (S ∩ T) ≡ᵀ number = no (λ ()) (S ∩ T) ≡ᵀ boolean = no (λ ()) (S ∩ T) ≡ᵀ (U ∪ V) = no (λ ()) @@ -146,37 +148,6 @@ just T ≡ᴹᵀ just U with T ≡ᵀ U (just T ≡ᴹᵀ just T) | yes refl = yes refl (just T ≡ᴹᵀ just U) | no p = no (λ q → p (just-inv q)) -data Mode : Set where - strict : Mode - nonstrict : Mode - -src : Mode → Type → Type -src m nil = none -src m number = none -src m boolean = none -src m string = none -src m (S ⇒ T) = S --- In nonstrict mode, functions are covaraiant, in strict mode they're contravariant -src strict (S ∪ T) = (src strict S) ∩ (src strict T) -src nonstrict (S ∪ T) = (src nonstrict S) ∪ (src nonstrict T) -src strict (S ∩ T) = (src strict S) ∪ (src strict T) -src nonstrict (S ∩ T) = (src nonstrict S) ∩ (src nonstrict T) -src strict none = any -src nonstrict none = none -src strict any = none -src nonstrict any = any - -tgt : Type → Type -tgt nil = none -tgt (S ⇒ T) = T -tgt none = none -tgt any = any -tgt number = none -tgt boolean = none -tgt string = none -tgt (S ∪ T) = (tgt S) ∪ (tgt T) -tgt (S ∩ T) = (tgt S) ∩ (tgt T) - optional : Type → Type optional nil = nil optional (T ∪ nil) = (T ∪ nil) diff --git a/prototyping/Luau/Type/FromJSON.agda b/prototyping/Luau/Type/FromJSON.agda index 2d6ba689..e3d1e8e7 100644 --- a/prototyping/Luau/Type/FromJSON.agda +++ b/prototyping/Luau/Type/FromJSON.agda @@ -2,7 +2,7 @@ module Luau.Type.FromJSON where -open import Luau.Type using (Type; nil; _⇒_; _∪_; _∩_; any; number; string) +open import Luau.Type using (Type; nil; _⇒_; _∪_; _∩_; unknown; never; number; string) open import Agda.Builtin.List using (List; _∷_; []) open import Agda.Builtin.Bool using (true; false) @@ -42,7 +42,9 @@ typeFromJSON (object o) | just (string "AstTypeFunction") | nothing | nothing = typeFromJSON (object o) | just (string "AstTypeReference") with lookup name o typeFromJSON (object o) | just (string "AstTypeReference") | just (string "nil") = Right nil -typeFromJSON (object o) | just (string "AstTypeReference") | just (string "any") = Right any +typeFromJSON (object o) | just (string "AstTypeReference") | just (string "any") = Right unknown -- not quite right +typeFromJSON (object o) | just (string "AstTypeReference") | just (string "unknown") = Right unknown +typeFromJSON (object o) | just (string "AstTypeReference") | just (string "never") = Right never typeFromJSON (object o) | just (string "AstTypeReference") | just (string "number") = Right number typeFromJSON (object o) | just (string "AstTypeReference") | just (string "string") = Right string typeFromJSON (object o) | just (string "AstTypeReference") | _ = Left "Unknown referenced type" diff --git a/prototyping/Luau/Type/ToString.agda b/prototyping/Luau/Type/ToString.agda index 2efe6632..a41ecec2 100644 --- a/prototyping/Luau/Type/ToString.agda +++ b/prototyping/Luau/Type/ToString.agda @@ -1,7 +1,7 @@ module Luau.Type.ToString where open import FFI.Data.String using (String; _++_) -open import Luau.Type using (Type; nil; _⇒_; none; any; number; boolean; string; _∪_; _∩_; normalizeOptional) +open import Luau.Type using (Type; nil; _⇒_; never; unknown; number; boolean; string; _∪_; _∩_; normalizeOptional) {-# TERMINATING #-} typeToString : Type → String @@ -10,8 +10,8 @@ typeToStringᴵ : Type → String typeToString nil = "nil" typeToString (S ⇒ T) = "(" ++ (typeToString S) ++ ") -> " ++ (typeToString T) -typeToString none = "none" -typeToString any = "any" +typeToString never = "never" +typeToString unknown = "unknown" typeToString number = "number" typeToString boolean = "boolean" typeToString string = "string" diff --git a/prototyping/Luau/TypeCheck.agda b/prototyping/Luau/TypeCheck.agda index c22618bc..1abc1eda 100644 --- a/prototyping/Luau/TypeCheck.agda +++ b/prototyping/Luau/TypeCheck.agda @@ -1,27 +1,25 @@ {-# OPTIONS --rewriting #-} -open import Luau.Type using (Mode) - -module Luau.TypeCheck (m : Mode) where +module Luau.TypeCheck where open import Agda.Builtin.Equality using (_≡_) +open import FFI.Data.Either using (Either; Left; Right) open import FFI.Data.Maybe using (Maybe; just) +open import Luau.ResolveOverloads using (resolve) open import Luau.Syntax using (Expr; Stat; Block; BinaryOperator; yes; nil; addr; number; bool; string; val; var; var_∈_; _⟨_⟩∈_; function_is_end; _$_; block_is_end; binexp; local_←_; _∙_; done; return; name; +; -; *; /; <; >; ==; ~=; <=; >=; ··) open import Luau.Var using (Var) open import Luau.Addr using (Addr) open import Luau.Heap using (Heap; Object; function_is_end) renaming (_[_] to _[_]ᴴ) -open import Luau.Type using (Type; Mode; nil; any; number; boolean; string; _⇒_; tgt) +open import Luau.Type using (Type; nil; unknown; number; boolean; string; _⇒_) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_) renaming (_[_] to _[_]ⱽ) open import FFI.Data.Vector using (Vector) open import FFI.Data.Maybe using (Maybe; just; nothing) +open import Properties.DecSubtyping using (dec-subtyping) open import Properties.Product using (_×_; _,_) -src : Type → Type -src = Luau.Type.src m - -orAny : Maybe Type → Type -orAny nothing = any -orAny (just T) = T +orUnknown : Maybe Type → Type +orUnknown nothing = unknown +orUnknown (just T) = T srcBinOp : BinaryOperator → Type srcBinOp + = number @@ -30,8 +28,8 @@ srcBinOp * = number srcBinOp / = number srcBinOp < = number srcBinOp > = number -srcBinOp == = any -srcBinOp ~= = any +srcBinOp == = unknown +srcBinOp ~= = unknown srcBinOp <= = number srcBinOp >= = number srcBinOp ·· = string @@ -89,7 +87,7 @@ data _⊢ᴱ_∈_ where var : ∀ {x T Γ} → - T ≡ orAny(Γ [ x ]ⱽ) → + T ≡ orUnknown(Γ [ x ]ⱽ) → ---------------- Γ ⊢ᴱ (var x) ∈ T @@ -117,8 +115,8 @@ data _⊢ᴱ_∈_ where Γ ⊢ᴱ M ∈ T → Γ ⊢ᴱ N ∈ U → - ---------------------- - Γ ⊢ᴱ (M $ N) ∈ (tgt T) + ---------------------------- + Γ ⊢ᴱ (M $ N) ∈ (resolve T U) function : ∀ {f x B T U V Γ} → diff --git a/prototyping/Luau/TypeNormalization.agda b/prototyping/Luau/TypeNormalization.agda new file mode 100644 index 00000000..08f14474 --- /dev/null +++ b/prototyping/Luau/TypeNormalization.agda @@ -0,0 +1,65 @@ +module Luau.TypeNormalization where + +open import Luau.Type using (Type; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) + +-- Operations on normalized types +_∪ᶠ_ : Type → Type → Type +_∪ⁿˢ_ : Type → Type → Type +_∩ⁿˢ_ : Type → Type → Type +_∪ⁿ_ : Type → Type → Type +_∩ⁿ_ : Type → Type → Type + +-- Union of function types +(F₁ ∩ F₂) ∪ᶠ G = (F₁ ∪ᶠ G) ∩ (F₂ ∪ᶠ G) +F ∪ᶠ (G₁ ∩ G₂) = (F ∪ᶠ G₁) ∩ (F ∪ᶠ G₂) +(R ⇒ S) ∪ᶠ (T ⇒ U) = (R ∩ⁿ T) ⇒ (S ∪ⁿ U) +F ∪ᶠ G = F ∪ G + +-- Union of normalized types +S ∪ⁿ (T₁ ∪ T₂) = (S ∪ⁿ T₁) ∪ T₂ +S ∪ⁿ unknown = unknown +S ∪ⁿ never = S +never ∪ⁿ T = T +unknown ∪ⁿ T = unknown +(S₁ ∪ S₂) ∪ⁿ G = (S₁ ∪ⁿ G) ∪ S₂ +F ∪ⁿ G = F ∪ᶠ G + +-- Intersection of normalized types +S ∩ⁿ (T₁ ∪ T₂) = (S ∩ⁿ T₁) ∪ⁿˢ (S ∩ⁿˢ T₂) +S ∩ⁿ unknown = S +S ∩ⁿ never = never +(S₁ ∪ S₂) ∩ⁿ G = (S₁ ∩ⁿ G) +unknown ∩ⁿ G = G +never ∩ⁿ G = never +F ∩ⁿ G = F ∩ G + +-- Intersection of normalized types with a scalar +(S₁ ∪ nil) ∩ⁿˢ nil = nil +(S₁ ∪ boolean) ∩ⁿˢ boolean = boolean +(S₁ ∪ number) ∩ⁿˢ number = number +(S₁ ∪ string) ∩ⁿˢ string = string +(S₁ ∪ S₂) ∩ⁿˢ T = S₁ ∩ⁿˢ T +unknown ∩ⁿˢ T = T +F ∩ⁿˢ T = never + +-- Union of normalized types with an optional scalar +S ∪ⁿˢ never = S +unknown ∪ⁿˢ T = unknown +(S₁ ∪ nil) ∪ⁿˢ nil = S₁ ∪ nil +(S₁ ∪ boolean) ∪ⁿˢ boolean = S₁ ∪ boolean +(S₁ ∪ number) ∪ⁿˢ number = S₁ ∪ number +(S₁ ∪ string) ∪ⁿˢ string = S₁ ∪ string +(S₁ ∪ S₂) ∪ⁿˢ T = (S₁ ∪ⁿˢ T) ∪ S₂ +F ∪ⁿˢ T = F ∪ T + +-- Normalize! +normalize : Type → Type +normalize nil = never ∪ nil +normalize (S ⇒ T) = (normalize S ⇒ normalize T) +normalize never = never +normalize unknown = unknown +normalize boolean = never ∪ boolean +normalize number = never ∪ number +normalize string = never ∪ string +normalize (S ∪ T) = normalize S ∪ⁿ normalize T +normalize (S ∩ T) = normalize S ∩ⁿ normalize T diff --git a/prototyping/Luau/TypeSaturation.agda b/prototyping/Luau/TypeSaturation.agda new file mode 100644 index 00000000..fa24ff73 --- /dev/null +++ b/prototyping/Luau/TypeSaturation.agda @@ -0,0 +1,66 @@ +module Luau.TypeSaturation where + +open import Luau.Type using (Type; _⇒_; _∩_; _∪_) +open import Luau.TypeNormalization using (_∪ⁿ_; _∩ⁿ_) + +-- So, there's a problem with overloaded functions +-- (of the form (S_1 ⇒ T_1) ∩⋯∩ (S_n ⇒ T_n)) +-- which is that it's not good enough to compare them +-- for subtyping by comparing all of their overloads. + +-- For example (nil → nil) is a subtype of (number? → number?) ∩ (string? → string?) +-- but not a subtype of any of its overloads. + +-- To fix this, we adapt the semantic subtyping algorithm for +-- function types, given in +-- https://www.irif.fr/~gc/papers/covcon-again.pdf and +-- https://pnwamk.github.io/sst-tutorial/ + +-- A function type is *intersection-saturated* if for any overloads +-- (S₁ ⇒ T₁) and (S₂ ⇒ T₂), there exists an overload which is a subtype +-- of ((S₁ ∩ S₂) ⇒ (T₁ ∩ T₂)). + +-- A function type is *union-saturated* if for any overloads +-- (S₁ ⇒ T₁) and (S₂ ⇒ T₂), there exists an overload which is a subtype +-- of ((S₁ ∪ S₂) ⇒ (T₁ ∪ T₂)). + +-- A function type is *saturated* if it is both intersection- and +-- union-saturated. + +-- For example (number? → number?) ∩ (string? → string?) +-- is not saturated, but (number? → number?) ∩ (string? → string?) ∩ (nil → nil) ∩ ((number ∪ string)? → (number ∪ string)?) +-- is. + +-- Saturated function types have the nice property that they can ber +-- compared by just comparing their overloads: F <: G whenever for any +-- overload of G, there is an overload os F which is a subtype of it. + +-- Forunately every function type can be saturated! +_⋓_ : Type → Type → Type +(S₁ ⇒ T₁) ⋓ (S₂ ⇒ T₂) = (S₁ ∪ⁿ S₂) ⇒ (T₁ ∪ⁿ T₂) +(F₁ ∩ G₁) ⋓ F₂ = (F₁ ⋓ F₂) ∩ (G₁ ⋓ F₂) +F₁ ⋓ (F₂ ∩ G₂) = (F₁ ⋓ F₂) ∩ (F₁ ⋓ G₂) +F ⋓ G = F ∩ G + +_⋒_ : Type → Type → Type +(S₁ ⇒ T₁) ⋒ (S₂ ⇒ T₂) = (S₁ ∩ⁿ S₂) ⇒ (T₁ ∩ⁿ T₂) +(F₁ ∩ G₁) ⋒ F₂ = (F₁ ⋒ F₂) ∩ (G₁ ⋒ F₂) +F₁ ⋒ (F₂ ∩ G₂) = (F₁ ⋒ F₂) ∩ (F₁ ⋒ G₂) +F ⋒ G = F ∩ G + +_∩ᵘ_ : Type → Type → Type +F ∩ᵘ G = (F ∩ G) ∩ (F ⋓ G) + +_∩ⁱ_ : Type → Type → Type +F ∩ⁱ G = (F ∩ G) ∩ (F ⋒ G) + +∪-saturate : Type → Type +∪-saturate (F ∩ G) = (∪-saturate F ∩ᵘ ∪-saturate G) +∪-saturate F = F + +∩-saturate : Type → Type +∩-saturate (F ∩ G) = (∩-saturate F ∩ⁱ ∩-saturate G) +∩-saturate F = F + +saturate : Type → Type +saturate F = ∪-saturate (∩-saturate F) diff --git a/prototyping/Properties.agda b/prototyping/Properties.agda index 5594812e..f883a3ea 100644 --- a/prototyping/Properties.agda +++ b/prototyping/Properties.agda @@ -4,6 +4,7 @@ module Properties where import Properties.Contradiction import Properties.Dec +import Properties.DecSubtyping import Properties.Equality import Properties.Functions import Properties.Remember @@ -11,3 +12,4 @@ import Properties.Step import Properties.StrictMode import Properties.Subtyping import Properties.TypeCheck +import Properties.TypeNormalization diff --git a/prototyping/Properties/DecSubtyping.agda b/prototyping/Properties/DecSubtyping.agda new file mode 100644 index 00000000..8dc7a446 --- /dev/null +++ b/prototyping/Properties/DecSubtyping.agda @@ -0,0 +1,174 @@ +{-# OPTIONS --rewriting #-} + +module Properties.DecSubtyping where + +open import Agda.Builtin.Equality using (_≡_; refl) +open import FFI.Data.Either using (Either; Left; Right; mapLR; swapLR; cond) +open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-function-tgt; scalar-scalar; function-scalar; function-ok; function-ok₁; function-ok₂; function-err; function-tgt; left; right; _,_) +open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) +open import Luau.TypeNormalization using (_∪ⁿ_; _∩ⁿ_) +open import Luau.TypeSaturation using (saturate) +open import Properties.Contradiction using (CONTRADICTION; ¬) +open import Properties.Functions using (_∘_) +open import Properties.Subtyping using (<:-refl; <:-trans; ≮:-trans-<:; <:-trans-≮:; <:-never; <:-unknown; <:-∪-left; <:-∪-right; <:-∪-lub; ≮:-∪-left; ≮:-∪-right; <:-∩-left; <:-∩-right; <:-∩-glb; ≮:-∩-left; ≮:-∩-right; dec-language; scalar-<:; <:-everything; <:-function; ≮:-function-left; ≮:-function-right; <:-impl-¬≮:; <:-intersect; <:-function-∩-∪; <:-function-∩; <:-union; ≮:-left-∪; ≮:-right-∪; <:-∩-distr-∪; <:-impl-⊇; language-comp) +open import Properties.TypeNormalization using (FunType; Normal; never; unknown; _∩_; _∪_; _⇒_; normal; <:-normalize; normalize-<:; normal-∩ⁿ; normal-∪ⁿ; ∪-<:-∪ⁿ; ∪ⁿ-<:-∪; ∩ⁿ-<:-∩; ∩-<:-∩ⁿ; normalᶠ; fun-top; fun-function; fun-¬scalar) +open import Properties.TypeSaturation using (Overloads; Saturated; _⊆ᵒ_; _<:ᵒ_; defn; here; left; right; ov-language; ov-<:; saturated; normal-saturate; normal-overload-src; normal-overload-tgt; saturate-<:; <:-saturate; <:ᵒ-impl-<:; _>>=ˡ_; _>>=ʳ_) +open import Properties.Equality using (_≢_) + +-- Honest this terminates, since saturation maintains the depth of nested arrows +{-# TERMINATING #-} +dec-subtypingˢⁿ : ∀ {T U} → Scalar T → Normal U → Either (T ≮: U) (T <: U) +dec-subtypingˢᶠ : ∀ {F G} → FunType F → Saturated F → FunType G → Either (F ≮: G) (F <:ᵒ G) +dec-subtypingᶠ : ∀ {F G} → FunType F → FunType G → Either (F ≮: G) (F <: G) +dec-subtypingᶠⁿ : ∀ {F U} → FunType F → Normal U → Either (F ≮: U) (F <: U) +dec-subtypingⁿ : ∀ {T U} → Normal T → Normal U → Either (T ≮: U) (T <: U) +dec-subtyping : ∀ T U → Either (T ≮: U) (T <: U) + +dec-subtypingˢⁿ T U with dec-language _ (scalar T) +dec-subtypingˢⁿ T U | Left p = Left (witness (scalar T) (scalar T) p) +dec-subtypingˢⁿ T U | Right p = Right (scalar-<: T p) + +dec-subtypingˢᶠ {F} {S ⇒ T} Fᶠ (defn sat-∩ sat-∪) (Sⁿ ⇒ Tⁿ) = result (top Fᶠ (λ o → o)) where + + data Top G : Set where + + defn : ∀ Sᵗ Tᵗ → + + Overloads F (Sᵗ ⇒ Tᵗ) → + (∀ {S′ T′} → Overloads G (S′ ⇒ T′) → (S′ <: Sᵗ)) → + ------------- + Top G + + top : ∀ {G} → (FunType G) → (G ⊆ᵒ F) → Top G + top {S′ ⇒ T′} _ G⊆F = defn S′ T′ (G⊆F here) (λ { here → <:-refl }) + top (Gᶠ ∩ Hᶠ) G⊆F with top Gᶠ (G⊆F ∘ left) | top Hᶠ (G⊆F ∘ right) + top (Gᶠ ∩ Hᶠ) G⊆F | defn Rᵗ Sᵗ p p₁ | defn Tᵗ Uᵗ q q₁ with sat-∪ p q + top (Gᶠ ∩ Hᶠ) G⊆F | defn Rᵗ Sᵗ p p₁ | defn Tᵗ Uᵗ q q₁ | defn n r r₁ = defn _ _ n + (λ { (left o) → <:-trans (<:-trans (p₁ o) <:-∪-left) r ; (right o) → <:-trans (<:-trans (q₁ o) <:-∪-right) r }) + + result : Top F → Either (F ≮: (S ⇒ T)) (F <:ᵒ (S ⇒ T)) + result (defn Sᵗ Tᵗ oᵗ srcᵗ) with dec-subtypingⁿ Sⁿ (normal-overload-src Fᶠ oᵗ) + result (defn Sᵗ Tᵗ oᵗ srcᵗ) | Left (witness s Ss ¬Sᵗs) = Left (witness (function-err s) (ov-language Fᶠ (λ o → function-err (<:-impl-⊇ (srcᵗ o) s ¬Sᵗs))) (function-err Ss)) + result (defn Sᵗ Tᵗ oᵗ srcᵗ) | Right S<:Sᵗ = result₀ (largest Fᶠ (λ o → o)) where + + data LargestSrc (G : Type) : Set where + + yes : ∀ S₀ T₀ → + + Overloads F (S₀ ⇒ T₀) → + T₀ <: T → + (∀ {S′ T′} → Overloads G (S′ ⇒ T′) → T′ <: T → (S′ <: S₀)) → + ----------------------- + LargestSrc G + + no : ∀ S₀ T₀ → + + Overloads F (S₀ ⇒ T₀) → + T₀ ≮: T → + (∀ {S′ T′} → Overloads G (S′ ⇒ T′) → T₀ <: T′) → + ----------------------- + LargestSrc G + + largest : ∀ {G} → (FunType G) → (G ⊆ᵒ F) → LargestSrc G + largest {S′ ⇒ T′} (S′ⁿ ⇒ T′ⁿ) G⊆F with dec-subtypingⁿ T′ⁿ Tⁿ + largest {S′ ⇒ T′} (S′ⁿ ⇒ T′ⁿ) G⊆F | Left T′≮:T = no S′ T′ (G⊆F here) T′≮:T λ { here → <:-refl } + largest {S′ ⇒ T′} (S′ⁿ ⇒ T′ⁿ) G⊆F | Right T′<:T = yes S′ T′ (G⊆F here) T′<:T (λ { here _ → <:-refl }) + largest (Gᶠ ∩ Hᶠ) GH⊆F with largest Gᶠ (GH⊆F ∘ left) | largest Hᶠ (GH⊆F ∘ right) + largest (Gᶠ ∩ Hᶠ) GH⊆F | no S₁ T₁ o₁ T₁≮:T tgt₁ | no S₂ T₂ o₂ T₂≮:T tgt₂ with sat-∩ o₁ o₂ + largest (Gᶠ ∩ Hᶠ) GH⊆F | no S₁ T₁ o₁ T₁≮:T tgt₁ | no S₂ T₂ o₂ T₂≮:T tgt₂ | defn o src tgt with dec-subtypingⁿ (normal-overload-tgt Fᶠ o) Tⁿ + largest (Gᶠ ∩ Hᶠ) GH⊆F | no S₁ T₁ o₁ T₁≮:T tgt₁ | no S₂ T₂ o₂ T₂≮:T tgt₂ | defn o src tgt | Left T₀≮:T = no _ _ o T₀≮:T (λ { (left o) → <:-trans tgt (<:-trans <:-∩-left (tgt₁ o)) ; (right o) → <:-trans tgt (<:-trans <:-∩-right (tgt₂ o)) }) + largest (Gᶠ ∩ Hᶠ) GH⊆F | no S₁ T₁ o₁ T₁≮:T tgt₁ | no S₂ T₂ o₂ T₂≮:T tgt₂ | defn o src tgt | Right T₀<:T = yes _ _ o T₀<:T (λ { (left o) p → CONTRADICTION (<:-impl-¬≮: p (<:-trans-≮: (tgt₁ o) T₁≮:T)) ; (right o) p → CONTRADICTION (<:-impl-¬≮: p (<:-trans-≮: (tgt₂ o) T₂≮:T)) }) + largest (Gᶠ ∩ Hᶠ) GH⊆F | no S₁ T₁ o₁ T₁≮:T tgt₁ | yes S₂ T₂ o₂ T₂<:T src₂ = yes S₂ T₂ o₂ T₂<:T (λ { (left o) p → CONTRADICTION (<:-impl-¬≮: p (<:-trans-≮: (tgt₁ o) T₁≮:T)) ; (right o) p → src₂ o p }) + largest (Gᶠ ∩ Hᶠ) GH⊆F | yes S₁ T₁ o₁ T₁<:T src₁ | no S₂ T₂ o₂ T₂≮:T tgt₂ = yes S₁ T₁ o₁ T₁<:T (λ { (left o) p → src₁ o p ; (right o) p → CONTRADICTION (<:-impl-¬≮: p (<:-trans-≮: (tgt₂ o) T₂≮:T)) }) + largest (Gᶠ ∩ Hᶠ) GH⊆F | yes S₁ T₁ o₁ T₁<:T src₁ | yes S₂ T₂ o₂ T₂<:T src₂ with sat-∪ o₁ o₂ + largest (Gᶠ ∩ Hᶠ) GH⊆F | yes S₁ T₁ o₁ T₁<:T src₁ | yes S₂ T₂ o₂ T₂<:T src₂ | defn o src tgt = yes _ _ o (<:-trans tgt (<:-∪-lub T₁<:T T₂<:T)) + (λ { (left o) T′<:T → <:-trans (src₁ o T′<:T) (<:-trans <:-∪-left src) + ; (right o) T′<:T → <:-trans (src₂ o T′<:T) (<:-trans <:-∪-right src) + }) + + result₀ : LargestSrc F → Either (F ≮: (S ⇒ T)) (F <:ᵒ (S ⇒ T)) + result₀ (no S₀ T₀ o₀ (witness t T₀t ¬Tt) tgt₀) = Left (witness (function-tgt t) (ov-language Fᶠ (λ o → function-tgt (tgt₀ o t T₀t))) (function-tgt ¬Tt)) + result₀ (yes S₀ T₀ o₀ T₀<:T src₀) with dec-subtypingⁿ Sⁿ (normal-overload-src Fᶠ o₀) + result₀ (yes S₀ T₀ o₀ T₀<:T src₀) | Right S<:S₀ = Right λ { here → defn o₀ S<:S₀ T₀<:T } + result₀ (yes S₀ T₀ o₀ T₀<:T src₀) | Left (witness s Ss ¬S₀s) = Left (result₁ (smallest Fᶠ (λ o → o))) where + + data SmallestTgt (G : Type) : Set where + + defn : ∀ S₁ T₁ → + + Overloads F (S₁ ⇒ T₁) → + Language S₁ s → + (∀ {S′ T′} → Overloads G (S′ ⇒ T′) → Language S′ s → (T₁ <: T′)) → + ----------------------- + SmallestTgt G + + smallest : ∀ {G} → (FunType G) → (G ⊆ᵒ F) → SmallestTgt G + smallest {S′ ⇒ T′} _ G⊆F with dec-language S′ s + smallest {S′ ⇒ T′} _ G⊆F | Left ¬S′s = defn Sᵗ Tᵗ oᵗ (S<:Sᵗ s Ss) λ { here S′s → CONTRADICTION (language-comp s ¬S′s S′s) } + smallest {S′ ⇒ T′} _ G⊆F | Right S′s = defn S′ T′ (G⊆F here) S′s (λ { here _ → <:-refl }) + smallest (Gᶠ ∩ Hᶠ) GH⊆F with smallest Gᶠ (GH⊆F ∘ left) | smallest Hᶠ (GH⊆F ∘ right) + smallest (Gᶠ ∩ Hᶠ) GH⊆F | defn S₁ T₁ o₁ R₁s tgt₁ | defn S₂ T₂ o₂ R₂s tgt₂ with sat-∩ o₁ o₂ + smallest (Gᶠ ∩ Hᶠ) GH⊆F | defn S₁ T₁ o₁ R₁s tgt₁ | defn S₂ T₂ o₂ R₂s tgt₂ | defn o src tgt = defn _ _ o (src s (R₁s , R₂s)) + (λ { (left o) S′s → <:-trans (<:-trans tgt <:-∩-left) (tgt₁ o S′s) + ; (right o) S′s → <:-trans (<:-trans tgt <:-∩-right) (tgt₂ o S′s) + }) + + result₁ : SmallestTgt F → (F ≮: (S ⇒ T)) + result₁ (defn S₁ T₁ o₁ S₁s tgt₁) with dec-subtypingⁿ (normal-overload-tgt Fᶠ o₁) Tⁿ + result₁ (defn S₁ T₁ o₁ S₁s tgt₁) | Right T₁<:T = CONTRADICTION (language-comp s ¬S₀s (src₀ o₁ T₁<:T s S₁s)) + result₁ (defn S₁ T₁ o₁ S₁s tgt₁) | Left (witness t T₁t ¬Tt) = witness (function-ok s t) (ov-language Fᶠ lemma) (function-ok Ss ¬Tt) where + + lemma : ∀ {S′ T′} → Overloads F (S′ ⇒ T′) → Language (S′ ⇒ T′) (function-ok s t) + lemma {S′} o with dec-language S′ s + lemma {S′} o | Left ¬S′s = function-ok₁ ¬S′s + lemma {S′} o | Right S′s = function-ok₂ (tgt₁ o S′s t T₁t) + +dec-subtypingˢᶠ F Fˢ (G ∩ H) with dec-subtypingˢᶠ F Fˢ G | dec-subtypingˢᶠ F Fˢ H +dec-subtypingˢᶠ F Fˢ (G ∩ H) | Left F≮:G | _ = Left (≮:-∩-left F≮:G) +dec-subtypingˢᶠ F Fˢ (G ∩ H) | _ | Left F≮:H = Left (≮:-∩-right F≮:H) +dec-subtypingˢᶠ F Fˢ (G ∩ H) | Right F<:G | Right F<:H = Right (λ { (left o) → F<:G o ; (right o) → F<:H o }) + +dec-subtypingᶠ F G with dec-subtypingˢᶠ (normal-saturate F) (saturated F) G +dec-subtypingᶠ F G | Left H≮:G = Left (<:-trans-≮: (saturate-<: F) H≮:G) +dec-subtypingᶠ F G | Right H<:G = Right (<:-trans (<:-saturate F) (<:ᵒ-impl-<: (normal-saturate F) G H<:G)) + +dec-subtypingᶠⁿ T never = Left (witness function (fun-function T) never) +dec-subtypingᶠⁿ T unknown = Right <:-unknown +dec-subtypingᶠⁿ T (U ⇒ V) = dec-subtypingᶠ T (U ⇒ V) +dec-subtypingᶠⁿ T (U ∩ V) = dec-subtypingᶠ T (U ∩ V) +dec-subtypingᶠⁿ T (U ∪ V) with dec-subtypingᶠⁿ T U +dec-subtypingᶠⁿ T (U ∪ V) | Left (witness t p q) = Left (witness t p (q , fun-¬scalar V T p)) +dec-subtypingᶠⁿ T (U ∪ V) | Right p = Right (<:-trans p <:-∪-left) + +dec-subtypingⁿ never U = Right <:-never +dec-subtypingⁿ unknown unknown = Right <:-refl +dec-subtypingⁿ unknown U with dec-subtypingᶠⁿ (never ⇒ unknown) U +dec-subtypingⁿ unknown U | Left p = Left (<:-trans-≮: <:-unknown p) +dec-subtypingⁿ unknown U | Right p₁ with dec-subtypingˢⁿ number U +dec-subtypingⁿ unknown U | Right p₁ | Left p = Left (<:-trans-≮: <:-unknown p) +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ with dec-subtypingˢⁿ string U +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ | Left p = Left (<:-trans-≮: <:-unknown p) +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ | Right p₃ with dec-subtypingˢⁿ nil U +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ | Right p₃ | Left p = Left (<:-trans-≮: <:-unknown p) +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ | Right p₃ | Right p₄ with dec-subtypingˢⁿ boolean U +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ | Right p₃ | Right p₄ | Left p = Left (<:-trans-≮: <:-unknown p) +dec-subtypingⁿ unknown U | Right p₁ | Right p₂ | Right p₃ | Right p₄ | Right p₅ = Right (<:-trans <:-everything (<:-∪-lub p₁ (<:-∪-lub p₂ (<:-∪-lub p₃ (<:-∪-lub p₄ p₅))))) +dec-subtypingⁿ (S ⇒ T) U = dec-subtypingᶠⁿ (S ⇒ T) U +dec-subtypingⁿ (S ∩ T) U = dec-subtypingᶠⁿ (S ∩ T) U +dec-subtypingⁿ (S ∪ T) U with dec-subtypingⁿ S U | dec-subtypingˢⁿ T U +dec-subtypingⁿ (S ∪ T) U | Left p | q = Left (≮:-∪-left p) +dec-subtypingⁿ (S ∪ T) U | Right p | Left q = Left (≮:-∪-right q) +dec-subtypingⁿ (S ∪ T) U | Right p | Right q = Right (<:-∪-lub p q) + +dec-subtyping T U with dec-subtypingⁿ (normal T) (normal U) +dec-subtyping T U | Left p = Left (<:-trans-≮: (normalize-<: T) (≮:-trans-<: p (<:-normalize U))) +dec-subtyping T U | Right p = Right (<:-trans (<:-normalize T) (<:-trans p (normalize-<: U))) + +-- As a corollary, for saturated functions +-- <:ᵒ coincides with <:, that is F is a subtype of (S ⇒ T) precisely +-- when one of its overloads is. + +<:-impl-<:ᵒ : ∀ {F G} → FunType F → Saturated F → FunType G → (F <: G) → (F <:ᵒ G) +<:-impl-<:ᵒ {F} {G} Fᶠ Fˢ Gᶠ F<:G with dec-subtypingˢᶠ Fᶠ Fˢ Gᶠ +<:-impl-<:ᵒ {F} {G} Fᶠ Fˢ Gᶠ F<:G | Left F≮:G = CONTRADICTION (<:-impl-¬≮: F<:G F≮:G) +<:-impl-<:ᵒ {F} {G} Fᶠ Fˢ Gᶠ F<:G | Right F<:ᵒG = F<:ᵒG diff --git a/prototyping/Properties/ResolveOverloads.agda b/prototyping/Properties/ResolveOverloads.agda new file mode 100644 index 00000000..8de4a875 --- /dev/null +++ b/prototyping/Properties/ResolveOverloads.agda @@ -0,0 +1,189 @@ +{-# OPTIONS --rewriting #-} + +module Properties.ResolveOverloads where + +open import FFI.Data.Either using (Left; Right) +open import Luau.ResolveOverloads using (Resolved; src; srcⁿ; resolve; resolveⁿ; resolveᶠ; resolveˢ; target; yes; no) +open import Luau.Subtyping using (_<:_; _≮:_; Language; ¬Language; witness; scalar; unknown; never; function; function-ok; function-err; function-tgt; function-scalar; function-ok₁; function-ok₂; scalar-scalar; scalar-function; scalar-function-ok; scalar-function-err; scalar-function-tgt; _,_; left; right) +open import Luau.Type using (Type ; Scalar; _⇒_; _∩_; _∪_; nil; boolean; number; string; unknown; never) +open import Luau.TypeSaturation using (saturate) +open import Luau.TypeNormalization using (normalize) +open import Properties.Contradiction using (CONTRADICTION) +open import Properties.DecSubtyping using (dec-subtyping; dec-subtypingⁿ; <:-impl-<:ᵒ) +open import Properties.Functions using (_∘_) +open import Properties.Subtyping using (<:-refl; <:-trans; <:-trans-≮:; ≮:-trans-<:; <:-∩-left; <:-∩-right; <:-∩-glb; <:-impl-¬≮:; <:-unknown; <:-function; function-≮:-never; <:-never; unknown-≮:-function; scalar-≮:-function; ≮:-∪-right; scalar-≮:-never; <:-∪-left; <:-∪-right; <:-impl-⊇; language-comp) +open import Properties.TypeNormalization using (Normal; FunType; normal; _⇒_; _∩_; _∪_; never; unknown; <:-normalize; normalize-<:; fun-≮:-never; unknown-≮:-fun; scalar-≮:-fun) +open import Properties.TypeSaturation using (Overloads; Saturated; _⊆ᵒ_; _<:ᵒ_; normal-saturate; saturated; <:-saturate; saturate-<:; defn; here; left; right) + +-- Properties of src +function-err-srcⁿ : ∀ {T t} → (FunType T) → (¬Language (srcⁿ T) t) → Language T (function-err t) +function-err-srcⁿ (S ⇒ T) p = function-err p +function-err-srcⁿ (S ∩ T) (p₁ , p₂) = (function-err-srcⁿ S p₁ , function-err-srcⁿ T p₂) + +¬function-err-srcᶠ : ∀ {T t} → (FunType T) → (Language (srcⁿ T) t) → ¬Language T (function-err t) +¬function-err-srcᶠ (S ⇒ T) p = function-err p +¬function-err-srcᶠ (S ∩ T) (left p) = left (¬function-err-srcᶠ S p) +¬function-err-srcᶠ (S ∩ T) (right p) = right (¬function-err-srcᶠ T p) + +¬function-err-srcⁿ : ∀ {T t} → (Normal T) → (Language (srcⁿ T) t) → ¬Language T (function-err t) +¬function-err-srcⁿ never p = never +¬function-err-srcⁿ unknown (scalar ()) +¬function-err-srcⁿ (S ⇒ T) p = function-err p +¬function-err-srcⁿ (S ∩ T) (left p) = left (¬function-err-srcᶠ S p) +¬function-err-srcⁿ (S ∩ T) (right p) = right (¬function-err-srcᶠ T p) +¬function-err-srcⁿ (S ∪ T) (scalar ()) + +¬function-err-src : ∀ {T t} → (Language (src T) t) → ¬Language T (function-err t) +¬function-err-src {T = S ⇒ T} p = function-err p +¬function-err-src {T = nil} p = scalar-function-err nil +¬function-err-src {T = never} p = never +¬function-err-src {T = unknown} (scalar ()) +¬function-err-src {T = boolean} p = scalar-function-err boolean +¬function-err-src {T = number} p = scalar-function-err number +¬function-err-src {T = string} p = scalar-function-err string +¬function-err-src {T = S ∪ T} p = <:-impl-⊇ (<:-normalize (S ∪ T)) _ (¬function-err-srcⁿ (normal (S ∪ T)) p) +¬function-err-src {T = S ∩ T} p = <:-impl-⊇ (<:-normalize (S ∩ T)) _ (¬function-err-srcⁿ (normal (S ∩ T)) p) + +src-¬function-errᶠ : ∀ {T t} → (FunType T) → Language T (function-err t) → (¬Language (srcⁿ T) t) +src-¬function-errᶠ (S ⇒ T) (function-err p) = p +src-¬function-errᶠ (S ∩ T) (p₁ , p₂) = (src-¬function-errᶠ S p₁ , src-¬function-errᶠ T p₂) + +src-¬function-errⁿ : ∀ {T t} → (Normal T) → Language T (function-err t) → (¬Language (srcⁿ T) t) +src-¬function-errⁿ unknown p = never +src-¬function-errⁿ (S ⇒ T) (function-err p) = p +src-¬function-errⁿ (S ∩ T) (p₁ , p₂) = (src-¬function-errᶠ S p₁ , src-¬function-errᶠ T p₂) +src-¬function-errⁿ (S ∪ T) p = never + +src-¬function-err : ∀ {T t} → Language T (function-err t) → (¬Language (src T) t) +src-¬function-err {T = S ⇒ T} (function-err p) = p +src-¬function-err {T = unknown} p = never +src-¬function-err {T = S ∪ T} p = src-¬function-errⁿ (normal (S ∪ T)) (<:-normalize (S ∪ T) _ p) +src-¬function-err {T = S ∩ T} p = src-¬function-errⁿ (normal (S ∩ T)) (<:-normalize (S ∩ T) _ p) + +fun-¬scalar : ∀ {S T} (s : Scalar S) → FunType T → ¬Language T (scalar s) +fun-¬scalar s (S ⇒ T) = function-scalar s +fun-¬scalar s (S ∩ T) = left (fun-¬scalar s S) + +¬fun-scalar : ∀ {S T t} (s : Scalar S) → FunType T → Language T t → ¬Language S t +¬fun-scalar s (S ⇒ T) function = scalar-function s +¬fun-scalar s (S ⇒ T) (function-ok₁ p) = scalar-function-ok s +¬fun-scalar s (S ⇒ T) (function-ok₂ p) = scalar-function-ok s +¬fun-scalar s (S ⇒ T) (function-err p) = scalar-function-err s +¬fun-scalar s (S ⇒ T) (function-tgt p) = scalar-function-tgt s +¬fun-scalar s (S ∩ T) (p₁ , p₂) = ¬fun-scalar s T p₂ + +fun-function : ∀ {T} → FunType T → Language T function +fun-function (S ⇒ T) = function +fun-function (S ∩ T) = (fun-function S , fun-function T) + +srcⁿ-¬scalar : ∀ {S T t} (s : Scalar S) → Normal T → Language T (scalar s) → (¬Language (srcⁿ T) t) +srcⁿ-¬scalar s never (scalar ()) +srcⁿ-¬scalar s unknown p = never +srcⁿ-¬scalar s (S ⇒ T) (scalar ()) +srcⁿ-¬scalar s (S ∩ T) (p₁ , p₂) = CONTRADICTION (language-comp (scalar s) (fun-¬scalar s S) p₁) +srcⁿ-¬scalar s (S ∪ T) p = never + +src-¬scalar : ∀ {S T t} (s : Scalar S) → Language T (scalar s) → (¬Language (src T) t) +src-¬scalar {T = nil} s p = never +src-¬scalar {T = T ⇒ U} s (scalar ()) +src-¬scalar {T = never} s (scalar ()) +src-¬scalar {T = unknown} s p = never +src-¬scalar {T = boolean} s p = never +src-¬scalar {T = number} s p = never +src-¬scalar {T = string} s p = never +src-¬scalar {T = T ∪ U} s p = srcⁿ-¬scalar s (normal (T ∪ U)) (<:-normalize (T ∪ U) (scalar s) p) +src-¬scalar {T = T ∩ U} s p = srcⁿ-¬scalar s (normal (T ∩ U)) (<:-normalize (T ∩ U) (scalar s) p) + +srcⁿ-unknown-≮: : ∀ {T U} → (Normal U) → (T ≮: srcⁿ U) → (U ≮: (T ⇒ unknown)) +srcⁿ-unknown-≮: never (witness t p q) = CONTRADICTION (language-comp t q unknown) +srcⁿ-unknown-≮: unknown (witness t p q) = witness (function-err t) unknown (function-err p) +srcⁿ-unknown-≮: (U ⇒ V) (witness t p q) = witness (function-err t) (function-err q) (function-err p) +srcⁿ-unknown-≮: (U ∩ V) (witness t p q) = witness (function-err t) (function-err-srcⁿ (U ∩ V) q) (function-err p) +srcⁿ-unknown-≮: (U ∪ V) (witness t p q) = witness (scalar V) (right (scalar V)) (function-scalar V) + +src-unknown-≮: : ∀ {T U} → (T ≮: src U) → (U ≮: (T ⇒ unknown)) +src-unknown-≮: {U = nil} (witness t p q) = witness (scalar nil) (scalar nil) (function-scalar nil) +src-unknown-≮: {U = T ⇒ U} (witness t p q) = witness (function-err t) (function-err q) (function-err p) +src-unknown-≮: {U = never} (witness t p q) = CONTRADICTION (language-comp t q unknown) +src-unknown-≮: {U = unknown} (witness t p q) = witness (function-err t) unknown (function-err p) +src-unknown-≮: {U = boolean} (witness t p q) = witness (scalar boolean) (scalar boolean) (function-scalar boolean) +src-unknown-≮: {U = number} (witness t p q) = witness (scalar number) (scalar number) (function-scalar number) +src-unknown-≮: {U = string} (witness t p q) = witness (scalar string) (scalar string) (function-scalar string) +src-unknown-≮: {U = T ∪ U} p = <:-trans-≮: (normalize-<: (T ∪ U)) (srcⁿ-unknown-≮: (normal (T ∪ U)) p) +src-unknown-≮: {U = T ∩ U} p = <:-trans-≮: (normalize-<: (T ∩ U)) (srcⁿ-unknown-≮: (normal (T ∩ U)) p) + +unknown-src-≮: : ∀ {S T U} → (U ≮: S) → (T ≮: (U ⇒ unknown)) → (U ≮: src T) +unknown-src-≮: (witness t x x₁) (witness (scalar s) p (function-scalar s)) = witness t x (src-¬scalar s p) +unknown-src-≮: r (witness (function-ok s .(scalar s₁)) p (function-ok x (scalar-scalar s₁ () x₂))) +unknown-src-≮: r (witness (function-ok s .function) p (function-ok x (scalar-function ()))) +unknown-src-≮: r (witness (function-ok s .(function-ok _ _)) p (function-ok x (scalar-function-ok ()))) +unknown-src-≮: r (witness (function-ok s .(function-err _)) p (function-ok x (scalar-function-err ()))) +unknown-src-≮: r (witness (function-err t) p (function-err q)) = witness t q (src-¬function-err p) +unknown-src-≮: r (witness (function-tgt t) p (function-tgt (scalar-function-tgt ()))) + +-- Properties of resolve +resolveˢ-<:-⇒ : ∀ {F V U} → (FunType F) → (Saturated F) → (FunType (V ⇒ U)) → (r : Resolved F V) → (F <: (V ⇒ U)) → (target r <: U) +resolveˢ-<:-⇒ Fᶠ Fˢ V⇒Uᶠ r F<:V⇒U with <:-impl-<:ᵒ Fᶠ Fˢ V⇒Uᶠ F<:V⇒U here +resolveˢ-<:-⇒ Fᶠ Fˢ V⇒Uᶠ (yes Sʳ Tʳ oʳ V<:Sʳ tgtʳ) F<:V⇒U | defn o o₁ o₂ = <:-trans (tgtʳ o o₁) o₂ +resolveˢ-<:-⇒ Fᶠ Fˢ V⇒Uᶠ (no tgtʳ) F<:V⇒U | defn o o₁ o₂ = CONTRADICTION (<:-impl-¬≮: o₁ (tgtʳ o)) + +resolveⁿ-<:-⇒ : ∀ {F V U} → (Fⁿ : Normal F) → (Vⁿ : Normal V) → (Uⁿ : Normal U) → (F <: (V ⇒ U)) → (resolveⁿ Fⁿ Vⁿ <: U) +resolveⁿ-<:-⇒ (Sⁿ ⇒ Tⁿ) Vⁿ Uⁿ F<:V⇒U = resolveˢ-<:-⇒ (normal-saturate (Sⁿ ⇒ Tⁿ)) (saturated (Sⁿ ⇒ Tⁿ)) (Vⁿ ⇒ Uⁿ) (resolveˢ (normal-saturate (Sⁿ ⇒ Tⁿ)) (saturated (Sⁿ ⇒ Tⁿ)) Vⁿ (λ o → o)) F<:V⇒U +resolveⁿ-<:-⇒ (Fⁿ ∩ Gⁿ) Vⁿ Uⁿ F<:V⇒U = resolveˢ-<:-⇒ (normal-saturate (Fⁿ ∩ Gⁿ)) (saturated (Fⁿ ∩ Gⁿ)) (Vⁿ ⇒ Uⁿ) (resolveˢ (normal-saturate (Fⁿ ∩ Gⁿ)) (saturated (Fⁿ ∩ Gⁿ)) Vⁿ (λ o → o)) (<:-trans (saturate-<: (Fⁿ ∩ Gⁿ)) F<:V⇒U) +resolveⁿ-<:-⇒ (Sⁿ ∪ Tˢ) Vⁿ Uⁿ F<:V⇒U = CONTRADICTION (<:-impl-¬≮: F<:V⇒U (<:-trans-≮: <:-∪-right (scalar-≮:-function Tˢ))) +resolveⁿ-<:-⇒ never Vⁿ Uⁿ F<:V⇒U = <:-never +resolveⁿ-<:-⇒ unknown Vⁿ Uⁿ F<:V⇒U = CONTRADICTION (<:-impl-¬≮: F<:V⇒U unknown-≮:-function) + +resolve-<:-⇒ : ∀ {F V U} → (F <: (V ⇒ U)) → (resolve F V <: U) +resolve-<:-⇒ {F} {V} {U} F<:V⇒U = <:-trans (resolveⁿ-<:-⇒ (normal F) (normal V) (normal U) (<:-trans (normalize-<: F) (<:-trans F<:V⇒U (<:-normalize (V ⇒ U))))) (normalize-<: U) + +resolve-≮:-⇒ : ∀ {F V U} → (resolve F V ≮: U) → (F ≮: (V ⇒ U)) +resolve-≮:-⇒ {F} {V} {U} FV≮:U with dec-subtyping F (V ⇒ U) +resolve-≮:-⇒ {F} {V} {U} FV≮:U | Left F≮:V⇒U = F≮:V⇒U +resolve-≮:-⇒ {F} {V} {U} FV≮:U | Right F<:V⇒U = CONTRADICTION (<:-impl-¬≮: (resolve-<:-⇒ F<:V⇒U) FV≮:U) + +<:-resolveˢ-⇒ : ∀ {S T V} → (r : Resolved (S ⇒ T) V) → (V <: S) → T <: target r +<:-resolveˢ-⇒ (yes S T here _ _) V<:S = <:-refl +<:-resolveˢ-⇒ (no _) V<:S = <:-unknown + +<:-resolveⁿ-⇒ : ∀ {S T V} → (Sⁿ : Normal S) → (Tⁿ : Normal T) → (Vⁿ : Normal V) → (V <: S) → T <: resolveⁿ (Sⁿ ⇒ Tⁿ) Vⁿ +<:-resolveⁿ-⇒ Sⁿ Tⁿ Vⁿ V<:S = <:-resolveˢ-⇒ (resolveˢ (Sⁿ ⇒ Tⁿ) (saturated (Sⁿ ⇒ Tⁿ)) Vⁿ (λ o → o)) V<:S + +<:-resolve-⇒ : ∀ {S T V} → (V <: S) → T <: resolve (S ⇒ T) V +<:-resolve-⇒ {S} {T} {V} V<:S = <:-trans (<:-normalize T) (<:-resolveⁿ-⇒ (normal S) (normal T) (normal V) (<:-trans (normalize-<: V) (<:-trans V<:S (<:-normalize S)))) + +<:-resolveˢ : ∀ {F G V W} → (r : Resolved F V) → (s : Resolved G W) → (F <:ᵒ G) → (V <: W) → target r <: target s +<:-resolveˢ (yes Sʳ Tʳ oʳ V<:Sʳ tgtʳ) (yes Sˢ Tˢ oˢ W<:Sˢ tgtˢ) F<:G V<:W with F<:G oˢ +<:-resolveˢ (yes Sʳ Tʳ oʳ V<:Sʳ tgtʳ) (yes Sˢ Tˢ oˢ W<:Sˢ tgtˢ) F<:G V<:W | defn o o₁ o₂ = <:-trans (tgtʳ o (<:-trans (<:-trans V<:W W<:Sˢ) o₁)) o₂ +<:-resolveˢ (no r) (yes Sˢ Tˢ oˢ W<:Sˢ tgtˢ) F<:G V<:W with F<:G oˢ +<:-resolveˢ (no r) (yes Sˢ Tˢ oˢ W<:Sˢ tgtˢ) F<:G V<:W | defn o o₁ o₂ = CONTRADICTION (<:-impl-¬≮: (<:-trans V<:W (<:-trans W<:Sˢ o₁)) (r o)) +<:-resolveˢ r (no s) F<:G V<:W = <:-unknown + +<:-resolveᶠ : ∀ {F G V W} → (Fᶠ : FunType F) → (Gᶠ : FunType G) → (Vⁿ : Normal V) → (Wⁿ : Normal W) → (F <: G) → (V <: W) → resolveᶠ Fᶠ Vⁿ <: resolveᶠ Gᶠ Wⁿ +<:-resolveᶠ Fᶠ Gᶠ Vⁿ Wⁿ F<:G V<:W = <:-resolveˢ + (resolveˢ (normal-saturate Fᶠ) (saturated Fᶠ) Vⁿ (λ o → o)) + (resolveˢ (normal-saturate Gᶠ) (saturated Gᶠ) Wⁿ (λ o → o)) + (<:-impl-<:ᵒ (normal-saturate Fᶠ) (saturated Fᶠ) (normal-saturate Gᶠ) (<:-trans (saturate-<: Fᶠ) (<:-trans F<:G (<:-saturate Gᶠ)))) + V<:W + +<:-resolveⁿ : ∀ {F G V W} → (Fⁿ : Normal F) → (Gⁿ : Normal G) → (Vⁿ : Normal V) → (Wⁿ : Normal W) → (F <: G) → (V <: W) → resolveⁿ Fⁿ Vⁿ <: resolveⁿ Gⁿ Wⁿ +<:-resolveⁿ (Rⁿ ⇒ Sⁿ) (Tⁿ ⇒ Uⁿ) Vⁿ Wⁿ F<:G V<:W = <:-resolveᶠ (Rⁿ ⇒ Sⁿ) (Tⁿ ⇒ Uⁿ) Vⁿ Wⁿ F<:G V<:W +<:-resolveⁿ (Rⁿ ⇒ Sⁿ) (Gⁿ ∩ Hⁿ) Vⁿ Wⁿ F<:G V<:W = <:-resolveᶠ (Rⁿ ⇒ Sⁿ) (Gⁿ ∩ Hⁿ) Vⁿ Wⁿ F<:G V<:W +<:-resolveⁿ (Eⁿ ∩ Fⁿ) (Tⁿ ⇒ Uⁿ) Vⁿ Wⁿ F<:G V<:W = <:-resolveᶠ (Eⁿ ∩ Fⁿ) (Tⁿ ⇒ Uⁿ) Vⁿ Wⁿ F<:G V<:W +<:-resolveⁿ (Eⁿ ∩ Fⁿ) (Gⁿ ∩ Hⁿ) Vⁿ Wⁿ F<:G V<:W = <:-resolveᶠ (Eⁿ ∩ Fⁿ) (Gⁿ ∩ Hⁿ) Vⁿ Wⁿ F<:G V<:W +<:-resolveⁿ (Fⁿ ∪ Sˢ) (Tⁿ ⇒ Uⁿ) Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G (≮:-∪-right (scalar-≮:-function Sˢ))) +<:-resolveⁿ unknown (Tⁿ ⇒ Uⁿ) Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G unknown-≮:-function) +<:-resolveⁿ (Fⁿ ∪ Sˢ) (Gⁿ ∩ Hⁿ) Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G (≮:-∪-right (scalar-≮:-fun (Gⁿ ∩ Hⁿ) Sˢ))) +<:-resolveⁿ unknown (Gⁿ ∩ Hⁿ) Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G (unknown-≮:-fun (Gⁿ ∩ Hⁿ))) +<:-resolveⁿ (Rⁿ ⇒ Sⁿ) never Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G (fun-≮:-never (Rⁿ ⇒ Sⁿ))) +<:-resolveⁿ (Eⁿ ∩ Fⁿ) never Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G (fun-≮:-never (Eⁿ ∩ Fⁿ))) +<:-resolveⁿ (Fⁿ ∪ Sˢ) never Vⁿ Wⁿ F<:G V<:W = CONTRADICTION (<:-impl-¬≮: F<:G (≮:-∪-right (scalar-≮:-never Sˢ))) +<:-resolveⁿ unknown never Vⁿ Wⁿ F<:G V<:W = F<:G +<:-resolveⁿ never Gⁿ Vⁿ Wⁿ F<:G V<:W = <:-never +<:-resolveⁿ Fⁿ (Gⁿ ∪ Uˢ) Vⁿ Wⁿ F<:G V<:W = <:-unknown +<:-resolveⁿ Fⁿ unknown Vⁿ Wⁿ F<:G V<:W = <:-unknown + +<:-resolve : ∀ {F G V W} → (F <: G) → (V <: W) → resolve F V <: resolve G W +<:-resolve {F} {G} {V} {W} F<:G V<:W = <:-resolveⁿ (normal F) (normal G) (normal V) (normal W) + (<:-trans (normalize-<: F) (<:-trans F<:G (<:-normalize G))) + (<:-trans (normalize-<: V) (<:-trans V<:W (<:-normalize W))) diff --git a/prototyping/Properties/StrictMode.agda b/prototyping/Properties/StrictMode.agda index 2ff2b153..948674b9 100644 --- a/prototyping/Properties/StrictMode.agda +++ b/prototyping/Properties/StrictMode.agda @@ -7,12 +7,13 @@ open import Agda.Builtin.Equality using (_≡_; refl) open import FFI.Data.Either using (Either; Left; Right; mapL; mapR; mapLR; swapLR; cond) open import FFI.Data.Maybe using (Maybe; just; nothing) open import Luau.Heap using (Heap; Object; function_is_end; defn; alloc; ok; next; lookup-not-allocated) renaming (_≡_⊕_↦_ to _≡ᴴ_⊕_↦_; _[_] to _[_]ᴴ; ∅ to ∅ᴴ) +open import Luau.ResolveOverloads using (src; resolve) open import Luau.StrictMode using (Warningᴱ; Warningᴮ; Warningᴼ; Warningᴴ; UnallocatedAddress; UnboundVariable; FunctionCallMismatch; app₁; app₂; BinOpMismatch₁; BinOpMismatch₂; bin₁; bin₂; BlockMismatch; block₁; return; LocalVarMismatch; local₁; local₂; FunctionDefnMismatch; function₁; function₂; heap; expr; block; addr) open import Luau.Substitution using (_[_/_]ᴮ; _[_/_]ᴱ; _[_/_]ᴮunless_; var_[_/_]ᴱwhenever_) -open import Luau.Subtyping using (_≮:_; witness; any; none; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_; Tree; Language; ¬Language) +open import Luau.Subtyping using (_<:_; _≮:_; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_; Tree; Language; ¬Language) open import Luau.Syntax using (Expr; yes; var; val; var_∈_; _⟨_⟩∈_; _$_; addr; number; bool; string; binexp; nil; function_is_end; block_is_end; done; return; local_←_; _∙_; fun; arg; name; ==; ~=) -open import Luau.Type using (Type; strict; nil; number; boolean; string; _⇒_; none; any; _∩_; _∪_; tgt; _≡ᵀ_; _≡ᴹᵀ_) -open import Luau.TypeCheck(strict) using (_⊢ᴮ_∈_; _⊢ᴱ_∈_; _⊢ᴴᴮ_▷_∈_; _⊢ᴴᴱ_▷_∈_; nil; var; addr; app; function; block; done; return; local; orAny; srcBinOp; tgtBinOp) +open import Luau.Type using (Type; nil; number; boolean; string; _⇒_; never; unknown; _∩_; _∪_; _≡ᵀ_; _≡ᴹᵀ_) +open import Luau.TypeCheck using (_⊢ᴮ_∈_; _⊢ᴱ_∈_; _⊢ᴴᴮ_▷_∈_; _⊢ᴴᴱ_▷_∈_; nil; var; addr; app; function; block; done; return; local; orUnknown; srcBinOp; tgtBinOp) open import Luau.Var using (_≡ⱽ_) open import Luau.Addr using (_≡ᴬ_) open import Luau.VarCtxt using (VarCtxt; ∅; _⋒_; _↦_; _⊕_↦_; _⊝_; ⊕-lookup-miss; ⊕-swap; ⊕-over) renaming (_[_] to _[_]ⱽ) @@ -22,14 +23,15 @@ open import Properties.Equality using (_≢_; sym; cong; trans; subst₁) open import Properties.Dec using (Dec; yes; no) open import Properties.Contradiction using (CONTRADICTION; ¬) open import Properties.Functions using (_∘_) -open import Properties.Subtyping using (any-≮:; ≡-trans-≮:; ≮:-trans-≡; none-tgt-≮:; tgt-none-≮:; src-any-≮:; any-src-≮:; ≮:-trans; ≮:-refl; scalar-≢-impl-≮:; function-≮:-scalar; scalar-≮:-function; function-≮:-none; any-≮:-scalar; scalar-≮:-none; any-≮:-none) -open import Properties.TypeCheck(strict) using (typeOfᴼ; typeOfᴹᴼ; typeOfⱽ; typeOfᴱ; typeOfᴮ; typeCheckᴱ; typeCheckᴮ; typeCheckᴼ; typeCheckᴴ) +open import Properties.DecSubtyping using (dec-subtyping) +open import Properties.Subtyping using (unknown-≮:; ≡-trans-≮:; ≮:-trans-≡; ≮:-trans; ≮:-refl; scalar-≢-impl-≮:; function-≮:-scalar; scalar-≮:-function; function-≮:-never; unknown-≮:-scalar; scalar-≮:-never; unknown-≮:-never; <:-refl; <:-unknown; <:-impl-¬≮:) +open import Properties.ResolveOverloads using (src-unknown-≮:; unknown-src-≮:; <:-resolve; resolve-<:-⇒; <:-resolve-⇒) +open import Properties.Subtyping using (unknown-≮:; ≡-trans-≮:; ≮:-trans-≡; ≮:-trans; <:-trans-≮:; ≮:-refl; scalar-≢-impl-≮:; function-≮:-scalar; scalar-≮:-function; function-≮:-never; unknown-≮:-scalar; scalar-≮:-never; unknown-≮:-never; ≡-impl-<:; ≡-trans-<:; <:-trans-≡; ≮:-trans-<:; <:-trans) +open import Properties.TypeCheck using (typeOfᴼ; typeOfᴹᴼ; typeOfⱽ; typeOfᴱ; typeOfᴮ; typeCheckᴱ; typeCheckᴮ; typeCheckᴼ; typeCheckᴴ) open import Luau.OpSem using (_⟦_⟧_⟶_; _⊢_⟶*_⊣_; _⊢_⟶ᴮ_⊣_; _⊢_⟶ᴱ_⊣_; app₁; app₂; function; beta; return; block; done; local; subst; binOp₀; binOp₁; binOp₂; refl; step; +; -; *; /; <; >; ==; ~=; <=; >=; ··) open import Luau.RuntimeError using (BinOpError; RuntimeErrorᴱ; RuntimeErrorᴮ; FunctionMismatch; BinOpMismatch₁; BinOpMismatch₂; UnboundVariable; SEGV; app₁; app₂; bin₁; bin₂; block; local; return; +; -; *; /; <; >; <=; >=; ··) open import Luau.RuntimeType using (RuntimeType; valueType; number; string; boolean; nil; function) -src = Luau.Type.src strict - data _⊑_ (H : Heap yes) : Heap yes → Set where refl : (H ⊑ H) snoc : ∀ {H′ a O} → (H′ ≡ᴴ H ⊕ a ↦ O) → (H ⊑ H′) @@ -63,51 +65,32 @@ lookup-⊑-nothing {H} a (snoc defn) p with a ≡ᴬ next H lookup-⊑-nothing {H} a (snoc defn) p | yes refl = refl lookup-⊑-nothing {H} a (snoc o) p | no q = trans (lookup-not-allocated o q) p -heap-weakeningᴱ : ∀ Γ H M {H′ U} → (H ⊑ H′) → (typeOfᴱ H′ Γ M ≮: U) → (typeOfᴱ H Γ M ≮: U) -heap-weakeningᴱ Γ H (var x) h p = p -heap-weakeningᴱ Γ H (val nil) h p = p -heap-weakeningᴱ Γ H (val (addr a)) refl p = p -heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = b} q) p with a ≡ᴬ b -heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = a} defn) p | yes refl = any-≮: p -heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = b} q) p | no r = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ (lookup-not-allocated q r))) p -heap-weakeningᴱ Γ H (val (number x)) h p = p -heap-weakeningᴱ Γ H (val (bool x)) h p = p -heap-weakeningᴱ Γ H (val (string x)) h p = p -heap-weakeningᴱ Γ H (M $ N) h p = none-tgt-≮: (heap-weakeningᴱ Γ H M h (tgt-none-≮: p)) -heap-weakeningᴱ Γ H (function f ⟨ var x ∈ T ⟩∈ U is B end) h p = p -heap-weakeningᴱ Γ H (block var b ∈ T is B end) h p = p -heap-weakeningᴱ Γ H (binexp M op N) h p = p +<:-heap-weakeningᴱ : ∀ Γ H M {H′} → (H ⊑ H′) → (typeOfᴱ H′ Γ M <: typeOfᴱ H Γ M) +<:-heap-weakeningᴱ Γ H (var x) h = <:-refl +<:-heap-weakeningᴱ Γ H (val nil) h = <:-refl +<:-heap-weakeningᴱ Γ H (val (addr a)) refl = <:-refl +<:-heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = b} q) with a ≡ᴬ b +<:-heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = a} defn) | yes refl = <:-unknown +<:-heap-weakeningᴱ Γ H (val (addr a)) (snoc {a = b} q) | no r = ≡-impl-<: (sym (cong orUnknown (cong typeOfᴹᴼ (lookup-not-allocated q r)))) +<:-heap-weakeningᴱ Γ H (val (number n)) h = <:-refl +<:-heap-weakeningᴱ Γ H (val (bool b)) h = <:-refl +<:-heap-weakeningᴱ Γ H (val (string s)) h = <:-refl +<:-heap-weakeningᴱ Γ H (M $ N) h = <:-resolve (<:-heap-weakeningᴱ Γ H M h) (<:-heap-weakeningᴱ Γ H N h) +<:-heap-weakeningᴱ Γ H (function f ⟨ var x ∈ S ⟩∈ T is B end) h = <:-refl +<:-heap-weakeningᴱ Γ H (block var b ∈ T is N end) h = <:-refl +<:-heap-weakeningᴱ Γ H (binexp M op N) h = <:-refl -heap-weakeningᴮ : ∀ Γ H B {H′ U} → (H ⊑ H′) → (typeOfᴮ H′ Γ B ≮: U) → (typeOfᴮ H Γ B ≮: U) -heap-weakeningᴮ Γ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) h p = heap-weakeningᴮ (Γ ⊕ f ↦ (T ⇒ U)) H B h p -heap-weakeningᴮ Γ H (local var x ∈ T ← M ∙ B) h p = heap-weakeningᴮ (Γ ⊕ x ↦ T) H B h p -heap-weakeningᴮ Γ H (return M ∙ B) h p = heap-weakeningᴱ Γ H M h p -heap-weakeningᴮ Γ H done h p = p +<:-heap-weakeningᴮ : ∀ Γ H B {H′} → (H ⊑ H′) → (typeOfᴮ H′ Γ B <: typeOfᴮ H Γ B) +<:-heap-weakeningᴮ Γ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) h = <:-heap-weakeningᴮ (Γ ⊕ f ↦ (T ⇒ U)) H B h +<:-heap-weakeningᴮ Γ H (local var x ∈ T ← M ∙ B) h = <:-heap-weakeningᴮ (Γ ⊕ x ↦ T) H B h +<:-heap-weakeningᴮ Γ H (return M ∙ B) h = <:-heap-weakeningᴱ Γ H M h +<:-heap-weakeningᴮ Γ H done h = <:-refl -substitutivityᴱ : ∀ {Γ T U} H M v x → (typeOfᴱ H Γ (M [ v / x ]ᴱ) ≮: U) → Either (typeOfᴱ H (Γ ⊕ x ↦ T) M ≮: U) (typeOfᴱ H ∅ (val v) ≮: T) -substitutivityᴱ-whenever : ∀ {Γ T U} H v x y (r : Dec(x ≡ y)) → (typeOfᴱ H Γ (var y [ v / x ]ᴱwhenever r) ≮: U) → Either (typeOfᴱ H (Γ ⊕ x ↦ T) (var y) ≮: U) (typeOfᴱ H ∅ (val v) ≮: T) -substitutivityᴮ : ∀ {Γ T U} H B v x → (typeOfᴮ H Γ (B [ v / x ]ᴮ) ≮: U) → Either (typeOfᴮ H (Γ ⊕ x ↦ T) B ≮: U) (typeOfᴱ H ∅ (val v) ≮: T) -substitutivityᴮ-unless : ∀ {Γ T U V} H B v x y (r : Dec(x ≡ y)) → (typeOfᴮ H (Γ ⊕ y ↦ U) (B [ v / x ]ᴮunless r) ≮: V) → Either (typeOfᴮ H ((Γ ⊕ x ↦ T) ⊕ y ↦ U) B ≮: V) (typeOfᴱ H ∅ (val v) ≮: T) -substitutivityᴮ-unless-yes : ∀ {Γ Γ′ T V} H B v x y (r : x ≡ y) → (Γ′ ≡ Γ) → (typeOfᴮ H Γ (B [ v / x ]ᴮunless yes r) ≮: V) → Either (typeOfᴮ H Γ′ B ≮: V) (typeOfᴱ H ∅ (val v) ≮: T) -substitutivityᴮ-unless-no : ∀ {Γ Γ′ T V} H B v x y (r : x ≢ y) → (Γ′ ≡ Γ ⊕ x ↦ T) → (typeOfᴮ H Γ (B [ v / x ]ᴮunless no r) ≮: V) → Either (typeOfᴮ H Γ′ B ≮: V) (typeOfᴱ H ∅ (val v) ≮: T) +≮:-heap-weakeningᴱ : ∀ Γ H M {H′ U} → (H ⊑ H′) → (typeOfᴱ H′ Γ M ≮: U) → (typeOfᴱ H Γ M ≮: U) +≮:-heap-weakeningᴱ Γ H M h p = <:-trans-≮: (<:-heap-weakeningᴱ Γ H M h) p -substitutivityᴱ H (var y) v x p = substitutivityᴱ-whenever H v x y (x ≡ⱽ y) p -substitutivityᴱ H (val w) v x p = Left p -substitutivityᴱ H (binexp M op N) v x p = Left p -substitutivityᴱ H (M $ N) v x p = mapL none-tgt-≮: (substitutivityᴱ H M v x (tgt-none-≮: p)) -substitutivityᴱ H (function f ⟨ var y ∈ T ⟩∈ U is B end) v x p = Left p -substitutivityᴱ H (block var b ∈ T is B end) v x p = Left p -substitutivityᴱ-whenever H v x x (yes refl) q = swapLR (≮:-trans q) -substitutivityᴱ-whenever H v x y (no p) q = Left (≡-trans-≮: (cong orAny (sym (⊕-lookup-miss x y _ _ p))) q) - -substitutivityᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x p = substitutivityᴮ-unless H B v x f (x ≡ⱽ f) p -substitutivityᴮ H (local var y ∈ T ← M ∙ B) v x p = substitutivityᴮ-unless H B v x y (x ≡ⱽ y) p -substitutivityᴮ H (return M ∙ B) v x p = substitutivityᴱ H M v x p -substitutivityᴮ H done v x p = Left p -substitutivityᴮ-unless H B v x y (yes p) q = substitutivityᴮ-unless-yes H B v x y p (⊕-over p) q -substitutivityᴮ-unless H B v x y (no p) q = substitutivityᴮ-unless-no H B v x y p (⊕-swap p) q -substitutivityᴮ-unless-yes H B v x y refl refl p = Left p -substitutivityᴮ-unless-no H B v x y p refl q = substitutivityᴮ H B v x q +≮:-heap-weakeningᴮ : ∀ Γ H B {H′ U} → (H ⊑ H′) → (typeOfᴮ H′ Γ B ≮: U) → (typeOfᴮ H Γ B ≮: U) +≮:-heap-weakeningᴮ Γ H B h p = <:-trans-≮: (<:-heap-weakeningᴮ Γ H B h) p binOpPreservation : ∀ H {op v w x} → (v ⟦ op ⟧ w ⟶ x) → (tgtBinOp op ≡ typeOfᴱ H ∅ (val x)) binOpPreservation H (+ m n) = refl @@ -122,24 +105,78 @@ binOpPreservation H (== v w) = refl binOpPreservation H (~= v w) = refl binOpPreservation H (·· v w) = refl -reflect-subtypingᴱ : ∀ H M {H′ M′ T} → (H ⊢ M ⟶ᴱ M′ ⊣ H′) → (typeOfᴱ H′ ∅ M′ ≮: T) → Either (typeOfᴱ H ∅ M ≮: T) (Warningᴱ H (typeCheckᴱ H ∅ M)) -reflect-subtypingᴮ : ∀ H B {H′ B′ T} → (H ⊢ B ⟶ᴮ B′ ⊣ H′) → (typeOfᴮ H′ ∅ B′ ≮: T) → Either (typeOfᴮ H ∅ B ≮: T) (Warningᴮ H (typeCheckᴮ H ∅ B)) +<:-substitutivityᴱ : ∀ {Γ T} H M v x → (typeOfᴱ H ∅ (val v) <: T) → (typeOfᴱ H Γ (M [ v / x ]ᴱ) <: typeOfᴱ H (Γ ⊕ x ↦ T) M) +<:-substitutivityᴱ-whenever : ∀ {Γ T} H v x y (r : Dec(x ≡ y)) → (typeOfᴱ H ∅ (val v) <: T) → (typeOfᴱ H Γ (var y [ v / x ]ᴱwhenever r) <: typeOfᴱ H (Γ ⊕ x ↦ T) (var y)) +<:-substitutivityᴮ : ∀ {Γ T} H B v x → (typeOfᴱ H ∅ (val v) <: T) → (typeOfᴮ H Γ (B [ v / x ]ᴮ) <: typeOfᴮ H (Γ ⊕ x ↦ T) B) +<:-substitutivityᴮ-unless : ∀ {Γ T U} H B v x y (r : Dec(x ≡ y)) → (typeOfᴱ H ∅ (val v) <: T) → (typeOfᴮ H (Γ ⊕ y ↦ U) (B [ v / x ]ᴮunless r) <: typeOfᴮ H ((Γ ⊕ x ↦ T) ⊕ y ↦ U) B) +<:-substitutivityᴮ-unless-yes : ∀ {Γ Γ′} H B v x y (r : x ≡ y) → (Γ′ ≡ Γ) → (typeOfᴮ H Γ (B [ v / x ]ᴮunless yes r) <: typeOfᴮ H Γ′ B) +<:-substitutivityᴮ-unless-no : ∀ {Γ Γ′ T} H B v x y (r : x ≢ y) → (Γ′ ≡ Γ ⊕ x ↦ T) → (typeOfᴱ H ∅ (val v) <: T) → (typeOfᴮ H Γ (B [ v / x ]ᴮunless no r) <: typeOfᴮ H Γ′ B) -reflect-subtypingᴱ H (M $ N) (app₁ s) p = mapLR none-tgt-≮: app₁ (reflect-subtypingᴱ H M s (tgt-none-≮: p)) -reflect-subtypingᴱ H (M $ N) (app₂ v s) p = Left (none-tgt-≮: (heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) (tgt-none-≮: p))) -reflect-subtypingᴱ H (M $ N) (beta (function f ⟨ var y ∈ T ⟩∈ U is B end) v refl q) p = Left (≡-trans-≮: (cong tgt (cong orAny (cong typeOfᴹᴼ q))) p) -reflect-subtypingᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) p = Left p -reflect-subtypingᴱ H (block var b ∈ T is B end) (block s) p = Left p -reflect-subtypingᴱ H (block var b ∈ T is return (val v) ∙ B end) (return v) p = mapR BlockMismatch (swapLR (≮:-trans p)) -reflect-subtypingᴱ H (block var b ∈ T is done end) done p = mapR BlockMismatch (swapLR (≮:-trans p)) -reflect-subtypingᴱ H (binexp M op N) (binOp₀ s) p = Left (≡-trans-≮: (binOpPreservation H s) p) -reflect-subtypingᴱ H (binexp M op N) (binOp₁ s) p = Left p -reflect-subtypingᴱ H (binexp M op N) (binOp₂ s) p = Left p +<:-substitutivityᴱ H (var y) v x p = <:-substitutivityᴱ-whenever H v x y (x ≡ⱽ y) p +<:-substitutivityᴱ H (val w) v x p = <:-refl +<:-substitutivityᴱ H (binexp M op N) v x p = <:-refl +<:-substitutivityᴱ H (M $ N) v x p = <:-resolve (<:-substitutivityᴱ H M v x p) (<:-substitutivityᴱ H N v x p) +<:-substitutivityᴱ H (function f ⟨ var y ∈ T ⟩∈ U is B end) v x p = <:-refl +<:-substitutivityᴱ H (block var b ∈ T is B end) v x p = <:-refl +<:-substitutivityᴱ-whenever H v x x (yes refl) p = p +<:-substitutivityᴱ-whenever H v x y (no o) p = (≡-impl-<: (cong orUnknown (⊕-lookup-miss x y _ _ o))) -reflect-subtypingᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a defn) p = mapLR (heap-weakeningᴮ _ _ B (snoc defn)) (CONTRADICTION ∘ ≮:-refl) (substitutivityᴮ _ B (addr a) f p) -reflect-subtypingᴮ H (local var x ∈ T ← M ∙ B) (local s) p = Left (heap-weakeningᴮ (x ↦ T) H B (rednᴱ⊑ s) p) -reflect-subtypingᴮ H (local var x ∈ T ← M ∙ B) (subst v) p = mapR LocalVarMismatch (substitutivityᴮ H B v x p) -reflect-subtypingᴮ H (return M ∙ B) (return s) p = mapR return (reflect-subtypingᴱ H M s p) +<:-substitutivityᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x p = <:-substitutivityᴮ-unless H B v x f (x ≡ⱽ f) p +<:-substitutivityᴮ H (local var y ∈ T ← M ∙ B) v x p = <:-substitutivityᴮ-unless H B v x y (x ≡ⱽ y) p +<:-substitutivityᴮ H (return M ∙ B) v x p = <:-substitutivityᴱ H M v x p +<:-substitutivityᴮ H done v x p = <:-refl +<:-substitutivityᴮ-unless H B v x y (yes r) p = <:-substitutivityᴮ-unless-yes H B v x y r (⊕-over r) +<:-substitutivityᴮ-unless H B v x y (no r) p = <:-substitutivityᴮ-unless-no H B v x y r (⊕-swap r) p +<:-substitutivityᴮ-unless-yes H B v x y refl refl = <:-refl +<:-substitutivityᴮ-unless-no H B v x y r refl p = <:-substitutivityᴮ H B v x p + +≮:-substitutivityᴱ : ∀ {Γ T U} H M v x → (typeOfᴱ H Γ (M [ v / x ]ᴱ) ≮: U) → Either (typeOfᴱ H (Γ ⊕ x ↦ T) M ≮: U) (typeOfᴱ H ∅ (val v) ≮: T) +≮:-substitutivityᴱ {T = T} H M v x p with dec-subtyping (typeOfᴱ H ∅ (val v)) T +≮:-substitutivityᴱ H M v x p | Left q = Right q +≮:-substitutivityᴱ H M v x p | Right q = Left (<:-trans-≮: (<:-substitutivityᴱ H M v x q) p) + +≮:-substitutivityᴮ : ∀ {Γ T U} H B v x → (typeOfᴮ H Γ (B [ v / x ]ᴮ) ≮: U) → Either (typeOfᴮ H (Γ ⊕ x ↦ T) B ≮: U) (typeOfᴱ H ∅ (val v) ≮: T) +≮:-substitutivityᴮ {T = T} H M v x p with dec-subtyping (typeOfᴱ H ∅ (val v)) T +≮:-substitutivityᴮ H M v x p | Left q = Right q +≮:-substitutivityᴮ H M v x p | Right q = Left (<:-trans-≮: (<:-substitutivityᴮ H M v x q) p) + +≮:-substitutivityᴮ-unless : ∀ {Γ T U V} H B v x y (r : Dec(x ≡ y)) → (typeOfᴮ H (Γ ⊕ y ↦ U) (B [ v / x ]ᴮunless r) ≮: V) → Either (typeOfᴮ H ((Γ ⊕ x ↦ T) ⊕ y ↦ U) B ≮: V) (typeOfᴱ H ∅ (val v) ≮: T) +≮:-substitutivityᴮ-unless {T = T} H B v x y r p with dec-subtyping (typeOfᴱ H ∅ (val v)) T +≮:-substitutivityᴮ-unless H B v x y r p | Left q = Right q +≮:-substitutivityᴮ-unless H B v x y r p | Right q = Left (<:-trans-≮: (<:-substitutivityᴮ-unless H B v x y r q) p) + +<:-reductionᴱ : ∀ H M {H′ M′} → (H ⊢ M ⟶ᴱ M′ ⊣ H′) → Either (typeOfᴱ H′ ∅ M′ <: typeOfᴱ H ∅ M) (Warningᴱ H (typeCheckᴱ H ∅ M)) +<:-reductionᴮ : ∀ H B {H′ B′} → (H ⊢ B ⟶ᴮ B′ ⊣ H′) → Either (typeOfᴮ H′ ∅ B′ <: typeOfᴮ H ∅ B) (Warningᴮ H (typeCheckᴮ H ∅ B)) + +<:-reductionᴱ H (M $ N) (app₁ s) = mapLR (λ p → <:-resolve p (<:-heap-weakeningᴱ ∅ H N (rednᴱ⊑ s))) app₁ (<:-reductionᴱ H M s) +<:-reductionᴱ H (M $ N) (app₂ q s) = mapLR (λ p → <:-resolve (<:-heap-weakeningᴱ ∅ H M (rednᴱ⊑ s)) p) app₂ (<:-reductionᴱ H N s) +<:-reductionᴱ H (M $ N) (beta (function f ⟨ var y ∈ S ⟩∈ U is B end) v refl q) with dec-subtyping (typeOfᴱ H ∅ (val v)) S +<:-reductionᴱ H (M $ N) (beta (function f ⟨ var y ∈ S ⟩∈ U is B end) v refl q) | Left r = Right (FunctionCallMismatch (≮:-trans-≡ r (cong src (cong orUnknown (cong typeOfᴹᴼ (sym q)))))) +<:-reductionᴱ H (M $ N) (beta (function f ⟨ var y ∈ S ⟩∈ U is B end) v refl q) | Right r = Left (<:-trans-≡ (<:-resolve-⇒ r) (cong (λ F → resolve F (typeOfᴱ H ∅ N)) (cong orUnknown (cong typeOfᴹᴼ (sym q))))) +<:-reductionᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) = Left <:-refl +<:-reductionᴱ H (block var b ∈ T is B end) (block s) = Left <:-refl +<:-reductionᴱ H (block var b ∈ T is return (val v) ∙ B end) (return v) with dec-subtyping (typeOfᴱ H ∅ (val v)) T +<:-reductionᴱ H (block var b ∈ T is return (val v) ∙ B end) (return v) | Left p = Right (BlockMismatch p) +<:-reductionᴱ H (block var b ∈ T is return (val v) ∙ B end) (return v) | Right p = Left p +<:-reductionᴱ H (block var b ∈ T is done end) done with dec-subtyping nil T +<:-reductionᴱ H (block var b ∈ T is done end) done | Left p = Right (BlockMismatch p) +<:-reductionᴱ H (block var b ∈ T is done end) done | Right p = Left p +<:-reductionᴱ H (binexp M op N) (binOp₀ s) = Left (≡-impl-<: (sym (binOpPreservation H s))) +<:-reductionᴱ H (binexp M op N) (binOp₁ s) = Left <:-refl +<:-reductionᴱ H (binexp M op N) (binOp₂ s) = Left <:-refl + +<:-reductionᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a defn) = Left (<:-trans (<:-substitutivityᴮ _ B (addr a) f <:-refl) (<:-heap-weakeningᴮ (f ↦ (T ⇒ U)) H B (snoc defn))) +<:-reductionᴮ H (local var x ∈ T ← M ∙ B) (local s) = Left (<:-heap-weakeningᴮ (x ↦ T) H B (rednᴱ⊑ s)) +<:-reductionᴮ H (local var x ∈ T ← M ∙ B) (subst v) with dec-subtyping (typeOfᴱ H ∅ (val v)) T +<:-reductionᴮ H (local var x ∈ T ← M ∙ B) (subst v) | Left p = Right (LocalVarMismatch p) +<:-reductionᴮ H (local var x ∈ T ← M ∙ B) (subst v) | Right p = Left (<:-substitutivityᴮ H B v x p) +<:-reductionᴮ H (return M ∙ B) (return s) = mapR return (<:-reductionᴱ H M s) + +≮:-reductionᴱ : ∀ H M {H′ M′ T} → (H ⊢ M ⟶ᴱ M′ ⊣ H′) → (typeOfᴱ H′ ∅ M′ ≮: T) → Either (typeOfᴱ H ∅ M ≮: T) (Warningᴱ H (typeCheckᴱ H ∅ M)) +≮:-reductionᴱ H M s p = mapL (λ q → <:-trans-≮: q p) (<:-reductionᴱ H M s) + +≮:-reductionᴮ : ∀ H B {H′ B′ T} → (H ⊢ B ⟶ᴮ B′ ⊣ H′) → (typeOfᴮ H′ ∅ B′ ≮: T) → Either (typeOfᴮ H ∅ B ≮: T) (Warningᴮ H (typeCheckᴮ H ∅ B)) +≮:-reductionᴮ H B s p = mapL (λ q → <:-trans-≮: q p) (<:-reductionᴮ H B s) reflect-substitutionᴱ : ∀ {Γ T} H M v x → Warningᴱ H (typeCheckᴱ H Γ (M [ v / x ]ᴱ)) → Either (Warningᴱ H (typeCheckᴱ H (Γ ⊕ x ↦ T) M)) (Either (Warningᴱ H (typeCheckᴱ H ∅ (val v))) (typeOfᴱ H ∅ (val v) ≮: T)) reflect-substitutionᴱ-whenever : ∀ {Γ T} H v x y (p : Dec(x ≡ y)) → Warningᴱ H (typeCheckᴱ H Γ (var y [ v / x ]ᴱwhenever p)) → Either (Warningᴱ H (typeCheckᴱ H (Γ ⊕ x ↦ T) (var y))) (Either (Warningᴱ H (typeCheckᴱ H ∅ (val v))) (typeOfᴱ H ∅ (val v) ≮: T)) @@ -150,29 +187,29 @@ reflect-substitutionᴮ-unless-no : ∀ {Γ Γ′ T} H B v x y (r : x ≢ y) → reflect-substitutionᴱ H (var y) v x W = reflect-substitutionᴱ-whenever H v x y (x ≡ⱽ y) W reflect-substitutionᴱ H (val (addr a)) v x (UnallocatedAddress r) = Left (UnallocatedAddress r) -reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) with substitutivityᴱ H N v x p +reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) with ≮:-substitutivityᴱ H N v x p reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Right W = Right (Right W) -reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q with substitutivityᴱ H M v x (src-any-≮: q) -reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q | Left r = Left ((FunctionCallMismatch ∘ any-src-≮: q) r) +reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q with ≮:-substitutivityᴱ H M v x (src-unknown-≮: q) +reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q | Left r = Left ((FunctionCallMismatch ∘ unknown-src-≮: q) r) reflect-substitutionᴱ H (M $ N) v x (FunctionCallMismatch p) | Left q | Right W = Right (Right W) reflect-substitutionᴱ H (M $ N) v x (app₁ W) = mapL app₁ (reflect-substitutionᴱ H M v x W) reflect-substitutionᴱ H (M $ N) v x (app₂ W) = mapL app₂ (reflect-substitutionᴱ H N v x W) -reflect-substitutionᴱ H (function f ⟨ var y ∈ T ⟩∈ U is B end) v x (FunctionDefnMismatch q) = mapLR FunctionDefnMismatch Right (substitutivityᴮ-unless H B v x y (x ≡ⱽ y) q) +reflect-substitutionᴱ H (function f ⟨ var y ∈ T ⟩∈ U is B end) v x (FunctionDefnMismatch q) = mapLR FunctionDefnMismatch Right (≮:-substitutivityᴮ-unless H B v x y (x ≡ⱽ y) q) reflect-substitutionᴱ H (function f ⟨ var y ∈ T ⟩∈ U is B end) v x (function₁ W) = mapL function₁ (reflect-substitutionᴮ-unless H B v x y (x ≡ⱽ y) W) -reflect-substitutionᴱ H (block var b ∈ T is B end) v x (BlockMismatch q) = mapLR BlockMismatch Right (substitutivityᴮ H B v x q) +reflect-substitutionᴱ H (block var b ∈ T is B end) v x (BlockMismatch q) = mapLR BlockMismatch Right (≮:-substitutivityᴮ H B v x q) reflect-substitutionᴱ H (block var b ∈ T is B end) v x (block₁ W′) = mapL block₁ (reflect-substitutionᴮ H B v x W′) -reflect-substitutionᴱ H (binexp M op N) v x (BinOpMismatch₁ q) = mapLR BinOpMismatch₁ Right (substitutivityᴱ H M v x q) -reflect-substitutionᴱ H (binexp M op N) v x (BinOpMismatch₂ q) = mapLR BinOpMismatch₂ Right (substitutivityᴱ H N v x q) +reflect-substitutionᴱ H (binexp M op N) v x (BinOpMismatch₁ q) = mapLR BinOpMismatch₁ Right (≮:-substitutivityᴱ H M v x q) +reflect-substitutionᴱ H (binexp M op N) v x (BinOpMismatch₂ q) = mapLR BinOpMismatch₂ Right (≮:-substitutivityᴱ H N v x q) reflect-substitutionᴱ H (binexp M op N) v x (bin₁ W) = mapL bin₁ (reflect-substitutionᴱ H M v x W) reflect-substitutionᴱ H (binexp M op N) v x (bin₂ W) = mapL bin₂ (reflect-substitutionᴱ H N v x W) reflect-substitutionᴱ-whenever H a x x (yes refl) (UnallocatedAddress p) = Right (Left (UnallocatedAddress p)) reflect-substitutionᴱ-whenever H v x y (no p) (UnboundVariable q) = Left (UnboundVariable (trans (sym (⊕-lookup-miss x y _ _ p)) q)) -reflect-substitutionᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x (FunctionDefnMismatch q) = mapLR FunctionDefnMismatch Right (substitutivityᴮ-unless H C v x y (x ≡ⱽ y) q) +reflect-substitutionᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x (FunctionDefnMismatch q) = mapLR FunctionDefnMismatch Right (≮:-substitutivityᴮ-unless H C v x y (x ≡ⱽ y) q) reflect-substitutionᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x (function₁ W) = mapL function₁ (reflect-substitutionᴮ-unless H C v x y (x ≡ⱽ y) W) reflect-substitutionᴮ H (function f ⟨ var y ∈ T ⟩∈ U is C end ∙ B) v x (function₂ W) = mapL function₂ (reflect-substitutionᴮ-unless H B v x f (x ≡ⱽ f) W) -reflect-substitutionᴮ H (local var y ∈ T ← M ∙ B) v x (LocalVarMismatch q) = mapLR LocalVarMismatch Right (substitutivityᴱ H M v x q) +reflect-substitutionᴮ H (local var y ∈ T ← M ∙ B) v x (LocalVarMismatch q) = mapLR LocalVarMismatch Right (≮:-substitutivityᴱ H M v x q) reflect-substitutionᴮ H (local var y ∈ T ← M ∙ B) v x (local₁ W) = mapL local₁ (reflect-substitutionᴱ H M v x W) reflect-substitutionᴮ H (local var y ∈ T ← M ∙ B) v x (local₂ W) = mapL local₂ (reflect-substitutionᴮ-unless H B v x y (x ≡ⱽ y) W) reflect-substitutionᴮ H (return M ∙ B) v x (return W) = mapL return (reflect-substitutionᴱ H M v x W) @@ -187,61 +224,61 @@ reflect-weakeningᴮ : ∀ Γ H B {H′} → (H ⊑ H′) → Warningᴮ H′ (t reflect-weakeningᴱ Γ H (var x) h (UnboundVariable p) = (UnboundVariable p) reflect-weakeningᴱ Γ H (val (addr a)) h (UnallocatedAddress p) = UnallocatedAddress (lookup-⊑-nothing a h p) -reflect-weakeningᴱ Γ H (M $ N) h (FunctionCallMismatch p) = FunctionCallMismatch (heap-weakeningᴱ Γ H N h (any-src-≮: p (heap-weakeningᴱ Γ H M h (src-any-≮: p)))) +reflect-weakeningᴱ Γ H (M $ N) h (FunctionCallMismatch p) = FunctionCallMismatch (≮:-heap-weakeningᴱ Γ H N h (unknown-src-≮: p (≮:-heap-weakeningᴱ Γ H M h (src-unknown-≮: p)))) reflect-weakeningᴱ Γ H (M $ N) h (app₁ W) = app₁ (reflect-weakeningᴱ Γ H M h W) reflect-weakeningᴱ Γ H (M $ N) h (app₂ W) = app₂ (reflect-weakeningᴱ Γ H N h W) -reflect-weakeningᴱ Γ H (binexp M op N) h (BinOpMismatch₁ p) = BinOpMismatch₁ (heap-weakeningᴱ Γ H M h p) -reflect-weakeningᴱ Γ H (binexp M op N) h (BinOpMismatch₂ p) = BinOpMismatch₂ (heap-weakeningᴱ Γ H N h p) +reflect-weakeningᴱ Γ H (binexp M op N) h (BinOpMismatch₁ p) = BinOpMismatch₁ (≮:-heap-weakeningᴱ Γ H M h p) +reflect-weakeningᴱ Γ H (binexp M op N) h (BinOpMismatch₂ p) = BinOpMismatch₂ (≮:-heap-weakeningᴱ Γ H N h p) reflect-weakeningᴱ Γ H (binexp M op N) h (bin₁ W′) = bin₁ (reflect-weakeningᴱ Γ H M h W′) reflect-weakeningᴱ Γ H (binexp M op N) h (bin₂ W′) = bin₂ (reflect-weakeningᴱ Γ H N h W′) -reflect-weakeningᴱ Γ H (function f ⟨ var y ∈ T ⟩∈ U is B end) h (FunctionDefnMismatch p) = FunctionDefnMismatch (heap-weakeningᴮ (Γ ⊕ y ↦ T) H B h p) +reflect-weakeningᴱ Γ H (function f ⟨ var y ∈ T ⟩∈ U is B end) h (FunctionDefnMismatch p) = FunctionDefnMismatch (≮:-heap-weakeningᴮ (Γ ⊕ y ↦ T) H B h p) reflect-weakeningᴱ Γ H (function f ⟨ var y ∈ T ⟩∈ U is B end) h (function₁ W) = function₁ (reflect-weakeningᴮ (Γ ⊕ y ↦ T) H B h W) -reflect-weakeningᴱ Γ H (block var b ∈ T is B end) h (BlockMismatch p) = BlockMismatch (heap-weakeningᴮ Γ H B h p) +reflect-weakeningᴱ Γ H (block var b ∈ T is B end) h (BlockMismatch p) = BlockMismatch (≮:-heap-weakeningᴮ Γ H B h p) reflect-weakeningᴱ Γ H (block var b ∈ T is B end) h (block₁ W) = block₁ (reflect-weakeningᴮ Γ H B h W) reflect-weakeningᴮ Γ H (return M ∙ B) h (return W) = return (reflect-weakeningᴱ Γ H M h W) -reflect-weakeningᴮ Γ H (local var y ∈ T ← M ∙ B) h (LocalVarMismatch p) = LocalVarMismatch (heap-weakeningᴱ Γ H M h p) +reflect-weakeningᴮ Γ H (local var y ∈ T ← M ∙ B) h (LocalVarMismatch p) = LocalVarMismatch (≮:-heap-weakeningᴱ Γ H M h p) reflect-weakeningᴮ Γ H (local var y ∈ T ← M ∙ B) h (local₁ W) = local₁ (reflect-weakeningᴱ Γ H M h W) reflect-weakeningᴮ Γ H (local var y ∈ T ← M ∙ B) h (local₂ W) = local₂ (reflect-weakeningᴮ (Γ ⊕ y ↦ T) H B h W) -reflect-weakeningᴮ Γ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) h (FunctionDefnMismatch p) = FunctionDefnMismatch (heap-weakeningᴮ (Γ ⊕ x ↦ T) H C h p) +reflect-weakeningᴮ Γ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) h (FunctionDefnMismatch p) = FunctionDefnMismatch (≮:-heap-weakeningᴮ (Γ ⊕ x ↦ T) H C h p) reflect-weakeningᴮ Γ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) h (function₁ W) = function₁ (reflect-weakeningᴮ (Γ ⊕ x ↦ T) H C h W) reflect-weakeningᴮ Γ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) h (function₂ W) = function₂ (reflect-weakeningᴮ (Γ ⊕ f ↦ (T ⇒ U)) H B h W) reflect-weakeningᴼ : ∀ H O {H′} → (H ⊑ H′) → Warningᴼ H′ (typeCheckᴼ H′ O) → Warningᴼ H (typeCheckᴼ H O) -reflect-weakeningᴼ H (just function f ⟨ var x ∈ T ⟩∈ U is B end) h (FunctionDefnMismatch p) = FunctionDefnMismatch (heap-weakeningᴮ (x ↦ T) H B h p) +reflect-weakeningᴼ H (just function f ⟨ var x ∈ T ⟩∈ U is B end) h (FunctionDefnMismatch p) = FunctionDefnMismatch (≮:-heap-weakeningᴮ (x ↦ T) H B h p) reflect-weakeningᴼ H (just function f ⟨ var x ∈ T ⟩∈ U is B end) h (function₁ W) = function₁ (reflect-weakeningᴮ (x ↦ T) H B h W) reflectᴱ : ∀ H M {H′ M′} → (H ⊢ M ⟶ᴱ M′ ⊣ H′) → Warningᴱ H′ (typeCheckᴱ H′ ∅ M′) → Either (Warningᴱ H (typeCheckᴱ H ∅ M)) (Warningᴴ H (typeCheckᴴ H)) reflectᴮ : ∀ H B {H′ B′} → (H ⊢ B ⟶ᴮ B′ ⊣ H′) → Warningᴮ H′ (typeCheckᴮ H′ ∅ B′) → Either (Warningᴮ H (typeCheckᴮ H ∅ B)) (Warningᴴ H (typeCheckᴴ H)) -reflectᴱ H (M $ N) (app₁ s) (FunctionCallMismatch p) = cond (Left ∘ FunctionCallMismatch ∘ heap-weakeningᴱ ∅ H N (rednᴱ⊑ s) ∘ any-src-≮: p) (Left ∘ app₁) (reflect-subtypingᴱ H M s (src-any-≮: p)) +reflectᴱ H (M $ N) (app₁ s) (FunctionCallMismatch p) = cond (Left ∘ FunctionCallMismatch ∘ ≮:-heap-weakeningᴱ ∅ H N (rednᴱ⊑ s) ∘ unknown-src-≮: p) (Left ∘ app₁) (≮:-reductionᴱ H M s (src-unknown-≮: p)) reflectᴱ H (M $ N) (app₁ s) (app₁ W′) = mapL app₁ (reflectᴱ H M s W′) reflectᴱ H (M $ N) (app₁ s) (app₂ W′) = Left (app₂ (reflect-weakeningᴱ ∅ H N (rednᴱ⊑ s) W′)) -reflectᴱ H (M $ N) (app₂ p s) (FunctionCallMismatch q) = cond (λ r → Left (FunctionCallMismatch (any-src-≮: r (heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) (src-any-≮: r))))) (Left ∘ app₂) (reflect-subtypingᴱ H N s q) +reflectᴱ H (M $ N) (app₂ p s) (FunctionCallMismatch q) = cond (λ r → Left (FunctionCallMismatch (unknown-src-≮: r (≮:-heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) (src-unknown-≮: r))))) (Left ∘ app₂) (≮:-reductionᴱ H N s q) reflectᴱ H (M $ N) (app₂ p s) (app₁ W′) = Left (app₁ (reflect-weakeningᴱ ∅ H M (rednᴱ⊑ s) W′)) reflectᴱ H (M $ N) (app₂ p s) (app₂ W′) = mapL app₂ (reflectᴱ H N s W′) -reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) with substitutivityᴮ H B v x q +reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) with ≮:-substitutivityᴮ H B v x q reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) | Left r = Right (addr a p (FunctionDefnMismatch r)) -reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) | Right r = Left (FunctionCallMismatch (≮:-trans-≡ r ((cong src (cong orAny (cong typeOfᴹᴼ (sym p))))))) +reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (BlockMismatch q) | Right r = Left (FunctionCallMismatch (≮:-trans-≡ r ((cong src (cong orUnknown (cong typeOfᴹᴼ (sym p))))))) reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) with reflect-substitutionᴮ _ B v x W′ reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) | Left W = Right (addr a p (function₁ W)) reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) | Right (Left W) = Left (app₂ W) -reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) | Right (Right q) = Left (FunctionCallMismatch (≮:-trans-≡ q (cong src (cong orAny (cong typeOfᴹᴼ (sym p)))))) -reflectᴱ H (block var b ∈ T is B end) (block s) (BlockMismatch p) = Left (cond BlockMismatch block₁ (reflect-subtypingᴮ H B s p)) +reflectᴱ H (val (addr a) $ N) (beta (function f ⟨ var x ∈ T ⟩∈ U is B end) v refl p) (block₁ W′) | Right (Right q) = Left (FunctionCallMismatch (≮:-trans-≡ q (cong src (cong orUnknown (cong typeOfᴹᴼ (sym p)))))) +reflectᴱ H (block var b ∈ T is B end) (block s) (BlockMismatch p) = Left (cond BlockMismatch block₁ (≮:-reductionᴮ H B s p)) reflectᴱ H (block var b ∈ T is B end) (block s) (block₁ W′) = mapL block₁ (reflectᴮ H B s W′) reflectᴱ H (block var b ∈ T is B end) (return v) W′ = Left (block₁ (return W′)) reflectᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) (UnallocatedAddress ()) reflectᴱ H (binexp M op N) (binOp₀ ()) (UnallocatedAddress p) -reflectᴱ H (binexp M op N) (binOp₁ s) (BinOpMismatch₁ p) = Left (cond BinOpMismatch₁ bin₁ (reflect-subtypingᴱ H M s p)) -reflectᴱ H (binexp M op N) (binOp₁ s) (BinOpMismatch₂ p) = Left (BinOpMismatch₂ (heap-weakeningᴱ ∅ H N (rednᴱ⊑ s) p)) +reflectᴱ H (binexp M op N) (binOp₁ s) (BinOpMismatch₁ p) = Left (cond BinOpMismatch₁ bin₁ (≮:-reductionᴱ H M s p)) +reflectᴱ H (binexp M op N) (binOp₁ s) (BinOpMismatch₂ p) = Left (BinOpMismatch₂ (≮:-heap-weakeningᴱ ∅ H N (rednᴱ⊑ s) p)) reflectᴱ H (binexp M op N) (binOp₁ s) (bin₁ W′) = mapL bin₁ (reflectᴱ H M s W′) reflectᴱ H (binexp M op N) (binOp₁ s) (bin₂ W′) = Left (bin₂ (reflect-weakeningᴱ ∅ H N (rednᴱ⊑ s) W′)) -reflectᴱ H (binexp M op N) (binOp₂ s) (BinOpMismatch₁ p) = Left (BinOpMismatch₁ (heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) p)) -reflectᴱ H (binexp M op N) (binOp₂ s) (BinOpMismatch₂ p) = Left (cond BinOpMismatch₂ bin₂ (reflect-subtypingᴱ H N s p)) +reflectᴱ H (binexp M op N) (binOp₂ s) (BinOpMismatch₁ p) = Left (BinOpMismatch₁ (≮:-heap-weakeningᴱ ∅ H M (rednᴱ⊑ s) p)) +reflectᴱ H (binexp M op N) (binOp₂ s) (BinOpMismatch₂ p) = Left (cond BinOpMismatch₂ bin₂ (≮:-reductionᴱ H N s p)) reflectᴱ H (binexp M op N) (binOp₂ s) (bin₁ W′) = Left (bin₁ (reflect-weakeningᴱ ∅ H M (rednᴱ⊑ s) W′)) reflectᴱ H (binexp M op N) (binOp₂ s) (bin₂ W′) = mapL bin₂ (reflectᴱ H N s W′) -reflectᴮ H (local var x ∈ T ← M ∙ B) (local s) (LocalVarMismatch p) = Left (cond LocalVarMismatch local₁ (reflect-subtypingᴱ H M s p)) +reflectᴮ H (local var x ∈ T ← M ∙ B) (local s) (LocalVarMismatch p) = Left (cond LocalVarMismatch local₁ (≮:-reductionᴱ H M s p)) reflectᴮ H (local var x ∈ T ← M ∙ B) (local s) (local₁ W′) = mapL local₁ (reflectᴱ H M s W′) reflectᴮ H (local var x ∈ T ← M ∙ B) (local s) (local₂ W′) = Left (local₂ (reflect-weakeningᴮ (x ↦ T) H B (rednᴱ⊑ s) W′)) reflectᴮ H (local var x ∈ T ← M ∙ B) (subst v) W′ = Left (cond local₂ (cond local₁ LocalVarMismatch) (reflect-substitutionᴮ H B v x W′)) @@ -258,7 +295,7 @@ reflectᴴᴱ H (M $ N) (app₁ s) W = mapL app₁ (reflectᴴᴱ H M s W) reflectᴴᴱ H (M $ N) (app₂ v s) W = mapL app₂ (reflectᴴᴱ H N s W) reflectᴴᴱ H (M $ N) (beta O v refl p) W = Right W reflectᴴᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a p) (addr b refl W) with b ≡ᴬ a -reflectᴴᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) (addr b refl (FunctionDefnMismatch p)) | yes refl = Left (FunctionDefnMismatch (heap-weakeningᴮ (x ↦ T) H B (snoc defn) p)) +reflectᴴᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) (addr b refl (FunctionDefnMismatch p)) | yes refl = Left (FunctionDefnMismatch (≮:-heap-weakeningᴮ (x ↦ T) H B (snoc defn) p)) reflectᴴᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a defn) (addr b refl (function₁ W)) | yes refl = Left (function₁ (reflect-weakeningᴮ (x ↦ T) H B (snoc defn) W)) reflectᴴᴱ H (function f ⟨ var x ∈ T ⟩∈ U is B end) (function a p) (addr b refl W) | no q = Right (addr b (lookup-not-allocated p q) (reflect-weakeningᴼ H _ (snoc p) W)) reflectᴴᴱ H (block var b ∈ T is B end) (block s) W = mapL block₁ (reflectᴴᴮ H B s W) @@ -269,7 +306,7 @@ reflectᴴᴱ H (binexp M op N) (binOp₁ s) W = mapL bin₁ (reflectᴴᴱ H M reflectᴴᴱ H (binexp M op N) (binOp₂ s) W = mapL bin₂ (reflectᴴᴱ H N s W) reflectᴴᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a p) (addr b refl W) with b ≡ᴬ a -reflectᴴᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a defn) (addr b refl (FunctionDefnMismatch p)) | yes refl = Left (FunctionDefnMismatch (heap-weakeningᴮ (x ↦ T) H C (snoc defn) p)) +reflectᴴᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a defn) (addr b refl (FunctionDefnMismatch p)) | yes refl = Left (FunctionDefnMismatch (≮:-heap-weakeningᴮ (x ↦ T) H C (snoc defn) p)) reflectᴴᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a defn) (addr b refl (function₁ W)) | yes refl = Left (function₁ (reflect-weakeningᴮ (x ↦ T) H C (snoc defn) W)) reflectᴴᴮ H (function f ⟨ var x ∈ T ⟩∈ U is C end ∙ B) (function a p) (addr b refl W) | no q = Right (addr b (lookup-not-allocated p q) (reflect-weakeningᴼ H _ (snoc p) W)) reflectᴴᴮ H (local var x ∈ T ← M ∙ B) (local s) W = mapL local₁ (reflectᴴᴱ H M s W) @@ -283,8 +320,8 @@ reflect* H B (step s t) W = cond (reflectᴮ H B s) (reflectᴴᴮ H B s) (refle isntNumber : ∀ H v → (valueType v ≢ number) → (typeOfᴱ H ∅ (val v) ≮: number) isntNumber H nil p = scalar-≢-impl-≮: nil number (λ ()) isntNumber H (addr a) p with remember (H [ a ]ᴴ) -isntNumber H (addr a) p | (just (function f ⟨ var x ∈ T ⟩∈ U is B end) , q) = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ q)) (function-≮:-scalar number) -isntNumber H (addr a) p | (nothing , q) = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ q)) (any-≮:-scalar number) +isntNumber H (addr a) p | (just (function f ⟨ var x ∈ T ⟩∈ U is B end) , q) = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ q)) (function-≮:-scalar number) +isntNumber H (addr a) p | (nothing , q) = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ q)) (unknown-≮:-scalar number) isntNumber H (number x) p = CONTRADICTION (p refl) isntNumber H (bool x) p = scalar-≢-impl-≮: boolean number (λ ()) isntNumber H (string x) p = scalar-≢-impl-≮: string number (λ ()) @@ -292,8 +329,8 @@ isntNumber H (string x) p = scalar-≢-impl-≮: string number (λ ()) isntString : ∀ H v → (valueType v ≢ string) → (typeOfᴱ H ∅ (val v) ≮: string) isntString H nil p = scalar-≢-impl-≮: nil string (λ ()) isntString H (addr a) p with remember (H [ a ]ᴴ) -isntString H (addr a) p | (just (function f ⟨ var x ∈ T ⟩∈ U is B end) , q) = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ q)) (function-≮:-scalar string) -isntString H (addr a) p | (nothing , q) = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ q)) (any-≮:-scalar string) +isntString H (addr a) p | (just (function f ⟨ var x ∈ T ⟩∈ U is B end) , q) = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ q)) (function-≮:-scalar string) +isntString H (addr a) p | (nothing , q) = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ q)) (unknown-≮:-scalar string) isntString H (number x) p = scalar-≢-impl-≮: number string (λ ()) isntString H (bool x) p = scalar-≢-impl-≮: boolean string (λ ()) isntString H (string x) p = CONTRADICTION (p refl) @@ -305,14 +342,14 @@ isntFunction H (number x) p = scalar-≮:-function number isntFunction H (bool x) p = scalar-≮:-function boolean isntFunction H (string x) p = scalar-≮:-function string -isntEmpty : ∀ H v → (typeOfᴱ H ∅ (val v) ≮: none) -isntEmpty H nil = scalar-≮:-none nil +isntEmpty : ∀ H v → (typeOfᴱ H ∅ (val v) ≮: never) +isntEmpty H nil = scalar-≮:-never nil isntEmpty H (addr a) with remember (H [ a ]ᴴ) -isntEmpty H (addr a) | (just (function f ⟨ var x ∈ T ⟩∈ U is B end) , p) = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ p)) function-≮:-none -isntEmpty H (addr a) | (nothing , p) = ≡-trans-≮: (cong orAny (cong typeOfᴹᴼ p)) any-≮:-none -isntEmpty H (number x) = scalar-≮:-none number -isntEmpty H (bool x) = scalar-≮:-none boolean -isntEmpty H (string x) = scalar-≮:-none string +isntEmpty H (addr a) | (just (function f ⟨ var x ∈ T ⟩∈ U is B end) , p) = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ p)) function-≮:-never +isntEmpty H (addr a) | (nothing , p) = ≡-trans-≮: (cong orUnknown (cong typeOfᴹᴼ p)) unknown-≮:-never +isntEmpty H (number x) = scalar-≮:-never number +isntEmpty H (bool x) = scalar-≮:-never boolean +isntEmpty H (string x) = scalar-≮:-never string runtimeBinOpWarning : ∀ H {op} v → BinOpError op (valueType v) → (typeOfᴱ H ∅ (val v) ≮: srcBinOp op) runtimeBinOpWarning H v (+ p) = isntNumber H v p @@ -330,7 +367,7 @@ runtimeWarningᴮ : ∀ H B → RuntimeErrorᴮ H B → Warningᴮ H (typeCheck runtimeWarningᴱ H (var x) UnboundVariable = UnboundVariable refl runtimeWarningᴱ H (val (addr a)) (SEGV p) = UnallocatedAddress p -runtimeWarningᴱ H (M $ N) (FunctionMismatch v w p) = FunctionCallMismatch (any-src-≮: (isntEmpty H w) (isntFunction H v p)) +runtimeWarningᴱ H (M $ N) (FunctionMismatch v w p) = FunctionCallMismatch (unknown-src-≮: (isntEmpty H w) (isntFunction H v p)) runtimeWarningᴱ H (M $ N) (app₁ err) = app₁ (runtimeWarningᴱ H M err) runtimeWarningᴱ H (M $ N) (app₂ err) = app₂ (runtimeWarningᴱ H N err) runtimeWarningᴱ H (block var b ∈ T is B end) (block err) = block₁ (runtimeWarningᴮ H B err) diff --git a/prototyping/Properties/Subtyping.agda b/prototyping/Properties/Subtyping.agda index 6a0b4203..73bf0e9a 100644 --- a/prototyping/Properties/Subtyping.agda +++ b/prototyping/Properties/Subtyping.agda @@ -4,13 +4,13 @@ module Properties.Subtyping where open import Agda.Builtin.Equality using (_≡_; refl) open import FFI.Data.Either using (Either; Left; Right; mapLR; swapLR; cond) -open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; any; none; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-scalar; function-scalar; function-ok; function-err; left; right; _,_) -open import Luau.Type using (Type; Scalar; strict; nil; number; string; boolean; none; any; _⇒_; _∪_; _∩_; tgt) -open import Properties.Contradiction using (CONTRADICTION; ¬) +open import FFI.Data.Maybe using (Maybe; just; nothing) +open import Luau.Subtyping using (_<:_; _≮:_; Tree; Language; ¬Language; witness; unknown; never; scalar; function; scalar-function; scalar-function-ok; scalar-function-err; scalar-function-tgt; scalar-scalar; function-scalar; function-ok; function-ok₁; function-ok₂; function-err; function-tgt; left; right; _,_) +open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_; skalar) +open import Properties.Contradiction using (CONTRADICTION; ¬; ⊥) open import Properties.Equality using (_≢_) open import Properties.Functions using (_∘_) - -src = Luau.Type.src strict +open import Properties.Product using (_×_; _,_) -- Language membership is decidable dec-language : ∀ T t → Either (¬Language T t) (Language T t) @@ -19,37 +19,42 @@ dec-language nil (scalar boolean) = Left (scalar-scalar boolean nil (λ ())) dec-language nil (scalar string) = Left (scalar-scalar string nil (λ ())) dec-language nil (scalar nil) = Right (scalar nil) dec-language nil function = Left (scalar-function nil) -dec-language nil (function-ok t) = Left (scalar-function-ok nil) -dec-language nil (function-err t) = Right (scalar-function-err nil) +dec-language nil (function-ok s t) = Left (scalar-function-ok nil) +dec-language nil (function-err t) = Left (scalar-function-err nil) dec-language boolean (scalar number) = Left (scalar-scalar number boolean (λ ())) dec-language boolean (scalar boolean) = Right (scalar boolean) dec-language boolean (scalar string) = Left (scalar-scalar string boolean (λ ())) dec-language boolean (scalar nil) = Left (scalar-scalar nil boolean (λ ())) dec-language boolean function = Left (scalar-function boolean) -dec-language boolean (function-ok t) = Left (scalar-function-ok boolean) -dec-language boolean (function-err t) = Right (scalar-function-err boolean) +dec-language boolean (function-ok s t) = Left (scalar-function-ok boolean) +dec-language boolean (function-err t) = Left (scalar-function-err boolean) dec-language number (scalar number) = Right (scalar number) dec-language number (scalar boolean) = Left (scalar-scalar boolean number (λ ())) dec-language number (scalar string) = Left (scalar-scalar string number (λ ())) dec-language number (scalar nil) = Left (scalar-scalar nil number (λ ())) dec-language number function = Left (scalar-function number) -dec-language number (function-ok t) = Left (scalar-function-ok number) -dec-language number (function-err t) = Right (scalar-function-err number) +dec-language number (function-ok s t) = Left (scalar-function-ok number) +dec-language number (function-err t) = Left (scalar-function-err number) dec-language string (scalar number) = Left (scalar-scalar number string (λ ())) dec-language string (scalar boolean) = Left (scalar-scalar boolean string (λ ())) dec-language string (scalar string) = Right (scalar string) dec-language string (scalar nil) = Left (scalar-scalar nil string (λ ())) dec-language string function = Left (scalar-function string) -dec-language string (function-ok t) = Left (scalar-function-ok string) -dec-language string (function-err t) = Right (scalar-function-err string) +dec-language string (function-ok s t) = Left (scalar-function-ok string) +dec-language string (function-err t) = Left (scalar-function-err string) dec-language (T₁ ⇒ T₂) (scalar s) = Left (function-scalar s) dec-language (T₁ ⇒ T₂) function = Right function -dec-language (T₁ ⇒ T₂) (function-ok t) = mapLR function-ok function-ok (dec-language T₂ t) +dec-language (T₁ ⇒ T₂) (function-ok s t) = cond (Right ∘ function-ok₁) (λ p → mapLR (function-ok p) function-ok₂ (dec-language T₂ t)) (dec-language T₁ s) dec-language (T₁ ⇒ T₂) (function-err t) = mapLR function-err function-err (swapLR (dec-language T₁ t)) -dec-language none t = Left none -dec-language any t = Right any +dec-language never t = Left never +dec-language unknown t = Right unknown dec-language (T₁ ∪ T₂) t = cond (λ p → cond (Left ∘ _,_ p) (Right ∘ right) (dec-language T₂ t)) (Right ∘ left) (dec-language T₁ t) dec-language (T₁ ∩ T₂) t = cond (Left ∘ left) (λ p → cond (Left ∘ right) (Right ∘ _,_ p) (dec-language T₂ t)) (dec-language T₁ t) +dec-language nil (function-tgt t) = Left (scalar-function-tgt nil) +dec-language (T₁ ⇒ T₂) (function-tgt t) = mapLR function-tgt function-tgt (dec-language T₂ t) +dec-language boolean (function-tgt t) = Left (scalar-function-tgt boolean) +dec-language number (function-tgt t) = Left (scalar-function-tgt number) +dec-language string (function-tgt t) = Left (scalar-function-tgt string) -- ¬Language T is the complement of Language T language-comp : ∀ {T} t → ¬Language T t → ¬(Language T t) @@ -59,11 +64,14 @@ language-comp t (left p) (q₁ , q₂) = language-comp t p q₁ language-comp t (right p) (q₁ , q₂) = language-comp t p q₂ language-comp (scalar s) (scalar-scalar s p₁ p₂) (scalar s) = p₂ refl language-comp (scalar s) (function-scalar s) (scalar s) = language-comp function (scalar-function s) function -language-comp (scalar s) none (scalar ()) +language-comp (scalar s) never (scalar ()) language-comp function (scalar-function ()) function -language-comp (function-ok t) (scalar-function-ok ()) (function-ok q) -language-comp (function-ok t) (function-ok p) (function-ok q) = language-comp t p q -language-comp (function-err t) (function-err p) (function-err q) = language-comp t q p +language-comp (function-ok s t) (scalar-function-ok ()) (function-ok₁ p) +language-comp (function-ok s t) (function-ok p₁ p₂) (function-ok₁ q) = language-comp s q p₁ +language-comp (function-ok s t) (function-ok p₁ p₂) (function-ok₂ q) = language-comp t p₂ q +language-comp (function-err t) (function-err p) (function-err q) = language-comp t q p +language-comp (function-tgt t) (scalar-function-tgt ()) (function-tgt q) +language-comp (function-tgt t) (function-tgt p) (function-tgt q) = language-comp t p q -- ≮: is the complement of <: ¬≮:-impl-<: : ∀ {T U} → ¬(T ≮: U) → (T <: U) @@ -74,6 +82,11 @@ language-comp (function-err t) (function-err p) (function-err q) = language-comp <:-impl-¬≮: : ∀ {T U} → (T <: U) → ¬(T ≮: U) <:-impl-¬≮: p (witness t q r) = language-comp t r (p t q) +<:-impl-⊇ : ∀ {T U} → (T <: U) → ∀ t → ¬Language U t → ¬Language T t +<:-impl-⊇ {T} p t q with dec-language T t +<:-impl-⊇ {_} p t q | Left r = r +<:-impl-⊇ {_} p t q | Right r = CONTRADICTION (language-comp t q (p t r)) + -- reflexivity ≮:-refl : ∀ {T} → ¬(T ≮: T) ≮:-refl (witness t p q) = language-comp t q p @@ -85,17 +98,227 @@ language-comp (function-err t) (function-err p) (function-err q) = language-comp ≮:-trans-≡ : ∀ {S T U} → (S ≮: T) → (T ≡ U) → (S ≮: U) ≮:-trans-≡ p refl = p +<:-trans-≡ : ∀ {S T U} → (S <: T) → (T ≡ U) → (S <: U) +<:-trans-≡ p refl = p + +≡-impl-<: : ∀ {T U} → (T ≡ U) → (T <: U) +≡-impl-<: refl = <:-refl + ≡-trans-≮: : ∀ {S T U} → (S ≡ T) → (T ≮: U) → (S ≮: U) ≡-trans-≮: refl p = p +≡-trans-<: : ∀ {S T U} → (S ≡ T) → (T <: U) → (S <: U) +≡-trans-<: refl p = p + ≮:-trans : ∀ {S T U} → (S ≮: U) → Either (S ≮: T) (T ≮: U) ≮:-trans {T = T} (witness t p q) = mapLR (witness t p) (λ z → witness t z q) (dec-language T t) <:-trans : ∀ {S T U} → (S <: T) → (T <: U) → (S <: U) -<:-trans p q = ¬≮:-impl-<: (cond (<:-impl-¬≮: p) (<:-impl-¬≮: q) ∘ ≮:-trans) +<:-trans p q t r = q t (p t r) + +<:-trans-≮: : ∀ {S T U} → (S <: T) → (S ≮: U) → (T ≮: U) +<:-trans-≮: p (witness t q r) = witness t (p t q) r + +≮:-trans-<: : ∀ {S T U} → (S ≮: U) → (T <: U) → (S ≮: T) +≮:-trans-<: (witness t p q) r = witness t p (<:-impl-⊇ r t q) + +-- Properties of union + +<:-union : ∀ {R S T U} → (R <: T) → (S <: U) → ((R ∪ S) <: (T ∪ U)) +<:-union p q t (left r) = left (p t r) +<:-union p q t (right r) = right (q t r) + +<:-∪-left : ∀ {S T} → S <: (S ∪ T) +<:-∪-left t p = left p + +<:-∪-right : ∀ {S T} → T <: (S ∪ T) +<:-∪-right t p = right p + +<:-∪-lub : ∀ {S T U} → (S <: U) → (T <: U) → ((S ∪ T) <: U) +<:-∪-lub p q t (left r) = p t r +<:-∪-lub p q t (right r) = q t r + +<:-∪-symm : ∀ {T U} → (T ∪ U) <: (U ∪ T) +<:-∪-symm t (left p) = right p +<:-∪-symm t (right p) = left p + +<:-∪-assocl : ∀ {S T U} → (S ∪ (T ∪ U)) <: ((S ∪ T) ∪ U) +<:-∪-assocl t (left p) = left (left p) +<:-∪-assocl t (right (left p)) = left (right p) +<:-∪-assocl t (right (right p)) = right p + +<:-∪-assocr : ∀ {S T U} → ((S ∪ T) ∪ U) <: (S ∪ (T ∪ U)) +<:-∪-assocr t (left (left p)) = left p +<:-∪-assocr t (left (right p)) = right (left p) +<:-∪-assocr t (right p) = right (right p) + +≮:-∪-left : ∀ {S T U} → (S ≮: U) → ((S ∪ T) ≮: U) +≮:-∪-left (witness t p q) = witness t (left p) q + +≮:-∪-right : ∀ {S T U} → (T ≮: U) → ((S ∪ T) ≮: U) +≮:-∪-right (witness t p q) = witness t (right p) q + +≮:-left-∪ : ∀ {S T U} → (S ≮: (T ∪ U)) → (S ≮: T) +≮:-left-∪ (witness t p (q₁ , q₂)) = witness t p q₁ + +≮:-right-∪ : ∀ {S T U} → (S ≮: (T ∪ U)) → (S ≮: U) +≮:-right-∪ (witness t p (q₁ , q₂)) = witness t p q₂ + +-- Properties of intersection + +<:-intersect : ∀ {R S T U} → (R <: T) → (S <: U) → ((R ∩ S) <: (T ∩ U)) +<:-intersect p q t (r₁ , r₂) = (p t r₁ , q t r₂) + +<:-∩-left : ∀ {S T} → (S ∩ T) <: S +<:-∩-left t (p , _) = p + +<:-∩-right : ∀ {S T} → (S ∩ T) <: T +<:-∩-right t (_ , p) = p + +<:-∩-glb : ∀ {S T U} → (S <: T) → (S <: U) → (S <: (T ∩ U)) +<:-∩-glb p q t r = (p t r , q t r) + +<:-∩-symm : ∀ {T U} → (T ∩ U) <: (U ∩ T) +<:-∩-symm t (p₁ , p₂) = (p₂ , p₁) + +<:-∩-assocl : ∀ {S T U} → (S ∩ (T ∩ U)) <: ((S ∩ T) ∩ U) +<:-∩-assocl t (p , (p₁ , p₂)) = (p , p₁) , p₂ + +<:-∩-assocr : ∀ {S T U} → ((S ∩ T) ∩ U) <: (S ∩ (T ∩ U)) +<:-∩-assocr t ((p , p₁) , p₂) = p , (p₁ , p₂) + +≮:-∩-left : ∀ {S T U} → (S ≮: T) → (S ≮: (T ∩ U)) +≮:-∩-left (witness t p q) = witness t p (left q) + +≮:-∩-right : ∀ {S T U} → (S ≮: U) → (S ≮: (T ∩ U)) +≮:-∩-right (witness t p q) = witness t p (right q) + +-- Distribution properties +<:-∩-distl-∪ : ∀ {S T U} → (S ∩ (T ∪ U)) <: ((S ∩ T) ∪ (S ∩ U)) +<:-∩-distl-∪ t (p₁ , left p₂) = left (p₁ , p₂) +<:-∩-distl-∪ t (p₁ , right p₂) = right (p₁ , p₂) + +∩-distl-∪-<: : ∀ {S T U} → ((S ∩ T) ∪ (S ∩ U)) <: (S ∩ (T ∪ U)) +∩-distl-∪-<: t (left (p₁ , p₂)) = (p₁ , left p₂) +∩-distl-∪-<: t (right (p₁ , p₂)) = (p₁ , right p₂) + +<:-∩-distr-∪ : ∀ {S T U} → ((S ∪ T) ∩ U) <: ((S ∩ U) ∪ (T ∩ U)) +<:-∩-distr-∪ t (left p₁ , p₂) = left (p₁ , p₂) +<:-∩-distr-∪ t (right p₁ , p₂) = right (p₁ , p₂) + +∩-distr-∪-<: : ∀ {S T U} → ((S ∩ U) ∪ (T ∩ U)) <: ((S ∪ T) ∩ U) +∩-distr-∪-<: t (left (p₁ , p₂)) = (left p₁ , p₂) +∩-distr-∪-<: t (right (p₁ , p₂)) = (right p₁ , p₂) + +<:-∪-distl-∩ : ∀ {S T U} → (S ∪ (T ∩ U)) <: ((S ∪ T) ∩ (S ∪ U)) +<:-∪-distl-∩ t (left p) = (left p , left p) +<:-∪-distl-∩ t (right (p₁ , p₂)) = (right p₁ , right p₂) + +∪-distl-∩-<: : ∀ {S T U} → ((S ∪ T) ∩ (S ∪ U)) <: (S ∪ (T ∩ U)) +∪-distl-∩-<: t (left p₁ , p₂) = left p₁ +∪-distl-∩-<: t (right p₁ , left p₂) = left p₂ +∪-distl-∩-<: t (right p₁ , right p₂) = right (p₁ , p₂) + +<:-∪-distr-∩ : ∀ {S T U} → ((S ∩ T) ∪ U) <: ((S ∪ U) ∩ (T ∪ U)) +<:-∪-distr-∩ t (left (p₁ , p₂)) = left p₁ , left p₂ +<:-∪-distr-∩ t (right p) = (right p , right p) + +∪-distr-∩-<: : ∀ {S T U} → ((S ∪ U) ∩ (T ∪ U)) <: ((S ∩ T) ∪ U) +∪-distr-∩-<: t (left p₁ , left p₂) = left (p₁ , p₂) +∪-distr-∩-<: t (left p₁ , right p₂) = right p₂ +∪-distr-∩-<: t (right p₁ , p₂) = right p₁ + +∩-<:-∪ : ∀ {S T} → (S ∩ T) <: (S ∪ T) +∩-<:-∪ t (p , _) = left p + +-- Properties of functions +<:-function : ∀ {R S T U} → (R <: S) → (T <: U) → (S ⇒ T) <: (R ⇒ U) +<:-function p q function function = function +<:-function p q (function-ok s t) (function-ok₁ r) = function-ok₁ (<:-impl-⊇ p s r) +<:-function p q (function-ok s t) (function-ok₂ r) = function-ok₂ (q t r) +<:-function p q (function-err s) (function-err r) = function-err (<:-impl-⊇ p s r) +<:-function p q (function-tgt t) (function-tgt r) = function-tgt (q t r) + +<:-function-∩-∩ : ∀ {R S T U} → ((R ⇒ T) ∩ (S ⇒ U)) <: ((R ∩ S) ⇒ (T ∩ U)) +<:-function-∩-∩ function (function , function) = function +<:-function-∩-∩ (function-ok s t) (function-ok₁ p , q) = function-ok₁ (left p) +<:-function-∩-∩ (function-ok s t) (function-ok₂ p , function-ok₁ q) = function-ok₁ (right q) +<:-function-∩-∩ (function-ok s t) (function-ok₂ p , function-ok₂ q) = function-ok₂ (p , q) +<:-function-∩-∩ (function-err s) (function-err p , q) = function-err (left p) +<:-function-∩-∩ (function-tgt s) (function-tgt p , function-tgt q) = function-tgt (p , q) + +<:-function-∩-∪ : ∀ {R S T U} → ((R ⇒ T) ∩ (S ⇒ U)) <: ((R ∪ S) ⇒ (T ∪ U)) +<:-function-∩-∪ function (function , function) = function +<:-function-∩-∪ (function-ok s t) (function-ok₁ p₁ , function-ok₁ p₂) = function-ok₁ (p₁ , p₂) +<:-function-∩-∪ (function-ok s t) (p₁ , function-ok₂ p₂) = function-ok₂ (right p₂) +<:-function-∩-∪ (function-ok s t) (function-ok₂ p₁ , p₂) = function-ok₂ (left p₁) +<:-function-∩-∪ (function-err s) (function-err p₁ , function-err q₂) = function-err (p₁ , q₂) +<:-function-∩-∪ (function-tgt t) (function-tgt p , q) = function-tgt (left p) + +<:-function-∩ : ∀ {S T U} → ((S ⇒ T) ∩ (S ⇒ U)) <: (S ⇒ (T ∩ U)) +<:-function-∩ function (function , function) = function +<:-function-∩ (function-ok s t) (p₁ , function-ok₁ p₂) = function-ok₁ p₂ +<:-function-∩ (function-ok s t) (function-ok₁ p₁ , p₂) = function-ok₁ p₁ +<:-function-∩ (function-ok s t) (function-ok₂ p₁ , function-ok₂ p₂) = function-ok₂ (p₁ , p₂) +<:-function-∩ (function-err s) (function-err p₁ , function-err p₂) = function-err p₂ +<:-function-∩ (function-tgt t) (function-tgt p₁ , function-tgt p₂) = function-tgt (p₁ , p₂) + +<:-function-∪ : ∀ {R S T U} → ((R ⇒ S) ∪ (T ⇒ U)) <: ((R ∩ T) ⇒ (S ∪ U)) +<:-function-∪ function (left function) = function +<:-function-∪ (function-ok s t) (left (function-ok₁ p)) = function-ok₁ (left p) +<:-function-∪ (function-ok s t) (left (function-ok₂ p)) = function-ok₂ (left p) +<:-function-∪ (function-err s) (left (function-err p)) = function-err (left p) +<:-function-∪ (scalar s) (left (scalar ())) +<:-function-∪ function (right function) = function +<:-function-∪ (function-ok s t) (right (function-ok₁ p)) = function-ok₁ (right p) +<:-function-∪ (function-ok s t) (right (function-ok₂ p)) = function-ok₂ (right p) +<:-function-∪ (function-err s) (right (function-err x)) = function-err (right x) +<:-function-∪ (scalar s) (right (scalar ())) +<:-function-∪ (function-tgt t) (left (function-tgt p)) = function-tgt (left p) +<:-function-∪ (function-tgt t) (right (function-tgt p)) = function-tgt (right p) + +<:-function-∪-∩ : ∀ {R S T U} → ((R ∩ S) ⇒ (T ∪ U)) <: ((R ⇒ T) ∪ (S ⇒ U)) +<:-function-∪-∩ function function = left function +<:-function-∪-∩ (function-ok s t) (function-ok₁ (left p)) = left (function-ok₁ p) +<:-function-∪-∩ (function-ok s t) (function-ok₂ (left p)) = left (function-ok₂ p) +<:-function-∪-∩ (function-ok s t) (function-ok₁ (right p)) = right (function-ok₁ p) +<:-function-∪-∩ (function-ok s t) (function-ok₂ (right p)) = right (function-ok₂ p) +<:-function-∪-∩ (function-err s) (function-err (left p)) = left (function-err p) +<:-function-∪-∩ (function-err s) (function-err (right p)) = right (function-err p) +<:-function-∪-∩ (function-tgt t) (function-tgt (left p)) = left (function-tgt p) +<:-function-∪-∩ (function-tgt t) (function-tgt (right p)) = right (function-tgt p) + +<:-function-left : ∀ {R S T U} → (S ⇒ T) <: (R ⇒ U) → (R <: S) +<:-function-left {R} {S} p s Rs with dec-language S s +<:-function-left p s Rs | Right Ss = Ss +<:-function-left p s Rs | Left ¬Ss with p (function-err s) (function-err ¬Ss) +<:-function-left p s Rs | Left ¬Ss | function-err ¬Rs = CONTRADICTION (language-comp s ¬Rs Rs) + +<:-function-right : ∀ {R S T U} → (S ⇒ T) <: (R ⇒ U) → (T <: U) +<:-function-right p t Tt with p (function-tgt t) (function-tgt Tt) +<:-function-right p t Tt | function-tgt St = St + +≮:-function-left : ∀ {R S T U} → (R ≮: S) → (S ⇒ T) ≮: (R ⇒ U) +≮:-function-left (witness t p q) = witness (function-err t) (function-err q) (function-err p) + +≮:-function-right : ∀ {R S T U} → (T ≮: U) → (S ⇒ T) ≮: (R ⇒ U) +≮:-function-right (witness t p q) = witness (function-tgt t) (function-tgt p) (function-tgt q) -- Properties of scalars -skalar = number ∪ (string ∪ (nil ∪ boolean)) +skalar-function-ok : ∀ {s t} → (¬Language skalar (function-ok s t)) +skalar-function-ok = (scalar-function-ok number , (scalar-function-ok string , (scalar-function-ok nil , scalar-function-ok boolean))) + +scalar-<: : ∀ {S T} → (s : Scalar S) → Language T (scalar s) → (S <: T) +scalar-<: number p (scalar number) (scalar number) = p +scalar-<: boolean p (scalar boolean) (scalar boolean) = p +scalar-<: string p (scalar string) (scalar string) = p +scalar-<: nil p (scalar nil) (scalar nil) = p + +scalar-∩-function-<:-never : ∀ {S T U} → (Scalar S) → ((T ⇒ U) ∩ S) <: never +scalar-∩-function-<:-never number .(scalar number) (() , scalar number) +scalar-∩-function-<:-never boolean .(scalar boolean) (() , scalar boolean) +scalar-∩-function-<:-never string .(scalar string) (() , scalar string) +scalar-∩-function-<:-never nil .(scalar nil) (() , scalar nil) function-≮:-scalar : ∀ {S T U} → (Scalar U) → ((S ⇒ T) ≮: U) function-≮:-scalar s = witness function function (scalar-function s) @@ -103,37 +326,17 @@ function-≮:-scalar s = witness function function (scalar-function s) scalar-≮:-function : ∀ {S T U} → (Scalar U) → (U ≮: (S ⇒ T)) scalar-≮:-function s = witness (scalar s) (scalar s) (function-scalar s) -any-≮:-scalar : ∀ {U} → (Scalar U) → (any ≮: U) -any-≮:-scalar s = witness (function-ok (scalar s)) any (scalar-function-ok s) +unknown-≮:-scalar : ∀ {U} → (Scalar U) → (unknown ≮: U) +unknown-≮:-scalar s = witness function unknown (scalar-function s) -scalar-≮:-none : ∀ {U} → (Scalar U) → (U ≮: none) -scalar-≮:-none s = witness (scalar s) (scalar s) none +scalar-≮:-never : ∀ {U} → (Scalar U) → (U ≮: never) +scalar-≮:-never s = witness (scalar s) (scalar s) never scalar-≢-impl-≮: : ∀ {T U} → (Scalar T) → (Scalar U) → (T ≢ U) → (T ≮: U) scalar-≢-impl-≮: s₁ s₂ p = witness (scalar s₁) (scalar s₁) (scalar-scalar s₁ s₂ p) --- Properties of tgt -tgt-function-ok : ∀ {T t} → (Language (tgt T) t) → Language T (function-ok t) -tgt-function-ok {T = nil} (scalar ()) -tgt-function-ok {T = T₁ ⇒ T₂} p = function-ok p -tgt-function-ok {T = none} (scalar ()) -tgt-function-ok {T = any} p = any -tgt-function-ok {T = boolean} (scalar ()) -tgt-function-ok {T = number} (scalar ()) -tgt-function-ok {T = string} (scalar ()) -tgt-function-ok {T = T₁ ∪ T₂} (left p) = left (tgt-function-ok p) -tgt-function-ok {T = T₁ ∪ T₂} (right p) = right (tgt-function-ok p) -tgt-function-ok {T = T₁ ∩ T₂} (p₁ , p₂) = (tgt-function-ok p₁ , tgt-function-ok p₂) - -function-ok-tgt : ∀ {T t} → Language T (function-ok t) → (Language (tgt T) t) -function-ok-tgt (function-ok p) = p -function-ok-tgt (left p) = left (function-ok-tgt p) -function-ok-tgt (right p) = right (function-ok-tgt p) -function-ok-tgt (p₁ , p₂) = (function-ok-tgt p₁ , function-ok-tgt p₂) -function-ok-tgt any = any - -skalar-function-ok : ∀ {t} → (¬Language skalar (function-ok t)) -skalar-function-ok = (scalar-function-ok number , (scalar-function-ok string , (scalar-function-ok nil , scalar-function-ok boolean))) +scalar-≢-∩-<:-never : ∀ {T U V} → (Scalar T) → (Scalar U) → (T ≢ U) → (T ∩ U) <: V +scalar-≢-∩-<:-never s t p u (scalar s₁ , scalar s₂) = CONTRADICTION (p refl) skalar-scalar : ∀ {T} (s : Scalar T) → (Language skalar (scalar s)) skalar-scalar number = left (scalar number) @@ -141,81 +344,138 @@ skalar-scalar boolean = right (right (right (scalar boolean))) skalar-scalar string = right (left (scalar string)) skalar-scalar nil = right (right (left (scalar nil))) -tgt-none-≮: : ∀ {T U} → (tgt T ≮: U) → (T ≮: (skalar ∪ (none ⇒ U))) -tgt-none-≮: (witness t p q) = witness (function-ok t) (tgt-function-ok p) (skalar-function-ok , function-ok q) +-- Properties of unknown and never +unknown-≮: : ∀ {T U} → (T ≮: U) → (unknown ≮: U) +unknown-≮: (witness t p q) = witness t unknown q -none-tgt-≮: : ∀ {T U} → (T ≮: (skalar ∪ (none ⇒ U))) → (tgt T ≮: U) -none-tgt-≮: (witness (scalar s) p (q₁ , q₂)) = CONTRADICTION (≮:-refl (witness (scalar s) (skalar-scalar s) q₁)) -none-tgt-≮: (witness function p (q₁ , scalar-function ())) -none-tgt-≮: (witness (function-ok t) p (q₁ , function-ok q₂)) = witness t (function-ok-tgt p) q₂ -none-tgt-≮: (witness (function-err (scalar s)) p (q₁ , function-err (scalar ()))) +never-≮: : ∀ {T U} → (T ≮: U) → (T ≮: never) +never-≮: (witness t p q) = witness t p never --- Properties of src -function-err-src : ∀ {T t} → (¬Language (src T) t) → Language T (function-err t) -function-err-src {T = nil} none = scalar-function-err nil -function-err-src {T = T₁ ⇒ T₂} p = function-err p -function-err-src {T = none} (scalar-scalar number () p) -function-err-src {T = none} (scalar-function-ok ()) -function-err-src {T = any} none = any -function-err-src {T = boolean} p = scalar-function-err boolean -function-err-src {T = number} p = scalar-function-err number -function-err-src {T = string} p = scalar-function-err string -function-err-src {T = T₁ ∪ T₂} (left p) = left (function-err-src p) -function-err-src {T = T₁ ∪ T₂} (right p) = right (function-err-src p) -function-err-src {T = T₁ ∩ T₂} (p₁ , p₂) = function-err-src p₁ , function-err-src p₂ +unknown-≮:-never : (unknown ≮: never) +unknown-≮:-never = witness (scalar nil) unknown never -¬function-err-src : ∀ {T t} → (Language (src T) t) → ¬Language T (function-err t) -¬function-err-src {T = nil} (scalar ()) -¬function-err-src {T = T₁ ⇒ T₂} p = function-err p -¬function-err-src {T = none} any = none -¬function-err-src {T = any} (scalar ()) -¬function-err-src {T = boolean} (scalar ()) -¬function-err-src {T = number} (scalar ()) -¬function-err-src {T = string} (scalar ()) -¬function-err-src {T = T₁ ∪ T₂} (p₁ , p₂) = (¬function-err-src p₁ , ¬function-err-src p₂) -¬function-err-src {T = T₁ ∩ T₂} (left p) = left (¬function-err-src p) -¬function-err-src {T = T₁ ∩ T₂} (right p) = right (¬function-err-src p) +unknown-≮:-function : ∀ {S T} → (unknown ≮: (S ⇒ T)) +unknown-≮:-function = witness (scalar nil) unknown (function-scalar nil) -src-¬function-err : ∀ {T t} → Language T (function-err t) → (¬Language (src T) t) -src-¬function-err {T = nil} p = none -src-¬function-err {T = T₁ ⇒ T₂} (function-err p) = p -src-¬function-err {T = none} (scalar-function-err ()) -src-¬function-err {T = any} p = none -src-¬function-err {T = boolean} p = none -src-¬function-err {T = number} p = none -src-¬function-err {T = string} p = none -src-¬function-err {T = T₁ ∪ T₂} (left p) = left (src-¬function-err p) -src-¬function-err {T = T₁ ∪ T₂} (right p) = right (src-¬function-err p) -src-¬function-err {T = T₁ ∩ T₂} (p₁ , p₂) = (src-¬function-err p₁ , src-¬function-err p₂) +function-≮:-never : ∀ {T U} → ((T ⇒ U) ≮: never) +function-≮:-never = witness function function never -src-¬scalar : ∀ {S T t} (s : Scalar S) → Language T (scalar s) → (¬Language (src T) t) -src-¬scalar number (scalar number) = none -src-¬scalar boolean (scalar boolean) = none -src-¬scalar string (scalar string) = none -src-¬scalar nil (scalar nil) = none -src-¬scalar s (left p) = left (src-¬scalar s p) -src-¬scalar s (right p) = right (src-¬scalar s p) -src-¬scalar s (p₁ , p₂) = (src-¬scalar s p₁ , src-¬scalar s p₂) -src-¬scalar s any = none +<:-never : ∀ {T} → (never <: T) +<:-never t (scalar ()) -src-any-≮: : ∀ {T U} → (T ≮: src U) → (U ≮: (T ⇒ any)) -src-any-≮: (witness t p q) = witness (function-err t) (function-err-src q) (¬function-err-src p) +≮:-never-left : ∀ {S T U} → (S <: (T ∪ U)) → (S ≮: T) → (S ∩ U) ≮: never +≮:-never-left p (witness t q₁ q₂) with p t q₁ +≮:-never-left p (witness t q₁ q₂) | left r = CONTRADICTION (language-comp t q₂ r) +≮:-never-left p (witness t q₁ q₂) | right r = witness t (q₁ , r) never -any-src-≮: : ∀ {S T U} → (U ≮: S) → (T ≮: (U ⇒ any)) → (U ≮: src T) -any-src-≮: (witness t x x₁) (witness (scalar s) p (function-scalar s)) = witness t x (src-¬scalar s p) -any-src-≮: r (witness (function-ok (scalar s)) p (function-ok (scalar-scalar s () q))) -any-src-≮: r (witness (function-ok (function-ok _)) p (function-ok (scalar-function-ok ()))) -any-src-≮: r (witness (function-err t) p (function-err q)) = witness t q (src-¬function-err p) +≮:-never-right : ∀ {S T U} → (S <: (T ∪ U)) → (S ≮: U) → (S ∩ T) ≮: never +≮:-never-right p (witness t q₁ q₂) with p t q₁ +≮:-never-right p (witness t q₁ q₂) | left r = witness t (q₁ , r) never +≮:-never-right p (witness t q₁ q₂) | right r = CONTRADICTION (language-comp t q₂ r) --- Properties of any and none -any-≮: : ∀ {T U} → (T ≮: U) → (any ≮: U) -any-≮: (witness t p q) = witness t any q +<:-unknown : ∀ {T} → (T <: unknown) +<:-unknown t p = unknown -none-≮: : ∀ {T U} → (T ≮: U) → (T ≮: none) -none-≮: (witness t p q) = witness t p none +<:-everything : unknown <: ((never ⇒ unknown) ∪ skalar) +<:-everything (scalar s) p = right (skalar-scalar s) +<:-everything function p = left function +<:-everything (function-ok s t) p = left (function-ok₁ never) +<:-everything (function-err s) p = left (function-err never) +<:-everything (function-tgt t) p = left (function-tgt unknown) -any-≮:-none : (any ≮: none) -any-≮:-none = witness (scalar nil) any none +-- A Gentle Introduction To Semantic Subtyping (https://www.cduce.org/papers/gentle.pdf) +-- defines a "set-theoretic" model (sec 2.5) +-- Unfortunately we don't quite have this property, due to uninhabited types, +-- for example (never -> T) is equivalent to (never -> U) +-- when types are interpreted as sets of syntactic values. -function-≮:-none : ∀ {T U} → ((T ⇒ U) ≮: none) -function-≮:-none = witness function function none +_⊆_ : ∀ {A : Set} → (A → Set) → (A → Set) → Set +(P ⊆ Q) = ∀ a → (P a) → (Q a) + +_⊗_ : ∀ {A B : Set} → (A → Set) → (B → Set) → ((A × B) → Set) +(P ⊗ Q) (a , b) = (P a) × (Q b) + +Comp : ∀ {A : Set} → (A → Set) → (A → Set) +Comp P a = ¬(P a) + +Lift : ∀ {A : Set} → (A → Set) → (Maybe A → Set) +Lift P nothing = ⊥ +Lift P (just a) = P a + +set-theoretic-if : ∀ {S₁ T₁ S₂ T₂} → + + -- This is the "if" part of being a set-theoretic model + -- though it uses the definition from Frisch's thesis + -- rather than from the Gentle Introduction. The difference + -- being the presence of Lift, (written D_Ω in Defn 4.2 of + -- https://www.cduce.org/papers/frisch_phd.pdf). + (Language (S₁ ⇒ T₁) ⊆ Language (S₂ ⇒ T₂)) → + (∀ Q → Q ⊆ Comp((Language S₁) ⊗ Comp(Lift(Language T₁))) → Q ⊆ Comp((Language S₂) ⊗ Comp(Lift(Language T₂)))) + +set-theoretic-if {S₁} {T₁} {S₂} {T₂} p Q q (t , just u) Qtu (S₂t , ¬T₂u) = q (t , just u) Qtu (S₁t , ¬T₁u) where + + S₁t : Language S₁ t + S₁t with dec-language S₁ t + S₁t | Left ¬S₁t with p (function-err t) (function-err ¬S₁t) + S₁t | Left ¬S₁t | function-err ¬S₂t = CONTRADICTION (language-comp t ¬S₂t S₂t) + S₁t | Right r = r + + ¬T₁u : ¬(Language T₁ u) + ¬T₁u T₁u with p (function-ok t u) (function-ok₂ T₁u) + ¬T₁u T₁u | function-ok₁ ¬S₂t = language-comp t ¬S₂t S₂t + ¬T₁u T₁u | function-ok₂ T₂u = ¬T₂u T₂u + +set-theoretic-if {S₁} {T₁} {S₂} {T₂} p Q q (t , nothing) Qt- (S₂t , _) = q (t , nothing) Qt- (S₁t , λ ()) where + + S₁t : Language S₁ t + S₁t with dec-language S₁ t + S₁t | Left ¬S₁t with p (function-err t) (function-err ¬S₁t) + S₁t | Left ¬S₁t | function-err ¬S₂t = CONTRADICTION (language-comp t ¬S₂t S₂t) + S₁t | Right r = r + +not-quite-set-theoretic-only-if : ∀ {S₁ T₁ S₂ T₂} → + + -- We don't quite have that this is a set-theoretic model + -- it's only true when Language S₂ is inhabited + -- in particular it's not true when S₂ is never, + ∀ s₂ → Language S₂ s₂ → + + -- This is the "only if" part of being a set-theoretic model + (∀ Q → Q ⊆ Comp((Language S₁) ⊗ Comp(Lift(Language T₁))) → Q ⊆ Comp((Language S₂) ⊗ Comp(Lift(Language T₂)))) → + (Language (S₁ ⇒ T₁) ⊆ Language (S₂ ⇒ T₂)) + +not-quite-set-theoretic-only-if {S₁} {T₁} {S₂} {T₂} s₂ S₂s₂ p = r where + + Q : (Tree × Maybe Tree) → Set + Q (t , just u) = Either (¬Language S₁ t) (Language T₁ u) + Q (t , nothing) = ¬Language S₁ t + + q : Q ⊆ Comp(Language S₁ ⊗ Comp(Lift(Language T₁))) + q (t , just u) (Left ¬S₁t) (S₁t , ¬T₁u) = language-comp t ¬S₁t S₁t + q (t , just u) (Right T₂u) (S₁t , ¬T₁u) = ¬T₁u T₂u + q (t , nothing) ¬S₁t (S₁t , _) = language-comp t ¬S₁t S₁t + + r : Language (S₁ ⇒ T₁) ⊆ Language (S₂ ⇒ T₂) + r function function = function + r (function-err s) (function-err ¬S₁s) with dec-language S₂ s + r (function-err s) (function-err ¬S₁s) | Left ¬S₂s = function-err ¬S₂s + r (function-err s) (function-err ¬S₁s) | Right S₂s = CONTRADICTION (p Q q (s , nothing) ¬S₁s (S₂s , λ ())) + r (function-ok s t) (function-ok₁ ¬S₁s) with dec-language S₂ s + r (function-ok s t) (function-ok₁ ¬S₁s) | Left ¬S₂s = function-ok₁ ¬S₂s + r (function-ok s t) (function-ok₁ ¬S₁s) | Right S₂s = CONTRADICTION (p Q q (s , nothing) ¬S₁s (S₂s , λ ())) + r (function-ok s t) (function-ok₂ T₁t) with dec-language T₂ t + r (function-ok s t) (function-ok₂ T₁t) | Left ¬T₂t with dec-language S₂ s + r (function-ok s t) (function-ok₂ T₁t) | Left ¬T₂t | Left ¬S₂s = function-ok₁ ¬S₂s + r (function-ok s t) (function-ok₂ T₁t) | Left ¬T₂t | Right S₂s = CONTRADICTION (p Q q (s , just t) (Right T₁t) (S₂s , language-comp t ¬T₂t)) + r (function-ok s t) (function-ok₂ T₁t) | Right T₂t = function-ok₂ T₂t + r (function-tgt t) (function-tgt T₁t) with dec-language T₂ t + r (function-tgt t) (function-tgt T₁t) | Left ¬T₂t = CONTRADICTION (p Q q (s₂ , just t) (Right T₁t) (S₂s₂ , language-comp t ¬T₂t)) + r (function-tgt t) (function-tgt T₁t) | Right T₂t = function-tgt T₂t + +-- A counterexample when the argument type is empty. + +set-theoretic-counterexample-one : (∀ Q → Q ⊆ Comp((Language never) ⊗ Comp(Lift(Language number))) → Q ⊆ Comp((Language never) ⊗ Comp(Lift(Language string)))) +set-theoretic-counterexample-one Q q ((scalar s) , u) Qtu (scalar () , p) + +set-theoretic-counterexample-two : (never ⇒ number) ≮: (never ⇒ string) +set-theoretic-counterexample-two = witness (function-tgt (scalar number)) (function-tgt (scalar number)) (function-tgt (scalar-scalar number string (λ ()))) diff --git a/prototyping/Properties/TypeCheck.agda b/prototyping/Properties/TypeCheck.agda index ead0c097..b53bbd04 100644 --- a/prototyping/Properties/TypeCheck.agda +++ b/prototyping/Properties/TypeCheck.agda @@ -1,16 +1,15 @@ {-# OPTIONS --rewriting #-} -open import Luau.Type using (Mode) - -module Properties.TypeCheck (m : Mode) where +module Properties.TypeCheck where open import Agda.Builtin.Equality using (_≡_; refl) open import Agda.Builtin.Bool using (Bool; true; false) open import FFI.Data.Maybe using (Maybe; just; nothing) open import FFI.Data.Either using (Either) -open import Luau.TypeCheck(m) using (_⊢ᴱ_∈_; _⊢ᴮ_∈_; ⊢ᴼ_; ⊢ᴴ_; _⊢ᴴᴱ_▷_∈_; _⊢ᴴᴮ_▷_∈_; nil; var; addr; number; bool; string; app; function; block; binexp; done; return; local; nothing; orAny; tgtBinOp) +open import Luau.ResolveOverloads using (resolve) +open import Luau.TypeCheck using (_⊢ᴱ_∈_; _⊢ᴮ_∈_; ⊢ᴼ_; ⊢ᴴ_; _⊢ᴴᴱ_▷_∈_; _⊢ᴴᴮ_▷_∈_; nil; var; addr; number; bool; string; app; function; block; binexp; done; return; local; nothing; orUnknown; tgtBinOp) open import Luau.Syntax using (Block; Expr; Value; BinaryOperator; yes; nil; addr; number; bool; string; val; var; binexp; _$_; function_is_end; block_is_end; _∙_; return; done; local_←_; _⟨_⟩; _⟨_⟩∈_; var_∈_; name; fun; arg; +; -; *; /; <; >; ==; ~=; <=; >=) -open import Luau.Type using (Type; nil; any; none; number; boolean; string; _⇒_; tgt) +open import Luau.Type using (Type; nil; unknown; never; number; boolean; string; _⇒_) open import Luau.RuntimeType using (RuntimeType; nil; number; function; string; valueType) open import Luau.VarCtxt using (VarCtxt; ∅; _↦_; _⊕_↦_; _⋒_; _⊝_) renaming (_[_] to _[_]ⱽ) open import Luau.Addr using (Addr) @@ -22,9 +21,6 @@ open import Properties.Equality using (_≢_; sym; trans; cong) open import Properties.Product using (_×_; _,_) open import Properties.Remember using (Remember; remember; _,_) -src : Type → Type -src = Luau.Type.src m - typeOfᴼ : Object yes → Type typeOfᴼ (function f ⟨ var x ∈ S ⟩∈ T is B end) = (S ⇒ T) @@ -42,9 +38,9 @@ typeOfⱽ H (string x) = just string typeOfᴱ : Heap yes → VarCtxt → (Expr yes) → Type typeOfᴮ : Heap yes → VarCtxt → (Block yes) → Type -typeOfᴱ H Γ (var x) = orAny(Γ [ x ]ⱽ) -typeOfᴱ H Γ (val v) = orAny(typeOfⱽ H v) -typeOfᴱ H Γ (M $ N) = tgt(typeOfᴱ H Γ M) +typeOfᴱ H Γ (var x) = orUnknown(Γ [ x ]ⱽ) +typeOfᴱ H Γ (val v) = orUnknown(typeOfⱽ H v) +typeOfᴱ H Γ (M $ N) = resolve (typeOfᴱ H Γ M) (typeOfᴱ H Γ N) typeOfᴱ H Γ (function f ⟨ var x ∈ S ⟩∈ T is B end) = S ⇒ T typeOfᴱ H Γ (block var b ∈ T is B end) = T typeOfᴱ H Γ (binexp M op N) = tgtBinOp op @@ -54,27 +50,19 @@ typeOfᴮ H Γ (local var x ∈ T ← M ∙ B) = typeOfᴮ H (Γ ⊕ x ↦ T) B typeOfᴮ H Γ (return M ∙ B) = typeOfᴱ H Γ M typeOfᴮ H Γ done = nil -mustBeFunction : ∀ H Γ v → (none ≢ src (typeOfᴱ H Γ (val v))) → (function ≡ valueType(v)) -mustBeFunction H Γ nil p = CONTRADICTION (p refl) -mustBeFunction H Γ (addr a) p = refl -mustBeFunction H Γ (number n) p = CONTRADICTION (p refl) -mustBeFunction H Γ (bool true) p = CONTRADICTION (p refl) -mustBeFunction H Γ (bool false) p = CONTRADICTION (p refl) -mustBeFunction H Γ (string x) p = CONTRADICTION (p refl) - mustBeNumber : ∀ H Γ v → (typeOfᴱ H Γ (val v) ≡ number) → (valueType(v) ≡ number) mustBeNumber H Γ (addr a) p with remember (H [ a ]ᴴ) -mustBeNumber H Γ (addr a) p | (just O , q) with trans (cong orAny (cong typeOfᴹᴼ (sym q))) p +mustBeNumber H Γ (addr a) p | (just O , q) with trans (cong orUnknown (cong typeOfᴹᴼ (sym q))) p mustBeNumber H Γ (addr a) p | (just function f ⟨ var x ∈ T ⟩∈ U is B end , q) | () -mustBeNumber H Γ (addr a) p | (nothing , q) with trans (cong orAny (cong typeOfᴹᴼ (sym q))) p +mustBeNumber H Γ (addr a) p | (nothing , q) with trans (cong orUnknown (cong typeOfᴹᴼ (sym q))) p mustBeNumber H Γ (addr a) p | nothing , q | () mustBeNumber H Γ (number n) p = refl mustBeString : ∀ H Γ v → (typeOfᴱ H Γ (val v) ≡ string) → (valueType(v) ≡ string) mustBeString H Γ (addr a) p with remember (H [ a ]ᴴ) -mustBeString H Γ (addr a) p | (just O , q) with trans (cong orAny (cong typeOfᴹᴼ (sym q))) p +mustBeString H Γ (addr a) p | (just O , q) with trans (cong orUnknown (cong typeOfᴹᴼ (sym q))) p mustBeString H Γ (addr a) p | (just function f ⟨ var x ∈ T ⟩∈ U is B end , q) | () -mustBeString H Γ (addr a) p | (nothing , q) with trans (cong orAny (cong typeOfᴹᴼ (sym q))) p +mustBeString H Γ (addr a) p | (nothing , q) with trans (cong orUnknown (cong typeOfᴹᴼ (sym q))) p mustBeString H Γ (addr a) p | (nothing , q) | () mustBeString H Γ (string x) p = refl @@ -83,7 +71,7 @@ typeCheckᴮ : ∀ H Γ B → (Γ ⊢ᴮ B ∈ (typeOfᴮ H Γ B)) typeCheckᴱ H Γ (var x) = var refl typeCheckᴱ H Γ (val nil) = nil -typeCheckᴱ H Γ (val (addr a)) = addr (orAny (typeOfᴹᴼ (H [ a ]ᴴ))) +typeCheckᴱ H Γ (val (addr a)) = addr (orUnknown (typeOfᴹᴼ (H [ a ]ᴴ))) typeCheckᴱ H Γ (val (number n)) = number typeCheckᴱ H Γ (val (bool b)) = bool typeCheckᴱ H Γ (val (string x)) = string diff --git a/prototyping/Properties/TypeNormalization.agda b/prototyping/Properties/TypeNormalization.agda new file mode 100644 index 00000000..cbd8139f --- /dev/null +++ b/prototyping/Properties/TypeNormalization.agda @@ -0,0 +1,408 @@ +{-# OPTIONS --rewriting #-} + +module Properties.TypeNormalization where + +open import Luau.Type using (Type; Scalar; nil; number; string; boolean; never; unknown; _⇒_; _∪_; _∩_) +open import Luau.Subtyping using (Tree; Language; ¬Language; function; scalar; unknown; left; right; function-ok₁; function-ok₂; function-err; function-tgt; scalar-function; scalar-function-ok; scalar-function-err; scalar-function-tgt; function-scalar; _,_) +open import Luau.TypeNormalization using (_∪ⁿ_; _∩ⁿ_; _∪ᶠ_; _∪ⁿˢ_; _∩ⁿˢ_; normalize) +open import Luau.Subtyping using (_<:_; _≮:_; witness; never) +open import Properties.Subtyping using (<:-trans; <:-refl; <:-unknown; <:-never; <:-∪-left; <:-∪-right; <:-∪-lub; <:-∩-left; <:-∩-right; <:-∩-glb; <:-∩-symm; <:-function; <:-function-∪-∩; <:-function-∩-∪; <:-function-∪; <:-everything; <:-union; <:-∪-assocl; <:-∪-assocr; <:-∪-symm; <:-intersect; ∪-distl-∩-<:; ∪-distr-∩-<:; <:-∪-distr-∩; <:-∪-distl-∩; ∩-distl-∪-<:; <:-∩-distl-∪; <:-∩-distr-∪; scalar-∩-function-<:-never; scalar-≢-∩-<:-never) + +-- Normal forms for types +data FunType : Type → Set +data Normal : Type → Set + +data FunType where + _⇒_ : ∀ {S T} → Normal S → Normal T → FunType (S ⇒ T) + _∩_ : ∀ {F G} → FunType F → FunType G → FunType (F ∩ G) + +data Normal where + _⇒_ : ∀ {S T} → Normal S → Normal T → Normal (S ⇒ T) + _∩_ : ∀ {F G} → FunType F → FunType G → Normal (F ∩ G) + _∪_ : ∀ {S T} → Normal S → Scalar T → Normal (S ∪ T) + never : Normal never + unknown : Normal unknown + +data OptScalar : Type → Set where + never : OptScalar never + number : OptScalar number + boolean : OptScalar boolean + string : OptScalar string + nil : OptScalar nil + +-- Top function type +fun-top : ∀ {F} → (FunType F) → (F <: (never ⇒ unknown)) +fun-top (S ⇒ T) = <:-function <:-never <:-unknown +fun-top (F ∩ G) = <:-trans <:-∩-left (fun-top F) + +-- function types are inhabited +fun-function : ∀ {F} → FunType F → Language F function +fun-function (S ⇒ T) = function +fun-function (F ∩ G) = (fun-function F , fun-function G) + +fun-≮:-never : ∀ {F} → FunType F → (F ≮: never) +fun-≮:-never F = witness function (fun-function F) never + +-- function types aren't scalars +fun-¬scalar : ∀ {F S t} → (s : Scalar S) → FunType F → Language F t → ¬Language S t +fun-¬scalar s (S ⇒ T) function = scalar-function s +fun-¬scalar s (S ⇒ T) (function-ok₁ p) = scalar-function-ok s +fun-¬scalar s (S ⇒ T) (function-ok₂ p) = scalar-function-ok s +fun-¬scalar s (S ⇒ T) (function-err p) = scalar-function-err s +fun-¬scalar s (S ⇒ T) (function-tgt p) = scalar-function-tgt s +fun-¬scalar s (F ∩ G) (p₁ , p₂) = fun-¬scalar s G p₂ + +¬scalar-fun : ∀ {F S} → FunType F → (s : Scalar S) → ¬Language F (scalar s) +¬scalar-fun (S ⇒ T) s = function-scalar s +¬scalar-fun (F ∩ G) s = left (¬scalar-fun F s) + +scalar-≮:-fun : ∀ {F S} → FunType F → Scalar S → S ≮: F +scalar-≮:-fun F s = witness (scalar s) (scalar s) (¬scalar-fun F s) + +unknown-≮:-fun : ∀ {F} → FunType F → unknown ≮: F +unknown-≮:-fun F = witness (scalar nil) unknown (¬scalar-fun F nil) + +-- Normalization produces normal types +normal : ∀ T → Normal (normalize T) +normalᶠ : ∀ {F} → FunType F → Normal F +normal-∪ⁿ : ∀ {S T} → Normal S → Normal T → Normal (S ∪ⁿ T) +normal-∩ⁿ : ∀ {S T} → Normal S → Normal T → Normal (S ∩ⁿ T) +normal-∪ⁿˢ : ∀ {S T} → Normal S → OptScalar T → Normal (S ∪ⁿˢ T) +normal-∩ⁿˢ : ∀ {S T} → Normal S → Scalar T → OptScalar (S ∩ⁿˢ T) +normal-∪ᶠ : ∀ {F G} → FunType F → FunType G → FunType (F ∪ᶠ G) + +normal nil = never ∪ nil +normal (S ⇒ T) = (normal S) ⇒ (normal T) +normal never = never +normal unknown = unknown +normal boolean = never ∪ boolean +normal number = never ∪ number +normal string = never ∪ string +normal (S ∪ T) = normal-∪ⁿ (normal S) (normal T) +normal (S ∩ T) = normal-∩ⁿ (normal S) (normal T) + +normalᶠ (S ⇒ T) = S ⇒ T +normalᶠ (F ∩ G) = F ∩ G + +normal-∪ⁿ S (T₁ ∪ T₂) = (normal-∪ⁿ S T₁) ∪ T₂ +normal-∪ⁿ S never = S +normal-∪ⁿ S unknown = unknown +normal-∪ⁿ never (T ⇒ U) = T ⇒ U +normal-∪ⁿ never (G₁ ∩ G₂) = G₁ ∩ G₂ +normal-∪ⁿ unknown (T ⇒ U) = unknown +normal-∪ⁿ unknown (G₁ ∩ G₂) = unknown +normal-∪ⁿ (R ⇒ S) (T ⇒ U) = normalᶠ (normal-∪ᶠ (R ⇒ S) (T ⇒ U)) +normal-∪ⁿ (R ⇒ S) (G₁ ∩ G₂) = normalᶠ (normal-∪ᶠ (R ⇒ S) (G₁ ∩ G₂)) +normal-∪ⁿ (F₁ ∩ F₂) (T ⇒ U) = normalᶠ (normal-∪ᶠ (F₁ ∩ F₂) (T ⇒ U)) +normal-∪ⁿ (F₁ ∩ F₂) (G₁ ∩ G₂) = normalᶠ (normal-∪ᶠ (F₁ ∩ F₂) (G₁ ∩ G₂)) +normal-∪ⁿ (S₁ ∪ S₂) (T₁ ⇒ T₂) = normal-∪ⁿ S₁ (T₁ ⇒ T₂) ∪ S₂ +normal-∪ⁿ (S₁ ∪ S₂) (G₁ ∩ G₂) = normal-∪ⁿ S₁ (G₁ ∩ G₂) ∪ S₂ + +normal-∩ⁿ S never = never +normal-∩ⁿ S unknown = S +normal-∩ⁿ S (T ∪ U) = normal-∪ⁿˢ (normal-∩ⁿ S T) (normal-∩ⁿˢ S U ) +normal-∩ⁿ never (T ⇒ U) = never +normal-∩ⁿ unknown (T ⇒ U) = T ⇒ U +normal-∩ⁿ (R ⇒ S) (T ⇒ U) = (R ⇒ S) ∩ (T ⇒ U) +normal-∩ⁿ (R ∩ S) (T ⇒ U) = (R ∩ S) ∩ (T ⇒ U) +normal-∩ⁿ (R ∪ S) (T ⇒ U) = normal-∩ⁿ R (T ⇒ U) +normal-∩ⁿ never (T ∩ U) = never +normal-∩ⁿ unknown (T ∩ U) = T ∩ U +normal-∩ⁿ (R ⇒ S) (T ∩ U) = (R ⇒ S) ∩ (T ∩ U) +normal-∩ⁿ (R ∩ S) (T ∩ U) = (R ∩ S) ∩ (T ∩ U) +normal-∩ⁿ (R ∪ S) (T ∩ U) = normal-∩ⁿ R (T ∩ U) + +normal-∪ⁿˢ S never = S +normal-∪ⁿˢ never number = never ∪ number +normal-∪ⁿˢ unknown number = unknown +normal-∪ⁿˢ (R ⇒ S) number = (R ⇒ S) ∪ number +normal-∪ⁿˢ (R ∩ S) number = (R ∩ S) ∪ number +normal-∪ⁿˢ (R ∪ number) number = R ∪ number +normal-∪ⁿˢ (R ∪ boolean) number = normal-∪ⁿˢ R number ∪ boolean +normal-∪ⁿˢ (R ∪ string) number = normal-∪ⁿˢ R number ∪ string +normal-∪ⁿˢ (R ∪ nil) number = normal-∪ⁿˢ R number ∪ nil +normal-∪ⁿˢ never boolean = never ∪ boolean +normal-∪ⁿˢ unknown boolean = unknown +normal-∪ⁿˢ (R ⇒ S) boolean = (R ⇒ S) ∪ boolean +normal-∪ⁿˢ (R ∩ S) boolean = (R ∩ S) ∪ boolean +normal-∪ⁿˢ (R ∪ number) boolean = normal-∪ⁿˢ R boolean ∪ number +normal-∪ⁿˢ (R ∪ boolean) boolean = R ∪ boolean +normal-∪ⁿˢ (R ∪ string) boolean = normal-∪ⁿˢ R boolean ∪ string +normal-∪ⁿˢ (R ∪ nil) boolean = normal-∪ⁿˢ R boolean ∪ nil +normal-∪ⁿˢ never string = never ∪ string +normal-∪ⁿˢ unknown string = unknown +normal-∪ⁿˢ (R ⇒ S) string = (R ⇒ S) ∪ string +normal-∪ⁿˢ (R ∩ S) string = (R ∩ S) ∪ string +normal-∪ⁿˢ (R ∪ number) string = normal-∪ⁿˢ R string ∪ number +normal-∪ⁿˢ (R ∪ boolean) string = normal-∪ⁿˢ R string ∪ boolean +normal-∪ⁿˢ (R ∪ string) string = R ∪ string +normal-∪ⁿˢ (R ∪ nil) string = normal-∪ⁿˢ R string ∪ nil +normal-∪ⁿˢ never nil = never ∪ nil +normal-∪ⁿˢ unknown nil = unknown +normal-∪ⁿˢ (R ⇒ S) nil = (R ⇒ S) ∪ nil +normal-∪ⁿˢ (R ∩ S) nil = (R ∩ S) ∪ nil +normal-∪ⁿˢ (R ∪ number) nil = normal-∪ⁿˢ R nil ∪ number +normal-∪ⁿˢ (R ∪ boolean) nil = normal-∪ⁿˢ R nil ∪ boolean +normal-∪ⁿˢ (R ∪ string) nil = normal-∪ⁿˢ R nil ∪ string +normal-∪ⁿˢ (R ∪ nil) nil = R ∪ nil + +normal-∩ⁿˢ never number = never +normal-∩ⁿˢ never boolean = never +normal-∩ⁿˢ never string = never +normal-∩ⁿˢ never nil = never +normal-∩ⁿˢ unknown number = number +normal-∩ⁿˢ unknown boolean = boolean +normal-∩ⁿˢ unknown string = string +normal-∩ⁿˢ unknown nil = nil +normal-∩ⁿˢ (R ⇒ S) number = never +normal-∩ⁿˢ (R ⇒ S) boolean = never +normal-∩ⁿˢ (R ⇒ S) string = never +normal-∩ⁿˢ (R ⇒ S) nil = never +normal-∩ⁿˢ (R ∩ S) number = never +normal-∩ⁿˢ (R ∩ S) boolean = never +normal-∩ⁿˢ (R ∩ S) string = never +normal-∩ⁿˢ (R ∩ S) nil = never +normal-∩ⁿˢ (R ∪ number) number = number +normal-∩ⁿˢ (R ∪ boolean) number = normal-∩ⁿˢ R number +normal-∩ⁿˢ (R ∪ string) number = normal-∩ⁿˢ R number +normal-∩ⁿˢ (R ∪ nil) number = normal-∩ⁿˢ R number +normal-∩ⁿˢ (R ∪ number) boolean = normal-∩ⁿˢ R boolean +normal-∩ⁿˢ (R ∪ boolean) boolean = boolean +normal-∩ⁿˢ (R ∪ string) boolean = normal-∩ⁿˢ R boolean +normal-∩ⁿˢ (R ∪ nil) boolean = normal-∩ⁿˢ R boolean +normal-∩ⁿˢ (R ∪ number) string = normal-∩ⁿˢ R string +normal-∩ⁿˢ (R ∪ boolean) string = normal-∩ⁿˢ R string +normal-∩ⁿˢ (R ∪ string) string = string +normal-∩ⁿˢ (R ∪ nil) string = normal-∩ⁿˢ R string +normal-∩ⁿˢ (R ∪ number) nil = normal-∩ⁿˢ R nil +normal-∩ⁿˢ (R ∪ boolean) nil = normal-∩ⁿˢ R nil +normal-∩ⁿˢ (R ∪ string) nil = normal-∩ⁿˢ R nil +normal-∩ⁿˢ (R ∪ nil) nil = nil + +normal-∪ᶠ (R ⇒ S) (T ⇒ U) = (normal-∩ⁿ R T) ⇒ (normal-∪ⁿ S U) +normal-∪ᶠ (R ⇒ S) (G ∩ H) = normal-∪ᶠ (R ⇒ S) G ∩ normal-∪ᶠ (R ⇒ S) H +normal-∪ᶠ (E ∩ F) G = normal-∪ᶠ E G ∩ normal-∪ᶠ F G + +scalar-∩-fun-<:-never : ∀ {F S} → FunType F → Scalar S → (F ∩ S) <: never +scalar-∩-fun-<:-never (T ⇒ U) S = scalar-∩-function-<:-never S +scalar-∩-fun-<:-never (F ∩ G) S = <:-trans (<:-intersect <:-∩-left <:-refl) (scalar-∩-fun-<:-never F S) + +flipper : ∀ {S T U} → ((S ∪ T) ∪ U) <: ((S ∪ U) ∪ T) +flipper = <:-trans <:-∪-assocr (<:-trans (<:-union <:-refl <:-∪-symm) <:-∪-assocl) + +∩-<:-∩ⁿ : ∀ {S T} → Normal S → Normal T → (S ∩ T) <: (S ∩ⁿ T) +∩ⁿ-<:-∩ : ∀ {S T} → Normal S → Normal T → (S ∩ⁿ T) <: (S ∩ T) +∩-<:-∩ⁿˢ : ∀ {S T} → Normal S → Scalar T → (S ∩ T) <: (S ∩ⁿˢ T) +∩ⁿˢ-<:-∩ : ∀ {S T} → Normal S → Scalar T → (S ∩ⁿˢ T) <: (S ∩ T) +∪ᶠ-<:-∪ : ∀ {F G} → FunType F → FunType G → (F ∪ᶠ G) <: (F ∪ G) +∪ⁿ-<:-∪ : ∀ {S T} → Normal S → Normal T → (S ∪ⁿ T) <: (S ∪ T) +∪-<:-∪ⁿ : ∀ {S T} → Normal S → Normal T → (S ∪ T) <: (S ∪ⁿ T) +∪ⁿˢ-<:-∪ : ∀ {S T} → Normal S → OptScalar T → (S ∪ⁿˢ T) <: (S ∪ T) +∪-<:-∪ⁿˢ : ∀ {S T} → Normal S → OptScalar T → (S ∪ T) <: (S ∪ⁿˢ T) + +∩-<:-∩ⁿ S never = <:-∩-right +∩-<:-∩ⁿ S unknown = <:-∩-left +∩-<:-∩ⁿ S (T ∪ U) = <:-trans <:-∩-distl-∪ (<:-trans (<:-union (∩-<:-∩ⁿ S T) (∩-<:-∩ⁿˢ S U)) (∪-<:-∪ⁿˢ (normal-∩ⁿ S T) (normal-∩ⁿˢ S U)) ) +∩-<:-∩ⁿ never (T ⇒ U) = <:-∩-left +∩-<:-∩ⁿ unknown (T ⇒ U) = <:-∩-right +∩-<:-∩ⁿ (R ⇒ S) (T ⇒ U) = <:-refl +∩-<:-∩ⁿ (R ∩ S) (T ⇒ U) = <:-refl +∩-<:-∩ⁿ (R ∪ S) (T ⇒ U) = <:-trans <:-∩-distr-∪ (<:-trans (<:-union (∩-<:-∩ⁿ R (T ⇒ U)) (<:-trans <:-∩-symm (∩-<:-∩ⁿˢ (T ⇒ U) S))) (<:-∪-lub <:-refl <:-never)) +∩-<:-∩ⁿ never (T ∩ U) = <:-∩-left +∩-<:-∩ⁿ unknown (T ∩ U) = <:-∩-right +∩-<:-∩ⁿ (R ⇒ S) (T ∩ U) = <:-refl +∩-<:-∩ⁿ (R ∩ S) (T ∩ U) = <:-refl +∩-<:-∩ⁿ (R ∪ S) (T ∩ U) = <:-trans <:-∩-distr-∪ (<:-trans (<:-union (∩-<:-∩ⁿ R (T ∩ U)) (<:-trans <:-∩-symm (∩-<:-∩ⁿˢ (T ∩ U) S))) (<:-∪-lub <:-refl <:-never)) + +∩ⁿ-<:-∩ S never = <:-never +∩ⁿ-<:-∩ S unknown = <:-∩-glb <:-refl <:-unknown +∩ⁿ-<:-∩ S (T ∪ U) = <:-trans (∪ⁿˢ-<:-∪ (normal-∩ⁿ S T) (normal-∩ⁿˢ S U)) (<:-trans (<:-union (∩ⁿ-<:-∩ S T) (∩ⁿˢ-<:-∩ S U)) ∩-distl-∪-<:) +∩ⁿ-<:-∩ never (T ⇒ U) = <:-never +∩ⁿ-<:-∩ unknown (T ⇒ U) = <:-∩-glb <:-unknown <:-refl +∩ⁿ-<:-∩ (R ⇒ S) (T ⇒ U) = <:-refl +∩ⁿ-<:-∩ (R ∩ S) (T ⇒ U) = <:-refl +∩ⁿ-<:-∩ (R ∪ S) (T ⇒ U) = <:-trans (∩ⁿ-<:-∩ R (T ⇒ U)) (<:-∩-glb (<:-trans <:-∩-left <:-∪-left) <:-∩-right) +∩ⁿ-<:-∩ never (T ∩ U) = <:-never +∩ⁿ-<:-∩ unknown (T ∩ U) = <:-∩-glb <:-unknown <:-refl +∩ⁿ-<:-∩ (R ⇒ S) (T ∩ U) = <:-refl +∩ⁿ-<:-∩ (R ∩ S) (T ∩ U) = <:-refl +∩ⁿ-<:-∩ (R ∪ S) (T ∩ U) = <:-trans (∩ⁿ-<:-∩ R (T ∩ U)) (<:-∩-glb (<:-trans <:-∩-left <:-∪-left) <:-∩-right) + +∩-<:-∩ⁿˢ never number = <:-∩-left +∩-<:-∩ⁿˢ never boolean = <:-∩-left +∩-<:-∩ⁿˢ never string = <:-∩-left +∩-<:-∩ⁿˢ never nil = <:-∩-left +∩-<:-∩ⁿˢ unknown T = <:-∩-right +∩-<:-∩ⁿˢ (R ⇒ S) T = scalar-∩-fun-<:-never (R ⇒ S) T +∩-<:-∩ⁿˢ (F ∩ G) T = scalar-∩-fun-<:-never (F ∩ G) T +∩-<:-∩ⁿˢ (R ∪ number) number = <:-∩-right +∩-<:-∩ⁿˢ (R ∪ boolean) number = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R number) (scalar-≢-∩-<:-never boolean number (λ ()))) +∩-<:-∩ⁿˢ (R ∪ string) number = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R number) (scalar-≢-∩-<:-never string number (λ ()))) +∩-<:-∩ⁿˢ (R ∪ nil) number = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R number) (scalar-≢-∩-<:-never nil number (λ ()))) +∩-<:-∩ⁿˢ (R ∪ number) boolean = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R boolean) (scalar-≢-∩-<:-never number boolean (λ ()))) +∩-<:-∩ⁿˢ (R ∪ boolean) boolean = <:-∩-right +∩-<:-∩ⁿˢ (R ∪ string) boolean = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R boolean) (scalar-≢-∩-<:-never string boolean (λ ()))) +∩-<:-∩ⁿˢ (R ∪ nil) boolean = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R boolean) (scalar-≢-∩-<:-never nil boolean (λ ()))) +∩-<:-∩ⁿˢ (R ∪ number) string = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R string) (scalar-≢-∩-<:-never number string (λ ()))) +∩-<:-∩ⁿˢ (R ∪ boolean) string = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R string) (scalar-≢-∩-<:-never boolean string (λ ()))) +∩-<:-∩ⁿˢ (R ∪ string) string = <:-∩-right +∩-<:-∩ⁿˢ (R ∪ nil) string = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R string) (scalar-≢-∩-<:-never nil string (λ ()))) +∩-<:-∩ⁿˢ (R ∪ number) nil = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R nil) (scalar-≢-∩-<:-never number nil (λ ()))) +∩-<:-∩ⁿˢ (R ∪ boolean) nil = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R nil) (scalar-≢-∩-<:-never boolean nil (λ ()))) +∩-<:-∩ⁿˢ (R ∪ string) nil = <:-trans <:-∩-distr-∪ (<:-∪-lub (∩-<:-∩ⁿˢ R nil) (scalar-≢-∩-<:-never string nil (λ ()))) +∩-<:-∩ⁿˢ (R ∪ nil) nil = <:-∩-right + +∩ⁿˢ-<:-∩ never T = <:-never +∩ⁿˢ-<:-∩ unknown T = <:-∩-glb <:-unknown <:-refl +∩ⁿˢ-<:-∩ (R ⇒ S) T = <:-never +∩ⁿˢ-<:-∩ (F ∩ G) T = <:-never +∩ⁿˢ-<:-∩ (R ∪ number) number = <:-∩-glb <:-∪-right <:-refl +∩ⁿˢ-<:-∩ (R ∪ boolean) number = <:-trans (∩ⁿˢ-<:-∩ R number) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ string) number = <:-trans (∩ⁿˢ-<:-∩ R number) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ nil) number = <:-trans (∩ⁿˢ-<:-∩ R number) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ number) boolean = <:-trans (∩ⁿˢ-<:-∩ R boolean) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ boolean) boolean = <:-∩-glb <:-∪-right <:-refl +∩ⁿˢ-<:-∩ (R ∪ string) boolean = <:-trans (∩ⁿˢ-<:-∩ R boolean) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ nil) boolean = <:-trans (∩ⁿˢ-<:-∩ R boolean) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ number) string = <:-trans (∩ⁿˢ-<:-∩ R string) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ boolean) string = <:-trans (∩ⁿˢ-<:-∩ R string) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ string) string = <:-∩-glb <:-∪-right <:-refl +∩ⁿˢ-<:-∩ (R ∪ nil) string = <:-trans (∩ⁿˢ-<:-∩ R string) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ number) nil = <:-trans (∩ⁿˢ-<:-∩ R nil) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ boolean) nil = <:-trans (∩ⁿˢ-<:-∩ R nil) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ string) nil = <:-trans (∩ⁿˢ-<:-∩ R nil) (<:-intersect <:-∪-left <:-refl) +∩ⁿˢ-<:-∩ (R ∪ nil) nil = <:-∩-glb <:-∪-right <:-refl + +∪ᶠ-<:-∪ (R ⇒ S) (T ⇒ U) = <:-trans (<:-function (∩-<:-∩ⁿ R T) (∪ⁿ-<:-∪ S U)) <:-function-∪-∩ +∪ᶠ-<:-∪ (R ⇒ S) (G ∩ H) = <:-trans (<:-intersect (∪ᶠ-<:-∪ (R ⇒ S) G) (∪ᶠ-<:-∪ (R ⇒ S) H)) ∪-distl-∩-<: +∪ᶠ-<:-∪ (E ∩ F) G = <:-trans (<:-intersect (∪ᶠ-<:-∪ E G) (∪ᶠ-<:-∪ F G)) ∪-distr-∩-<: + +∪-<:-∪ᶠ : ∀ {F G} → FunType F → FunType G → (F ∪ G) <: (F ∪ᶠ G) +∪-<:-∪ᶠ (R ⇒ S) (T ⇒ U) = <:-trans <:-function-∪ (<:-function (∩ⁿ-<:-∩ R T) (∪-<:-∪ⁿ S U)) +∪-<:-∪ᶠ (R ⇒ S) (G ∩ H) = <:-trans <:-∪-distl-∩ (<:-intersect (∪-<:-∪ᶠ (R ⇒ S) G) (∪-<:-∪ᶠ (R ⇒ S) H)) +∪-<:-∪ᶠ (E ∩ F) G = <:-trans <:-∪-distr-∩ (<:-intersect (∪-<:-∪ᶠ E G) (∪-<:-∪ᶠ F G)) + +∪ⁿˢ-<:-∪ S never = <:-∪-left +∪ⁿˢ-<:-∪ never number = <:-refl +∪ⁿˢ-<:-∪ never boolean = <:-refl +∪ⁿˢ-<:-∪ never string = <:-refl +∪ⁿˢ-<:-∪ never nil = <:-refl +∪ⁿˢ-<:-∪ unknown number = <:-∪-left +∪ⁿˢ-<:-∪ unknown boolean = <:-∪-left +∪ⁿˢ-<:-∪ unknown string = <:-∪-left +∪ⁿˢ-<:-∪ unknown nil = <:-∪-left +∪ⁿˢ-<:-∪ (R ⇒ S) number = <:-refl +∪ⁿˢ-<:-∪ (R ⇒ S) boolean = <:-refl +∪ⁿˢ-<:-∪ (R ⇒ S) string = <:-refl +∪ⁿˢ-<:-∪ (R ⇒ S) nil = <:-refl +∪ⁿˢ-<:-∪ (R ∩ S) number = <:-refl +∪ⁿˢ-<:-∪ (R ∩ S) boolean = <:-refl +∪ⁿˢ-<:-∪ (R ∩ S) string = <:-refl +∪ⁿˢ-<:-∪ (R ∩ S) nil = <:-refl +∪ⁿˢ-<:-∪ (R ∪ number) number = <:-union <:-∪-left <:-refl +∪ⁿˢ-<:-∪ (R ∪ boolean) number = <:-trans (<:-union (∪ⁿˢ-<:-∪ R number) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ string) number = <:-trans (<:-union (∪ⁿˢ-<:-∪ R number) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ nil) number = <:-trans (<:-union (∪ⁿˢ-<:-∪ R number) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ number) boolean = <:-trans (<:-union (∪ⁿˢ-<:-∪ R boolean) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ boolean) boolean = <:-union <:-∪-left <:-refl +∪ⁿˢ-<:-∪ (R ∪ string) boolean = <:-trans (<:-union (∪ⁿˢ-<:-∪ R boolean) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ nil) boolean = <:-trans (<:-union (∪ⁿˢ-<:-∪ R boolean) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ number) string = <:-trans (<:-union (∪ⁿˢ-<:-∪ R string) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ boolean) string = <:-trans (<:-union (∪ⁿˢ-<:-∪ R string) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ string) string = <:-union <:-∪-left <:-refl +∪ⁿˢ-<:-∪ (R ∪ nil) string = <:-trans (<:-union (∪ⁿˢ-<:-∪ R string) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ number) nil = <:-trans (<:-union (∪ⁿˢ-<:-∪ R nil) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ boolean) nil = <:-trans (<:-union (∪ⁿˢ-<:-∪ R nil) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ string) nil = <:-trans (<:-union (∪ⁿˢ-<:-∪ R nil) <:-refl) flipper +∪ⁿˢ-<:-∪ (R ∪ nil) nil = <:-union <:-∪-left <:-refl + +∪-<:-∪ⁿˢ T never = <:-∪-lub <:-refl <:-never +∪-<:-∪ⁿˢ never number = <:-refl +∪-<:-∪ⁿˢ never boolean = <:-refl +∪-<:-∪ⁿˢ never string = <:-refl +∪-<:-∪ⁿˢ never nil = <:-refl +∪-<:-∪ⁿˢ unknown number = <:-unknown +∪-<:-∪ⁿˢ unknown boolean = <:-unknown +∪-<:-∪ⁿˢ unknown string = <:-unknown +∪-<:-∪ⁿˢ unknown nil = <:-unknown +∪-<:-∪ⁿˢ (R ⇒ S) number = <:-refl +∪-<:-∪ⁿˢ (R ⇒ S) boolean = <:-refl +∪-<:-∪ⁿˢ (R ⇒ S) string = <:-refl +∪-<:-∪ⁿˢ (R ⇒ S) nil = <:-refl +∪-<:-∪ⁿˢ (R ∩ S) number = <:-refl +∪-<:-∪ⁿˢ (R ∩ S) boolean = <:-refl +∪-<:-∪ⁿˢ (R ∩ S) string = <:-refl +∪-<:-∪ⁿˢ (R ∩ S) nil = <:-refl +∪-<:-∪ⁿˢ (R ∪ number) number = <:-∪-lub <:-refl <:-∪-right +∪-<:-∪ⁿˢ (R ∪ boolean) number = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R number) <:-refl) +∪-<:-∪ⁿˢ (R ∪ string) number = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R number) <:-refl) +∪-<:-∪ⁿˢ (R ∪ nil) number = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R number) <:-refl) +∪-<:-∪ⁿˢ (R ∪ number) boolean = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R boolean) <:-refl) +∪-<:-∪ⁿˢ (R ∪ boolean) boolean = <:-∪-lub <:-refl <:-∪-right +∪-<:-∪ⁿˢ (R ∪ string) boolean = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R boolean) <:-refl) +∪-<:-∪ⁿˢ (R ∪ nil) boolean = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R boolean) <:-refl) +∪-<:-∪ⁿˢ (R ∪ number) string = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R string) <:-refl) +∪-<:-∪ⁿˢ (R ∪ boolean) string = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R string) <:-refl) +∪-<:-∪ⁿˢ (R ∪ string) string = <:-∪-lub <:-refl <:-∪-right +∪-<:-∪ⁿˢ (R ∪ nil) string = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R string) <:-refl) +∪-<:-∪ⁿˢ (R ∪ number) nil = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R nil) <:-refl) +∪-<:-∪ⁿˢ (R ∪ boolean) nil = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R nil) <:-refl) +∪-<:-∪ⁿˢ (R ∪ string) nil = <:-trans flipper (<:-union (∪-<:-∪ⁿˢ R nil) <:-refl) +∪-<:-∪ⁿˢ (R ∪ nil) nil = <:-∪-lub <:-refl <:-∪-right + +∪ⁿ-<:-∪ S never = <:-∪-left +∪ⁿ-<:-∪ S unknown = <:-∪-right +∪ⁿ-<:-∪ never (T ⇒ U) = <:-∪-right +∪ⁿ-<:-∪ unknown (T ⇒ U) = <:-∪-left +∪ⁿ-<:-∪ (R ⇒ S) (T ⇒ U) = ∪ᶠ-<:-∪ (R ⇒ S) (T ⇒ U) +∪ⁿ-<:-∪ (R ∩ S) (T ⇒ U) = ∪ᶠ-<:-∪ (R ∩ S) (T ⇒ U) +∪ⁿ-<:-∪ (R ∪ S) (T ⇒ U) = <:-trans (<:-union (∪ⁿ-<:-∪ R (T ⇒ U)) <:-refl) (<:-∪-lub (<:-∪-lub (<:-trans <:-∪-left <:-∪-left) <:-∪-right) (<:-trans <:-∪-right <:-∪-left)) +∪ⁿ-<:-∪ never (T ∩ U) = <:-∪-right +∪ⁿ-<:-∪ unknown (T ∩ U) = <:-∪-left +∪ⁿ-<:-∪ (R ⇒ S) (T ∩ U) = ∪ᶠ-<:-∪ (R ⇒ S) (T ∩ U) +∪ⁿ-<:-∪ (R ∩ S) (T ∩ U) = ∪ᶠ-<:-∪ (R ∩ S) (T ∩ U) +∪ⁿ-<:-∪ (R ∪ S) (T ∩ U) = <:-trans (<:-union (∪ⁿ-<:-∪ R (T ∩ U)) <:-refl) (<:-∪-lub (<:-∪-lub (<:-trans <:-∪-left <:-∪-left) <:-∪-right) (<:-trans <:-∪-right <:-∪-left)) +∪ⁿ-<:-∪ S (T ∪ U) = <:-∪-lub (<:-trans (∪ⁿ-<:-∪ S T) (<:-union <:-refl <:-∪-left)) (<:-trans <:-∪-right <:-∪-right) + +∪-<:-∪ⁿ S never = <:-∪-lub <:-refl <:-never +∪-<:-∪ⁿ S unknown = <:-unknown +∪-<:-∪ⁿ never (T ⇒ U) = <:-∪-lub <:-never <:-refl +∪-<:-∪ⁿ unknown (T ⇒ U) = <:-unknown +∪-<:-∪ⁿ (R ⇒ S) (T ⇒ U) = ∪-<:-∪ᶠ (R ⇒ S) (T ⇒ U) +∪-<:-∪ⁿ (R ∩ S) (T ⇒ U) = ∪-<:-∪ᶠ (R ∩ S) (T ⇒ U) +∪-<:-∪ⁿ (R ∪ S) (T ⇒ U) = <:-trans <:-∪-assocr (<:-trans (<:-union <:-refl <:-∪-symm) (<:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ R (T ⇒ U)) <:-refl))) +∪-<:-∪ⁿ never (T ∩ U) = <:-∪-lub <:-never <:-refl +∪-<:-∪ⁿ unknown (T ∩ U) = <:-unknown +∪-<:-∪ⁿ (R ⇒ S) (T ∩ U) = ∪-<:-∪ᶠ (R ⇒ S) (T ∩ U) +∪-<:-∪ⁿ (R ∩ S) (T ∩ U) = ∪-<:-∪ᶠ (R ∩ S) (T ∩ U) +∪-<:-∪ⁿ (R ∪ S) (T ∩ U) = <:-trans <:-∪-assocr (<:-trans (<:-union <:-refl <:-∪-symm) (<:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ R (T ∩ U)) <:-refl))) +∪-<:-∪ⁿ never (T ∪ U) = <:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ never T) <:-refl) +∪-<:-∪ⁿ unknown (T ∪ U) = <:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ unknown T) <:-refl) +∪-<:-∪ⁿ (R ⇒ S) (T ∪ U) = <:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ (R ⇒ S) T) <:-refl) +∪-<:-∪ⁿ (R ∩ S) (T ∪ U) = <:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ (R ∩ S) T) <:-refl) +∪-<:-∪ⁿ (R ∪ S) (T ∪ U) = <:-trans <:-∪-assocl (<:-union (∪-<:-∪ⁿ (R ∪ S) T) <:-refl) + +normalize-<: : ∀ T → normalize T <: T +<:-normalize : ∀ T → T <: normalize T + +<:-normalize nil = <:-∪-right +<:-normalize (S ⇒ T) = <:-function (normalize-<: S) (<:-normalize T) +<:-normalize never = <:-refl +<:-normalize unknown = <:-refl +<:-normalize boolean = <:-∪-right +<:-normalize number = <:-∪-right +<:-normalize string = <:-∪-right +<:-normalize (S ∪ T) = <:-trans (<:-union (<:-normalize S) (<:-normalize T)) (∪-<:-∪ⁿ (normal S) (normal T)) +<:-normalize (S ∩ T) = <:-trans (<:-intersect (<:-normalize S) (<:-normalize T)) (∩-<:-∩ⁿ (normal S) (normal T)) + +normalize-<: nil = <:-∪-lub <:-never <:-refl +normalize-<: (S ⇒ T) = <:-function (<:-normalize S) (normalize-<: T) +normalize-<: never = <:-refl +normalize-<: unknown = <:-refl +normalize-<: boolean = <:-∪-lub <:-never <:-refl +normalize-<: number = <:-∪-lub <:-never <:-refl +normalize-<: string = <:-∪-lub <:-never <:-refl +normalize-<: (S ∪ T) = <:-trans (∪ⁿ-<:-∪ (normal S) (normal T)) (<:-union (normalize-<: S) (normalize-<: T)) +normalize-<: (S ∩ T) = <:-trans (∩ⁿ-<:-∩ (normal S) (normal T)) (<:-intersect (normalize-<: S) (normalize-<: T)) + + diff --git a/prototyping/Properties/TypeSaturation.agda b/prototyping/Properties/TypeSaturation.agda new file mode 100644 index 00000000..13f7d171 --- /dev/null +++ b/prototyping/Properties/TypeSaturation.agda @@ -0,0 +1,433 @@ +{-# OPTIONS --rewriting #-} + +module Properties.TypeSaturation where + +open import Agda.Builtin.Equality using (_≡_; refl) +open import FFI.Data.Either using (Either; Left; Right) +open import Luau.Subtyping using (Tree; Language; ¬Language; _<:_; _≮:_; witness; scalar; function; function-err; function-ok; function-ok₁; function-ok₂; scalar-function; _,_; never) +open import Luau.Type using (Type; _⇒_; _∩_; _∪_; never; unknown) +open import Luau.TypeNormalization using (_∩ⁿ_; _∪ⁿ_) +open import Luau.TypeSaturation using (_⋓_; _⋒_; _∩ᵘ_; _∩ⁱ_; ∪-saturate; ∩-saturate; saturate) +open import Properties.Subtyping using (dec-language; language-comp; <:-impl-⊇; <:-refl; <:-trans; <:-trans-≮:; <:-impl-¬≮: ; <:-never; <:-unknown; <:-function; <:-union; <:-∪-symm; <:-∪-left; <:-∪-right; <:-∪-lub; <:-∪-assocl; <:-∪-assocr; <:-intersect; <:-∩-symm; <:-∩-left; <:-∩-right; <:-∩-glb; ≮:-function-left; ≮:-function-right; <:-function-∩-∪; <:-function-∩-∩; <:-∩-assocl; <:-∩-assocr; ∩-<:-∪; <:-∩-distl-∪; ∩-distl-∪-<:; <:-∩-distr-∪; ∩-distr-∪-<:) +open import Properties.TypeNormalization using (Normal; FunType; _⇒_; _∩_; _∪_; never; unknown; normal-∪ⁿ; normal-∩ⁿ; ∪ⁿ-<:-∪; ∪-<:-∪ⁿ; ∩ⁿ-<:-∩; ∩-<:-∩ⁿ) +open import Properties.Contradiction using (CONTRADICTION) +open import Properties.Functions using (_∘_) + +-- Saturation preserves normalization +normal-⋒ : ∀ {F G} → FunType F → FunType G → FunType (F ⋒ G) +normal-⋒ (R ⇒ S) (T ⇒ U) = (normal-∩ⁿ R T) ⇒ (normal-∩ⁿ S U) +normal-⋒ (R ⇒ S) (G ∩ H) = normal-⋒ (R ⇒ S) G ∩ normal-⋒ (R ⇒ S) H +normal-⋒ (E ∩ F) G = normal-⋒ E G ∩ normal-⋒ F G + +normal-⋓ : ∀ {F G} → FunType F → FunType G → FunType (F ⋓ G) +normal-⋓ (R ⇒ S) (T ⇒ U) = (normal-∪ⁿ R T) ⇒ (normal-∪ⁿ S U) +normal-⋓ (R ⇒ S) (G ∩ H) = normal-⋓ (R ⇒ S) G ∩ normal-⋓ (R ⇒ S) H +normal-⋓ (E ∩ F) G = normal-⋓ E G ∩ normal-⋓ F G + +normal-∩-saturate : ∀ {F} → FunType F → FunType (∩-saturate F) +normal-∩-saturate (S ⇒ T) = S ⇒ T +normal-∩-saturate (F ∩ G) = (normal-∩-saturate F ∩ normal-∩-saturate G) ∩ normal-⋒ (normal-∩-saturate F) (normal-∩-saturate G) + +normal-∪-saturate : ∀ {F} → FunType F → FunType (∪-saturate F) +normal-∪-saturate (S ⇒ T) = S ⇒ T +normal-∪-saturate (F ∩ G) = (normal-∪-saturate F ∩ normal-∪-saturate G) ∩ normal-⋓ (normal-∪-saturate F) (normal-∪-saturate G) + +normal-saturate : ∀ {F} → FunType F → FunType (saturate F) +normal-saturate F = normal-∪-saturate (normal-∩-saturate F) + +-- Saturation resects subtyping +∪-saturate-<: : ∀ {F} → FunType F → ∪-saturate F <: F +∪-saturate-<: (S ⇒ T) = <:-refl +∪-saturate-<: (F ∩ G) = <:-trans <:-∩-left (<:-intersect (∪-saturate-<: F) (∪-saturate-<: G)) + +∩-saturate-<: : ∀ {F} → FunType F → ∩-saturate F <: F +∩-saturate-<: (S ⇒ T) = <:-refl +∩-saturate-<: (F ∩ G) = <:-trans <:-∩-left (<:-intersect (∩-saturate-<: F) (∩-saturate-<: G)) + +saturate-<: : ∀ {F} → FunType F → saturate F <: F +saturate-<: F = <:-trans (∪-saturate-<: (normal-∩-saturate F)) (∩-saturate-<: F) + +∩-<:-⋓ : ∀ {F G} → FunType F → FunType G → (F ∩ G) <: (F ⋓ G) +∩-<:-⋓ (R ⇒ S) (T ⇒ U) = <:-trans <:-function-∩-∪ (<:-function (∪ⁿ-<:-∪ R T) (∪-<:-∪ⁿ S U)) +∩-<:-⋓ (R ⇒ S) (G ∩ H) = <:-trans (<:-∩-glb (<:-intersect <:-refl <:-∩-left) (<:-intersect <:-refl <:-∩-right)) (<:-intersect (∩-<:-⋓ (R ⇒ S) G) (∩-<:-⋓ (R ⇒ S) H)) +∩-<:-⋓ (E ∩ F) G = <:-trans (<:-∩-glb (<:-intersect <:-∩-left <:-refl) (<:-intersect <:-∩-right <:-refl)) (<:-intersect (∩-<:-⋓ E G) (∩-<:-⋓ F G)) + +∩-<:-⋒ : ∀ {F G} → FunType F → FunType G → (F ∩ G) <: (F ⋒ G) +∩-<:-⋒ (R ⇒ S) (T ⇒ U) = <:-trans <:-function-∩-∩ (<:-function (∩ⁿ-<:-∩ R T) (∩-<:-∩ⁿ S U)) +∩-<:-⋒ (R ⇒ S) (G ∩ H) = <:-trans (<:-∩-glb (<:-intersect <:-refl <:-∩-left) (<:-intersect <:-refl <:-∩-right)) (<:-intersect (∩-<:-⋒ (R ⇒ S) G) (∩-<:-⋒ (R ⇒ S) H)) +∩-<:-⋒ (E ∩ F) G = <:-trans (<:-∩-glb (<:-intersect <:-∩-left <:-refl) (<:-intersect <:-∩-right <:-refl)) (<:-intersect (∩-<:-⋒ E G) (∩-<:-⋒ F G)) + +<:-∪-saturate : ∀ {F} → FunType F → F <: ∪-saturate F +<:-∪-saturate (S ⇒ T) = <:-refl +<:-∪-saturate (F ∩ G) = <:-∩-glb (<:-intersect (<:-∪-saturate F) (<:-∪-saturate G)) (<:-trans (<:-intersect (<:-∪-saturate F) (<:-∪-saturate G)) (∩-<:-⋓ (normal-∪-saturate F) (normal-∪-saturate G))) + +<:-∩-saturate : ∀ {F} → FunType F → F <: ∩-saturate F +<:-∩-saturate (S ⇒ T) = <:-refl +<:-∩-saturate (F ∩ G) = <:-∩-glb (<:-intersect (<:-∩-saturate F) (<:-∩-saturate G)) (<:-trans (<:-intersect (<:-∩-saturate F) (<:-∩-saturate G)) (∩-<:-⋒ (normal-∩-saturate F) (normal-∩-saturate G))) + +<:-saturate : ∀ {F} → FunType F → F <: saturate F +<:-saturate F = <:-trans (<:-∩-saturate F) (<:-∪-saturate (normal-∩-saturate F)) + +-- Overloads F is the set of overloads of F +data Overloads : Type → Type → Set where + + here : ∀ {S T} → Overloads (S ⇒ T) (S ⇒ T) + left : ∀ {S T F G} → Overloads F (S ⇒ T) → Overloads (F ∩ G) (S ⇒ T) + right : ∀ {S T F G} → Overloads G (S ⇒ T) → Overloads (F ∩ G) (S ⇒ T) + +normal-overload-src : ∀ {F S T} → FunType F → Overloads F (S ⇒ T) → Normal S +normal-overload-src (S ⇒ T) here = S +normal-overload-src (F ∩ G) (left o) = normal-overload-src F o +normal-overload-src (F ∩ G) (right o) = normal-overload-src G o + +normal-overload-tgt : ∀ {F S T} → FunType F → Overloads F (S ⇒ T) → Normal T +normal-overload-tgt (S ⇒ T) here = T +normal-overload-tgt (F ∩ G) (left o) = normal-overload-tgt F o +normal-overload-tgt (F ∩ G) (right o) = normal-overload-tgt G o + +-- An inductive presentation of the overloads of F ⋓ G +data ∪-Lift (P Q : Type → Set) : Type → Set where + + union : ∀ {R S T U} → + + P (R ⇒ S) → + Q (T ⇒ U) → + -------------------- + ∪-Lift P Q ((R ∪ T) ⇒ (S ∪ U)) + +-- An inductive presentation of the overloads of F ⋒ G +data ∩-Lift (P Q : Type → Set) : Type → Set where + + intersect : ∀ {R S T U} → + + P (R ⇒ S) → + Q (T ⇒ U) → + -------------------- + ∩-Lift P Q ((R ∩ T) ⇒ (S ∩ U)) + +-- An inductive presentation of the overloads of ∪-saturate F +data ∪-Saturate (P : Type → Set) : Type → Set where + + base : ∀ {S T} → + + P (S ⇒ T) → + -------------------- + ∪-Saturate P (S ⇒ T) + + union : ∀ {R S T U} → + + ∪-Saturate P (R ⇒ S) → + ∪-Saturate P (T ⇒ U) → + -------------------- + ∪-Saturate P ((R ∪ T) ⇒ (S ∪ U)) + +-- An inductive presentation of the overloads of ∩-saturate F +data ∩-Saturate (P : Type → Set) : Type → Set where + + base : ∀ {S T} → + + P (S ⇒ T) → + -------------------- + ∩-Saturate P (S ⇒ T) + + intersect : ∀ {R S T U} → + + ∩-Saturate P (R ⇒ S) → + ∩-Saturate P (T ⇒ U) → + -------------------- + ∩-Saturate P ((R ∩ T) ⇒ (S ∩ U)) + +-- The <:-up-closure of a set of function types +data <:-Close (P : Type → Set) : Type → Set where + + defn : ∀ {R S T U} → + + P (S ⇒ T) → + R <: S → + T <: U → + ------------------ + <:-Close P (R ⇒ U) + +-- F ⊆ᵒ G whenever every overload of F is an overload of G +_⊆ᵒ_ : Type → Type → Set +F ⊆ᵒ G = ∀ {S T} → Overloads F (S ⇒ T) → Overloads G (S ⇒ T) + +-- F <:ᵒ G when every overload of G is a supertype of an overload of F +_<:ᵒ_ : Type → Type → Set +_<:ᵒ_ F G = ∀ {S T} → Overloads G (S ⇒ T) → <:-Close (Overloads F) (S ⇒ T) + +-- P ⊂: Q when any type in P is a subtype of some type in Q +_⊂:_ : (Type → Set) → (Type → Set) → Set +P ⊂: Q = ∀ {S T} → P (S ⇒ T) → <:-Close Q (S ⇒ T) + +-- <:-Close is a monad +just : ∀ {P S T} → P (S ⇒ T) → <:-Close P (S ⇒ T) +just p = defn p <:-refl <:-refl + +infixl 5 _>>=_ _>>=ˡ_ _>>=ʳ_ +_>>=_ : ∀ {P Q S T} → <:-Close P (S ⇒ T) → (P ⊂: Q) → <:-Close Q (S ⇒ T) +(defn p p₁ p₂) >>= P⊂Q with P⊂Q p +(defn p p₁ p₂) >>= P⊂Q | defn q q₁ q₂ = defn q (<:-trans p₁ q₁) (<:-trans q₂ p₂) + +_>>=ˡ_ : ∀ {P R S T} → <:-Close P (S ⇒ T) → (R <: S) → <:-Close P (R ⇒ T) +(defn p p₁ p₂) >>=ˡ q = defn p (<:-trans q p₁) p₂ + +_>>=ʳ_ : ∀ {P S T U} → <:-Close P (S ⇒ T) → (T <: U) → <:-Close P (S ⇒ U) +(defn p p₁ p₂) >>=ʳ q = defn p p₁ (<:-trans p₂ q) + +-- Properties of ⊂: +⊂:-refl : ∀ {P} → P ⊂: P +⊂:-refl p = just p + +_[∪]_ : ∀ {P Q R S T U} → <:-Close P (R ⇒ S) → <:-Close Q (T ⇒ U) → <:-Close (∪-Lift P Q) ((R ∪ T) ⇒ (S ∪ U)) +(defn p p₁ p₂) [∪] (defn q q₁ q₂) = defn (union p q) (<:-union p₁ q₁) (<:-union p₂ q₂) + +_[∩]_ : ∀ {P Q R S T U} → <:-Close P (R ⇒ S) → <:-Close Q (T ⇒ U) → <:-Close (∩-Lift P Q) ((R ∩ T) ⇒ (S ∩ U)) +(defn p p₁ p₂) [∩] (defn q q₁ q₂) = defn (intersect p q) (<:-intersect p₁ q₁) (<:-intersect p₂ q₂) + +⊂:-∩-saturate-inj : ∀ {P} → P ⊂: ∩-Saturate P +⊂:-∩-saturate-inj p = defn (base p) <:-refl <:-refl + +⊂:-∪-saturate-inj : ∀ {P} → P ⊂: ∪-Saturate P +⊂:-∪-saturate-inj p = just (base p) + +⊂:-∩-lift-saturate : ∀ {P} → ∩-Lift (∩-Saturate P) (∩-Saturate P) ⊂: ∩-Saturate P +⊂:-∩-lift-saturate (intersect p q) = just (intersect p q) + +⊂:-∪-lift-saturate : ∀ {P} → ∪-Lift (∪-Saturate P) (∪-Saturate P) ⊂: ∪-Saturate P +⊂:-∪-lift-saturate (union p q) = just (union p q) + +⊂:-∩-lift : ∀ {P Q R S} → (P ⊂: Q) → (R ⊂: S) → (∩-Lift P R ⊂: ∩-Lift Q S) +⊂:-∩-lift P⊂Q R⊂S (intersect n o) = P⊂Q n [∩] R⊂S o + +⊂:-∪-lift : ∀ {P Q R S} → (P ⊂: Q) → (R ⊂: S) → (∪-Lift P R ⊂: ∪-Lift Q S) +⊂:-∪-lift P⊂Q R⊂S (union n o) = P⊂Q n [∪] R⊂S o + +⊂:-∩-saturate : ∀ {P Q} → (P ⊂: Q) → (∩-Saturate P ⊂: ∩-Saturate Q) +⊂:-∩-saturate P⊂Q (base p) = P⊂Q p >>= ⊂:-∩-saturate-inj +⊂:-∩-saturate P⊂Q (intersect p q) = (⊂:-∩-saturate P⊂Q p [∩] ⊂:-∩-saturate P⊂Q q) >>= ⊂:-∩-lift-saturate + +⊂:-∪-saturate : ∀ {P Q} → (P ⊂: Q) → (∪-Saturate P ⊂: ∪-Saturate Q) +⊂:-∪-saturate P⊂Q (base p) = P⊂Q p >>= ⊂:-∪-saturate-inj +⊂:-∪-saturate P⊂Q (union p q) = (⊂:-∪-saturate P⊂Q p [∪] ⊂:-∪-saturate P⊂Q q) >>= ⊂:-∪-lift-saturate + +⊂:-∩-saturate-indn : ∀ {P Q} → (P ⊂: Q) → (∩-Lift Q Q ⊂: Q) → (∩-Saturate P ⊂: Q) +⊂:-∩-saturate-indn P⊂Q QQ⊂Q (base p) = P⊂Q p +⊂:-∩-saturate-indn P⊂Q QQ⊂Q (intersect p q) = (⊂:-∩-saturate-indn P⊂Q QQ⊂Q p [∩] ⊂:-∩-saturate-indn P⊂Q QQ⊂Q q) >>= QQ⊂Q + +⊂:-∪-saturate-indn : ∀ {P Q} → (P ⊂: Q) → (∪-Lift Q Q ⊂: Q) → (∪-Saturate P ⊂: Q) +⊂:-∪-saturate-indn P⊂Q QQ⊂Q (base p) = P⊂Q p +⊂:-∪-saturate-indn P⊂Q QQ⊂Q (union p q) = (⊂:-∪-saturate-indn P⊂Q QQ⊂Q p [∪] ⊂:-∪-saturate-indn P⊂Q QQ⊂Q q) >>= QQ⊂Q + +∪-saturate-resp-∩-saturation : ∀ {P} → (∩-Lift P P ⊂: P) → (∩-Lift (∪-Saturate P) (∪-Saturate P) ⊂: ∪-Saturate P) +∪-saturate-resp-∩-saturation ∩P⊂P (intersect (base p) (base q)) = ∩P⊂P (intersect p q) >>= ⊂:-∪-saturate-inj +∪-saturate-resp-∩-saturation ∩P⊂P (intersect p (union q q₁)) = (∪-saturate-resp-∩-saturation ∩P⊂P (intersect p q) [∪] ∪-saturate-resp-∩-saturation ∩P⊂P (intersect p q₁)) >>= ⊂:-∪-lift-saturate >>=ˡ <:-∩-distl-∪ >>=ʳ ∩-distl-∪-<: +∪-saturate-resp-∩-saturation ∩P⊂P (intersect (union p p₁) q) = (∪-saturate-resp-∩-saturation ∩P⊂P (intersect p q) [∪] ∪-saturate-resp-∩-saturation ∩P⊂P (intersect p₁ q)) >>= ⊂:-∪-lift-saturate >>=ˡ <:-∩-distr-∪ >>=ʳ ∩-distr-∪-<: + +ov-language : ∀ {F t} → FunType F → (∀ {S T} → Overloads F (S ⇒ T) → Language (S ⇒ T) t) → Language F t +ov-language (S ⇒ T) p = p here +ov-language (F ∩ G) p = (ov-language F (p ∘ left) , ov-language G (p ∘ right)) + +ov-<: : ∀ {F R S T U} → FunType F → Overloads F (R ⇒ S) → ((R ⇒ S) <: (T ⇒ U)) → F <: (T ⇒ U) +ov-<: F here p = p +ov-<: (F ∩ G) (left o) p = <:-trans <:-∩-left (ov-<: F o p) +ov-<: (F ∩ G) (right o) p = <:-trans <:-∩-right (ov-<: G o p) + +<:ᵒ-impl-<: : ∀ {F G} → FunType F → FunType G → (F <:ᵒ G) → (F <: G) +<:ᵒ-impl-<: F (T ⇒ U) F>= ⊂:-overloads-left +⊂:-overloads-⋒ (R ⇒ S) (G ∩ H) (intersect here (right o)) = ⊂:-overloads-⋒ (R ⇒ S) H (intersect here o) >>= ⊂:-overloads-right +⊂:-overloads-⋒ (E ∩ F) G (intersect (left n) o) = ⊂:-overloads-⋒ E G (intersect n o) >>= ⊂:-overloads-left +⊂:-overloads-⋒ (E ∩ F) G (intersect (right n) o) = ⊂:-overloads-⋒ F G (intersect n o) >>= ⊂:-overloads-right + +⊂:-⋒-overloads : ∀ {F G} → FunType F → FunType G → Overloads (F ⋒ G) ⊂: ∩-Lift (Overloads F) (Overloads G) +⊂:-⋒-overloads (R ⇒ S) (T ⇒ U) here = defn (intersect here here) (∩ⁿ-<:-∩ R T) (∩-<:-∩ⁿ S U) +⊂:-⋒-overloads (R ⇒ S) (G ∩ H) (left o) = ⊂:-⋒-overloads (R ⇒ S) G o >>= ⊂:-∩-lift ⊂:-refl ⊂:-overloads-left +⊂:-⋒-overloads (R ⇒ S) (G ∩ H) (right o) = ⊂:-⋒-overloads (R ⇒ S) H o >>= ⊂:-∩-lift ⊂:-refl ⊂:-overloads-right +⊂:-⋒-overloads (E ∩ F) G (left o) = ⊂:-⋒-overloads E G o >>= ⊂:-∩-lift ⊂:-overloads-left ⊂:-refl +⊂:-⋒-overloads (E ∩ F) G (right o) = ⊂:-⋒-overloads F G o >>= ⊂:-∩-lift ⊂:-overloads-right ⊂:-refl + +⊂:-overloads-⋓ : ∀ {F G} → FunType F → FunType G → ∪-Lift (Overloads F) (Overloads G) ⊂: Overloads (F ⋓ G) +⊂:-overloads-⋓ (R ⇒ S) (T ⇒ U) (union here here) = defn here (∪-<:-∪ⁿ R T) (∪ⁿ-<:-∪ S U) +⊂:-overloads-⋓ (R ⇒ S) (G ∩ H) (union here (left o)) = ⊂:-overloads-⋓ (R ⇒ S) G (union here o) >>= ⊂:-overloads-left +⊂:-overloads-⋓ (R ⇒ S) (G ∩ H) (union here (right o)) = ⊂:-overloads-⋓ (R ⇒ S) H (union here o) >>= ⊂:-overloads-right +⊂:-overloads-⋓ (E ∩ F) G (union (left n) o) = ⊂:-overloads-⋓ E G (union n o) >>= ⊂:-overloads-left +⊂:-overloads-⋓ (E ∩ F) G (union (right n) o) = ⊂:-overloads-⋓ F G (union n o) >>= ⊂:-overloads-right + +⊂:-⋓-overloads : ∀ {F G} → FunType F → FunType G → Overloads (F ⋓ G) ⊂: ∪-Lift (Overloads F) (Overloads G) +⊂:-⋓-overloads (R ⇒ S) (T ⇒ U) here = defn (union here here) (∪ⁿ-<:-∪ R T) (∪-<:-∪ⁿ S U) +⊂:-⋓-overloads (R ⇒ S) (G ∩ H) (left o) = ⊂:-⋓-overloads (R ⇒ S) G o >>= ⊂:-∪-lift ⊂:-refl ⊂:-overloads-left +⊂:-⋓-overloads (R ⇒ S) (G ∩ H) (right o) = ⊂:-⋓-overloads (R ⇒ S) H o >>= ⊂:-∪-lift ⊂:-refl ⊂:-overloads-right +⊂:-⋓-overloads (E ∩ F) G (left o) = ⊂:-⋓-overloads E G o >>= ⊂:-∪-lift ⊂:-overloads-left ⊂:-refl +⊂:-⋓-overloads (E ∩ F) G (right o) = ⊂:-⋓-overloads F G o >>= ⊂:-∪-lift ⊂:-overloads-right ⊂:-refl + +∪-saturate-overloads : ∀ {F} → FunType F → Overloads (∪-saturate F) ⊂: ∪-Saturate (Overloads F) +∪-saturate-overloads (S ⇒ T) here = just (base here) +∪-saturate-overloads (F ∩ G) (left (left o)) = ∪-saturate-overloads F o >>= ⊂:-∪-saturate ⊂:-overloads-left +∪-saturate-overloads (F ∩ G) (left (right o)) = ∪-saturate-overloads G o >>= ⊂:-∪-saturate ⊂:-overloads-right +∪-saturate-overloads (F ∩ G) (right o) = + ⊂:-⋓-overloads (normal-∪-saturate F) (normal-∪-saturate G) o >>= + ⊂:-∪-lift (∪-saturate-overloads F) (∪-saturate-overloads G) >>= + ⊂:-∪-lift (⊂:-∪-saturate ⊂:-overloads-left) (⊂:-∪-saturate ⊂:-overloads-right) >>= + ⊂:-∪-lift-saturate + +overloads-∪-saturate : ∀ {F} → FunType F → ∪-Saturate (Overloads F) ⊂: Overloads (∪-saturate F) +overloads-∪-saturate F = ⊂:-∪-saturate-indn (inj F) (step F) where + + inj : ∀ {F} → FunType F → Overloads F ⊂: Overloads (∪-saturate F) + inj (S ⇒ T) here = just here + inj (F ∩ G) (left p) = inj F p >>= ⊂:-overloads-left >>= ⊂:-overloads-left + inj (F ∩ G) (right p) = inj G p >>= ⊂:-overloads-right >>= ⊂:-overloads-left + + step : ∀ {F} → FunType F → ∪-Lift (Overloads (∪-saturate F)) (Overloads (∪-saturate F)) ⊂: Overloads (∪-saturate F) + step (S ⇒ T) (union here here) = defn here (<:-∪-lub <:-refl <:-refl) <:-∪-left + step (F ∩ G) (union (left (left p)) (left (left q))) = step F (union p q) >>= ⊂:-overloads-left >>= ⊂:-overloads-left + step (F ∩ G) (union (left (left p)) (left (right q))) = ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) (union p q) >>= ⊂:-overloads-right + step (F ∩ G) (union (left (right p)) (left (left q))) = ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) (union q p) >>= ⊂:-overloads-right >>=ˡ <:-∪-symm >>=ʳ <:-∪-symm + step (F ∩ G) (union (left (right p)) (left (right q))) = step G (union p q) >>= ⊂:-overloads-right >>= ⊂:-overloads-left + step (F ∩ G) (union p (right q)) with ⊂:-⋓-overloads (normal-∪-saturate F) (normal-∪-saturate G) q + step (F ∩ G) (union (left (left p)) (right q)) | defn (union q₁ q₂) q₃ q₄ = + (step F (union p q₁) [∪] just q₂) >>= + ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-union <:-refl q₃) <:-∪-assocl >>=ʳ + <:-trans <:-∪-assocr (<:-union <:-refl q₄) + step (F ∩ G) (union (left (right p)) (right q)) | defn (union q₁ q₂) q₃ q₄ = + (just q₁ [∪] step G (union p q₂)) >>= + ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-union <:-refl q₃) (<:-∪-lub (<:-trans <:-∪-left <:-∪-right) (<:-∪-lub <:-∪-left (<:-trans <:-∪-right <:-∪-right))) >>=ʳ + <:-trans (<:-∪-lub (<:-trans <:-∪-left <:-∪-right) (<:-∪-lub <:-∪-left (<:-trans <:-∪-right <:-∪-right))) (<:-union <:-refl q₄) + step (F ∩ G) (union (right p) (right q)) | defn (union q₁ q₂) q₃ q₄ with ⊂:-⋓-overloads (normal-∪-saturate F) (normal-∪-saturate G) p + step (F ∩ G) (union (right p) (right q)) | defn (union q₁ q₂) q₃ q₄ | defn (union p₁ p₂) p₃ p₄ = + (step F (union p₁ q₁) [∪] step G (union p₂ q₂)) >>= + ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-union p₃ q₃) (<:-∪-lub (<:-union <:-∪-left <:-∪-left) (<:-union <:-∪-right <:-∪-right)) >>=ʳ + <:-trans (<:-∪-lub (<:-union <:-∪-left <:-∪-left) (<:-union <:-∪-right <:-∪-right)) (<:-union p₄ q₄) + step (F ∩ G) (union (right p) q) with ⊂:-⋓-overloads (normal-∪-saturate F) (normal-∪-saturate G) p + step (F ∩ G) (union (right p) (left (left q))) | defn (union p₁ p₂) p₃ p₄ = + (step F (union p₁ q) [∪] just p₂) >>= + ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-union p₃ <:-refl) (<:-∪-lub (<:-union <:-∪-left <:-refl) (<:-trans <:-∪-right <:-∪-left)) >>=ʳ + <:-trans (<:-∪-lub (<:-union <:-∪-left <:-refl) (<:-trans <:-∪-right <:-∪-left)) (<:-union p₄ <:-refl) + step (F ∩ G) (union (right p) (left (right q))) | defn (union p₁ p₂) p₃ p₄ = + (just p₁ [∪] step G (union p₂ q)) >>= + ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-union p₃ <:-refl) <:-∪-assocr >>=ʳ + <:-trans <:-∪-assocl (<:-union p₄ <:-refl) + step (F ∩ G) (union (right p) (right q)) | defn (union p₁ p₂) p₃ p₄ with ⊂:-⋓-overloads (normal-∪-saturate F) (normal-∪-saturate G) q + step (F ∩ G) (union (right p) (right q)) | defn (union p₁ p₂) p₃ p₄ | defn (union q₁ q₂) q₃ q₄ = + (step F (union p₁ q₁) [∪] step G (union p₂ q₂)) >>= + ⊂:-overloads-⋓ (normal-∪-saturate F) (normal-∪-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-union p₃ q₃) (<:-∪-lub (<:-union <:-∪-left <:-∪-left) (<:-union <:-∪-right <:-∪-right)) >>=ʳ + <:-trans (<:-∪-lub (<:-union <:-∪-left <:-∪-left) (<:-union <:-∪-right <:-∪-right)) (<:-union p₄ q₄) + +∪-saturated : ∀ {F} → FunType F → ∪-Lift (Overloads (∪-saturate F)) (Overloads (∪-saturate F)) ⊂: Overloads (∪-saturate F) +∪-saturated F o = + ⊂:-∪-lift (∪-saturate-overloads F) (∪-saturate-overloads F) o >>= + ⊂:-∪-lift-saturate >>= + overloads-∪-saturate F + +∩-saturate-overloads : ∀ {F} → FunType F → Overloads (∩-saturate F) ⊂: ∩-Saturate (Overloads F) +∩-saturate-overloads (S ⇒ T) here = just (base here) +∩-saturate-overloads (F ∩ G) (left (left o)) = ∩-saturate-overloads F o >>= ⊂:-∩-saturate ⊂:-overloads-left +∩-saturate-overloads (F ∩ G) (left (right o)) = ∩-saturate-overloads G o >>= ⊂:-∩-saturate ⊂:-overloads-right +∩-saturate-overloads (F ∩ G) (right o) = + ⊂:-⋒-overloads (normal-∩-saturate F) (normal-∩-saturate G) o >>= + ⊂:-∩-lift (∩-saturate-overloads F) (∩-saturate-overloads G) >>= + ⊂:-∩-lift (⊂:-∩-saturate ⊂:-overloads-left) (⊂:-∩-saturate ⊂:-overloads-right) >>= + ⊂:-∩-lift-saturate + +overloads-∩-saturate : ∀ {F} → FunType F → ∩-Saturate (Overloads F) ⊂: Overloads (∩-saturate F) +overloads-∩-saturate F = ⊂:-∩-saturate-indn (inj F) (step F) where + + inj : ∀ {F} → FunType F → Overloads F ⊂: Overloads (∩-saturate F) + inj (S ⇒ T) here = just here + inj (F ∩ G) (left p) = inj F p >>= ⊂:-overloads-left >>= ⊂:-overloads-left + inj (F ∩ G) (right p) = inj G p >>= ⊂:-overloads-right >>= ⊂:-overloads-left + + step : ∀ {F} → FunType F → ∩-Lift (Overloads (∩-saturate F)) (Overloads (∩-saturate F)) ⊂: Overloads (∩-saturate F) + step (S ⇒ T) (intersect here here) = defn here <:-∩-left (<:-∩-glb <:-refl <:-refl) + step (F ∩ G) (intersect (left (left p)) (left (left q))) = step F (intersect p q) >>= ⊂:-overloads-left >>= ⊂:-overloads-left + step (F ∩ G) (intersect (left (left p)) (left (right q))) = ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) (intersect p q) >>= ⊂:-overloads-right + step (F ∩ G) (intersect (left (right p)) (left (left q))) = ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) (intersect q p) >>= ⊂:-overloads-right >>=ˡ <:-∩-symm >>=ʳ <:-∩-symm + step (F ∩ G) (intersect (left (right p)) (left (right q))) = step G (intersect p q) >>= ⊂:-overloads-right >>= ⊂:-overloads-left + step (F ∩ G) (intersect (right p) q) with ⊂:-⋒-overloads (normal-∩-saturate F) (normal-∩-saturate G) p + step (F ∩ G) (intersect (right p) (left (left q))) | defn (intersect p₁ p₂) p₃ p₄ = + (step F (intersect p₁ q) [∩] just p₂) >>= + ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-intersect p₃ <:-refl) (<:-∩-glb (<:-intersect <:-∩-left <:-refl) (<:-trans <:-∩-left <:-∩-right)) >>=ʳ + <:-trans (<:-∩-glb (<:-intersect <:-∩-left <:-refl) (<:-trans <:-∩-left <:-∩-right)) (<:-intersect p₄ <:-refl) + step (F ∩ G) (intersect (right p) (left (right q))) | defn (intersect p₁ p₂) p₃ p₄ = + (just p₁ [∩] step G (intersect p₂ q)) >>= + ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-intersect p₃ <:-refl) <:-∩-assocr >>=ʳ + <:-trans <:-∩-assocl (<:-intersect p₄ <:-refl) + step (F ∩ G) (intersect (right p) (right q)) | defn (intersect p₁ p₂) p₃ p₄ with ⊂:-⋒-overloads (normal-∩-saturate F) (normal-∩-saturate G) q + step (F ∩ G) (intersect (right p) (right q)) | defn (intersect p₁ p₂) p₃ p₄ | defn (intersect q₁ q₂) q₃ q₄ = + (step F (intersect p₁ q₁) [∩] step G (intersect p₂ q₂)) >>= + ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-intersect p₃ q₃) (<:-∩-glb (<:-intersect <:-∩-left <:-∩-left) (<:-intersect <:-∩-right <:-∩-right)) >>=ʳ + <:-trans (<:-∩-glb (<:-intersect <:-∩-left <:-∩-left) (<:-intersect <:-∩-right <:-∩-right)) (<:-intersect p₄ q₄) + step (F ∩ G) (intersect p (right q)) with ⊂:-⋒-overloads (normal-∩-saturate F) (normal-∩-saturate G) q + step (F ∩ G) (intersect (left (left p)) (right q)) | defn (intersect q₁ q₂) q₃ q₄ = + (step F (intersect p q₁) [∩] just q₂) >>= + ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-intersect <:-refl q₃) <:-∩-assocl >>=ʳ + <:-trans <:-∩-assocr (<:-intersect <:-refl q₄) + step (F ∩ G) (intersect (left (right p)) (right q)) | defn (intersect q₁ q₂) q₃ q₄ = + (just q₁ [∩] step G (intersect p q₂) ) >>= + ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-intersect <:-refl q₃) (<:-∩-glb (<:-trans <:-∩-right <:-∩-left) (<:-∩-glb <:-∩-left (<:-trans <:-∩-right <:-∩-right))) >>=ʳ + <:-∩-glb (<:-trans <:-∩-right <:-∩-left) (<:-trans (<:-∩-glb <:-∩-left (<:-trans <:-∩-right <:-∩-right)) q₄) + step (F ∩ G) (intersect (right p) (right q)) | defn (intersect q₁ q₂) q₃ q₄ with ⊂:-⋒-overloads (normal-∩-saturate F) (normal-∩-saturate G) p + step (F ∩ G) (intersect (right p) (right q)) | defn (intersect q₁ q₂) q₃ q₄ | defn (intersect p₁ p₂) p₃ p₄ = + (step F (intersect p₁ q₁) [∩] step G (intersect p₂ q₂)) >>= + ⊂:-overloads-⋒ (normal-∩-saturate F) (normal-∩-saturate G) >>= + ⊂:-overloads-right >>=ˡ + <:-trans (<:-intersect p₃ q₃) (<:-∩-glb (<:-intersect <:-∩-left <:-∩-left) (<:-intersect <:-∩-right <:-∩-right)) >>=ʳ + <:-trans (<:-∩-glb (<:-intersect <:-∩-left <:-∩-left) (<:-intersect <:-∩-right <:-∩-right)) (<:-intersect p₄ q₄) + +saturate-overloads : ∀ {F} → FunType F → Overloads (saturate F) ⊂: ∪-Saturate (∩-Saturate (Overloads F)) +saturate-overloads F o = ∪-saturate-overloads (normal-∩-saturate F) o >>= (⊂:-∪-saturate (∩-saturate-overloads F)) + +overloads-saturate : ∀ {F} → FunType F → ∪-Saturate (∩-Saturate (Overloads F)) ⊂: Overloads (saturate F) +overloads-saturate F o = ⊂:-∪-saturate (overloads-∩-saturate F) o >>= overloads-∪-saturate (normal-∩-saturate F) + +-- Saturated F whenever +-- * if F has overloads (R ⇒ S) and (T ⇒ U) then F has an overload which is a subtype of ((R ∩ T) ⇒ (S ∩ U)) +-- * ditto union +data Saturated (F : Type) : Set where + + defn : + + (∀ {R S T U} → Overloads F (R ⇒ S) → Overloads F (T ⇒ U) → <:-Close (Overloads F) ((R ∩ T) ⇒ (S ∩ U))) → + (∀ {R S T U} → Overloads F (R ⇒ S) → Overloads F (T ⇒ U) → <:-Close (Overloads F) ((R ∪ T) ⇒ (S ∪ U))) → + ----------- + Saturated F + +-- saturated F is saturated! +saturated : ∀ {F} → FunType F → Saturated (saturate F) +saturated F = defn + (λ n o → (saturate-overloads F n [∩] saturate-overloads F o) >>= ∪-saturate-resp-∩-saturation ⊂:-∩-lift-saturate >>= overloads-saturate F) + (λ n o → ∪-saturated (normal-∩-saturate F) (union n o)) diff --git a/prototyping/Tests/PrettyPrinter/smoke_test/out.txt b/prototyping/Tests/PrettyPrinter/smoke_test/out.txt index 34e0c4fe..ca95cae9 100644 --- a/prototyping/Tests/PrettyPrinter/smoke_test/out.txt +++ b/prototyping/Tests/PrettyPrinter/smoke_test/out.txt @@ -10,10 +10,10 @@ local function comp(f) end local id2 = comp(id)(id) local nil2 = id2(nil) -local a : any = nil +local a : unknown = nil local b : nil = nil local c : (nil) -> nil = nil -local d : (any & nil) = nil -local e : any? = nil +local d : (unknown & nil) = nil +local e : unknown? = nil local f : number = 123.0 return id2(nil2) diff --git a/rfcs/STATUS.md b/rfcs/STATUS.md index 93a09ece..23a1be83 100644 --- a/rfcs/STATUS.md +++ b/rfcs/STATUS.md @@ -15,35 +15,26 @@ This document tracks unimplemented RFCs. **Status**: Needs implementation -## Sealed/unsealed typing changes - -[RFC: Unsealed table literals](https://github.com/Roblox/luau/blob/master/rfcs/unsealed-table-literals.md) | -[RFC: Only strip optional properties from unsealed tables during subtyping](https://github.com/Roblox/luau/blob/master/rfcs/unsealed-table-subtyping-strips-optional-properties.md) - -**Status**: Implemented but not fully rolled out yet. - -## Singleton types - -[RFC: Singleton types](https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md) - -**Status**: Implemented but not fully rolled out yet. - -## Safe navigation operator - -[RFC: Safe navigation postfix operator (?)](https://github.com/Roblox/luau/blob/master/rfcs/syntax-safe-navigation-operator.md) - -**Status**: Needs implementation. - -**Notes**: We have unresolved issues with interaction between this feature and Roblox instance hierarchy. This may affect the viability of this proposal. - ## String interpolation [RFC: String interpolation](https://github.com/Roblox/luau/blob/master/rfcs/syntax-string-interpolation.md) **Status**: Needs implementation -## Generalized iteration +## Lower Bounds Calculation -[RFC: Generalized iteration](https://github.com/Roblox/luau/blob/master/rfcs/generalized-iteration.md) +[RFC: Lower bounds calculation](https://github.com/Roblox/luau/blob/master/rfcs/lower-bounds-calculation.md) + +**Status**: Implemented but not fully rolled out yet. + +## never and unknown types + +[RFC: never and unknown types](https://github.com/Roblox/luau/blob/master/rfcs/never-and-unknown-types.md) + +**Status**: Needs implementation + +## __len metamethod for tables and rawlen function + +[RFC: Support __len metamethod for tables and rawlen function](https://github.com/Roblox/luau/blob/master/rfcs/len-metamethod-rawlen.md) **Status**: Needs implementation diff --git a/rfcs/function-bit32-countlz-countrz.md b/rfcs/function-bit32-countlz-countrz.md index d2439f72..b4ccb197 100644 --- a/rfcs/function-bit32-countlz-countrz.md +++ b/rfcs/function-bit32-countlz-countrz.md @@ -1,11 +1,11 @@ # bit32.countlz/countrz +**Status**: Implemented + ## Summary Add bit32.countlz (count left zeroes) and bit32.countrz (count right zeroes) to accelerate bit scanning -**Status**: Implemented - ## Motivation All CPUs have instructions to determine the position of first/last set bit in an integer. These instructions have a variety of uses, the popular ones being: diff --git a/rfcs/function-coroutine-close.md b/rfcs/function-coroutine-close.md index 6def1533..b9ffbf6f 100644 --- a/rfcs/function-coroutine-close.md +++ b/rfcs/function-coroutine-close.md @@ -1,11 +1,11 @@ # coroutine.close +**Status**: Implemented + ## Summary Add `coroutine.close` function from Lua 5.4 that takes a suspended coroutine and makes it "dead" (non-runnable). -**Status**: Implemented - ## Motivation When implementing various higher level objects on top of coroutines, such as promises, it can be useful to cancel the coroutine execution externally - when the caller is not diff --git a/rfcs/generalized-iteration.md b/rfcs/generalized-iteration.md index 72bdd69e..99671090 100644 --- a/rfcs/generalized-iteration.md +++ b/rfcs/generalized-iteration.md @@ -1,5 +1,7 @@ # Generalized iteration +**Status**: Implemented + ## Summary Introduce support for iterating over tables without using `pairs`/`ipairs` as well as a generic customization point for iteration via `__iter` metamethod. diff --git a/rfcs/len-metamethod-rawlen.md b/rfcs/len-metamethod-rawlen.md new file mode 100644 index 00000000..45284b71 --- /dev/null +++ b/rfcs/len-metamethod-rawlen.md @@ -0,0 +1,43 @@ +# Support `__len` metamethod for tables and `rawlen` function + +## Summary + +`__len` metamethod will be called by `#` operator on tables, matching Lua 5.2 + +## Motivation + +Lua 5.1 invokes `__len` only on userdata objects, whereas Lua 5.2 extends this to tables. In addition to making `__len` metamethod more uniform and making Luau +more compatible with later versions of Lua, this has the important advantage which is that it makes it possible to implement an index based container. + +Before `__iter` and `__len` it was possible to implement a custom container using `__index`/`__newindex`, but to iterate through the container a custom function was +necessary, because Luau didn't support generalized iteration, `__pairs`/`__ipairs` from Lua 5.2, or `#` override. + +With generalized iteration, a custom container can implement its own iteration behavior so as long as code uses `for k,v in obj` iteration style, the container can +be interfaced with the same way as a table. However, when the container uses integer indices, manual iteration via `#` would still not work - which is required for some +more complicated algorithms, or even to simply iterate through the container backwards. + +Supporting `__len` would make it possible to implement a custom integer based container that exposes the same interface as a table does. + +## Design + +`#v` will call `__len` metamethod if the object is a table and the metamethod exists; the result of the metamethod will be returned if it's a number (an error will be raised otherwise). + +`table.` functions that implicitly compute table length, such as `table.getn`, `table.insert`, will continue using the actual table length. This is consistent with the +general policy that Luau doesn't support metamethods in `table.` functions. + +A new function, `rawlen(v)`, will be added to the standard library; given a string or a table, it will return the length of the object without calling any metamethods. +The new function has the previous behavior of `#` operator with the exception of not supporting userdata inputs, as userdata doesn't have an inherent definition of length. + +## Drawbacks + +`#` is an operator that is used frequently and as such an extra metatable check here may impact performance. However, `#` is usually called on tables without metatables, +and even when it is, using the existing metamethod-absence-caching approach we use for many other metamethods a test version of the change to support `__len` shows no +statistically significant difference on existing benchmark suite. This does complicate the `#` computation a little more which may affect JIT as well, but even if the +table doesn't have a metatable the process of computing `#` involves a series of condition checks and as such will likely require slow paths anyway. + +This is technically changing semantics of `#` when called on tables with an existing `__len` metamethod, and as such has a potential to change behavior of an existing valid program. +That said, it's unlikely that any table would have a metatable with `__len` metamethod as outside of userdata it would not anything, and this drawback is not feasible to resolve with any alternate version of the proposal. + +## Alternatives + +Do not implement `__len`. diff --git a/rfcs/lower-bounds-calculation.md b/rfcs/lower-bounds-calculation.md new file mode 100644 index 00000000..a1793884 --- /dev/null +++ b/rfcs/lower-bounds-calculation.md @@ -0,0 +1,217 @@ +# Lower Bounds Calculation + +## Summary + +We propose adapting lower bounds calculation from Pierce's Local Type Inference paper into the Luau type inference algorithm. + +https://www.cis.upenn.edu/~bcpierce/papers/lti-toplas.pdf + +## Motivation + +There are a number of important scenarios that occur where Luau cannot infer a sensible type without annotations. + +Many of these revolve around type variables that occur in contravariant positions. + +### Function Return Types + +A very common thing to write in Luau is a function to try to find something in some data structure. These functions habitually return the relevant datum when it is successfully found, or `nil` in the case that it cannot. For instance: + +```lua +-- A.lua +function find_first_if(vec, f) + for i, e in ipairs(vec) do + if f(e) then + return i + end + end + + return nil +end +``` + +This function has two `return` statements: One returns `number` and the other `nil`. Today, Luau flags this as an error. We ask authors to add a return annotation to make this error go away. + +We would like to automatically infer `find_first_if : ({T}, (T) -> boolean) -> number?`. + +Higher order functions also present a similar problem. + +```lua +-- B.lua +function foo(f) + f(5) + f("string") +end +``` + +There is nothing wrong with the implementation of `foo` here, but Luau fails to typecheck it all the same because `f` is used in an inconsistent way. This too can be worked around by introducing a type annotation for `f`. + +The fact that the return type of `f` is never used confounds things a little, but for now it would be a big improvement if we inferred `f : ((number | string) -> T...) -> ()`. + +## Design + +We introduce a new kind of TypeVar, `ConstrainedTypeVar` to represent a TypeVar whose lower bounds are known. We will never expose syntax for a user to write these types: They only temporarily exist as type inference is being performed. + +When unifying some type with a `ConstrainedTypeVar` we _broaden_ the set of constraints that can be placed upon it. + +It may help to realize that what we have been doing up until now has been _upper bounds calculation_. + +When we `quantify` a function, we will _normalize_ each type and convert each `ConstrainedTypeVar` into a `UnionTypeVar`. + +### Normalization + +When computing lower bounds, we need to have some process by which we reduce types down to a minimal shape and canonicalize them, if only to have a clean way to flush out degenerate unions like `A | A`. Normalization is about reducing union and intersection types to a minimal, canonicalizable shape. + +A normalized union is one where there do not exist two branches on the union where one is a subtype of the other. It is quite straightforward to implement. + +A normalized intersection is a little bit more complicated: + +1. The tables of an intersection are always combined into a single table. Coincident properties are merged into intersections of their own. + * eg `normalize({x: number, y: string} & {y: number, z: number}) == {x: number, y: string & number, z: number}` + * This is recursive. eg `normalize({x: {y: number}} & {x: {y: string}}) == {x: {y: number & string}}` +1. If two functions in the intersection have a subtyping relationship, the normalization results only in the super-type-most function. (more on function subtyping later) + +### Function subtyping relationships + +If we are going to infer intersections of functions, then we need to be very careful about keeping combinatorics under control. We therefore need to be very deliberate about what subtyping rules we have for functions of differing arity. We have some important requirements: + +* We'd like some way to canonicalize intersections of functions, and yet +* optional function arguments are a great feature that we don't want to break + +A very important use case for us is the case where the user is providing a callback to some higher-order function, and that function will be invoked with extra arguments that the original customer doesn't actually care about. For example: + +```lua +-- C.lua +function map_array(arr, f) + local result = {} + for i, e in ipairs(arr) do + table.insert(result, f(e, i, arr)) + end + return result +end + +local example = {1, 2, 3, 4} +local example_result = map_array(example, function(i) return i * 2 end) +``` + +This function mirrors the actual `Array.map` function in JavaScript. It is very frequent for users of this function to provide a lambda that only accepts one argument. It would be annoying for callers to be forced to provide a lambda that accepts two unused arguments. This obviously becomes even worse if the function later changes to provide yet more optional information to the callback. + +This use case is very important for Roblox, as we have many APIs that accept callbacks. Implementors of those callbacks frequently omit arguments that they don't care about. + +Here is an example straight out of the Roblox developer documentation. ([full example here](https://developer.roblox.com/en-us/api-reference/event/BasePart/Touched)) + +```lua +-- D.lua +local part = script.Parent + +local function blink() + -- ... +end + +part.Touched:Connect(blink) +``` + +The `Touched` event actually passes a single argument: the part that touched the `Instance` in question. In this example, it is omitted from the callback handler. + +We therefore want _oversaturation_ of a function to be allowed, but this combines with optional function arguments to create a problem with soundness. Consider the following: + +```lua +-- E.lua +type Callback = (Instance) -> () + +local cb: Callback +function register_callback(c: Callback) + cb = c +end + +function invoke_callback(i: Instance) + cb(i) +end + +--- + +function bad_callback(x: number?) +end + +local obscured: () -> () = bad_callback + +register_callback(obscured) + +function good_callback() +end + +register_callback(good_callback) +``` + +The problem we run into is, if we allow the subtyping rule `(T?) -> () <: () -> ()` and also allow oversaturation of a function, it becomes easy to obscure an argument type and pass the wrong type of value to it. + +Next, consider the following type alias + +```lua +-- F.lua +type OldFunctionType = (any, any) -> any +type NewFunctionType = (any) -> any +type FunctionType = OldFunctionType & NewFunctionType +``` + +If we have a subtyping rule `(T0..TN) <: (T0..TN-1)` to permit the function subtyping relationship `(T0..TN-1) -> R <: (T0..TN) -> R`, then the above type alias normalizes to `(any) -> any`. In order to call the two-argument variation, we would need to permit oversaturation, which runs afoul of the soundness hole from the previous example. + +We need a solution here. + +To resolve this, let's reframe things in simpler terms: + +If there is never a subtyping relationship between packs of different length, then we don't have any soundness issues, but we find ourselves unable to register `good_callback`. + +To resolve _that_, consider that we are in truth being a bit hasty when we say `good_callback : () -> ()`. We can pass any number of arguments to this function safely. We could choose to type `good_callback : () -> () & (any) -> () & (any, any) -> () & ...`. Luau already has syntax for this particular sort of infinite intersection: `good_callback : (any...) -> ()`. + +So, we propose some different inference rules for functions: + +1. The AST fragment `function(arg0..argN) ... end` is typed `(T0..TN, any...) -> R` where `arg0..argN : T0..TN` and `R` is the inferred return type of the function body. Function statements are inferred the same way. +1. Type annotations are unchanged. `() -> ()` is still a nullary function. + +For reference, the subtyping rules for unions and functions are unchanged. We include them here for clarity. + +1. `A <: A | B` +1. `B <: A | B` +1. `A | B <: T` if `A <: T` or `B <: T` +1. `T -> R <: U -> S` if `U <: T` and `R <: S` + +We propose new subtyping rules for type packs: + +1. `(T0..TN) <: (U0..UN)` if, for each `T` and `U`, `T <: U` +1. `(U...)` is the same as `() | (U) | (U, U) | (U, U, U) | ...`, therefore +1. `(T0..TN) <: (U...)` if for each `T`, `T <: U`, therefore +1. `(U...) -> R <: (T0..TN) -> R` if for each `T`, `T <: U` + +The important difference is that we remove all subtyping rules that mention options. Functions of different arities are no longer considered subtypes of one another. Optional function arguments are still allowed, but function as a feature of function calls. + +Under these rules, functions of different arities can never be converted to one another, but actual functions are known to be safe to oversaturate with anything, and so gain a type that says so. + +Under these subtyping rules, snippets `C.lua` and `D.lua`, check the way we want: literal functions are implicitly safe to oversaturate, so it is fine to cast them as the necessary callback function type. + +`E.lua` also typechecks the way we need it to: `(Instance) -> () ()` and so `obscured` cannot receive the value `bad_callback`, which prevents it from being passed to `register_callback`. However, `good_callback : (any...) -> ()` and `(any...) -> () <: (Instance) -> ()` and so it is safe to register `good_callback`. + +Snippet `F.lua` is also fixed with this ruleset: There is no subtyping relationship between `(any) -> ()` and `(any, any) -> ()`, so the intersection is not combined under normalization. + +This works, but itself creates some small problems that we need to resolve: + +First, the `...` symbol still needs to be unavailable for functions that have been given this implicit `...any` type. This is actually taken care of in the Luau parser, so no code change is required. + +Secondly, we do not want to silently allow oversaturation of direct calls to a function if we know that the arguments will be ignored. We need to treat these variadic packs differently when unifying for function calls. + +Thirdly, we don't want to display this variadic in the signature if the author doesn't expect to see it. + +We solve these issues by adding a property `bool VariadicTypePack::hidden` to the implementation and switching on it in the above scenarios. The implementation is relatively straightforward for all 3 cases. + +## Drawbacks + +There is a potential cause for concern that we will be inferring unions of functions in cases where we previously did not. Unions are known to be potential sources of performance issues. One possibility is to allow Luau to be less intelligent and have it "give up" and produce less precise types. This would come at the cost of accuracy and soundness. + +If we allow functions to be oversaturated, we are going to miss out on opportunities to warn the user about legitimate problems with their program. I think we will have to work out some kind of special logic to detect when we are oversaturating a function whose exact definition is known and warn on that. + +Allowing indirect function calls to be oversaturated with `nil` values only should be safe, but a little bit unfortunate. As long as we statically know for certain that `nil` is actually a permissible value for that argument position, it should be safe. + +## Alternatives + +If we are willing to sacrifice soundness, we could adopt success typing and come up with an inference algorithm that produces less precise type information. + +We could also technically choose to do nothing, but this has some unpalatable consequences: Something I would like to do in the near future is to have the inference algorithm assume the same `self` type for all methods of a table. This will make inference of common OO patterns dramatically more intuitive and ergonomic, but inference of polymorphic methods requires some kind of lower bounds calculation to work correctly. diff --git a/rfcs/never-and-unknown-types.md b/rfcs/never-and-unknown-types.md new file mode 100644 index 00000000..d996afc6 --- /dev/null +++ b/rfcs/never-and-unknown-types.md @@ -0,0 +1,144 @@ +# never and unknown types + +## Summary + +Add `unknown` and `never` types that are inhabited by everything and nothing respectively. + +## Motivation + +There are lots of cases in local type inference, semantic subtyping, +and type normalization, where it would be useful to have top and +bottom types. Currently, `any` is filling that role, but it has +special "switch off the type system" superpowers. + +Any use of `unknown` must be narrowed by type refinements unless another `unknown` or `any` is expected. For +example a function which can return any value is: + +```lua + function anything() : unknown ... end +``` + +and can be used as: + +```lua + local x = anything() + if type(x) == "number" then + print(x + 1) + end +``` + +The type of this function cannot be given concisely in current +Luau. The nearest equivalent is `any`, but this switches off the type system, for example +if the type of `anything` is `() -> any` then the following code typechecks: + +```lua + local x = anything() + print(x + 1) +``` + +This is fine in nonstrict mode, but strict mode should flag this as an error. + +The `never` type comes up whenever type inference infers incompatible types for a variable, for example + +```lua + function oops(x) + print("hi " .. x) -- constrains x must be a string + print(math.abs(x)) -- constrains x must be a number + end +``` + +The most general type of `x` is `string & number`, so this code gives +a type error, but we still need to provide a type for `oops`. With a +`never` type, we can infer the type `oops : (never) -> ()`. + +or when exhaustive type casing is achieved: + +```lua + function f(x: string | number) + if type(x) == "string" then + -- x : string + elseif type(x) == "number" then + -- x : number + else + -- x : never + end + end +``` + +or even when the type casing is simply nonsensical: + +```lua + function f(x: string | number) + if type(x) == "string" and type(x) == "number" then + -- x : string & number which is never + end + end +``` + +The `never` type is also useful in cases such as tagged unions where +some of the cases are impossible. For example: + +```lua + type Result = { err: false, val: T } | { err: true, err: E } +``` + +For code which we know is successful, we would like to be able to +indicate that the error case is impossible. With a `never` type, we +can do this with `Result`. Similarly, code which cannot succeed +has type `Result`. + +These types can _almost_ be defined in current Luau, but only quite verbosely: + +```lua + type never = number & string + type unknown = nil | number | boolean | string | {} | (...never) -> (...unknown) +``` + +But even for `unknown` it is impossible to include every single data types, e.g. every root class. + +Providing `never` and `unknown` as built-in types makes the code for +type inference simpler, for example we have a way to present a union +type with no options (as `never`). Otherwise we have to contend with ad hoc +corner cases. + +## Design + +Add: + +* a type `never`, inhabited by nothing, and +* a type `unknown`, inhabited by everything. + +And under success types (nonstrict mode), `unknown` is exactly equivalent to `any` because `unknown` +encompasses everything as does `any`. + +The interesting thing is that `() -> (never, string)` is equivalent to `() -> never` because all +values in a pack must be inhabitable in order for the pack itself to also be inhabitable. In fact, +the type `() -> never` is not completely accurate, it should be `() -> (never, ...never)` to avoid +cascading type errors. Ditto for when an expression list `f(), g()` where the resulting type pack is +`(never, string, number)` is still the same as `(never, ...never)`. + +```lua + function f(): never error() end + function g(): string return "" end + + -- no cascading type error where count mismatches, because the expression list f(), g() + -- was made to return (never, ...never) due to the presence of a never type in the pack + local x, y, z = f(), g() + -- x : never + -- y : never + -- z : never +``` + +## Drawbacks + +Another bit of complexity budget spent. + +These types will be visible to creators, so yay bikeshedding! + +Replacing `any` with `unknown` is a breaking change: code in strict mode may now produce errors. + +## Alternatives + +Stick with the current use of `any` for these cases. + +Make `never` and `unknown` type aliases rather than built-ins. diff --git a/rfcs/syntax-safe-navigation-operator.md b/rfcs/syntax-safe-navigation-operator.md deleted file mode 100644 index c98f3957..00000000 --- a/rfcs/syntax-safe-navigation-operator.md +++ /dev/null @@ -1,102 +0,0 @@ -# Safe navigation postfix operator (?) - -## Summary - -Introduce syntax to navigate through `nil` values, or short-circuit with `nil` if it was encountered. - - -## Motivation - -nil values are very common in Lua, and take care to prevent runtime errors. - -Currently, attempting to index `dog.name` while caring for `dog` being nil requires some form of the following: - -```lua -local dogName = nil -if dog ~= nil then - dogName = dog.name -end -``` - -...or the unusual to read... - -```lua -local dogName = dog and dog.name -``` - -...which will return `false` if `dog` is `false`, instead of throwing an error because of the index of `false.name`. - -Luau provides the if...else expression making this turn into: - -```lua -local dogName = if dog == nil then nil else dog.name -``` - -...but this is fairly clunky for such a common expression. - -## Design - -The safe navigation operator will make all of these smooth, by supporting `x?.y` to safely index nil values. `dog?.name` would resolve to `nil` if `dog` was nil, or the name otherwise. - -The previous example turns into `local dogName = dog?.name` (or just using `dog?.name` elsewhere). - -Failing the nil-safety check early would make the entire expression nil, for instance `dog?.body.legs` would resolve to `nil` if `dog` is nil, rather than resolve `dog?.body` into nil, then turning into `nil.legs`. - -```lua -dog?.name --[[ is the same as ]] if dog == nil then nil else dog.name -``` - -The short-circuiting is limited within the expression. - -```lua -dog?.owner.name -- This will return nil if `dog` is nil -(dog?.owner).name -- `(dog?.owner)` resolves to nil, of which `name` is then indexed. This will error at runtime if `dog` is nil. - -dog?.legs + 3 -- `dog?.legs` is resolved on its own, meaning this will error at runtime if it is nil (`nil + 3`) -``` - -The operator must be used in the context of either a call or an index, and so: - -```lua -local value = x? -``` - -...would be invalid syntax. - -This syntax would be based on expressions, and not identifiers, meaning that `(x or y)?.call()` would be valid syntax. - -### Type -If the expression is typed as an optional, then the resulting type would be the final expression, also optional. Otherwise, it'll just be the resulting type if `?` wasn't used. - -```lua -local optionalObject: { name: string }? -local optionalObjectName = optionalObject?.name -- resolves to `string?` - -local nonOptionalObject: { name: string } -local nonOptionalObjectName = nonOptionalObject?.name -- resolves to `string` -``` - -### Calling - -This RFC only specifies `x?.y` as an index method. `x?:y()` is currently unspecified, and `x?.y(args)` as a syntax will be reserved (will error if you try to use it). - -While being able to support `dog?.getName()` is useful, it provides [some logistical issues for the language](https://github.com/Roblox/luau/pull/142#issuecomment-990563536). - -`x?.y(args)` will be reserved both so that this can potentially be resolved later down the line if something comes up, but also because it would be a guaranteed runtime error under this RFC: `dog?.getName()` will first index `dog?.getName`, which will return nil, then will attempt to call it. - -### Assignment -`x?.y = z` is not supported, and will be reported as a syntax error. - -## Drawbacks - -As with all syntax additions, this adds complexity to the parsing of expressions, and the execution of cancelling the rest of the expression could prove challenging. - -Furthermore, with the proposed syntax, it might lock off other uses of `?` within code (and not types) for the future as being ambiguous. - -## Alternatives - -Doing nothing is an option, as current standard if-checks already work, as well as the `and` trick in other use cases, but as shown before this can create some hard to read code, and nil values are common enough that the safe navigation operator is welcome. - -Supporting optional calls/indexes, such as `x?[1]` and `x?()`, while not out of scope, are likely too fringe to support, while adding on a significant amount of parsing difficulty, especially in the case of shorthand function calls, such as `x?{}` and `x?""`. - -It is possible to make `x?.y = z` resolve to only setting `x.y` if `x` is nil, but assignments silently failing can be seen as surprising. diff --git a/rfcs/syntax-singleton-types.md b/rfcs/syntax-singleton-types.md index 26ea3028..2c1f5442 100644 --- a/rfcs/syntax-singleton-types.md +++ b/rfcs/syntax-singleton-types.md @@ -2,6 +2,8 @@ > Note: this RFC was adapted from an internal proposal that predates RFC process +**Status**: Implemented + ## Summary Introduce a new kind of type variable, called singleton types. They are just like normal types but has the capability to represent a constant runtime value as a type. diff --git a/rfcs/syntax-type-ascription-bidi.md b/rfcs/syntax-type-ascription-bidi.md index bf37eca2..0831aba5 100644 --- a/rfcs/syntax-type-ascription-bidi.md +++ b/rfcs/syntax-type-ascription-bidi.md @@ -1,11 +1,11 @@ # Relaxing type assertions +**Status**: Implemented + ## Summary The way `::` works today is really strange. The best solution we can come up with is to allow `::` to convert between any two related types. -**Status**: Implemented - ## Motivation Due to an accident of the implementation, the Luau `::` operator can only be used for downcasts and casts to `any`. diff --git a/rfcs/unsealed-table-assign-optional-property.md b/rfcs/unsealed-table-assign-optional-property.md index ed037b14..477399c2 100644 --- a/rfcs/unsealed-table-assign-optional-property.md +++ b/rfcs/unsealed-table-assign-optional-property.md @@ -1,5 +1,7 @@ # Unsealed table assignment creates an optional property +**Status**: Implemented + ## Summary In Luau, tables have a state, which can, among others, be "unsealed". diff --git a/rfcs/unsealed-table-literals.md b/rfcs/unsealed-table-literals.md index 320bf7ca..669b67d4 100644 --- a/rfcs/unsealed-table-literals.md +++ b/rfcs/unsealed-table-literals.md @@ -1,5 +1,7 @@ # Unsealed table literals +**Status**: Implemented + ## Summary Currently the only way to create an unsealed table is as an empty table literal `{}`. @@ -73,4 +75,4 @@ We could introduce a new table state for unsealed-but-precise tables. The trade-off is that that would be more precise, at the cost of adding user-visible complexity to the type system. -We could continue to treat array-like tables as sealed. \ No newline at end of file +We could continue to treat array-like tables as sealed. diff --git a/rfcs/unsealed-table-subtyping-strips-optional-properties.md b/rfcs/unsealed-table-subtyping-strips-optional-properties.md index deecfdb3..d99c1f81 100644 --- a/rfcs/unsealed-table-subtyping-strips-optional-properties.md +++ b/rfcs/unsealed-table-subtyping-strips-optional-properties.md @@ -1,5 +1,7 @@ # Only strip optional properties from unsealed tables during subtyping +**Status**: Implemented + ## Summary Currently subtyping allows optional properties to be stripped from table types during subtyping. diff --git a/scripts/run-with-cachegrind.sh b/scripts/run-with-cachegrind.sh new file mode 100644 index 00000000..787043ff --- /dev/null +++ b/scripts/run-with-cachegrind.sh @@ -0,0 +1,109 @@ +#!/bin/bash +set -euo pipefail +IFS=$'\n\t' + +declare -A event_map +event_map[Ir]="TotalInstructionsExecuted,executions\n" +event_map[I1mr]="L1_InstrReadCacheMisses,misses/op\n" +event_map[ILmr]="LL_InstrReadCacheMisses,misses/op\n" +event_map[Dr]="TotalMemoryReads,reads\n" +event_map[D1mr]="L1_DataReadCacheMisses,misses/op\n" +event_map[DLmr]="LL_DataReadCacheMisses,misses/op\n" +event_map[Dw]="TotalMemoryWrites,writes\n" +event_map[D1mw]="L1_DataWriteCacheMisses,misses/op\n" +event_map[DLmw]="LL_DataWriteCacheMisses,misses/op\n" +event_map[Bc]="ConditionalBranchesExecuted,executions\n" +event_map[Bcm]="ConditionalBranchMispredictions,mispredictions/op\n" +event_map[Bi]="IndirectBranchesExecuted,executions\n" +event_map[Bim]="IndirectBranchMispredictions,mispredictions/op\n" + +now_ms() { + echo -n $(date +%s%N | cut -b1-13) +} + +# Run cachegrind on a given benchmark and echo the results. +ITERATION_COUNT=$4 +START_TIME=$(now_ms) + +ARGS=( "$@" ) +REST_ARGS="${ARGS[@]:4}" + +valgrind \ + --quiet \ + --tool=cachegrind \ + "$1" "$2" $REST_ARGS>/dev/null + +ARGS=( "$@" ) +REST_ARGS="${ARGS[@]:4}" + + +TIME_ELAPSED=$(bc <<< "$(now_ms) - ${START_TIME}") + +# Generate report using cg_annotate and extract the header and totals of the +# recorded events valgrind was configured to record. +CG_RESULTS=$(cg_annotate $(ls -t cachegrind.out.* | head -1)) +CG_HEADERS=$(grep -B2 'PROGRAM TOTALS$' <<< "$CG_RESULTS" | head -1 | sed -E 's/\s+/\n/g' | sed '/^$/d') +CG_TOTALS=$(grep 'PROGRAM TOTALS$' <<< "$CG_RESULTS" | head -1 | grep -Po '[0-9,]+\s' | tr -d ', ') + +TOTALS_ARRAY=($CG_TOTALS) +HEADERS_ARRAY=($CG_HEADERS) + +declare -A header_map +for i in "${!TOTALS_ARRAY[@]}"; do + header_map[${HEADERS_ARRAY[$i]}]=$i +done + +# Map the results to the format that the benchmark script expects. +for i in "${!TOTALS_ARRAY[@]}"; do + TOTAL=${TOTALS_ARRAY[$i]} + + # Labels and unit descriptions are packed together in the map. + EVENT_TUPLE=${event_map[${HEADERS_ARRAY[$i]}]} + IFS=$',' read -d '\n' -ra EVENT_VALUES < <(printf "%s" "$EVENT_TUPLE") + EVENT_NAME="${EVENT_VALUES[0]}" + UNIT="${EVENT_VALUES[1]}" + + case ${HEADERS_ARRAY[$i]} in + I1mr | ILmr) + REF=${TOTALS_ARRAY[header_map["Ir"]]} + OPS_PER_SEC=$(bc -l <<< "$TOTAL / $REF") + ;; + + D1mr | DLmr) + REF=${TOTALS_ARRAY[header_map["Dr"]]} + OPS_PER_SEC=$(bc -l <<< "$TOTAL / $REF") + ;; + + D1mw | DLmw) + REF=${TOTALS_ARRAY[header_map["Dw"]]} + OPS_PER_SEC=$(bc -l <<< "$TOTAL / $REF") + ;; + + Bcm) + REF=${TOTALS_ARRAY[header_map["Bc"]]} + OPS_PER_SEC=$(bc -l <<< "$TOTAL / $REF") + ;; + + Bim) + REF=${TOTALS_ARRAY[header_map["Bi"]]} + OPS_PER_SEC=$(bc -l <<< "$TOTAL / $REF") + ;; + + *) + OPS_PER_SEC=$(bc -l <<< "$TOTAL") + ;; + esac + + STD_DEV="0%" + RUNS="1" + + if [[ $OPS_PER_SEC =~ ^[+-]?[0-9]*$ ]] + then # $OPS_PER_SEC is integer + printf "%s#%s x %.0f %s ±%s (%d runs sampled)\n" \ + "$3" "$EVENT_NAME" "$OPS_PER_SEC" "$UNIT" "$STD_DEV" "$RUNS" + else # $OPS_PER_SEC is float + printf "%s#%s x %.10f %s ±%s (%d runs sampled)\n" \ + "$3" "$EVENT_NAME" "$OPS_PER_SEC" "$UNIT" "$STD_DEV" "$RUNS" + fi + +done diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp new file mode 100644 index 00000000..7f863c6f --- /dev/null +++ b/tests/AssemblyBuilderX64.test.cpp @@ -0,0 +1,410 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/AssemblyBuilderX64.h" +#include "Luau/StringUtils.h" + +#include "doctest.h" + +#include +#include + +using namespace Luau::CodeGen; + +std::string bytecodeAsArray(const std::vector& bytecode) +{ + std::string result = "{"; + + for (size_t i = 0; i < bytecode.size(); i++) + Luau::formatAppend(result, "%s0x%02x", i == 0 ? "" : ", ", bytecode[i]); + + return result.append("}"); +} + +class AssemblyBuilderX64Fixture +{ +public: + void check(std::function f, std::vector result) + { + AssemblyBuilderX64 build(/* logText= */ false); + + f(build); + + build.finalize(); + + if (build.code != result) + { + printf("Expected: %s\nReceived: %s\n", bytecodeAsArray(result).c_str(), bytecodeAsArray(build.code).c_str()); + CHECK(false); + } + } +}; + +TEST_SUITE_BEGIN("x64Assembly"); + +#define SINGLE_COMPARE(inst, ...) \ + check( \ + [](AssemblyBuilderX64& build) { \ + build.inst; \ + }, \ + {__VA_ARGS__}) + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "BaseBinaryInstructionForms") +{ + // reg, reg + SINGLE_COMPARE(add(rax, rcx), 0x48, 0x03, 0xc1); + SINGLE_COMPARE(add(rsp, r12), 0x49, 0x03, 0xe4); + SINGLE_COMPARE(add(r14, r10), 0x4d, 0x03, 0xf2); + + // reg, imm + SINGLE_COMPARE(add(rax, 0), 0x48, 0x83, 0xc0, 0x00); + SINGLE_COMPARE(add(rax, 0x7f), 0x48, 0x83, 0xc0, 0x7f); + SINGLE_COMPARE(add(rax, 0x80), 0x48, 0x81, 0xc0, 0x80, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(r10, 0x7fffffff), 0x49, 0x81, 0xc2, 0xff, 0xff, 0xff, 0x7f); + + // reg, [reg] + SINGLE_COMPARE(add(rax, qword[rax]), 0x48, 0x03, 0x00); + SINGLE_COMPARE(add(rax, qword[rbx]), 0x48, 0x03, 0x03); + SINGLE_COMPARE(add(rax, qword[rsp]), 0x48, 0x03, 0x04, 0x24); + SINGLE_COMPARE(add(rax, qword[rbp]), 0x48, 0x03, 0x45, 0x00); + SINGLE_COMPARE(add(rax, qword[r10]), 0x49, 0x03, 0x02); + SINGLE_COMPARE(add(rax, qword[r12]), 0x49, 0x03, 0x04, 0x24); + SINGLE_COMPARE(add(rax, qword[r13]), 0x49, 0x03, 0x45, 0x00); + + SINGLE_COMPARE(add(r12, qword[rax]), 0x4c, 0x03, 0x20); + SINGLE_COMPARE(add(r12, qword[rbx]), 0x4c, 0x03, 0x23); + SINGLE_COMPARE(add(r12, qword[rsp]), 0x4c, 0x03, 0x24, 0x24); + SINGLE_COMPARE(add(r12, qword[rbp]), 0x4c, 0x03, 0x65, 0x00); + SINGLE_COMPARE(add(r12, qword[r10]), 0x4d, 0x03, 0x22); + SINGLE_COMPARE(add(r12, qword[r12]), 0x4d, 0x03, 0x24, 0x24); + SINGLE_COMPARE(add(r12, qword[r13]), 0x4d, 0x03, 0x65, 0x00); + + // reg, [base+imm8] + SINGLE_COMPARE(add(rax, qword[rax + 0x1b]), 0x48, 0x03, 0x40, 0x1b); + SINGLE_COMPARE(add(rax, qword[rbx + 0x1b]), 0x48, 0x03, 0x43, 0x1b); + SINGLE_COMPARE(add(rax, qword[rsp + 0x1b]), 0x48, 0x03, 0x44, 0x24, 0x1b); + SINGLE_COMPARE(add(rax, qword[rbp + 0x1b]), 0x48, 0x03, 0x45, 0x1b); + SINGLE_COMPARE(add(rax, qword[r10 + 0x1b]), 0x49, 0x03, 0x42, 0x1b); + SINGLE_COMPARE(add(rax, qword[r12 + 0x1b]), 0x49, 0x03, 0x44, 0x24, 0x1b); + SINGLE_COMPARE(add(rax, qword[r13 + 0x1b]), 0x49, 0x03, 0x45, 0x1b); + + SINGLE_COMPARE(add(r12, qword[rax + 0x1b]), 0x4c, 0x03, 0x60, 0x1b); + SINGLE_COMPARE(add(r12, qword[rbx + 0x1b]), 0x4c, 0x03, 0x63, 0x1b); + SINGLE_COMPARE(add(r12, qword[rsp + 0x1b]), 0x4c, 0x03, 0x64, 0x24, 0x1b); + SINGLE_COMPARE(add(r12, qword[rbp + 0x1b]), 0x4c, 0x03, 0x65, 0x1b); + SINGLE_COMPARE(add(r12, qword[r10 + 0x1b]), 0x4d, 0x03, 0x62, 0x1b); + SINGLE_COMPARE(add(r12, qword[r12 + 0x1b]), 0x4d, 0x03, 0x64, 0x24, 0x1b); + SINGLE_COMPARE(add(r12, qword[r13 + 0x1b]), 0x4d, 0x03, 0x65, 0x1b); + + // reg, [base+imm32] + SINGLE_COMPARE(add(rax, qword[rax + 0xabab]), 0x48, 0x03, 0x80, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rbx + 0xabab]), 0x48, 0x03, 0x83, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rsp + 0xabab]), 0x48, 0x03, 0x84, 0x24, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rbp + 0xabab]), 0x48, 0x03, 0x85, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[r10 + 0xabab]), 0x49, 0x03, 0x82, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[r12 + 0xabab]), 0x49, 0x03, 0x84, 0x24, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[r13 + 0xabab]), 0x49, 0x03, 0x85, 0xab, 0xab, 0x00, 0x00); + + SINGLE_COMPARE(add(r12, qword[rax + 0xabab]), 0x4c, 0x03, 0xa0, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rbx + 0xabab]), 0x4c, 0x03, 0xa3, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rsp + 0xabab]), 0x4c, 0x03, 0xa4, 0x24, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rbp + 0xabab]), 0x4c, 0x03, 0xa5, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[r10 + 0xabab]), 0x4d, 0x03, 0xa2, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[r12 + 0xabab]), 0x4d, 0x03, 0xa4, 0x24, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[r13 + 0xabab]), 0x4d, 0x03, 0xa5, 0xab, 0xab, 0x00, 0x00); + + // reg, [index*scale] + SINGLE_COMPARE(add(rax, qword[rax * 2]), 0x48, 0x03, 0x04, 0x45, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rbx * 2]), 0x48, 0x03, 0x04, 0x5d, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rbp * 2]), 0x48, 0x03, 0x04, 0x6d, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[r10 * 2]), 0x4a, 0x03, 0x04, 0x55, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[r12 * 2]), 0x4a, 0x03, 0x04, 0x65, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[r13 * 2]), 0x4a, 0x03, 0x04, 0x6d, 0x00, 0x00, 0x00, 0x00); + + SINGLE_COMPARE(add(r12, qword[rax * 2]), 0x4c, 0x03, 0x24, 0x45, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rbx * 2]), 0x4c, 0x03, 0x24, 0x5d, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rbp * 2]), 0x4c, 0x03, 0x24, 0x6d, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[r10 * 2]), 0x4e, 0x03, 0x24, 0x55, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[r12 * 2]), 0x4e, 0x03, 0x24, 0x65, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[r13 * 2]), 0x4e, 0x03, 0x24, 0x6d, 0x00, 0x00, 0x00, 0x00); + + // reg, [base+index*scale+imm] + SINGLE_COMPARE(add(rax, qword[rax + rax * 2]), 0x48, 0x03, 0x04, 0x40); + SINGLE_COMPARE(add(rax, qword[rax + rbx * 2 + 0x1b]), 0x48, 0x03, 0x44, 0x58, 0x1b); + SINGLE_COMPARE(add(rax, qword[rax + rbp * 2]), 0x48, 0x03, 0x04, 0x68); + SINGLE_COMPARE(add(rax, qword[rax + rbp + 0xabab]), 0x48, 0x03, 0x84, 0x28, 0xAB, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rax + r12 + 0x1b]), 0x4a, 0x03, 0x44, 0x20, 0x1b); + SINGLE_COMPARE(add(rax, qword[rax + r12 * 4 + 0xabab]), 0x4a, 0x03, 0x84, 0xa0, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rax + r13 * 2 + 0x1b]), 0x4a, 0x03, 0x44, 0x68, 0x1b); + SINGLE_COMPARE(add(rax, qword[rax + r13 + 0xabab]), 0x4a, 0x03, 0x84, 0x28, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rax + r12 * 2]), 0x4e, 0x03, 0x24, 0x60); + SINGLE_COMPARE(add(r12, qword[rax + r13 + 0xabab]), 0x4e, 0x03, 0xA4, 0x28, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rax + rbp * 2 + 0x1b]), 0x4c, 0x03, 0x64, 0x68, 0x1b); + + // reg, [imm32] + SINGLE_COMPARE(add(rax, qword[0]), 0x48, 0x03, 0x04, 0x25, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[0xabab]), 0x48, 0x03, 0x04, 0x25, 0xab, 0xab, 0x00, 0x00); + + // [addr], reg + SINGLE_COMPARE(add(qword[rax], rax), 0x48, 0x01, 0x00); + SINGLE_COMPARE(add(qword[rax + rax * 4 + 0xabab], rax), 0x48, 0x01, 0x84, 0x80, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(qword[rbx + rax * 2 + 0x1b], rax), 0x48, 0x01, 0x44, 0x43, 0x1b); + SINGLE_COMPARE(add(qword[rbx + rbp * 2 + 0x1b], rax), 0x48, 0x01, 0x44, 0x6b, 0x1b); + SINGLE_COMPARE(add(qword[rbp + rbp * 4 + 0xabab], rax), 0x48, 0x01, 0x84, 0xad, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(qword[rbp + r12 + 0x1b], rax), 0x4a, 0x01, 0x44, 0x25, 0x1b); + SINGLE_COMPARE(add(qword[r12], rax), 0x49, 0x01, 0x04, 0x24); + SINGLE_COMPARE(add(qword[r13 + rbx + 0xabab], rax), 0x49, 0x01, 0x84, 0x1d, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(qword[rax + r13 * 2 + 0x1b], rsi), 0x4a, 0x01, 0x74, 0x68, 0x1b); + SINGLE_COMPARE(add(qword[rbp + rbx * 2], rsi), 0x48, 0x01, 0x74, 0x5d, 0x00); + SINGLE_COMPARE(add(qword[rsp + r10 * 2 + 0x1b], r10), 0x4e, 0x01, 0x54, 0x54, 0x1b); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "BaseUnaryInstructionForms") +{ + SINGLE_COMPARE(div(rcx), 0x48, 0xf7, 0xf1); + SINGLE_COMPARE(idiv(qword[rax]), 0x48, 0xf7, 0x38); + SINGLE_COMPARE(mul(qword[rax + rbx]), 0x48, 0xf7, 0x24, 0x18); + SINGLE_COMPARE(neg(r9), 0x49, 0xf7, 0xd9); + SINGLE_COMPARE(not_(r12), 0x49, 0xf7, 0xd4); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfMov") +{ + SINGLE_COMPARE(mov(rcx, 1), 0x48, 0xb9, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(mov64(rcx, 0x1234567812345678ll), 0x48, 0xb9, 0x78, 0x56, 0x34, 0x12, 0x78, 0x56, 0x34, 0x12); + SINGLE_COMPARE(mov(ecx, 2), 0xb9, 0x02, 0x00, 0x00, 0x00); + SINGLE_COMPARE(mov(cl, 2), 0xb1, 0x02); + SINGLE_COMPARE(mov(rcx, qword[rdi]), 0x48, 0x8b, 0x0f); + SINGLE_COMPARE(mov(dword[rax], 0xabcd), 0xc7, 0x00, 0xcd, 0xab, 0x00, 0x00); + SINGLE_COMPARE(mov(r13, 1), 0x49, 0xbd, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(mov64(r13, 0x1234567812345678ll), 0x49, 0xbd, 0x78, 0x56, 0x34, 0x12, 0x78, 0x56, 0x34, 0x12); + SINGLE_COMPARE(mov(r13d, 2), 0x41, 0xbd, 0x02, 0x00, 0x00, 0x00); + SINGLE_COMPARE(mov(r13, qword[r12]), 0x4d, 0x8b, 0x2c, 0x24); + SINGLE_COMPARE(mov(dword[r13], 0xabcd), 0x41, 0xc7, 0x45, 0x00, 0xcd, 0xab, 0x00, 0x00); + SINGLE_COMPARE(mov(qword[rdx], r9), 0x4c, 0x89, 0x0a); + SINGLE_COMPARE(mov(byte[rsi], 0x3), 0xc6, 0x06, 0x03); + SINGLE_COMPARE(mov(byte[rsi], al), 0x88, 0x06); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfTest") +{ + SINGLE_COMPARE(test(al, 8), 0xf6, 0xc0, 0x08); + SINGLE_COMPARE(test(eax, 8), 0xf7, 0xc0, 0x08, 0x00, 0x00, 0x00); + SINGLE_COMPARE(test(rax, 8), 0x48, 0xf7, 0xc0, 0x08, 0x00, 0x00, 0x00); + SINGLE_COMPARE(test(rcx, 0xabab), 0x48, 0xf7, 0xc1, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(test(rcx, rax), 0x48, 0x85, 0xc8); + SINGLE_COMPARE(test(rax, qword[rcx]), 0x48, 0x85, 0x01); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfShift") +{ + SINGLE_COMPARE(shl(al, 1), 0xd0, 0xe0); + SINGLE_COMPARE(shl(al, cl), 0xd2, 0xe0); + SINGLE_COMPARE(shr(al, 4), 0xc0, 0xe8, 0x04); + SINGLE_COMPARE(shr(eax, 1), 0xd1, 0xe8); + SINGLE_COMPARE(sal(eax, cl), 0xd3, 0xe0); + SINGLE_COMPARE(sal(eax, 4), 0xc1, 0xe0, 0x04); + SINGLE_COMPARE(sar(rax, 4), 0x48, 0xc1, 0xf8, 0x04); + SINGLE_COMPARE(sar(r11, 1), 0x49, 0xd1, 0xfb); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea") +{ + SINGLE_COMPARE(lea(rax, qword[rdx + rcx]), 0x48, 0x8d, 0x04, 0x0a); + SINGLE_COMPARE(lea(rax, qword[rdx + rax * 4]), 0x48, 0x8d, 0x04, 0x82); + SINGLE_COMPARE(lea(rax, qword[r13 + r12 * 4 + 4]), 0x4b, 0x8d, 0x44, 0xa5, 0x04); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow") +{ + // Jump back + check( + [](AssemblyBuilderX64& build) { + Label start = build.setLabel(); + build.add(rsi, 1); + build.cmp(rsi, rdi); + build.jcc(Condition::Equal, start); + }, + {0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf3, 0xff, 0xff, 0xff}); + + // Jump back, but the label is set before use + check( + [](AssemblyBuilderX64& build) { + Label start; + build.add(rsi, 1); + build.setLabel(start); + build.cmp(rsi, rdi); + build.jcc(Condition::Equal, start); + }, + {0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf7, 0xff, 0xff, 0xff}); + + // Jump forward + check( + [](AssemblyBuilderX64& build) { + Label skip; + + build.cmp(rsi, rdi); + build.jcc(Condition::Greater, skip); + build.or_(rdi, 0x3e); + build.setLabel(skip); + }, + {0x48, 0x3b, 0xf7, 0x0f, 0x8f, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xcf, 0x3e}); + + // Regular jump + check( + [](AssemblyBuilderX64& build) { + Label skip; + + build.jmp(skip); + build.and_(rdi, 0x3e); + build.setLabel(skip); + }, + {0xe9, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xe7, 0x3e}); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") +{ + SINGLE_COMPARE(vaddpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xa9, 0x58, 0xc6); + SINGLE_COMPARE(vaddpd(xmm8, xmm10, xmmword[r9]), 0xc4, 0x41, 0xa9, 0x58, 0x01); + SINGLE_COMPARE(vaddpd(ymm8, ymm10, ymm14), 0xc4, 0x41, 0xad, 0x58, 0xc6); + SINGLE_COMPARE(vaddpd(ymm8, ymm10, ymmword[r9]), 0xc4, 0x41, 0xad, 0x58, 0x01); + SINGLE_COMPARE(vaddps(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xa8, 0x58, 0xc6); + SINGLE_COMPARE(vaddps(xmm8, xmm10, xmmword[r9]), 0xc4, 0x41, 0xa8, 0x58, 0x01); + SINGLE_COMPARE(vaddsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xab, 0x58, 0xc6); + SINGLE_COMPARE(vaddsd(xmm8, xmm10, qword[r9]), 0xc4, 0x41, 0xab, 0x58, 0x01); + SINGLE_COMPARE(vaddss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xaa, 0x58, 0xc6); + SINGLE_COMPARE(vaddss(xmm8, xmm10, dword[r9]), 0xc4, 0x41, 0xaa, 0x58, 0x01); + + SINGLE_COMPARE(vaddps(xmm1, xmm2, xmm3), 0xc4, 0xe1, 0xe8, 0x58, 0xcb); + SINGLE_COMPARE(vaddps(xmm9, xmm12, xmmword[r9 + r14 * 2 + 0x1c]), 0xc4, 0x01, 0x98, 0x58, 0x4c, 0x71, 0x1c); + SINGLE_COMPARE(vaddps(ymm1, ymm2, ymm3), 0xc4, 0xe1, 0xec, 0x58, 0xcb); + SINGLE_COMPARE(vaddps(ymm9, ymm12, ymmword[r9 + r14 * 2 + 0x1c]), 0xc4, 0x01, 0x9c, 0x58, 0x4c, 0x71, 0x1c); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXUnaryMergeInstructionForms") +{ + SINGLE_COMPARE(vsqrtpd(xmm8, xmm10), 0xc4, 0x41, 0xf9, 0x51, 0xc2); + SINGLE_COMPARE(vsqrtpd(xmm8, xmmword[r9]), 0xc4, 0x41, 0xf9, 0x51, 0x01); + SINGLE_COMPARE(vsqrtpd(ymm8, ymm10), 0xc4, 0x41, 0xfd, 0x51, 0xc2); + SINGLE_COMPARE(vsqrtpd(ymm8, ymmword[r9]), 0xc4, 0x41, 0xfd, 0x51, 0x01); + SINGLE_COMPARE(vsqrtps(xmm8, xmm10), 0xc4, 0x41, 0xf8, 0x51, 0xc2); + SINGLE_COMPARE(vsqrtps(xmm8, xmmword[r9]), 0xc4, 0x41, 0xf8, 0x51, 0x01); + SINGLE_COMPARE(vsqrtsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xab, 0x51, 0xc6); + SINGLE_COMPARE(vsqrtsd(xmm8, xmm10, qword[r9]), 0xc4, 0x41, 0xab, 0x51, 0x01); + SINGLE_COMPARE(vsqrtss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xaa, 0x51, 0xc6); + SINGLE_COMPARE(vsqrtss(xmm8, xmm10, dword[r9]), 0xc4, 0x41, 0xaa, 0x51, 0x01); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXMoveInstructionForms") +{ + SINGLE_COMPARE(vmovsd(qword[r9], xmm10), 0xc4, 0x41, 0xfb, 0x11, 0x11); + SINGLE_COMPARE(vmovsd(xmm8, qword[r9]), 0xc4, 0x41, 0xfb, 0x10, 0x01); + SINGLE_COMPARE(vmovsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xab, 0x10, 0xc6); + SINGLE_COMPARE(vmovss(dword[r9], xmm10), 0xc4, 0x41, 0xfa, 0x11, 0x11); + SINGLE_COMPARE(vmovss(xmm8, dword[r9]), 0xc4, 0x41, 0xfa, 0x10, 0x01); + SINGLE_COMPARE(vmovss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xaa, 0x10, 0xc6); + SINGLE_COMPARE(vmovapd(xmm8, xmmword[r9]), 0xc4, 0x41, 0xf9, 0x28, 0x01); + SINGLE_COMPARE(vmovapd(xmmword[r9], xmm10), 0xc4, 0x41, 0xf9, 0x29, 0x11); + SINGLE_COMPARE(vmovapd(ymm8, ymmword[r9]), 0xc4, 0x41, 0xfd, 0x28, 0x01); + SINGLE_COMPARE(vmovaps(xmm8, xmmword[r9]), 0xc4, 0x41, 0xf8, 0x28, 0x01); + SINGLE_COMPARE(vmovaps(xmmword[r9], xmm10), 0xc4, 0x41, 0xf8, 0x29, 0x11); + SINGLE_COMPARE(vmovaps(ymm8, ymmword[r9]), 0xc4, 0x41, 0xfc, 0x28, 0x01); + SINGLE_COMPARE(vmovupd(xmm8, xmmword[r9]), 0xc4, 0x41, 0xf9, 0x10, 0x01); + SINGLE_COMPARE(vmovupd(xmmword[r9], xmm10), 0xc4, 0x41, 0xf9, 0x11, 0x11); + SINGLE_COMPARE(vmovupd(ymm8, ymmword[r9]), 0xc4, 0x41, 0xfd, 0x10, 0x01); + SINGLE_COMPARE(vmovups(xmm8, xmmword[r9]), 0xc4, 0x41, 0xf8, 0x10, 0x01); + SINGLE_COMPARE(vmovups(xmmword[r9], xmm10), 0xc4, 0x41, 0xf8, 0x11, 0x11); + SINGLE_COMPARE(vmovups(ymm8, ymmword[r9]), 0xc4, 0x41, 0xfc, 0x10, 0x01); +} + +TEST_CASE("LogTest") +{ + AssemblyBuilderX64 build(/* logText= */ true); + + build.push(r12); + build.add(rax, rdi); + build.add(rcx, 8); + build.sub(dword[rax], 0x1fdc); + build.and_(dword[rcx], 0x37); + build.mov(rdi, qword[rax + rsi * 2]); + build.vaddss(xmm0, xmm0, dword[rax + r14 * 2 + 0x1c]); + + Label start = build.setLabel(); + build.cmp(rsi, rdi); + build.jcc(Condition::Equal, start); + + build.jmp(qword[rdx]); + build.vaddps(ymm9, ymm12, ymmword[rbp + 0xc]); + build.vaddpd(ymm2, ymm7, build.f64(2.5)); + build.neg(qword[rbp + r12 * 2]); + build.mov64(r10, 0x1234567812345678ll); + build.vmovapd(xmmword[rax], xmm11); + build.pop(r12); + build.ret(); + + build.finalize(); + + bool same = "\n" + build.text == R"( + push r12 + add rax,rdi + add rcx,8 + sub dword ptr [rax],1FDCh + and dword ptr [rcx],37h + mov rdi,qword ptr [rax+rsi*2] + vaddss xmm0,xmm0,dword ptr [rax+r14*2+01Ch] +.L1: + cmp rsi,rdi + je .L1 + jmp qword ptr [rdx] + vaddps ymm9,ymm12,ymmword ptr [rbp+0Ch] + vaddpd ymm2,ymm7,qword ptr [.start-8] + neg qword ptr [rbp+r12*2] + mov r10,1234567812345678h + vmovapd xmmword ptr [rax],xmm11 + pop r12 + ret +)"; + CHECK(same); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "Constants") +{ + // clang-format off + check( + [](AssemblyBuilderX64& build) { + build.xor_(rax, rax); + build.add(rax, build.i64(0x1234567887654321)); + build.vmovss(xmm2, build.f32(1.0f)); + build.vmovsd(xmm3, build.f64(1.0)); + build.vmovaps(xmm4, build.f32x4(1.0f, 2.0f, 4.0f, 8.0f)); + build.ret(); + }, + { + 0x48, 0x33, 0xc0, + 0x48, 0x03, 0x05, 0xee, 0xff, 0xff, 0xff, + 0xc4, 0xe1, 0xfa, 0x10, 0x15, 0xe1, 0xff, 0xff, 0xff, + 0xc4, 0xe1, 0xfb, 0x10, 0x1d, 0xcc, 0xff, 0xff, 0xff, + 0xc4, 0xe1, 0xf8, 0x28, 0x25, 0xab, 0xff, 0xff, 0xff, + 0xc3 + }); + // clang-format on +} + +TEST_CASE("ConstantStorage") +{ + AssemblyBuilderX64 build(/* logText= */ false); + + for (int i = 0; i <= 3000; i++) + build.vaddss(xmm0, xmm0, build.f32(float(i))); + + build.finalize(); + + LUAU_ASSERT(build.data.size() == 12004); + + for (int i = 0; i <= 3000; i++) + { + float v; + memcpy(&v, &build.data[build.data.size() - (i + 1) * sizeof(float)], sizeof(v)); + LUAU_ASSERT(v == float(i)); + } +} + +TEST_SUITE_END(); diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 292625b0..f0017509 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -7,7 +7,7 @@ using namespace Luau; -struct DocumentationSymbolFixture : Fixture +struct DocumentationSymbolFixture : BuiltinsFixture { std::optional getDocSymbol(const std::string& source, Position position) { @@ -92,4 +92,17 @@ bar(foo()) CHECK_EQ("number", toString(*expectedOty)); } +TEST_CASE_FIXTURE(Fixture, "ast_ancestry_at_eof") +{ + check(R"( +if true then + )"); + + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), Position(2, 4)); + REQUIRE_GE(ancestry.size(), 2); + AstStat* parentStat = ancestry[ancestry.size() - 2]->asStat(); + REQUIRE(bool(parentStat)); + REQUIRE(parentStat->is()); +} + TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 1db782cc..f3b0bcad 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -14,7 +14,6 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) -LUAU_FASTFLAG(LuauTableCloneType) using namespace Luau; @@ -26,6 +25,11 @@ static std::optional nullCallback(std::string tag, std::op template struct ACFixtureImpl : BaseType { + ACFixtureImpl() + : BaseType(true, true) + { + } + AutocompleteResult autocomplete(unsigned row, unsigned column) { return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); @@ -73,7 +77,18 @@ struct ACFixtureImpl : BaseType } LUAU_ASSERT("Digit expected after @ symbol" && prevChar != '@'); - return Fixture::check(filteredSource); + return BaseType::check(filteredSource); + } + + LoadDefinitionFileResult loadDefinition(const std::string& source) + { + TypeChecker& typeChecker = this->frontend.typeCheckerForAutocomplete; + unfreeze(typeChecker.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, source, "@test"); + freeze(typeChecker.globalTypes); + + REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); + return result; } const Position& getPosition(char marker) const @@ -88,6 +103,18 @@ struct ACFixtureImpl : BaseType }; struct ACFixture : ACFixtureImpl +{ + ACFixture() + : ACFixtureImpl() + { + addGlobalBinding(frontend.typeChecker, "table", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.typeChecker, "math", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.typeCheckerForAutocomplete, "table", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.typeCheckerForAutocomplete, "math", Binding{typeChecker.anyType}); + } +}; + +struct ACBuiltinsFixture : ACFixtureImpl { }; @@ -254,7 +281,7 @@ TEST_CASE_FIXTURE(ACFixture, "function_parameters") CHECK(ac.entryMap.count("test")); } -TEST_CASE_FIXTURE(ACFixture, "get_member_completions") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "get_member_completions") { check(R"( local a = table.@1 @@ -262,7 +289,7 @@ TEST_CASE_FIXTURE(ACFixture, "get_member_completions") auto ac = autocomplete('1'); - CHECK_EQ(FFlag::LuauTableCloneType ? 17 : 16, ac.entryMap.size()); + CHECK_EQ(17, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); CHECK(!ac.entryMap.count("math")); @@ -353,7 +380,7 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") CHECK(ac.entryMap.count("c3")); } -TEST_CASE_FIXTURE(ACFixture, "get_string_completions") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "get_string_completions") { check(R"( local a = ("foo"):@1 @@ -404,7 +431,7 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") CHECK(!ac.entryMap.count("math")); } -TEST_CASE_FIXTURE(ACFixture, "method_call_inside_if_conditional") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "method_call_inside_if_conditional") { check(R"( if table: @1 @@ -1861,7 +1888,7 @@ ex.b(function(x: CHECK(!ac.entryMap.count("(done) -> number")); } -TEST_CASE_FIXTURE(ACFixture, "suggest_external_module_type") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "suggest_external_module_type") { fileResolver.source["Module/A"] = R"( export type done = { x: number, y: number } @@ -1965,6 +1992,7 @@ local fp: @1= f auto ac = autocomplete('1'); + REQUIRE_EQ("({| x: number, y: number |}) -> number", toString(requireType("f"))); CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } @@ -2212,7 +2240,7 @@ local a: aaa.do CHECK(ac.entryMap.count("other")); } -TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteSource") { std::string_view source = R"( local a = table. -- Line 1 @@ -2221,7 +2249,7 @@ TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") auto ac = autocompleteSource(frontend, source, Position{1, 24}, nullCallback).result; - CHECK_EQ(FFlag::LuauTableCloneType ? 17 : 16, ac.entryMap.size()); + CHECK_EQ(17, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); CHECK(!ac.entryMap.count("math")); @@ -2246,7 +2274,7 @@ TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_comments") CHECK_EQ(0, ac.entryMap.size()); } -TEST_CASE_FIXTURE(ACFixture, "autocompleteProp_index_function_metamethod_is_variadic") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteProp_index_function_metamethod_is_variadic") { std::string_view source = R"( type Foo = {x: number} @@ -2497,7 +2525,7 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") +TEST_CASE_FIXTURE(ACFixture, "autocomplete_documentation_symbols") { loadDefinition(R"( declare y: { @@ -2505,13 +2533,11 @@ TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") } )"); - fileResolver.source["Module/A"] = R"( - local a = y. - )"; + check(R"( + local a = y.@1 + )"); - frontend.check("Module/A"); - - auto ac = autocomplete(frontend, "Module/A", Position{1, 21}, nullCallback); + auto ac = autocomplete('1'); REQUIRE(ac.entryMap.count("x")); CHECK_EQ(ac.entryMap["x"].documentationSymbol, "@test/global/y.x"); @@ -2595,7 +2621,6 @@ a = if temp then even elseif true then temp else e@9 TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_else_regression") { - ScopedFastFlag FFlagLuauIfElseExprFixCompletionIssue("LuauIfElseExprFixCompletionIssue", true); check(R"( local abcdef = 0; local temp = false @@ -2699,7 +2724,7 @@ type A = () -> T CHECK(ac.entryMap.count("string")); } -TEST_CASE_FIXTURE(ACFixture, "autocomplete_oop_implicit_self") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_oop_implicit_self") { check(R"( --!strict @@ -2707,15 +2732,15 @@ local Class = {} Class.__index = Class type Class = typeof(setmetatable({} :: { x: number }, Class)) function Class.new(x: number): Class - return setmetatable({x = x}, Class) + return setmetatable({x = x}, Class) end function Class.getx(self: Class) - return self.x + return self.x end function test() - local c = Class.new(42) - local n = c:@1 - print(n) + local c = Class.new(42) + local n = c:@1 + print(n) end )"); @@ -2724,7 +2749,7 @@ end CHECK(ac.entryMap.count("getx")); } -TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_on_string_singletons") { check(R"( --!strict @@ -2737,6 +2762,98 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") CHECK(ac.entryMap.count("format")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") +{ + check(R"( + type tag = "cat" | "dog" + local function f(a: tag) end + f("@1") + f(@2) + local x: tag = "@3" + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("\"cat\"")); + CHECK(ac.entryMap.count("\"dog\"")); + + ac = autocomplete('3'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + check(R"( + type tagged = {tag:"cat", fieldx:number} | {tag:"dog", fieldy:number} + local x: tagged = {tag="@4"} + )"); + + ac = autocomplete('4'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") +{ + check(R"( + type tagged = {tag:"cat", fieldx:number} | {tag:"dog", fieldy:number} + local x: tagged = {tag="cat", fieldx=2} + if x.tag == "@1" or "@2" ~= x.tag then end + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + // CLI-48823: assignment to x.tag should also autocomplete, but union l-values are not supported yet +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_boolean_singleton") +{ + check(R"( +local function f(x: true) end +f(@1) + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("true")); + CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); + REQUIRE(ac.entryMap.count("false")); + CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape") +{ + check(R"( + type tag = "strange\t\"cat\"" | 'nice\t"dog"' + local function f(x: tag) end + f(@1) + f("@2") + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("\"strange\\t\\\"cat\\\"\"")); + CHECK(ac.entryMap.count("\"nice\\t\\\"dog\\\"\"")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("strange\\t\\\"cat\\\"")); + CHECK(ac.entryMap.count("nice\\t\\\"dog\\\"")); +} + TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses_2") { check(R"( @@ -2752,7 +2869,7 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_on_class") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; loadDefinition(R"( declare class Foo @@ -2792,7 +2909,7 @@ t.@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local t = {} @@ -2808,7 +2925,7 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_2") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local f: (() -> number) & ((number) -> number) = function(x: number?) return 2 end @@ -2840,7 +2957,7 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local s = "hello" @@ -2859,7 +2976,7 @@ s:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_non_self_calls_are_avoided") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local s = "hello" @@ -2868,17 +2985,15 @@ s.@1 auto ac = autocomplete('1'); - REQUIRE(ac.entryMap.count("byte")); - CHECK(ac.entryMap["byte"].wrongIndexType == true); REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == false); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == true); } -TEST_CASE_FIXTURE(ACFixture, "string_library_non_self_calls_are_fine") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_non_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( string.@1 @@ -2892,11 +3007,24 @@ string.@1 CHECK(ac.entryMap["char"].wrongIndexType == false); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == false); + + check(R"( +table.@1 + )"); + + ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("remove")); + CHECK(ac.entryMap["remove"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("getn")); + CHECK(ac.entryMap["getn"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("insert")); + CHECK(ac.entryMap["insert"].wrongIndexType == false); } -TEST_CASE_FIXTURE(ACFixture, "string_library_self_calls_are_invalid") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_self_calls_are_invalid") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( string:@1 @@ -2912,4 +3040,40 @@ string:@1 CHECK(ac.entryMap["sub"].wrongIndexType == true); } +TEST_CASE_FIXTURE(ACFixture, "source_module_preservation_and_invalidation") +{ + check(R"( +local a = { x = 2, y = 4 } +a.@1 + )"); + + frontend.clear(); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + frontend.check("MainModule", {}); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + frontend.markDirty("MainModule", nullptr); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + frontend.check("MainModule", {}); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); +} + TEST_SUITE_END(); diff --git a/tests/BuiltinDefinitions.test.cpp b/tests/BuiltinDefinitions.test.cpp index dbe80f2c..496df4b4 100644 --- a/tests/BuiltinDefinitions.test.cpp +++ b/tests/BuiltinDefinitions.test.cpp @@ -10,8 +10,10 @@ using namespace Luau; TEST_SUITE_BEGIN("BuiltinDefinitionsTest"); -TEST_CASE_FIXTURE(Fixture, "lib_documentation_symbols") +TEST_CASE_FIXTURE(BuiltinsFixture, "lib_documentation_symbols") { + CHECK(!typeChecker.globalScope->bindings.empty()); + for (const auto& [name, binding] : typeChecker.globalScope->bindings) { std::string nameString(name.c_str()); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 3dc57da0..655e48cb 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -17,11 +17,13 @@ std::string rep(const std::string& s, size_t n); using namespace Luau; -static std::string compileFunction(const char* source, uint32_t id) +static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1) { Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); - Luau::compileOrThrow(bcb, source); + Luau::CompileOptions options; + options.optimizationLevel = optimizationLevel; + Luau::compileOrThrow(bcb, source, options); return bcb.dumpFunction(id); } @@ -160,10 +162,10 @@ TEST_CASE("ImportCall") { CHECK_EQ("\n" + compileFunction0("return math.max(1, 2)"), R"( LOADN R1 1 -FASTCALL2K 18 R1 K0 +4 +FASTCALL2K 18 R1 K0 L0 LOADK R2 K0 GETIMPORT R0 3 -CALL R0 2 -1 +L0: CALL R0 2 -1 RETURN R0 -1 )"); } @@ -244,16 +246,16 @@ RETURN R0 1 TEST_CASE("RepeatLocals") { CHECK_EQ("\n" + compileFunction0("repeat local a a = 5 until a - 4 < 0 or a - 4 >= 0"), R"( -LOADNIL R0 +L0: LOADNIL R0 LOADN R0 5 SUBK R1 R0 K0 LOADN R2 0 -JUMPIFLT R1 R2 +6 +JUMPIFLT R1 R2 L1 SUBK R1 R0 K0 LOADN R2 0 -JUMPIFLE R2 R1 +2 -JUMPBACK -11 -RETURN R0 0 +JUMPIFLE R2 R1 L1 +JUMPBACK L0 +L1: RETURN R0 0 )"); } @@ -264,12 +266,12 @@ TEST_CASE("ForBytecode") LOADN R2 1 LOADN R0 5 LOADN R1 1 -FORNPREP R0 +5 -GETIMPORT R3 1 +FORNPREP R0 L1 +L0: GETIMPORT R3 1 MOVE R4 R2 CALL R3 1 0 -FORNLOOP R0 -5 -RETURN R0 0 +FORNLOOP R0 L0 +L1: RETURN R0 0 )"); // when you assign the variable internally, we freak out and copy the variable so that you aren't changing the loop behavior @@ -277,14 +279,14 @@ RETURN R0 0 LOADN R2 1 LOADN R0 5 LOADN R1 1 -FORNPREP R0 +7 -MOVE R3 R2 +FORNPREP R0 L1 +L0: MOVE R3 R2 LOADN R3 7 GETIMPORT R4 1 MOVE R5 R3 CALL R4 1 0 -FORNLOOP R0 -7 -RETURN R0 0 +FORNLOOP R0 L0 +L1: RETURN R0 0 )"); // basic for-in loop, generic version @@ -293,11 +295,11 @@ GETIMPORT R0 2 LOADK R1 K3 LOADK R2 K4 CALL R0 2 3 -JUMP +4 -GETIMPORT R5 6 +FORGPREP R0 L1 +L0: GETIMPORT R5 6 MOVE R6 R3 CALL R5 1 0 -FORGLOOP R0 -5 1 +L1: FORGLOOP R0 L0 1 RETURN R0 0 )"); @@ -306,12 +308,12 @@ RETURN R0 0 GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 -FORGPREP_INEXT R0 +5 -GETIMPORT R5 3 +FORGPREP_INEXT R0 L1 +L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -FORGLOOP_INEXT R0 -6 +L1: FORGLOOP_INEXT R0 L0 RETURN R0 0 )"); @@ -320,12 +322,12 @@ RETURN R0 0 GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 -FORGPREP_NEXT R0 +5 -GETIMPORT R5 3 +FORGPREP_NEXT R0 L1 +L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -FORGLOOP_NEXT R0 -6 +L1: FORGLOOP R0 L0 2 RETURN R0 0 )"); @@ -333,12 +335,12 @@ RETURN R0 0 GETIMPORT R0 1 NEWTABLE R1 0 0 LOADNIL R2 -FORGPREP_NEXT R0 +5 -GETIMPORT R5 3 +FORGPREP_NEXT R0 L1 +L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -FORGLOOP_NEXT R0 -6 +L1: FORGLOOP R0 L0 2 RETURN R0 0 )"); } @@ -350,8 +352,8 @@ TEST_CASE("ForBytecodeBuiltin") GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 -FORGPREP_INEXT R0 +0 -FORGLOOP_INEXT R0 -1 +FORGPREP_INEXT R0 L0 +L0: FORGLOOP_INEXT R0 L0 RETURN R0 0 )"); @@ -361,8 +363,8 @@ GETIMPORT R0 1 MOVE R1 R0 NEWTABLE R2 0 0 CALL R1 1 3 -FORGPREP_INEXT R1 +0 -FORGLOOP_INEXT R1 -1 +FORGPREP_INEXT R1 L0 +L0: FORGLOOP_INEXT R1 L0 RETURN R0 0 )"); @@ -371,8 +373,8 @@ RETURN R0 0 GETUPVAL R0 0 NEWTABLE R1 0 0 CALL R0 1 3 -FORGPREP_INEXT R0 +0 -FORGLOOP_INEXT R0 -1 +FORGPREP_INEXT R0 L0 +L0: FORGLOOP_INEXT R0 L0 RETURN R0 0 )"); @@ -383,8 +385,8 @@ GETIMPORT R0 3 MOVE R1 R0 NEWTABLE R2 0 0 CALL R1 1 3 -JUMP +0 -FORGLOOP R1 -1 2 +FORGPREP R1 L0 +L0: FORGLOOP R1 L0 2 RETURN R0 0 )"); @@ -395,8 +397,8 @@ SETGLOBAL R0 K2 GETGLOBAL R0 K2 NEWTABLE R1 0 0 CALL R0 1 3 -JUMP +0 -FORGLOOP R0 -1 2 +FORGPREP R0 L0 +L0: FORGLOOP R0 L0 2 RETURN R0 0 )"); @@ -405,8 +407,8 @@ RETURN R0 0 GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 -JUMP +0 -FORGLOOP R0 -1 2 +FORGPREP R0 L0 +L0: FORGLOOP R0 L0 2 RETURN R0 0 )"); } @@ -806,11 +808,11 @@ NEWTABLE R0 0 4 LOADN R3 1 LOADN R1 4 LOADN R2 1 -FORNPREP R1 +3 -LOADN R4 0 +FORNPREP R1 L1 +L0: LOADN R4 0 SETTABLE R4 R0 R3 -FORNLOOP R1 -3 -RETURN R0 1 +FORNLOOP R1 L0 +L1: RETURN R0 1 )"); } @@ -860,18 +862,18 @@ TEST_CASE("ConditionalBasic") { CHECK_EQ("\n" + compileFunction0("local a = ... if a then return 5 end"), R"( GETVARARGS R0 1 -JUMPIFNOT R0 +2 +JUMPIFNOT R0 L0 LOADN R1 5 RETURN R1 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = ... if not a then return 5 end"), R"( GETVARARGS R0 1 -JUMPIF R0 +2 +JUMPIF R0 L0 LOADN R1 5 RETURN R1 1 -RETURN R0 0 +L0: RETURN R0 0 )"); } @@ -879,50 +881,50 @@ TEST_CASE("ConditionalCompare") { CHECK_EQ("\n" + compileFunction0("local a, b = ... if a < b then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFNOTLT R0 R1 +3 +JUMPIFNOTLT R0 R1 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b = ... if a <= b then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFNOTLE R0 R1 +3 +JUMPIFNOTLE R0 R1 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b = ... if a > b then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFNOTLT R1 R0 +3 +JUMPIFNOTLT R1 R0 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b = ... if a >= b then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFNOTLE R1 R0 +3 +JUMPIFNOTLE R1 R0 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b = ... if a == b then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFNOTEQ R0 R1 +3 +JUMPIFNOTEQ R0 R1 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b = ... if a ~= b then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFEQ R0 R1 +3 +JUMPIFEQ R0 R1 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); } @@ -930,18 +932,18 @@ TEST_CASE("ConditionalNot") { CHECK_EQ("\n" + compileFunction0("local a, b = ... if not (not (a < b)) then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFNOTLT R0 R1 +3 +JUMPIFNOTLT R0 R1 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b = ... if not (not (not (a < b))) then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFLT R0 R1 +3 +JUMPIFLT R0 R1 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); } @@ -949,51 +951,51 @@ TEST_CASE("ConditionalAndOr") { CHECK_EQ("\n" + compileFunction0("local a, b, c = ... if a < b and b < c then return 5 end"), R"( GETVARARGS R0 3 -JUMPIFNOTLT R0 R1 +5 -JUMPIFNOTLT R1 R2 +3 +JUMPIFNOTLT R0 R1 L0 +JUMPIFNOTLT R1 R2 L0 LOADN R3 5 RETURN R3 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b, c = ... if a < b or b < c then return 5 end"), R"( GETVARARGS R0 3 -JUMPIFLT R0 R1 +3 -JUMPIFNOTLT R1 R2 +3 -LOADN R3 5 +JUMPIFLT R0 R1 L0 +JUMPIFNOTLT R1 R2 L1 +L0: LOADN R3 5 RETURN R3 1 -RETURN R0 0 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a,b,c,d = ... if (a or b) and not (c and d) then return 5 end"), R"( GETVARARGS R0 4 -JUMPIF R0 +1 -JUMPIFNOT R1 +4 -JUMPIFNOT R2 +1 -JUMPIF R3 +2 -LOADN R4 5 +JUMPIF R0 L0 +JUMPIFNOT R1 L2 +L0: JUMPIFNOT R2 L1 +JUMPIF R3 L2 +L1: LOADN R4 5 RETURN R4 1 -RETURN R0 0 +L2: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a,b,c = ... if a or not b or c then return 5 end"), R"( GETVARARGS R0 3 -JUMPIF R0 +2 -JUMPIFNOT R1 +1 -JUMPIFNOT R2 +2 -LOADN R3 5 +JUMPIF R0 L0 +JUMPIFNOT R1 L0 +JUMPIFNOT R2 L1 +L0: LOADN R3 5 RETURN R3 1 -RETURN R0 0 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a,b,c = ... if a and not b and c then return 5 end"), R"( GETVARARGS R0 3 -JUMPIFNOT R0 +4 -JUMPIF R1 +3 -JUMPIFNOT R2 +2 +JUMPIFNOT R0 L0 +JUMPIF R1 L0 +JUMPIFNOT R2 L0 LOADN R3 5 RETURN R3 1 -RETURN R0 0 +L0: RETURN R0 0 )"); } @@ -1018,9 +1020,9 @@ LOADN R0 1 LOADN R1 2 SETGLOBAL R1 K0 MOVE R1 R0 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 GETGLOBAL R1 K0 -MOVE R0 R1 +L0: MOVE R0 R1 RETURN R0 1 )"); @@ -1043,9 +1045,9 @@ LOADN R0 1 LOADN R1 2 SETGLOBAL R1 K0 MOVE R1 R0 -JUMPIF R1 +2 +JUMPIF R1 L0 GETGLOBAL R1 K0 -MOVE R0 R1 +L0: MOVE R0 R1 RETURN R0 1 )"); @@ -1057,9 +1059,9 @@ MOVE R0 R0 LOADN R1 2 SETGLOBAL R1 K0 MOVE R1 R0 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 GETGLOBAL R1 K0 -RETURN R1 1 +L0: RETURN R1 1 )"); CHECK_EQ("\n" + compileFunction0("local a = 1 a = a b = 2 local c = a or b return c"), R"( @@ -1068,9 +1070,42 @@ MOVE R0 R0 LOADN R1 2 SETGLOBAL R1 K0 MOVE R1 R0 -JUMPIF R1 +2 +JUMPIF R1 L0 GETGLOBAL R1 K0 -RETURN R1 1 +L0: RETURN R1 1 +)"); +} + +TEST_CASE("AndOrFoldLeft") +{ + // constant folding and/or expression is possible even if just the left hand is constant + CHECK_EQ("\n" + compileFunction0("local a = false if a and b then b() end"), R"( +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = true if a or b then b() end"), R"( +GETIMPORT R0 1 +CALL R0 0 0 +RETURN R0 0 +)"); + + // however, if right hand side is constant we can't constant fold the entire expression + // (note that we don't need to evaluate the right hand side, but we do need a branch) + CHECK_EQ("\n" + compileFunction0("local a = false if b and a then b() end"), R"( +GETIMPORT R0 1 +JUMPIFNOT R0 L0 +RETURN R0 0 +GETIMPORT R0 1 +CALL R0 0 0 +L0: RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = true if b or a then b() end"), R"( +GETIMPORT R0 1 +JUMPIF R0 L0 +L0: GETIMPORT R0 1 +CALL R0 0 0 +RETURN R0 0 )"); } @@ -1089,19 +1124,19 @@ GETIMPORT R3 1 SUB R1 R2 R3 GETIMPORT R3 4 ADDK R2 R3 K2 -JUMPIFNOTLT R1 R2 +4 +JUMPIFNOTLT R1 R2 L0 GETIMPORT R0 8 -JUMPIF R0 +15 -GETIMPORT R1 10 +JUMPIF R0 L2 +L0: GETIMPORT R1 10 LOADN R2 0 -JUMPIFNOTLT R2 R1 +9 +JUMPIFNOTLT R2 R1 L1 GETIMPORT R1 10 LOADN R2 1 -JUMPIFNOTLT R1 R2 +4 +JUMPIFNOTLT R1 R2 L1 GETIMPORT R0 8 -JUMPIF R0 +2 -GETIMPORT R0 12 -RETURN R0 1 +JUMPIF R0 L2 +L1: GETIMPORT R0 12 +L2: RETURN R0 1 )"); } @@ -1147,50 +1182,50 @@ RETURN R0 1 // codegen for a non-constant condition CHECK_EQ("\n" + compileFunction0("return if condition then 10 else 20"), R"( GETIMPORT R1 1 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 LOADN R0 10 RETURN R0 1 -LOADN R0 20 +L0: LOADN R0 20 RETURN R0 1 )"); // codegen for a non-constant condition using an assignment CHECK_EQ("\n" + compileFunction0("result = if condition then 10 else 20"), R"( GETIMPORT R1 1 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 LOADN R0 10 -JUMP +1 -LOADN R0 20 -SETGLOBAL R0 K2 +JUMP L1 +L0: LOADN R0 20 +L1: SETGLOBAL R0 K2 RETURN R0 0 )"); // codegen for a non-constant condition using an assignment to a local variable CHECK_EQ("\n" + compileFunction0("local result = if condition then 10 else 20"), R"( GETIMPORT R1 1 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 LOADN R0 10 RETURN R0 0 -LOADN R0 20 +L0: LOADN R0 20 RETURN R0 0 )"); // codegen for an if-else expression with multiple elseif's CHECK_EQ("\n" + compileFunction0("result = if condition1 then 10 elseif condition2 then 20 elseif condition3 then 30 else 40"), R"( GETIMPORT R1 1 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 LOADN R0 10 -JUMP +11 -GETIMPORT R1 3 -JUMPIFNOT R1 +2 +JUMP L3 +L0: GETIMPORT R1 3 +JUMPIFNOT R1 L1 LOADN R0 20 -JUMP +6 -GETIMPORT R1 5 -JUMPIFNOT R1 +2 +JUMP L3 +L1: GETIMPORT R1 5 +JUMPIFNOT R1 L2 LOADN R0 30 -JUMP +1 -LOADN R0 40 -SETGLOBAL R0 K6 +JUMP L3 +L2: LOADN R0 40 +L3: SETGLOBAL R0 K6 RETURN R0 0 )"); } @@ -1428,16 +1463,16 @@ RETURN R0 1 // constant fold parts in chains of and/or statements CHECK_EQ("\n" + compileFunction0("return a and true and b"), R"( GETIMPORT R0 1 -JUMPIFNOT R0 +2 +JUMPIFNOT R0 L0 GETIMPORT R0 3 -RETURN R0 1 +L0: RETURN R0 1 )"); CHECK_EQ("\n" + compileFunction0("return a or false or b"), R"( GETIMPORT R0 1 -JUMPIF R0 +2 +JUMPIF R0 L0 GETIMPORT R0 3 -RETURN R0 1 +L0: RETURN R0 1 )"); } @@ -1445,38 +1480,38 @@ TEST_CASE("ConstantFoldConditionalAndOr") { CHECK_EQ("\n" + compileFunction0("local a = ... if false or a then print(1) end"), R"( GETVARARGS R0 1 -JUMPIFNOT R0 +4 +JUMPIFNOT R0 L0 GETIMPORT R1 1 LOADN R2 1 CALL R1 1 0 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = ... if not (false or a) then print(1) end"), R"( GETVARARGS R0 1 -JUMPIF R0 +4 +JUMPIF R0 L0 GETIMPORT R1 1 LOADN R2 1 CALL R1 1 0 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = ... if true and a then print(1) end"), R"( GETVARARGS R0 1 -JUMPIFNOT R0 +4 +JUMPIFNOT R0 L0 GETIMPORT R1 1 LOADN R2 1 CALL R1 1 0 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = ... if not (true and a) then print(1) end"), R"( GETVARARGS R0 1 -JUMPIF R0 +4 +JUMPIF R0 L0 GETIMPORT R1 1 LOADN R2 1 CALL R1 1 0 -RETURN R0 0 +L0: RETURN R0 0 )"); } @@ -1510,10 +1545,10 @@ RETURN R0 0 // while CHECK_EQ("\n" + compileFunction0("while true do print(1) end"), R"( -GETIMPORT R0 1 +L0: GETIMPORT R0 1 LOADN R1 1 CALL R0 1 0 -JUMPBACK -5 +JUMPBACK L0 RETURN R0 0 )"); @@ -1530,21 +1565,21 @@ RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("repeat print(1) until false"), R"( -GETIMPORT R0 1 +L0: GETIMPORT R0 1 LOADN R1 1 CALL R0 1 0 -JUMPBACK -5 +JUMPBACK L0 RETURN R0 0 )"); // there's an odd case in repeat..until compilation where we evaluate the expression that is always false for side-effects of the left hand side CHECK_EQ("\n" + compileFunction0("repeat print(1) until five and false"), R"( -GETIMPORT R0 1 +L0: GETIMPORT R0 1 LOADN R1 1 CALL R0 1 0 GETIMPORT R0 3 -JUMPIFNOT R0 +0 -JUMPBACK -8 +JUMPIFNOT R0 L1 +L1: JUMPBACK L0 RETURN R0 0 )"); } @@ -1553,24 +1588,24 @@ TEST_CASE("LoopBreak") { // default codegen: compile breaks as unconditional jumps CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break else end end"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFNOTLT R0 R1 +3 +JUMPIFNOTLT R0 R1 L1 RETURN R0 0 -JUMP +0 -JUMPBACK -9 +JUMP L1 +L1: JUMPBACK L0 RETURN R0 0 )"); // optimization: if then body is a break statement, flip the branches CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break end end"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFLT R0 R1 +2 -JUMPBACK -7 -RETURN R0 0 +JUMPIFLT R0 R1 L1 +JUMPBACK L0 +L1: RETURN R0 0 )"); } @@ -1578,28 +1613,28 @@ TEST_CASE("LoopContinue") { // default codegen: compile continue as unconditional jumps CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue else end break until false error()"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFNOTLT R0 R1 +5 -JUMP +2 -JUMP +2 -JUMP +1 -JUMPBACK -10 -GETIMPORT R0 5 +JUMPIFNOTLT R0 R1 L2 +JUMP L1 +JUMP L2 +JUMP L2 +L1: JUMPBACK L0 +L2: GETIMPORT R0 5 CALL R0 0 0 RETURN R0 0 )"); // optimization: if then body is a continue statement, flip the branches CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue end break until false error()"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFLT R0 R1 +2 -JUMP +1 -JUMPBACK -8 -GETIMPORT R0 5 +JUMPIFLT R0 R1 L1 +JUMP L2 +L1: JUMPBACK L0 +L2: GETIMPORT R0 5 CALL R0 0 0 RETURN R0 0 )"); @@ -1609,15 +1644,15 @@ TEST_CASE("LoopContinueUntil") { // it's valid to use locals defined inside the loop in until expression if they're defined before continue CHECK_EQ("\n" + compileFunction0("repeat local r = math.random() if r > 0.5 then continue end r = r + 0.3 until r < 0.5"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFLT R1 R0 +2 +JUMPIFLT R1 R0 L1 ADDK R0 R0 K4 -LOADK R1 K3 -JUMPIFLT R0 R1 +2 -JUMPBACK -11 -RETURN R0 0 +L1: LOADK R1 K3 +JUMPIFLT R0 R1 L2 +JUMPBACK L0 +L2: RETURN R0 0 )"); // it's however invalid to use locals if they are defined after continue @@ -1648,16 +1683,16 @@ until rr < 0.5 CHECK_EQ("\n" + compileFunction0( "repeat local r = math.random() repeat if r > 0.5 then continue end r = r - 0.1 until true r = r + 0.3 until r < 0.5"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFLT R1 R0 +2 +JUMPIFLT R1 R0 L1 SUBK R0 R0 K4 -ADDK R0 R0 K5 +L1: ADDK R0 R0 K5 LOADK R1 K3 -JUMPIFLT R0 R1 +2 -JUMPBACK -12 -RETURN R0 0 +JUMPIFLT R0 R1 L2 +JUMPBACK L0 +L2: RETURN R0 0 )"); // and it's also okay to use a local defined in the until expression as long as it's inside a function! @@ -1665,20 +1700,20 @@ RETURN R0 0 "\n" + compileFunction( "repeat local r = math.random() if r > 0.5 then continue end r = r + 0.3 until (function() local a = r return a < 0.5 end)()", 1), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFNOTLT R1 R0 +3 +JUMPIFNOTLT R1 R0 L1 CLOSEUPVALS R0 -JUMP +1 -ADDK R0 R0 K4 -NEWCLOSURE R1 P0 +JUMP L2 +L1: ADDK R0 R0 K4 +L2: NEWCLOSURE R1 P0 CAPTURE REF R0 CALL R1 0 1 -JUMPIF R1 +2 -CLOSEUPVALS R0 -JUMPBACK -15 +JUMPIF R1 L3 CLOSEUPVALS R0 +JUMPBACK L0 +L3: CLOSEUPVALS R0 RETURN R0 0 )"); @@ -1709,17 +1744,17 @@ until (function() return rr end)() < 0.5 CHECK_EQ("\n" + compileFunction0("local stop = false stop = true function test() repeat local r = math.random() if r > 0.5 then " "continue end r = r + 0.3 until stop or r < 0.5 end"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFLT R1 R0 +2 +JUMPIFLT R1 R0 L1 ADDK R0 R0 K4 -GETUPVAL R1 0 -JUMPIF R1 +4 +L1: GETUPVAL R1 0 +JUMPIF R1 L2 LOADK R1 K3 -JUMPIFLT R0 R1 +2 -JUMPBACK -13 -RETURN R0 0 +JUMPIFLT R0 R1 L2 +JUMPBACK L0 +L2: RETURN R0 0 )"); // including upvalue references from a function expression @@ -1727,21 +1762,21 @@ RETURN R0 0 "end r = r + 0.3 until (function() return stop or r < 0.5 end)() end", 1), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFNOTLT R1 R0 +3 +JUMPIFNOTLT R1 R0 L1 CLOSEUPVALS R0 -JUMP +1 -ADDK R0 R0 K4 -NEWCLOSURE R1 P0 +JUMP L2 +L1: ADDK R0 R0 K4 +L2: NEWCLOSURE R1 P0 CAPTURE UPVAL U0 CAPTURE REF R0 CALL R1 0 1 -JUMPIF R1 +2 -CLOSEUPVALS R0 -JUMPBACK -16 +JUMPIF R1 L3 CLOSEUPVALS R0 +JUMPBACK L0 +L3: CLOSEUPVALS R0 RETURN R0 0 )"); } @@ -1782,11 +1817,11 @@ ORK R2 R1 K0 SUB R0 R0 R2 LOADN R4 1 LOADN R8 0 -JUMPIFNOTLT R0 R8 +3 +JUMPIFNOTLT R0 R8 L0 MINUS R7 R0 -JUMPIF R7 +1 -MOVE R7 R0 -MULK R6 R7 K1 +JUMPIF R7 L1 +L0: MOVE R7 R0 +L1: MULK R6 R7 K1 LOADN R8 1 SUB R7 R8 R2 DIV R5 R6 R7 @@ -1806,14 +1841,14 @@ LOADB R2 0 LOADK R4 K0 MULK R5 R1 K1 SUB R3 R4 R5 -JUMPIFNOTLT R3 R0 +8 +JUMPIFNOTLT R3 R0 L1 LOADK R4 K0 MULK R5 R1 K1 ADD R3 R4 R5 -JUMPIFLT R0 R3 +2 +JUMPIFLT R0 R3 L0 LOADB R2 0 +1 -LOADB R2 1 -RETURN R2 1 +L0: LOADB R2 1 +L1: RETURN R2 1 )"); // sometimes we need to compute a boolean; this uses LOADB with an offset for the last op, note that first op is compiled better @@ -1828,14 +1863,14 @@ LOADB R2 1 LOADK R4 K0 MULK R5 R1 K1 SUB R3 R4 R5 -JUMPIFLT R0 R3 +8 +JUMPIFLT R0 R3 L1 LOADK R4 K0 MULK R5 R1 K1 ADD R3 R4 R5 -JUMPIFLT R3 R0 +2 +JUMPIFLT R3 R0 L0 LOADB R2 0 +1 -LOADB R2 1 -RETURN R2 1 +L0: LOADB R2 1 +L1: RETURN R2 1 )"); // trivial ternary if with constants @@ -1846,10 +1881,10 @@ end )", 0), R"( -JUMPIFNOT R0 +2 +JUMPIFNOT R0 L0 LOADN R1 1 RETURN R1 1 -LOADN R1 0 +L0: LOADN R1 0 RETURN R1 1 )"); @@ -1862,14 +1897,14 @@ end 0), R"( LOADN R2 0 -JUMPIFNOTLT R0 R2 +3 +JUMPIFNOTLT R0 R2 L0 LOADN R1 0 RETURN R1 1 -LOADN R2 1 -JUMPIFNOTLT R2 R0 +3 +L0: LOADN R2 1 +JUMPIFNOTLT R2 R0 L1 LOADN R1 1 RETURN R1 1 -MOVE R1 R0 +L1: MOVE R1 R0 RETURN R1 1 )"); } @@ -1879,24 +1914,24 @@ TEST_CASE("JumpFold") // jump-to-return folding to return CHECK_EQ("\n" + compileFunction0("return a and 1 or 0"), R"( GETIMPORT R1 1 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 LOADN R0 1 RETURN R0 1 -LOADN R0 0 +L0: LOADN R0 0 RETURN R0 1 )"); // conditional jump in the inner if() folding to jump out of the expression (JUMPIFNOT+5 skips over all jumps, JUMP+1 skips over JUMP+0) CHECK_EQ("\n" + compileFunction0("if a then if b then b() else end else end d()"), R"( GETIMPORT R0 1 -JUMPIFNOT R0 +8 +JUMPIFNOT R0 L0 GETIMPORT R0 3 -JUMPIFNOT R0 +5 +JUMPIFNOT R0 L0 GETIMPORT R0 3 CALL R0 0 0 -JUMP +1 -JUMP +0 -GETIMPORT R0 5 +JUMP L0 +JUMP L0 +L0: GETIMPORT R0 5 CALL R0 0 0 RETURN R0 0 )"); @@ -1904,14 +1939,14 @@ RETURN R0 0 // same as example before but the unconditional jumps are folded with RETURN CHECK_EQ("\n" + compileFunction0("if a then if b then b() else end else end"), R"( GETIMPORT R0 1 -JUMPIFNOT R0 +8 +JUMPIFNOT R0 L0 GETIMPORT R0 3 -JUMPIFNOT R0 +5 +JUMPIFNOT R0 L0 GETIMPORT R0 3 CALL R0 0 0 RETURN R0 0 RETURN R0 0 -RETURN R0 0 +L0: RETURN R0 0 )"); // in this example, we do *not* have a JUMP after RETURN in the if branch @@ -1931,7 +1966,7 @@ end R"( ORK R6 R3 K0 ORK R7 R4 K1 -JUMPIF R5 +19 +JUMPIF R5 L0 GETIMPORT R10 5 DIV R13 R0 R7 MULK R14 R6 K6 @@ -1948,7 +1983,7 @@ CALL R10 3 1 MULK R9 R10 K2 ADDK R8 R9 K2 RETURN R8 1 -GETIMPORT R8 5 +L0: GETIMPORT R8 5 DIV R11 R0 R7 MULK R12 R6 K6 ADD R10 R11 R12 @@ -2057,15 +2092,15 @@ RETURN R0 0 TEST_CASE("NestedFunctionCalls") { CHECK_EQ("\n" + compileFunction0("function clamp(t,a,b) return math.min(math.max(t,a),b) end"), R"( -FASTCALL2 18 R0 R1 +5 +FASTCALL2 18 R0 R1 L0 MOVE R5 R0 MOVE R6 R1 GETIMPORT R4 2 -CALL R4 2 1 -FASTCALL2 19 R4 R2 +4 +L0: CALL R4 2 1 +FASTCALL2 19 R4 R2 L1 MOVE R5 R2 GETIMPORT R3 4 -CALL R3 2 -1 +L1: CALL R3 2 -1 RETURN R3 -1 )"); } @@ -2089,20 +2124,20 @@ end LOADN R2 1 LOADN R0 10 LOADN R1 1 -FORNPREP R0 +14 -MOVE R3 R2 +FORNPREP R0 L2 +L0: MOVE R3 R2 MOVE R3 R3 GETIMPORT R4 1 NEWCLOSURE R5 P0 CAPTURE REF R3 CALL R4 1 0 GETIMPORT R4 3 -JUMPIFNOT R4 +2 +JUMPIFNOT R4 L1 CLOSEUPVALS R3 -JUMP +2 -CLOSEUPVALS R3 -FORNLOOP R0 -14 -LOADN R0 0 +JUMP L2 +L1: CLOSEUPVALS R3 +FORNLOOP R0 L0 +L2: LOADN R0 0 RETURN R0 1 )"); @@ -2123,19 +2158,19 @@ end GETIMPORT R0 1 GETIMPORT R1 3 CALL R0 1 3 -FORGPREP_INEXT R0 +12 -MOVE R3 R3 +FORGPREP_INEXT R0 L2 +L0: MOVE R3 R3 GETIMPORT R5 5 NEWCLOSURE R6 P0 CAPTURE REF R3 CALL R5 1 0 GETIMPORT R5 7 -JUMPIFNOT R5 +2 +JUMPIFNOT R5 L1 CLOSEUPVALS R3 -JUMP +2 -CLOSEUPVALS R3 -FORGLOOP_INEXT R0 -13 -LOADN R0 0 +JUMP L3 +L1: CLOSEUPVALS R3 +L2: FORGLOOP_INEXT R0 L0 +L3: LOADN R0 0 RETURN R0 1 )"); @@ -2157,8 +2192,8 @@ end 1), R"( LOADN R0 0 -LOADN R1 5 -JUMPIFNOTLT R0 R1 +16 +L0: LOADN R1 5 +JUMPIFNOTLT R0 R1 L2 LOADNIL R1 MOVE R1 R0 GETIMPORT R2 1 @@ -2167,12 +2202,12 @@ CAPTURE REF R1 CALL R2 1 0 ADDK R0 R0 K2 GETIMPORT R2 4 -JUMPIFNOT R2 +2 +JUMPIFNOT R2 L1 CLOSEUPVALS R1 -JUMP +2 -CLOSEUPVALS R1 -JUMPBACK -18 -LOADN R1 0 +JUMP L2 +L1: CLOSEUPVALS R1 +JUMPBACK L0 +L2: LOADN R1 0 RETURN R1 1 )"); @@ -2194,7 +2229,7 @@ end 1), R"( LOADN R0 0 -LOADNIL R1 +L0: LOADNIL R1 MOVE R1 R0 GETIMPORT R2 1 NEWCLOSURE R3 P0 @@ -2202,15 +2237,15 @@ CAPTURE REF R1 CALL R2 1 0 ADDK R0 R0 K2 GETIMPORT R2 4 -JUMPIFNOT R2 +2 +JUMPIFNOT R2 L1 CLOSEUPVALS R1 -JUMP +6 -LOADN R2 5 -JUMPIFLT R0 R2 +3 +JUMP L3 +L1: LOADN R2 5 +JUMPIFLT R0 R2 L2 CLOSEUPVALS R1 -JUMPBACK -18 -CLOSEUPVALS R1 -LOADN R1 0 +JUMPBACK L0 +L2: CLOSEUPVALS R1 +L3: LOADN R1 0 RETURN R1 1 )"); } @@ -2270,11 +2305,11 @@ return result 14: GETIMPORT R2 11 14: MOVE R3 R0 14: CALL R2 1 3 -14: FORGPREP_NEXT R2 +3 -15: MOVE R7 R1 +14: FORGPREP_NEXT R2 L1 +15: L0: MOVE R7 R1 15: MOVE R8 R5 15: CONCAT R1 R7 R8 -14: FORGLOOP_NEXT R2 -4 +14: L1: FORGLOOP R2 L0 1 17: RETURN R1 1 )"); } @@ -2301,11 +2336,11 @@ end 5: LOADN R0 1 7: LOADN R1 2 9: LOADN R2 3 -9: JUMP +4 -11: GETIMPORT R5 1 +9: FORGPREP R0 L1 +11: L0: GETIMPORT R5 1 11: MOVE R6 R3 11: CALL R5 1 0 -2: FORGLOOP R0 -5 1 +2: L1: FORGLOOP R0 L0 1 13: RETURN R0 0 )"); } @@ -2327,14 +2362,14 @@ end CHECK_EQ("\n" + bcb.dumpFunction(0), R"( 2: LOADN R0 0 -4: ADDK R0 R0 K0 +4: L0: ADDK R0 R0 K0 5: LOADN R1 1 -5: JUMPIFNOTLT R1 R0 +6 +5: JUMPIFNOTLT R1 R0 L1 6: GETIMPORT R1 2 6: LOADK R2 K3 6: CALL R1 1 0 10: RETURN R0 0 -3: JUMPBACK -10 +3: L1: JUMPBACK L0 10: RETURN R0 0 )"); } @@ -2355,16 +2390,16 @@ until f == 0 0), R"( 2: LOADN R0 0 -4: ADDK R0 R0 K0 -5: JUMPIFNOTEQK R0 K0 +6 +4: L0: ADDK R0 R0 K0 +5: JUMPIFNOTEQK R0 K0 L1 6: GETIMPORT R1 2 6: MOVE R2 R0 6: CALL R1 1 0 -6: JUMP +1 -8: LOADN R0 0 -10: JUMPIFEQK R0 K3 +2 -10: JUMPBACK -12 -11: RETURN R0 0 +6: JUMP L2 +8: L1: LOADN R0 0 +10: L2: JUMPIFEQK R0 K3 L3 +10: JUMPBACK L0 +11: L3: RETURN R0 0 )"); } @@ -2466,11 +2501,11 @@ return CHECK_EQ("\n" + bcb.dumpFunction(0), R"( 2: GETVARARGS R0 2 -5: FASTCALL2 18 R0 R1 +5 +5: FASTCALL2 18 R0 R1 L0 5: MOVE R3 R0 5: MOVE R4 R1 5: GETIMPORT R2 2 -5: CALL R2 2 -1 +5: L0: CALL R2 2 -1 5: RETURN R2 -1 )"); } @@ -2567,13 +2602,13 @@ LOADK R1 K9 GETIMPORT R2 11 MOVE R3 R0 CALL R2 1 3 -FORGPREP_NEXT R2 +3 +FORGPREP_NEXT R2 L1 15: result = result .. k -MOVE R7 R1 +L0: MOVE R7 R1 MOVE R8 R5 CONCAT R1 R7 R8 14: for k in pairs(kSelectedBiomes) do -FORGLOOP_NEXT R2 -4 +L1: FORGLOOP R2 L0 1 17: return result RETURN R1 1 )"); @@ -2618,29 +2653,29 @@ end local 0: reg 5, start pc 5 line 5, end pc 8 line 5 local 1: reg 6, start pc 14 line 8, end pc 18 line 8 local 2: reg 7, start pc 14 line 8, end pc 18 line 8 -local 3: reg 3, start pc 21 line 12, end pc 24 line 12 -local 4: reg 3, start pc 26 line 16, end pc 30 line 16 -local 5: reg 0, start pc 0 line 3, end pc 34 line 21 -local 6: reg 1, start pc 0 line 3, end pc 34 line 21 -local 7: reg 2, start pc 1 line 4, end pc 34 line 21 -local 8: reg 3, start pc 34 line 21, end pc 34 line 21 +local 3: reg 3, start pc 22 line 12, end pc 25 line 12 +local 4: reg 3, start pc 27 line 16, end pc 31 line 16 +local 5: reg 0, start pc 0 line 3, end pc 35 line 21 +local 6: reg 1, start pc 0 line 3, end pc 35 line 21 +local 7: reg 2, start pc 1 line 4, end pc 35 line 21 +local 8: reg 3, start pc 35 line 21, end pc 35 line 21 3: LOADN R2 1 4: LOADN R5 1 4: LOADN R3 3 4: LOADN R4 1 -4: FORNPREP R3 +5 -5: GETIMPORT R6 1 +4: FORNPREP R3 L1 +5: L0: GETIMPORT R6 1 5: MOVE R7 R5 5: CALL R6 1 0 -4: FORNLOOP R3 -5 -7: GETIMPORT R3 3 +4: FORNLOOP R3 L0 +7: L1: GETIMPORT R3 3 7: CALL R3 0 3 -7: FORGPREP_NEXT R3 +5 -8: GETIMPORT R8 1 +7: FORGPREP_NEXT R3 L3 +8: L2: GETIMPORT R8 1 8: MOVE R9 R6 8: MOVE R10 R7 8: CALL R8 2 0 -7: FORGLOOP_NEXT R3 -6 +7: L3: FORGLOOP R3 L2 2 11: LOADN R3 2 12: GETIMPORT R4 1 12: LOADN R5 2 @@ -2656,6 +2691,33 @@ local 8: reg 3, start pc 34 line 21, end pc 34 line 21 )"); } +TEST_CASE("DebugRemarks") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Remarks); + + uint32_t fid = bcb.beginFunction(0); + + bcb.addDebugRemark("test remark #%d", 1); + bcb.emitABC(LOP_LOADNIL, 0, 0, 0); + bcb.addDebugRemark("test remark #%d", 2); + bcb.addDebugRemark("test remark #%d", 3); + bcb.emitABC(LOP_RETURN, 0, 1, 0); + + bcb.endFunction(1, 0); + + bcb.setMainFunction(fid); + bcb.finalize(); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +REMARK test remark #1 +LOADNIL R0 +REMARK test remark #2 +REMARK test remark #3 +RETURN R0 0 +)"); +} + TEST_CASE("AssignmentConflict") { // assignments are left to right @@ -2758,9 +2820,9 @@ TEST_CASE("FastcallBytecode") // direct global call CHECK_EQ("\n" + compileFunction0("return math.abs(-5)"), R"( LOADN R1 -5 -FASTCALL1 2 R1 +2 +FASTCALL1 2 R1 L0 GETIMPORT R0 2 -CALL R0 1 -1 +L0: CALL R0 1 -1 RETURN R0 -1 )"); @@ -2768,18 +2830,18 @@ RETURN R0 -1 CHECK_EQ("\n" + compileFunction0("local abs = math.abs return abs(-5)"), R"( GETIMPORT R0 2 LOADN R2 -5 -FASTCALL1 2 R2 +1 +FASTCALL1 2 R2 L0 MOVE R1 R0 -CALL R1 1 -1 +L0: CALL R1 1 -1 RETURN R1 -1 )"); // call through an upvalue CHECK_EQ("\n" + compileFunction0("local abs = math.abs function foo() return abs(-5) end return foo()"), R"( LOADN R1 -5 -FASTCALL1 2 R1 +1 +FASTCALL1 2 R1 L0 GETUPVAL R0 0 -CALL R0 1 -1 +L0: CALL R0 1 -1 RETURN R0 -1 )"); @@ -2822,10 +2884,10 @@ TEST_CASE("FastcallSelect") // select(_, ...) compiles to a builtin call CHECK_EQ("\n" + compileFunction0("return (select('#', ...))"), R"( LOADK R1 K0 -FASTCALL1 57 R1 +3 +FASTCALL1 57 R1 L0 GETIMPORT R0 2 GETVARARGS R2 -1 -CALL R0 -1 1 +L0: CALL R0 -1 1 RETURN R0 1 )"); @@ -2841,21 +2903,21 @@ return sum LOADN R0 0 LOADN R3 1 LOADK R5 K0 -FASTCALL1 57 R5 +3 +FASTCALL1 57 R5 L0 GETIMPORT R4 2 GETVARARGS R6 -1 -CALL R4 -1 1 +L0: CALL R4 -1 1 MOVE R1 R4 LOADN R2 1 -FORNPREP R1 +8 -FASTCALL1 57 R3 +4 +FORNPREP R1 L3 +L1: FASTCALL1 57 R3 L2 GETIMPORT R4 2 MOVE R5 R3 GETVARARGS R6 -1 -CALL R4 -1 1 +L2: CALL R4 -1 1 ADD R0 R0 R4 -FORNLOOP R1 -8 -RETURN R0 1 +FORNLOOP R1 L1 +L3: RETURN R0 1 )"); // currently we assume a single value return to avoid dealing with stack resizing @@ -3158,15 +3220,24 @@ TEST_CASE("FastCallImportFallback") std::vector insns = Luau::split(code, '\n'); - CHECK_EQ(insns[insns.size() - 9], "LOADN R1 1024"); - CHECK_EQ(insns[insns.size() - 8], "LOADK R2 K1023"); - CHECK_EQ(insns[insns.size() - 7], "SETTABLE R2 R0 R1"); - CHECK_EQ(insns[insns.size() - 6], "LOADN R2 -1"); - CHECK_EQ(insns[insns.size() - 5], "FASTCALL1 2 R2 +4"); - CHECK_EQ(insns[insns.size() - 4], "GETGLOBAL R3 K1024"); // note: it's important that this doesn't overwrite R2 - CHECK_EQ(insns[insns.size() - 3], "GETTABLEKS R1 R3 K1025"); - CHECK_EQ(insns[insns.size() - 2], "CALL R1 1 -1"); - CHECK_EQ(insns[insns.size() - 1], "RETURN R1 -1"); + std::string fragment; + for (size_t i = 9; i > 1; --i) + { + fragment += std::string(insns[insns.size() - i]); + fragment += "\n"; + } + + // note: it's important that GETGLOBAL below doesn't overwrite R2 + CHECK_EQ("\n" + fragment, R"( +LOADN R1 1024 +LOADK R2 K1023 +SETTABLE R2 R0 R1 +LOADN R2 -1 +FASTCALL1 2 R2 L0 +GETGLOBAL R3 K1024 +GETTABLEKS R1 R3 K1025 +L0: CALL R1 1 -1 +)"); } TEST_CASE("CompoundAssignment") @@ -3323,36 +3394,47 @@ TEST_CASE("JumpTrampoline") insns.push_back(insn); // FORNPREP and early JUMPs (break) need to go through a trampoline - CHECK_EQ(insns[0], "LOADN R0 0"); - CHECK_EQ(insns[1], "LOADN R3 1"); - CHECK_EQ(insns[2], "LOADN R1 3"); - CHECK_EQ(insns[3], "LOADN R2 1"); - CHECK_EQ(insns[4], "JUMP +1"); - CHECK_EQ(insns[5], "JUMPX +54542"); - CHECK_EQ(insns[6], "FORNPREP R1 -2"); - CHECK_EQ(insns[7], "ADD R0 R0 R3"); - CHECK_EQ(insns[8], "LOADK R4 K0"); - CHECK_EQ(insns[9], "JUMP +1"); - CHECK_EQ(insns[10], "JUMPX +54537"); - CHECK_EQ(insns[11], "JUMPIFLT R4 R0 -2"); - CHECK_EQ(insns[12], "ADD R0 R0 R3"); - CHECK_EQ(insns[13], "LOADK R4 K0"); - CHECK_EQ(insns[14], "JUMP +1"); - CHECK_EQ(insns[15], "JUMPX +54531"); - CHECK_EQ(insns[16], "JUMPIFLT R4 R0 -2"); + std::string head; + for (size_t i = 0; i < 16; ++i) + head += insns[i] + "\n"; + + CHECK_EQ("\n" + head, R"( +LOADN R0 0 +LOADN R3 1 +LOADN R1 3 +LOADN R2 1 +JUMP L1 +L0: JUMPX L14543 +L1: FORNPREP R1 L0 +L2: ADD R0 R0 R3 +LOADK R4 K0 +JUMP L4 +L3: JUMPX L14543 +L4: JUMPIFLT R4 R0 L3 +ADD R0 R0 R3 +LOADK R4 K0 +JUMP L6 +L5: JUMPX L14543 +)"); // FORNLOOP has to go through a trampoline since the jump is back to the beginning of the function // however, late JUMPs (break) don't need a trampoline since the loop end is really close by - CHECK_EQ(insns[44539], "ADD R0 R0 R3"); - CHECK_EQ(insns[44540], "LOADK R4 K0"); - CHECK_EQ(insns[44541], "JUMPIFLT R4 R0 +8"); - CHECK_EQ(insns[44542], "ADD R0 R0 R3"); - CHECK_EQ(insns[44543], "LOADK R4 K0"); - CHECK_EQ(insns[44544], "JUMPIFLT R4 R0 +4"); - CHECK_EQ(insns[44545], "JUMP +1"); - CHECK_EQ(insns[44546], "JUMPX -54540"); - CHECK_EQ(insns[44547], "FORNLOOP R1 -2"); - CHECK_EQ(insns[44548], "RETURN R0 1"); + std::string tail; + for (size_t i = 44539; i < insns.size(); ++i) + tail += insns[i] + "\n"; + + CHECK_EQ("\n" + tail, R"( +ADD R0 R0 R3 +LOADK R4 K0 +JUMPIFLT R4 R0 L14543 +ADD R0 R0 R3 +LOADK R4 K0 +JUMPIFLT R4 R0 L14543 +JUMP L14542 +L14541: JUMPX L2 +L14542: FORNLOOP R1 L14541 +L14543: RETURN R0 1 +)"); } TEST_CASE("CompileBytecode") @@ -3427,10 +3509,10 @@ local b = obj == 1 )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 +2 +JUMPIFEQK R0 K0 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0(R"( @@ -3439,10 +3521,10 @@ local b = 1 == obj )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 +2 +JUMPIFEQK R0 K0 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0(R"( @@ -3451,10 +3533,10 @@ local b = "Hello, Sailor!" == obj )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 +2 +JUMPIFEQK R0 K0 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0(R"( @@ -3463,10 +3545,10 @@ local b = nil == obj )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 +2 +JUMPIFEQK R0 K0 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0(R"( @@ -3475,10 +3557,10 @@ local b = true == obj )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 +2 +JUMPIFEQK R0 K0 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0(R"( @@ -3487,10 +3569,10 @@ local b = nil ~= obj )"), R"( GETVARARGS R0 1 -JUMPIFNOTEQK R0 K0 +2 +JUMPIFNOTEQK R0 K0 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); // table literals should not generate IFEQK variants @@ -3501,10 +3583,10 @@ local b = obj == {} R"( GETVARARGS R0 1 NEWTABLE R2 0 0 -JUMPIFEQ R0 R2 +2 +JUMPIFEQ R0 R2 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); } @@ -3566,13 +3648,13 @@ end R"( 2: COVERAGE 2: GETIMPORT R0 1 -2: JUMPIFNOT R0 +6 +2: JUMPIFNOT R0 L0 3: COVERAGE 3: GETIMPORT R0 3 3: LOADN R1 1 3: CALL R0 1 0 7: RETURN R0 0 -5: COVERAGE +5: L0: COVERAGE 5: GETIMPORT R0 3 5: LOADN R1 2 5: CALL R0 1 0 @@ -3594,13 +3676,13 @@ end R"( 2: COVERAGE 2: GETIMPORT R0 1 -2: JUMPIFNOT R0 +6 +2: JUMPIFNOT R0 L0 4: COVERAGE 4: GETIMPORT R0 3 4: LOADN R1 1 4: CALL R0 1 0 9: RETURN R0 0 -7: COVERAGE +7: L0: COVERAGE 7: GETIMPORT R0 3 7: LOADN R1 2 7: CALL R0 1 0 @@ -3833,32 +3915,32 @@ end LOADN R2 1 LOADN R0 10 LOADN R1 1 -FORNPREP R0 +6 -GETIMPORT R3 1 +FORNPREP R0 L1 +L0: GETIMPORT R3 1 NEWCLOSURE R4 P0 CAPTURE VAL R2 CALL R3 1 0 -FORNLOOP R0 -6 -GETIMPORT R0 3 +FORNLOOP R0 L0 +L1: GETIMPORT R0 3 GETVARARGS R1 -1 CALL R0 -1 3 -FORGPREP_NEXT R0 +5 -GETIMPORT R5 1 +FORGPREP_NEXT R0 L3 +L2: GETIMPORT R5 1 NEWCLOSURE R6 P1 CAPTURE VAL R3 CALL R5 1 0 -FORGLOOP_NEXT R0 -6 +L3: FORGLOOP R0 L2 2 LOADN R2 1 LOADN R0 10 LOADN R1 1 -FORNPREP R0 +7 -MOVE R3 R2 +FORNPREP R0 L5 +L4: MOVE R3 R2 GETIMPORT R4 1 NEWCLOSURE R5 P2 CAPTURE VAL R3 CALL R4 1 0 -FORNLOOP R0 -7 -RETURN R0 0 +FORNLOOP R0 L4 +L5: RETURN R0 0 )"); } @@ -3973,9 +4055,9 @@ TEST_CASE("VectorFastCall") LOADN R1 1 LOADN R2 2 LOADN R3 3 -FASTCALL 54 +2 +FASTCALL 54 L0 GETIMPORT R0 2 -CALL R0 3 -1 +L0: CALL R0 3 -1 RETURN R0 -1 )"); } @@ -4043,4 +4125,1635 @@ RETURN R1 6 )"); } +TEST_CASE("LoopUnrollBasic") +{ + // forward loops + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,2 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 2 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 2 +SETTABLEN R1 R0 2 +RETURN R0 1 +)"); + + // backward loops + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=2,1,-1 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 2 +SETTABLEN R1 R0 2 +LOADN R1 1 +SETTABLEN R1 R0 1 +RETURN R0 1 +)"); + + // loops with step that doesn't divide to-from + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,4,2 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 3 +SETTABLEN R1 R0 3 +RETURN R0 1 +)"); + + // empty loops + CHECK_EQ("\n" + compileFunction(R"( +for i=2,1 do +end +)", + 0, 2), + R"( +RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollNested") +{ + // we can unroll nested loops just fine + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=0,1 do + for j=0,1 do + t[i*2+(j+1)] = 0 + end +end +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 0 +SETTABLEN R1 R0 1 +LOADN R1 0 +SETTABLEN R1 R0 2 +LOADN R1 0 +SETTABLEN R1 R0 3 +LOADN R1 0 +SETTABLEN R1 R0 4 +RETURN R0 0 +)"); + + // if the inner loop is too expensive, we won't unroll the outer loop though, but we'll still unroll the inner loop! + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=0,3 do + for j=0,3 do + t[i*4+(j+1)] = 0 + end +end +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R3 0 +LOADN R1 3 +LOADN R2 1 +FORNPREP R1 L1 +L0: MULK R5 R3 K1 +ADDK R4 R5 K0 +LOADN R5 0 +SETTABLE R5 R0 R4 +MULK R5 R3 K1 +ADDK R4 R5 K2 +LOADN R5 0 +SETTABLE R5 R0 R4 +MULK R5 R3 K1 +ADDK R4 R5 K3 +LOADN R5 0 +SETTABLE R5 R0 R4 +MULK R5 R3 K1 +ADDK R4 R5 K1 +LOADN R5 0 +SETTABLE R5 R0 R4 +FORNLOOP R1 L0 +L1: RETURN R0 0 +)"); + + // note, we sometimes can even unroll a loop with varying internal iterations + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=0,1 do + for j=0,i do + t[i*2+(j+1)] = 0 + end +end +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 0 +SETTABLEN R1 R0 1 +LOADN R1 0 +SETTABLEN R1 R0 3 +LOADN R1 0 +SETTABLEN R1 R0 4 +RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollUnsupported") +{ + // can't unroll loops with non-constant bounds + CHECK_EQ("\n" + compileFunction(R"( +for i=x,y,z do +end +)", + 0, 2), + R"( +GETIMPORT R2 1 +GETIMPORT R0 3 +GETIMPORT R1 5 +FORNPREP R0 L1 +L0: FORNLOOP R0 L0 +L1: RETURN R0 0 +)"); + + // can't unroll loops with bounds where we can't compute trip count + CHECK_EQ("\n" + compileFunction(R"( +for i=1,1,0 do +end +)", + 0, 2), + R"( +LOADN R2 1 +LOADN R0 1 +LOADN R1 0 +FORNPREP R0 L1 +L0: FORNLOOP R0 L0 +L1: RETURN R0 0 +)"); + + // can't unroll loops with bounds that might be imprecise (non-integer) + CHECK_EQ("\n" + compileFunction(R"( +for i=1,2,0.1 do +end +)", + 0, 2), + R"( +LOADN R2 1 +LOADN R0 2 +LOADK R1 K0 +FORNPREP R0 L1 +L0: FORNLOOP R0 L0 +L1: RETURN R0 0 +)"); + + // can't unroll loops if the bounds are too large, as it might overflow trip count math + CHECK_EQ("\n" + compileFunction(R"( +for i=4294967295,4294967296 do +end +)", + 0, 2), + R"( +LOADK R2 K0 +LOADK R0 K1 +LOADN R1 1 +FORNPREP R0 L1 +L0: FORNLOOP R0 L0 +L1: RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollControlFlow") +{ + ScopedFastInt sfis[] = { + {"LuauCompileLoopUnrollThreshold", 50}, + {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, + }; + + // break jumps to the end + CHECK_EQ("\n" + compileFunction(R"( +for i=1,3 do + if math.random() < 0.5 then + break + end +end +)", + 0, 2), + R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 L0 +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 L0 +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 L0 +L0: RETURN R0 0 +)"); + + // continue jumps to the next iteration + CHECK_EQ("\n" + compileFunction(R"( +for i=1,3 do + if math.random() < 0.5 then + continue + end + print(i) +end +)", + 0, 2), + R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 L0 +GETIMPORT R0 5 +LOADN R1 1 +CALL R0 1 0 +L0: GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 L1 +GETIMPORT R0 5 +LOADN R1 2 +CALL R0 1 0 +L1: GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 L2 +GETIMPORT R0 5 +LOADN R1 3 +CALL R0 1 0 +L2: RETURN R0 0 +)"); + + // continue needs to properly close upvalues + CHECK_EQ("\n" + compileFunction(R"( +for i=1,1 do + local j = math.abs(i) + print(function() return j end) + if math.random() < 0.5 then + continue + end + j += 1 +end +)", + 1, 2), + R"( +LOADN R1 1 +FASTCALL1 2 R1 L0 +GETIMPORT R0 2 +L0: CALL R0 1 1 +GETIMPORT R1 4 +NEWCLOSURE R2 P0 +CAPTURE REF R0 +CALL R1 1 0 +GETIMPORT R1 6 +CALL R1 0 1 +LOADK R2 K7 +JUMPIFNOTLT R1 R2 L1 +CLOSEUPVALS R0 +RETURN R0 0 +L1: ADDK R0 R0 K8 +CLOSEUPVALS R0 +RETURN R0 0 +)"); + + // this weird contraption just disappears + CHECK_EQ("\n" + compileFunction(R"( +for i=1,1 do + for j=1,1 do + if i == 1 then + continue + else + break + end + end +end +)", + 0, 2), + R"( +RETURN R0 0 +RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollNestedClosure") +{ + // if the body has functions that refer to loop variables, we unroll the loop and use MOVE+CAPTURE for upvalues + CHECK_EQ("\n" + compileFunction(R"( +for i=1,2 do + local x = function() return i end +end +)", + 1, 2), + R"( +LOADN R1 1 +NEWCLOSURE R0 P0 +CAPTURE VAL R1 +LOADN R1 2 +NEWCLOSURE R0 P0 +CAPTURE VAL R1 +RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollCost") +{ + ScopedFastInt sfis[] = { + {"LuauCompileLoopUnrollThreshold", 25}, + {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, + }; + + // loops with short body + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,10 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 10 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 2 +SETTABLEN R1 R0 2 +LOADN R1 3 +SETTABLEN R1 R0 3 +LOADN R1 4 +SETTABLEN R1 R0 4 +LOADN R1 5 +SETTABLEN R1 R0 5 +LOADN R1 6 +SETTABLEN R1 R0 6 +LOADN R1 7 +SETTABLEN R1 R0 7 +LOADN R1 8 +SETTABLEN R1 R0 8 +LOADN R1 9 +SETTABLEN R1 R0 9 +LOADN R1 10 +SETTABLEN R1 R0 10 +RETURN R0 1 +)"); + + // loops with body that's too long + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,100 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R3 1 +LOADN R1 100 +LOADN R2 1 +FORNPREP R1 L1 +L0: SETTABLE R3 R0 R3 +FORNLOOP R1 L0 +L1: RETURN R0 1 +)"); + + // loops with body that's long but has a high boost factor due to constant folding + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,25 do + t[i] = i * i * i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 8 +SETTABLEN R1 R0 2 +LOADN R1 27 +SETTABLEN R1 R0 3 +LOADN R1 64 +SETTABLEN R1 R0 4 +LOADN R1 125 +SETTABLEN R1 R0 5 +LOADN R1 216 +SETTABLEN R1 R0 6 +LOADN R1 343 +SETTABLEN R1 R0 7 +LOADN R1 512 +SETTABLEN R1 R0 8 +LOADN R1 729 +SETTABLEN R1 R0 9 +LOADN R1 1000 +SETTABLEN R1 R0 10 +LOADN R1 1331 +SETTABLEN R1 R0 11 +LOADN R1 1728 +SETTABLEN R1 R0 12 +LOADN R1 2197 +SETTABLEN R1 R0 13 +LOADN R1 2744 +SETTABLEN R1 R0 14 +LOADN R1 3375 +SETTABLEN R1 R0 15 +LOADN R1 4096 +SETTABLEN R1 R0 16 +LOADN R1 4913 +SETTABLEN R1 R0 17 +LOADN R1 5832 +SETTABLEN R1 R0 18 +LOADN R1 6859 +SETTABLEN R1 R0 19 +LOADN R1 8000 +SETTABLEN R1 R0 20 +LOADN R1 9261 +SETTABLEN R1 R0 21 +LOADN R1 10648 +SETTABLEN R1 R0 22 +LOADN R1 12167 +SETTABLEN R1 R0 23 +LOADN R1 13824 +SETTABLEN R1 R0 24 +LOADN R1 15625 +SETTABLEN R1 R0 25 +RETURN R0 1 +)"); + + // loops with body that's long and doesn't have a high boost factor + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,10 do + t[i] = math.abs(math.sin(i)) +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 10 +LOADN R3 1 +LOADN R1 10 +LOADN R2 1 +FORNPREP R1 L3 +L0: FASTCALL1 24 R3 L1 +MOVE R6 R3 +GETIMPORT R5 2 +L1: CALL R5 1 -1 +FASTCALL 2 L2 +GETIMPORT R4 4 +L2: CALL R4 -1 1 +SETTABLE R4 R0 R3 +FORNLOOP R1 L0 +L3: RETURN R0 1 +)"); +} + +TEST_CASE("LoopUnrollMutable") +{ + // can't unroll loops that mutate iteration variable + CHECK_EQ("\n" + compileFunction(R"( +for i=1,3 do + i = 3 + print(i) -- should print 3 three times in a row +end +)", + 0, 2), + R"( +LOADN R2 1 +LOADN R0 3 +LOADN R1 1 +FORNPREP R0 L1 +L0: MOVE R3 R2 +LOADN R3 3 +GETIMPORT R4 1 +MOVE R5 R3 +CALL R4 1 0 +FORNLOOP R0 L0 +L1: RETURN R0 0 +)"); +} + +TEST_CASE("InlineBasic") +{ + // inline function that returns a constant + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return 42 +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // inline function that returns the argument + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = foo(42) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // inline function that returns one of the two arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b, c) + if a then + return b + else + return c + end +end + +local x = foo(true, math.random(), 5) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETIMPORT R2 3 +CALL R2 0 1 +MOVE R1 R2 +RETURN R1 1 +RETURN R1 1 +)"); + + // inline function that returns one of the two arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b, c) + if a then + return b + else + return c + end +end + +local x = foo(true, 5, math.random()) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETIMPORT R2 3 +CALL R2 0 1 +LOADN R1 5 +RETURN R1 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineBasicProhibited") +{ + // we can't inline variadic functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(...) + return 42 +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineNestedLoops") +{ + // functions with basic loops get inlined + CHECK_EQ("\n" + compileFunction(R"( +local function foo(t) + for i=1,3 do + t[i] = i + end + return t +end + +local x = foo({}) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +NEWTABLE R2 0 0 +LOADN R3 1 +SETTABLEN R3 R2 1 +LOADN R3 2 +SETTABLEN R3 R2 2 +LOADN R3 3 +SETTABLEN R3 R2 3 +MOVE R1 R2 +RETURN R1 1 +)"); + + // we can even unroll the loops based on inline argument + CHECK_EQ("\n" + compileFunction(R"( +local function foo(t, n) + for i=1, n do + t[i] = i + end + return t +end + +local x = foo({}, 3) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +NEWTABLE R2 0 0 +LOADN R3 1 +SETTABLEN R3 R2 1 +LOADN R3 2 +SETTABLEN R3 R2 2 +LOADN R3 3 +SETTABLEN R3 R2 3 +MOVE R1 R2 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineNestedClosures") +{ + // we can inline functions that contain/return functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(x) + return function(y) return x + y end +end + +local x = foo(1)(2) +return x +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 1 +NEWCLOSURE R1 P1 +CAPTURE VAL R2 +LOADN R2 2 +CALL R1 1 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineMutate") +{ + // if the argument is mutated, it gets a register even if the value is constant + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 5 + return a +end + +local x = foo(42) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 42 +ORK R2 R2 K1 +MOVE R1 R2 +RETURN R1 1 +)"); + + // if the argument is a local, it can be used directly + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = ... +local y = foo(x) +return y +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R2 R1 +RETURN R2 1 +)"); + + // ... but if it's mutated, we move it in case it is mutated through a capture during the inlined function + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = ... +x = nil +local y = foo(x) +return y +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +LOADNIL R1 +MOVE R3 R1 +MOVE R2 R3 +RETURN R2 1 +)"); + + // we also don't inline functions if they have been assigned to + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +foo = foo + +local x = foo(42) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R0 R0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineUpval") +{ + // if the argument is an upvalue, we naturally need to copy it to a local + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local b = ... + +function bar() + local x = foo(b) + return x +end +)", + 1, 2), + R"( +GETUPVAL R1 0 +MOVE R0 R1 +RETURN R0 1 +)"); + + // if the function uses an upvalue it's more complicated, because the lexical upvalue may become a local + CHECK_EQ("\n" + compileFunction(R"( +local b = ... + +local function foo(a) + return a + b +end + +local x = foo(42) +return x +)", + 1, 2), + R"( +GETVARARGS R0 1 +DUPCLOSURE R1 K0 +CAPTURE VAL R0 +LOADN R3 42 +ADD R2 R3 R0 +RETURN R2 1 +)"); + + // sometimes the lexical upvalue is deep enough that it's still an upvalue though + CHECK_EQ("\n" + compileFunction(R"( +local b = ... + +function bar() + local function foo(a) + return a + b + end + + local x = foo(42) + return x +end +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +CAPTURE UPVAL U0 +LOADN R2 42 +GETUPVAL R3 0 +ADD R1 R2 R3 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineCapture") +{ + // if the argument is captured by a nested closure, normally we can rely on capture by value + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local x = ... +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +NEWCLOSURE R2 P1 +CAPTURE VAL R1 +RETURN R2 1 +)"); + + // if the argument is a constant, we move it to a register so that capture by value can happen + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local y = foo(42) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 42 +NEWCLOSURE R1 P1 +CAPTURE VAL R2 +RETURN R1 1 +)"); + + // if the argument is an externally mutated variable, we copy it to an argument and capture it by value + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local x x = 42 +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADN R1 42 +MOVE R3 R1 +NEWCLOSURE R2 P1 +CAPTURE VAL R3 +RETURN R2 1 +)"); + + // finally, if the argument is mutated internally, we must capture it by reference and close the upvalue + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 42 + return function() return a end +end + +local y = foo() +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R2 +ORK R2 R2 K1 +NEWCLOSURE R1 P1 +CAPTURE REF R2 +CLOSEUPVALS R2 +RETURN R1 1 +)"); + + // note that capture might need to be performed during the fallthrough block + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 42 + print(function() return a end) +end + +local x = ... +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R3 R1 +ORK R3 R3 K1 +GETIMPORT R4 3 +NEWCLOSURE R5 P1 +CAPTURE REF R3 +CALL R4 1 0 +LOADNIL R2 +CLOSEUPVALS R3 +RETURN R2 1 +)"); + + // note that mutation and capture might be inside internal control flow + // TODO: this has an oddly redundant CLOSEUPVALS after JUMP; it's not due to inlining, and is an artifact of how StatBlock/StatReturn interact + // fixing this would reduce the number of redundant CLOSEUPVALS a bit but it only affects bytecode size as these instructions aren't executed + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + if not a then + local b b = 42 + return function() return b end + end +end + +local x = ... +local y = foo(x) +return y, x +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +JUMPIF R1 L0 +LOADNIL R3 +LOADN R3 42 +NEWCLOSURE R2 P1 +CAPTURE REF R3 +CLOSEUPVALS R3 +JUMP L1 +CLOSEUPVALS R3 +L0: LOADNIL R2 +L1: MOVE R3 R2 +MOVE R4 R1 +RETURN R3 2 +)"); +} + +TEST_CASE("InlineFallthrough") +{ + // if the function doesn't return, we still fill the results with nil + CHECK_EQ("\n" + compileFunction(R"( +local function foo() +end + +local a, b = foo() +return a, b +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADNIL R2 +RETURN R1 2 +)"); + + // this happens even if the function returns conditionally + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + if a then return 42 end +end + +local a, b = foo(false) +return a, b +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADNIL R2 +RETURN R1 2 +)"); + + // note though that we can't inline a function like this in multret context + // this is because we don't have a SETTOP instruction + CHECK_EQ("\n" + compileFunction(R"( +local function foo() +end + +return foo() +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 -1 +RETURN R1 -1 +)"); +} + +TEST_CASE("InlineArgMismatch") +{ + // when inlining a function, we must respect all the usual rules + + // caller might not have enough arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +RETURN R1 1 +)"); + + // caller might be using multret for arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x = foo(math.modf(1.5)) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADK R3 K1 +FASTCALL1 20 R3 L0 +GETIMPORT R2 4 +L0: CALL R2 1 2 +ADD R1 R2 R3 +RETURN R1 1 +)"); + + // caller might be using varargs for arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x = foo(...) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R2 2 +ADD R1 R2 R3 +RETURN R1 1 +)"); + + // caller might have too many arguments, but we still need to compute them for side effects + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = foo(42, print()) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETIMPORT R2 2 +CALL R2 0 1 +LOADN R1 42 +RETURN R1 1 +)"); + + // caller might not have enough arguments, and the arg might be mutated so it needs a register + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = 42 + return a +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R2 +LOADN R2 42 +MOVE R1 R2 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineMultiple") +{ + // we call this with a different set of variable/constant args + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x, y = ... +local a = foo(x, 1) +local b = foo(1, x) +local c = foo(1, 2) +local d = foo(x, y) +return a, b, c, d +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 2 +ADDK R3 R1 K1 +LOADN R5 1 +ADD R4 R5 R1 +LOADN R5 3 +ADD R6 R1 R2 +RETURN R3 4 +)"); +} + +TEST_CASE("InlineChain") +{ + // inline a chain of functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local function bar(x) + return foo(x, 1) * foo(x, -1) +end + +local function baz() + return (bar(42)) +end + +return (baz()) +)", + 3, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +DUPCLOSURE R2 K2 +LOADN R4 43 +LOADN R5 41 +MUL R3 R4 R5 +RETURN R3 1 +)"); +} + +TEST_CASE("InlineThresholds") +{ + ScopedFastInt sfis[] = { + {"LuauCompileInlineThreshold", 25}, + {"LuauCompileInlineThresholdMaxBoost", 300}, + {"LuauCompileInlineDepth", 2}, + }; + + // this function has enormous register pressure (50 regs) so we choose not to inline it + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return {{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}} +end + +return (foo()) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // this function has less register pressure but a large cost + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return {},{},{},{},{} +end + +return (foo()) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // this chain of function is of length 3 but our limit in this test is 2, so we call foo twice + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local function bar(x) + return foo(x, 1) * foo(x, -1) +end + +local function baz() + return (bar(42)) +end + +return (baz()) +)", + 3, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +DUPCLOSURE R2 K2 +MOVE R4 R0 +LOADN R5 42 +LOADN R6 1 +CALL R4 2 1 +MOVE R5 R0 +LOADN R6 42 +LOADN R7 -1 +CALL R5 2 1 +MUL R3 R4 R5 +RETURN R3 1 +)"); +} + +TEST_CASE("InlineIIFE") +{ + // IIFE with arguments + CHECK_EQ("\n" + compileFunction(R"( +function choose(a, b, c) + return ((function(a, b, c) if a then return b else return c end end)(a, b, c)) +end +)", + 1, 2), + R"( +JUMPIFNOT R0 L0 +MOVE R3 R1 +RETURN R3 1 +L0: MOVE R3 R2 +RETURN R3 1 +RETURN R3 1 +)"); + + // IIFE with upvalues + CHECK_EQ("\n" + compileFunction(R"( +function choose(a, b, c) + return ((function() if a then return b else return c end end)()) +end +)", + 1, 2), + R"( +JUMPIFNOT R0 L0 +MOVE R3 R1 +RETURN R3 1 +L0: MOVE R3 R2 +RETURN R3 1 +RETURN R3 1 +)"); +} + +TEST_CASE("InlineRecurseArguments") +{ + // we can't inline a function if it's used to compute its own arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) +end +foo(foo(foo,foo(foo,foo))[foo]) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R2 R0 +MOVE R3 R0 +MOVE R4 R0 +MOVE R5 R0 +MOVE R6 R0 +CALL R4 2 -1 +CALL R2 -1 1 +GETTABLE R1 R2 R0 +RETURN R0 0 +)"); +} + +TEST_CASE("InlineFastCallK") +{ + CHECK_EQ("\n" + compileFunction(R"( +local function set(l0) + rawset({}, l0) +end + +set(false) +set({}) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +NEWTABLE R2 0 0 +FASTCALL2K 49 R2 K1 L0 +LOADK R3 K1 +GETIMPORT R1 3 +L0: CALL R1 2 0 +NEWTABLE R1 0 0 +NEWTABLE R3 0 0 +FASTCALL2 49 R3 R1 L1 +MOVE R4 R1 +GETIMPORT R2 3 +L1: CALL R2 2 0 +RETURN R0 0 +)"); +} + +TEST_CASE("InlineExprIndexK") +{ + CHECK_EQ("\n" + compileFunction(R"( +local _ = function(l0) +local _ = nil +while _(_)[_] do +end +end +local _ = _(0)[""] +if _ then +do +for l0=0,8 do +end +end +elseif _ then +_ = nil +do +for l0=0,8 do +return true +end +end +end +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +L0: LOADNIL R4 +LOADNIL R5 +CALL R4 1 1 +LOADNIL R5 +GETTABLE R3 R4 R5 +JUMPIFNOT R3 L1 +JUMPBACK L0 +L1: LOADNIL R2 +GETTABLEKS R1 R2 K1 +JUMPIFNOT R1 L2 +RETURN R0 0 +L2: JUMPIFNOT R1 L3 +LOADNIL R1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +L3: RETURN R0 0 +)"); +} + +TEST_CASE("InlineHiddenMutation") +{ + // when the argument is assigned inside the function, we can't reuse the local + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = 42 + return a +end + +local x = ... +local y = foo(x :: number) +return y +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R3 R1 +LOADN R3 42 +MOVE R2 R3 +RETURN R2 1 +)"); + + // and neither can we do that when it's assigned outside the function + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + mutator() + return a +end + +local x = ... +mutator = function() x = 42 end + +local y = foo(x :: number) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +NEWCLOSURE R2 P1 +CAPTURE REF R1 +SETGLOBAL R2 K1 +MOVE R3 R1 +GETGLOBAL R4 K1 +CALL R4 0 0 +MOVE R2 R3 +CLOSEUPVALS R1 +RETURN R2 1 +)"); +} + +TEST_CASE("InlineMultret") +{ + // inlining a function in multret context is prohibited since we can't adjust L->top outside of CALL/GETVARARGS + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a() +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // however, if we can deduce statically that a function always returns a single value, the inlining will work + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // this analysis will also propagate through other functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local function bar(a) + return foo(a) +end + +return bar(42) +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +LOADN R2 42 +RETURN R2 1 +)"); + + // we currently don't do this analysis fully for recursive functions since they can't be inlined anyway + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return foo(a) +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +CAPTURE VAL R0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // and unfortunately we can't do this analysis for builtins or method calls due to getfenv + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return math.abs(a) +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); +} + +TEST_CASE("ReturnConsecutive") +{ + // we can return a single local directly + CHECK_EQ("\n" + compileFunction0(R"( +local x = ... +return x +)"), + R"( +GETVARARGS R0 1 +RETURN R0 1 +)"); + + // or multiple, when they are allocated in consecutive registers + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return x, y +)"), + R"( +GETVARARGS R0 2 +RETURN R0 2 +)"); + + // but not if it's an expression + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return x, y + 1 +)"), + R"( +GETVARARGS R0 2 +MOVE R2 R0 +ADDK R3 R1 K0 +RETURN R2 2 +)"); + + // or a local with wrong register number + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return y, x +)"), + R"( +GETVARARGS R0 2 +MOVE R2 R1 +MOVE R3 R0 +RETURN R2 2 +)"); + + // also double check the optimization doesn't trip on no-argument return (these are rare) + CHECK_EQ("\n" + compileFunction0(R"( +return +)"), + R"( +RETURN R0 0 +)"); + + // this optimization also works in presence of group / type casts + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return (x), y :: number +)"), + R"( +GETVARARGS R0 2 +RETURN R0 2 +)"); +} + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 83d4518d..96a2775f 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -241,9 +241,17 @@ TEST_CASE("Math") TEST_CASE("Table") { - ScopedFastFlag sff("LuauTableClone", true); - - runConformance("nextvar.lua"); + runConformance("nextvar.lua", [](lua_State* L) { + lua_pushcfunction( + L, + [](lua_State* L) { + unsigned v = luaL_checkunsigned(L, 1); + lua_pushlightuserdata(L, reinterpret_cast(uintptr_t(v))); + return 1; + }, + "makelud"); + lua_setglobal(L, "makelud"); + }); } TEST_CASE("PatternMatch") @@ -467,8 +475,6 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { - ScopedFastFlag sff("LuauTableCloneType", true); - runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; @@ -735,7 +741,7 @@ TEST_CASE("ApiTables") lua_pop(L, 1); } -TEST_CASE("ApiFunctionCalls") +TEST_CASE("ApiCalls") { StateRef globalState = runConformance("apicalls.lua"); lua_State* L = globalState.get(); @@ -784,6 +790,58 @@ TEST_CASE("ApiFunctionCalls") CHECK(lua_equal(L2, -1, -2) == 1); lua_pop(L2, 2); } + + // lua_clonefunction + fenv + { + lua_getfield(L, LUA_GLOBALSINDEX, "getpi"); + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3.1415926); + lua_pop(L, 1); + + lua_getfield(L, LUA_GLOBALSINDEX, "getpi"); + + // clone & override env + lua_clonefunction(L, -1); + lua_newtable(L); + lua_pushnumber(L, 42); + lua_setfield(L, -2, "pi"); + lua_setfenv(L, -2); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + + // this one calls original function again + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3.1415926); + lua_pop(L, 1); + } + + // lua_clonefunction + upvalues + { + lua_getfield(L, LUA_GLOBALSINDEX, "incuv"); + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 1); + lua_pop(L, 1); + + lua_getfield(L, LUA_GLOBALSINDEX, "incuv"); + // two clones + lua_clonefunction(L, -1); + lua_clonefunction(L, -2); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 2); + lua_pop(L, 1); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3); + lua_pop(L, 1); + + // this one calls original function again + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 4); + lua_pop(L, 1); + } } static bool endsWith(const std::string& str, const std::string& suffix) @@ -1060,7 +1118,7 @@ TEST_CASE("UserdataApi") lua_State* L = globalState.get(); // setup dtor for tag 42 (created later) - lua_setuserdatadtor(L, 42, [](void* data) { + lua_setuserdatadtor(L, 42, [](lua_State* l, void* data) { dtorhits += *(int*)data; }); @@ -1068,6 +1126,7 @@ TEST_CASE("UserdataApi") int lud; lua_pushlightuserdata(L, &lud); + CHECK(lua_tolightuserdata(L, -1) == &lud); CHECK(lua_touserdata(L, -1) == &lud); CHECK(lua_topointer(L, -1) == &lud); @@ -1075,6 +1134,7 @@ TEST_CASE("UserdataApi") int* ud1 = (int*)lua_newuserdata(L, 4); *ud1 = 42; + CHECK(lua_tolightuserdata(L, -1) == nullptr); CHECK(lua_touserdata(L, -1) == ud1); CHECK(lua_topointer(L, -1) == ud1); @@ -1103,4 +1163,210 @@ TEST_CASE("UserdataApi") CHECK(dtorhits == 42); } +TEST_CASE("Iter") +{ + runConformance("iter.lua"); +} + +const int kInt64Tag = 1; +static int gInt64MT = -1; + +static int64_t getInt64(lua_State* L, int idx) +{ + if (void* p = lua_touserdatatagged(L, idx, kInt64Tag)) + return *static_cast(p); + + if (lua_isnumber(L, idx)) + return lua_tointeger(L, idx); + + luaL_typeerror(L, 1, "int64"); +} + +static void pushInt64(lua_State* L, int64_t value) +{ + void* p = lua_newuserdatatagged(L, sizeof(int64_t), kInt64Tag); + + lua_getref(L, gInt64MT); + lua_setmetatable(L, -2); + + *static_cast(p) = value; +} + +TEST_CASE("Userdata") +{ + runConformance("userdata.lua", [](lua_State* L) { + // create metatable with all the metamethods + lua_newtable(L); + gInt64MT = lua_ref(L, -1); + + // __index + lua_pushcfunction( + L, + [](lua_State* L) { + void* p = lua_touserdatatagged(L, 1, kInt64Tag); + if (!p) + luaL_typeerror(L, 1, "int64"); + + const char* name = luaL_checkstring(L, 2); + + if (strcmp(name, "value") == 0) + { + lua_pushnumber(L, double(*static_cast(p))); + return 1; + } + + luaL_error(L, "unknown field %s", name); + }, + nullptr); + lua_setfield(L, -2, "__index"); + + // __newindex + lua_pushcfunction( + L, + [](lua_State* L) { + void* p = lua_touserdatatagged(L, 1, kInt64Tag); + if (!p) + luaL_typeerror(L, 1, "int64"); + + const char* name = luaL_checkstring(L, 2); + + if (strcmp(name, "value") == 0) + { + double value = luaL_checknumber(L, 3); + *static_cast(p) = int64_t(value); + return 0; + } + + luaL_error(L, "unknown field %s", name); + }, + nullptr); + lua_setfield(L, -2, "__newindex"); + + // __eq + lua_pushcfunction( + L, + [](lua_State* L) { + lua_pushboolean(L, getInt64(L, 1) == getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__eq"); + + // __lt + lua_pushcfunction( + L, + [](lua_State* L) { + lua_pushboolean(L, getInt64(L, 1) < getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__lt"); + + // __le + lua_pushcfunction( + L, + [](lua_State* L) { + lua_pushboolean(L, getInt64(L, 1) <= getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__le"); + + // __add + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, getInt64(L, 1) + getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__add"); + + // __sub + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, getInt64(L, 1) - getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__sub"); + + // __mul + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, getInt64(L, 1) * getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__mul"); + + // __div + lua_pushcfunction( + L, + [](lua_State* L) { + // ideally we'd guard against 0 but it's a test so eh + pushInt64(L, getInt64(L, 1) / getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__div"); + + // __mod + lua_pushcfunction( + L, + [](lua_State* L) { + // ideally we'd guard against 0 and INT64_MIN but it's a test so eh + pushInt64(L, getInt64(L, 1) % getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__mod"); + + // __pow + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, int64_t(pow(double(getInt64(L, 1)), double(getInt64(L, 2))))); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__pow"); + + // __unm + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, -getInt64(L, 1)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__unm"); + + // __tostring + lua_pushcfunction( + L, + [](lua_State* L) { + int64_t value = getInt64(L, 1); + std::string str = std::to_string(value); + lua_pushlstring(L, str.c_str(), str.length()); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__tostring"); + + // ctor + lua_pushcfunction( + L, + [](lua_State* L) { + double v = luaL_checknumber(L, 1); + pushInt64(L, int64_t(v)); + return 1; + }, + "int64"); + lua_setglobal(L, "int64"); + }); +} + TEST_SUITE_END(); diff --git a/tests/ConstraintGraphBuilder.test.cpp b/tests/ConstraintGraphBuilder.test.cpp new file mode 100644 index 00000000..96b21613 --- /dev/null +++ b/tests/ConstraintGraphBuilder.test.cpp @@ -0,0 +1,126 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" +#include "Luau/ConstraintGraphBuilder.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("ConstraintGraphBuilder"); + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello_world") +{ + AstStatBlock* block = parse(R"( + local a = "hello" + local b = a + )"); + + cgb.visit(block); + auto constraints = collectConstraints(cgb.rootScope); + + REQUIRE(2 == constraints.size()); + + ToStringOptions opts; + CHECK("string <: a" == toString(*constraints[0], opts)); + CHECK("a <: b" == toString(*constraints[1], opts)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives") +{ + AstStatBlock* block = parse(R"( + local s = "hello" + local n = 555 + local b = true + local n2 = nil + )"); + + cgb.visit(block); + auto constraints = collectConstraints(cgb.rootScope); + + REQUIRE(3 == constraints.size()); + + ToStringOptions opts; + CHECK("string <: a" == toString(*constraints[0], opts)); + CHECK("number <: b" == toString(*constraints[1], opts)); + CHECK("boolean <: c" == toString(*constraints[2], opts)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "nil_primitive") +{ + AstStatBlock* block = parse(R"( + local function a() return nil end + local b = a() + )"); + + cgb.visit(block); + auto constraints = collectConstraints(cgb.rootScope); + + ToStringOptions opts; + REQUIRE(5 <= constraints.size()); + + CHECK("*blocked-1* ~ gen () -> (a...)" == toString(*constraints[0], opts)); + CHECK("b ~ inst *blocked-1*" == toString(*constraints[1], opts)); + CHECK("() -> (c...) <: b" == toString(*constraints[2], opts)); + CHECK("c... <: d" == toString(*constraints[3], opts)); + CHECK("nil <: a..." == toString(*constraints[4], opts)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") +{ + AstStatBlock* block = parse(R"( + local a = "hello" + local b = a("world") + )"); + + cgb.visit(block); + auto constraints = collectConstraints(cgb.rootScope); + + REQUIRE(4 == constraints.size()); + + ToStringOptions opts; + CHECK("string <: a" == toString(*constraints[0], opts)); + CHECK("b ~ inst a" == toString(*constraints[1], opts)); + CHECK("(string) -> (c...) <: b" == toString(*constraints[2], opts)); + CHECK("c... <: d" == toString(*constraints[3], opts)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") +{ + AstStatBlock* block = parse(R"( + local function f(a) + return a + end + )"); + + cgb.visit(block); + auto constraints = collectConstraints(cgb.rootScope); + + REQUIRE(2 == constraints.size()); + + ToStringOptions opts; + CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts)); + CHECK("a <: b..." == toString(*constraints[1], opts)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function") +{ + AstStatBlock* block = parse(R"( + local function f(a) + return f(a) + end + )"); + + cgb.visit(block); + auto constraints = collectConstraints(cgb.rootScope); + + REQUIRE(4 == constraints.size()); + + ToStringOptions opts; + CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts)); + CHECK("c ~ inst (a) -> (b...)" == toString(*constraints[1], opts)); + CHECK("(a) -> (d...) <: c" == toString(*constraints[2], opts)); + CHECK("d... <: b..." == toString(*constraints[3], opts)); +} + +TEST_SUITE_END(); diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp new file mode 100644 index 00000000..5959f55c --- /dev/null +++ b/tests/ConstraintSolver.test.cpp @@ -0,0 +1,87 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintSolver.h" + +using namespace Luau; + +static TypeId requireBinding(Scope2* scope, const char* name) +{ + auto b = linearSearchForBinding(scope, name); + LUAU_ASSERT(b.has_value()); + return *b; +} + +TEST_SUITE_BEGIN("ConstraintSolver"); + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") +{ + AstStatBlock* block = parse(R"( + local a = 55 + local b = a + )"); + + cgb.visit(block); + + ConstraintSolver cs{&arena, cgb.rootScope}; + + cs.run(); + + TypeId bType = requireBinding(cgb.rootScope, "b"); + + CHECK("number" == toString(bType)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") +{ + AstStatBlock* block = parse(R"( + local function id(a) + return a + end + )"); + + cgb.visit(block); + + ConstraintSolver cs{&arena, cgb.rootScope}; + + cs.run(); + + TypeId idType = requireBinding(cgb.rootScope, "id"); + + CHECK("(a) -> a" == toString(idType)); +} + +#if 1 +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") +{ + AstStatBlock* block = parse(R"( + local function a(c) + local function d(e) + return c + end + + return d + end + + local b = a(5) + )"); + + cgb.visit(block); + + ToStringOptions opts; + + ConstraintSolver cs{&arena, cgb.rootScope}; + + cs.run(); + + TypeId idType = requireBinding(cgb.rootScope, "b"); + + CHECK("(a) -> number" == toString(idType, opts)); +} +#endif + +TEST_SUITE_END(); diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp new file mode 100644 index 00000000..c709ba8e --- /dev/null +++ b/tests/CostModel.test.cpp @@ -0,0 +1,226 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" + +#include "doctest.h" + +using namespace Luau; + +namespace Luau +{ +namespace Compile +{ + +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount); +int computeCost(uint64_t model, const bool* varsConst, size_t varCount); + +} // namespace Compile +} // namespace Luau + +TEST_SUITE_BEGIN("CostModel"); + +static uint64_t modelFunction(const char* source) +{ + Allocator allocator; + AstNameTable names(allocator); + + ParseResult result = Parser::parse(source, strlen(source), names, allocator); + REQUIRE(result.root != nullptr); + + AstStatFunction* func = result.root->body.data[0]->as(); + REQUIRE(func); + + return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size); +} + +TEST_CASE("Expression") +{ + uint64_t model = modelFunction(R"( +function test(a, b, c) + return a + (b + 1) * (b + 1) - c +end +)"); + + const bool args1[] = {false, false, false}; + const bool args2[] = {false, true, false}; + + CHECK_EQ(5, Luau::Compile::computeCost(model, args1, 3)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 3)); +} + +TEST_CASE("PropagateVariable") +{ + uint64_t model = modelFunction(R"( +function test(a) + local b = a * a * a + return b * b +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(0, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("LoopAssign") +{ + uint64_t model = modelFunction(R"( +function test(a) + for i=1,3 do + a[i] = i + end +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + // loop baseline cost is 5 + CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("MutableVariable") +{ + uint64_t model = modelFunction(R"( +function test(a, b) + local x = a * a + x += b + return x * x +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("ImportCall") +{ + uint64_t model = modelFunction(R"( +function test(a) + return Instance.new(a) +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("FastCall") +{ + uint64_t model = modelFunction(R"( +function test(a) + return math.abs(a + 1) +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + // note: we currently don't treat fast calls differently from cost model perspective + CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(5, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("ControlFlow") +{ + uint64_t model = modelFunction(R"( +function test(a) + while a < 0 do + a += 1 + end + for i=10,1,-1 do + a += 1 + end + for i in pairs({}) do + a += 1 + if a % 2 == 0 then continue end + end + repeat + a += 1 + if a % 2 == 0 then break end + until a > 10 + return a +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(82, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(79, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("Conditional") +{ + uint64_t model = modelFunction(R"( +function test(a) + return if a < 0 then -a else a +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(4, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("VarArgs") +{ + uint64_t model = modelFunction(R"( +function test(...) + return select('#', ...) :: number +end +)"); + + CHECK_EQ(8, Luau::Compile::computeCost(model, nullptr, 0)); +} + +TEST_CASE("TablesFunctions") +{ + uint64_t model = modelFunction(R"( +function test() + return { 42, op = function() end } +end +)"); + + CHECK_EQ(22, Luau::Compile::computeCost(model, nullptr, 0)); +} + +TEST_CASE("CostOverflow") +{ + uint64_t model = modelFunction(R"( +function test() + return {{{{{{{{{{{{{{{}}}}}}}}}}}}}}} +end +)"); + + CHECK_EQ(127, Luau::Compile::computeCost(model, nullptr, 0)); +} + +TEST_CASE("TableAssign") +{ + uint64_t model = modelFunction(R"( +function test(a) + for i=1,#a do + a[i] = i + end +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(7, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index a7e7ea39..ac22f65b 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -17,6 +17,8 @@ static const char* mainModuleName = "MainModule"; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { @@ -83,7 +85,7 @@ std::optional TestFileResolver::getEnvironmentForModule(const Modul return std::nullopt; } -Fixture::Fixture(bool freeze) +Fixture::Fixture(bool freeze, bool prepareAutocomplete) : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true}) , typeChecker(frontend.typeChecker) @@ -92,9 +94,8 @@ Fixture::Fixture(bool freeze) configResolver.defaultConfig.enabledLint.warningMask = ~0ull; configResolver.defaultConfig.parseOptions.captureComments = true; - registerBuiltinTypes(frontend.typeChecker); - registerTestTypes(); Luau::freeze(frontend.typeChecker.globalTypes); + Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); Luau::setPrintLine([](auto s) {}); } @@ -228,7 +229,7 @@ ModulePtr Fixture::getMainModule() SourceModule* Fixture::getMainSourceModule() { - return frontend.getSourceModule(fromString("MainModule")); + return frontend.getSourceModule(fromString(mainModuleName)); } std::optional Fixture::getPrimitiveType(TypeId ty) @@ -250,13 +251,16 @@ std::optional Fixture::getType(const std::string& name) ModulePtr module = getMainModule(); REQUIRE(module); - return lookupName(module->getModuleScope(), name); + if (FFlag::DebugLuauDeferredConstraintResolution) + return linearSearchForBinding(module->getModuleScope2(), name.c_str()); + else + return lookupName(module->getModuleScope(), name); } TypeId Fixture::requireType(const std::string& name) { std::optional ty = getType(name); - REQUIRE(bool(ty)); + REQUIRE_MESSAGE(bool(ty), "Unable to requireType \"" << name << "\""); return follow(*ty); } @@ -341,7 +345,7 @@ void Fixture::dumpErrors(std::ostream& os, const std::vector& errors) if (error.location.begin.line >= lines.size()) { os << "\tSource not available?" << std::endl; - return; + continue; } std::string_view theLine = lines[error.location.begin.line]; @@ -407,6 +411,28 @@ LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) return result; } +BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) + : Fixture(freeze, prepareAutocomplete) +{ + Luau::unfreeze(frontend.typeChecker.globalTypes); + Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); + + registerBuiltinTypes(frontend.typeChecker); + if (prepareAutocomplete) + registerBuiltinTypes(frontend.typeCheckerForAutocomplete); + registerTestTypes(); + + Luau::freeze(frontend.typeChecker.globalTypes); + Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); +} + +ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() + : Fixture() + , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} +{ + BlockedTypeVar::nextIndex = 0; +} + ModuleName fromString(std::string_view name) { return ModuleName(name); @@ -446,4 +472,27 @@ std::optional lookupName(ScopePtr scope, const std::string& name) return std::nullopt; } +std::optional linearSearchForBinding(Scope2* scope, const char* name) +{ + while (scope) + { + for (const auto& [n, ty] : scope->bindings) + { + if (n.astName() == name) + return ty; + } + + scope = scope->parent; + } + + return std::nullopt; +} + +void dump(const std::vector& constraints) +{ + ToStringOptions opts; + for (const auto& c : constraints) + printf("%s\n", toString(c, opts).c_str()); +} + } // namespace Luau diff --git a/tests/Fixture.h b/tests/Fixture.h index 4e45a952..0e3735f6 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -8,6 +8,7 @@ #include "Luau/Linter.h" #include "Luau/Location.h" #include "Luau/ModuleResolver.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -91,7 +92,7 @@ struct TestConfigResolver : ConfigResolver struct Fixture { - explicit Fixture(bool freeze = true); + explicit Fixture(bool freeze = true, bool prepareAutocomplete = false); ~Fixture(); // Throws Luau::ParseErrors if the parse fails. @@ -151,6 +152,21 @@ struct Fixture LoadDefinitionFileResult loadDefinition(const std::string& source); }; +struct BuiltinsFixture : Fixture +{ + BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); +}; + +struct ConstraintGraphBuilderFixture : Fixture +{ + TypeArena arena; + ConstraintGraphBuilder cgb{&arena}; + + ScopedFastFlag forceTheFlag; + + ConstraintGraphBuilderFixture(); +}; + ModuleName fromString(std::string_view name); template @@ -170,9 +186,12 @@ bool isInArena(TypeId t, const TypeArena& arena); void dumpErrors(const ModulePtr& module); void dumpErrors(const Module& module); void dump(const std::string& name, TypeId ty); +void dump(const std::vector& constraints); std::optional lookupName(ScopePtr scope, const std::string& name); // Warning: This function runs in O(n**2) +std::optional linearSearchForBinding(Scope2* scope, const char* name); + } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 8a59acd1..b9c24704 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -77,7 +77,7 @@ struct NaiveFileResolver : NullFileResolver } // namespace -struct FrontendFixture : Fixture +struct FrontendFixture : BuiltinsFixture { FrontendFixture() { @@ -97,8 +97,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "find_a_require") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(1, res.requires.size()); - CHECK_EQ(res.requires[0].first, "Modules/Foo/Bar"); + CHECK_EQ(1, res.requireList.size()); + CHECK_EQ(res.requireList[0].first, "Modules/Foo/Bar"); } // It could be argued that this should not work. @@ -113,7 +113,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "find_a_require_inside_a_function") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(1, res.requires.size()); + CHECK_EQ(1, res.requireList.size()); } TEST_CASE_FIXTURE(FrontendFixture, "real_source") @@ -138,7 +138,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "real_source") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(8, res.requires.size()); + CHECK_EQ(8, res.requireList.size()); } TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_dependent_scripts") @@ -384,6 +384,66 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_error_paths") CHECK_EQ(ce2->cycle[1], "game/Gui/Modules/A"); } +TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface") +{ + fileResolver.source["game/A"] = R"( + return {hello = 2} + )"; + + CheckResult result = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["game/A"] = R"( + local me = require(game.A) + return {hello = 2} + )"; + frontend.markDirty("game/A"); + + result = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(result); + + auto ty = requireType("game/A", "me"); + CHECK_EQ(toString(ty), "any"); +} + +TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_longer") +{ + fileResolver.source["game/A"] = R"( + return {mod_a = 2} + )"; + + CheckResult result = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["game/B"] = R"( + local me = require(game.A) + return {mod_b = 4} + )"; + + result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["game/A"] = R"( + local me = require(game.B) + return {mod_a_prime = 3} + )"; + + frontend.markDirty("game/A"); + frontend.markDirty("game/B"); + + result = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(result); + + TypeId tyA = requireType("game/A", "me"); + CHECK_EQ(toString(tyA), "any"); + + result = frontend.check("game/B"); + LUAU_REQUIRE_ERRORS(result); + + TypeId tyB = requireType("game/B", "me"); + CHECK_EQ(toString(tyB), "any"); +} + TEST_CASE_FIXTURE(FrontendFixture, "dont_reparse_clean_file_when_linting") { fileResolver.source["Modules/A"] = R"( @@ -911,8 +971,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "typecheck_twice_for_ast_types") TEST_CASE_FIXTURE(FrontendFixture, "imported_table_modification_2") { - ScopedFastFlag sffs("LuauSealExports", true); - frontend.options.retainFullTypeGraphs = false; fileResolver.source["Module/A"] = R"( @@ -971,4 +1029,18 @@ return false; fix.frontend.check("Module/B"); } +TEST_CASE("check_without_builtin_next") +{ + TestFileResolver fileResolver; + TestConfigResolver configResolver; + Frontend frontend(&fileResolver, &configResolver); + + fileResolver.source["Module/A"] = "for k,v in 2 do end"; + fileResolver.source["Module/B"] = "return next"; + + // We don't care about the result. That we haven't crashed is enough. + frontend.check("Module/A"); + frontend.check("Module/B"); +} + TEST_SUITE_END(); diff --git a/tests/JsonEncoder.test.cpp b/tests/JsonEncoder.test.cpp index cb508072..8a263bd2 100644 --- a/tests/JsonEncoder.test.cpp +++ b/tests/JsonEncoder.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Ast.h" #include "Luau/JsonEncoder.h" +#include "Luau/Parser.h" #include "doctest.h" @@ -8,6 +9,46 @@ using namespace Luau; +struct JsonEncoderFixture +{ + Allocator allocator; + AstNameTable names{allocator}; + + ParseResult parse(std::string_view src) + { + ParseOptions opts; + opts.allowDeclarationSyntax = true; + return Parser::parse(src.data(), src.size(), names, allocator, opts); + } + + AstStatBlock* expectParse(std::string_view src) + { + ParseResult res = parse(src); + REQUIRE(res.errors.size() == 0); + return res.root; + } + + AstStat* expectParseStatement(std::string_view src) + { + AstStatBlock* root = expectParse(src); + REQUIRE(1 == root->body.size); + return root->body.data[0]; + } + + AstExpr* expectParseExpr(std::string_view src) + { + std::string s = "a = "; + s.append(src); + AstStatBlock* root = expectParse(s); + + AstStatAssign* statAssign = root->body.data[0]->as(); + REQUIRE(statAssign != nullptr); + REQUIRE(statAssign->values.size == 1); + + return statAssign->values.data[0]; + } +}; + TEST_SUITE_BEGIN("JsonEncoderTests"); TEST_CASE("encode_constants") @@ -50,4 +91,329 @@ TEST_CASE("encode_AstStatBlock") toJson(&block)); } +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_tables") +{ + std::string src = R"( + local x: { + foo: number + } = { + foo = 123, + } + )"; + + AstStatBlock* root = expectParse(src); + std::string json = toJson(root); + + CHECK( + json == + R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"type":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","location":"2,12 - 2,15","type":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":false},"name":"x","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); +} + +TEST_CASE("encode_AstExprGroup") +{ + AstExprConstantNumber number{Location{}, 5.0}; + AstExprGroup group{Location{}, &number}; + + std::string json = toJson(&group); + + const std::string expected = + R"({"type":"AstExprGroup","location":"0,0 - 0,0","expr":{"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":5}})"; + + CHECK(json == expected); +} + +TEST_CASE("encode_AstExprGlobal") +{ + AstExprGlobal global{Location{}, AstName{"print"}}; + + std::string json = toJson(&global); + std::string expected = R"({"type":"AstExprGlobal","location":"0,0 - 0,0","global":"print"})"; + + CHECK(json == expected); +} + +TEST_CASE("encode_AstExprLocal") +{ + AstLocal local{AstName{"foo"}, Location{}, nullptr, 0, 0, nullptr}; + AstExprLocal exprLocal{Location{}, &local, false}; + + CHECK(toJson(&exprLocal) == R"({"type":"AstExprLocal","location":"0,0 - 0,0","local":{"type":null,"name":"foo","location":"0,0 - 0,0"}})"); +} + +TEST_CASE("encode_AstExprVarargs") +{ + AstExprVarargs varargs{Location{}}; + + CHECK(toJson(&varargs) == R"({"type":"AstExprVarargs","location":"0,0 - 0,0"})"); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprCall") +{ + AstExpr* expr = expectParseExpr("foo(1, 2, 3)"); + std::string_view expected = + R"({"type":"AstExprCall","location":"0,4 - 0,16","func":{"type":"AstExprGlobal","location":"0,4 - 0,7","global":"foo"},"args":[{"type":"AstExprConstantNumber","location":"0,8 - 0,9","value":1},{"type":"AstExprConstantNumber","location":"0,11 - 0,12","value":2},{"type":"AstExprConstantNumber","location":"0,14 - 0,15","value":3}],"self":false,"argLocation":"0,8 - 0,16"})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprIndexName") +{ + AstExpr* expr = expectParseExpr("foo.bar"); + + std::string_view expected = + R"({"type":"AstExprIndexName","location":"0,4 - 0,11","expr":{"type":"AstExprGlobal","location":"0,4 - 0,7","global":"foo"},"index":"bar","indexLocation":"0,8 - 0,11","op":"."})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprIndexExpr") +{ + AstExpr* expr = expectParseExpr("foo['bar']"); + + std::string_view expected = + R"({"type":"AstExprIndexExpr","location":"0,4 - 0,14","expr":{"type":"AstExprGlobal","location":"0,4 - 0,7","global":"foo"},"index":{"type":"AstExprConstantString","location":"0,8 - 0,13","value":"bar"}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprFunction") +{ + AstExpr* expr = expectParseExpr("function (a) return a end"); + + std::string_view expected = + R"({"type":"AstExprFunction","location":"0,4 - 0,29","generics":[],"genericPacks":[],"args":[{"type":null,"name":"a","location":"0,14 - 0,15"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,16 - 0,26","body":[{"type":"AstStatReturn","location":"0,17 - 0,25","list":[{"type":"AstExprLocal","location":"0,24 - 0,25","local":{"type":null,"name":"a","location":"0,14 - 0,15"}}]}]},"functionDepth":1,"debugname":"","hasEnd":true})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprTable") +{ + AstExpr* expr = expectParseExpr("{true, key=true, [key2]=true}"); + + std::string_view expected = + R"({"type":"AstExprTable","location":"0,4 - 0,33","items":[{"kind":"item","value":{"type":"AstExprConstantBool","location":"0,5 - 0,9","value":true}},{"kind":"record","key":{"type":"AstExprConstantString","location":"0,11 - 0,14","value":"key"},"value":{"type":"AstExprConstantBool","location":"0,15 - 0,19","value":true}},{"kind":"general","key":{"type":"AstExprGlobal","location":"0,22 - 0,26","global":"key2"},"value":{"type":"AstExprConstantBool","location":"0,28 - 0,32","value":true}}]})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprUnary") +{ + AstExpr* expr = expectParseExpr("-b"); + + std::string_view expected = + R"({"type":"AstExprUnary","location":"0,4 - 0,6","op":"minus","expr":{"type":"AstExprGlobal","location":"0,5 - 0,6","global":"b"}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprBinary") +{ + AstExpr* expr = expectParseExpr("b + c"); + + std::string_view expected = + R"({"type":"AstExprBinary","location":"0,4 - 0,9","op":"Add","left":{"type":"AstExprGlobal","location":"0,4 - 0,5","global":"b"},"right":{"type":"AstExprGlobal","location":"0,8 - 0,9","global":"c"}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprTypeAssertion") +{ + AstExpr* expr = expectParseExpr("b :: any"); + + std::string_view expected = + R"({"type":"AstExprTypeAssertion","location":"0,4 - 0,12","expr":{"type":"AstExprGlobal","location":"0,4 - 0,5","global":"b"},"annotation":{"type":"AstTypeReference","location":"0,9 - 0,12","name":"any","parameters":[]}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprError") +{ + std::string_view src = "a = "; + ParseResult parseResult = Parser::parse(src.data(), src.size(), names, allocator); + + REQUIRE(1 == parseResult.root->body.size); + + AstStatAssign* statAssign = parseResult.root->body.data[0]->as(); + REQUIRE(statAssign != nullptr); + REQUIRE(1 == statAssign->values.size); + + AstExpr* expr = statAssign->values.data[0]; + + std::string_view expected = R"({"type":"AstExprError","location":"0,4 - 0,4","expressions":[],"messageIndex":0})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatIf") +{ + AstStat* statement = expectParseStatement("if true then else end"); + + std::string_view expected = + R"({"type":"AstStatIf","location":"0,0 - 0,21","condition":{"type":"AstExprConstantBool","location":"0,3 - 0,7","value":true},"thenbody":{"type":"AstStatBlock","location":"0,12 - 0,13","body":[]},"elsebody":{"type":"AstStatBlock","location":"0,17 - 0,18","body":[]},"hasThen":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatWhile") +{ + AstStat* statement = expectParseStatement("while true do end"); + + std::string_view expected = + R"({"type":"AtStatWhile","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatRepeat") +{ + AstStat* statement = expectParseStatement("repeat until true"); + + std::string_view expected = + R"({"type":"AstStatRepeat","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,13 - 0,17","value":true},"body":{"type":"AstStatBlock","location":"0,6 - 0,7","body":[]},"hasUntil":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatBreak") +{ + AstStat* statement = expectParseStatement("while true do break end"); + + std::string_view expected = + R"({"type":"AtStatWhile","location":"0,0 - 0,23","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,20","body":[{"type":"AstStatBreak","location":"0,14 - 0,19"}]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatContinue") +{ + AstStat* statement = expectParseStatement("while true do continue end"); + + std::string_view expected = + R"({"type":"AtStatWhile","location":"0,0 - 0,26","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,23","body":[{"type":"AstStatContinue","location":"0,14 - 0,22"}]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatFor") +{ + AstStat* statement = expectParseStatement("for a=0,1 do end"); + + std::string_view expected = + R"({"type":"AstStatFor","location":"0,0 - 0,16","var":{"type":null,"name":"a","location":"0,4 - 0,5"},"from":{"type":"AstExprConstantNumber","location":"0,6 - 0,7","value":0},"to":{"type":"AstExprConstantNumber","location":"0,8 - 0,9","value":1},"body":{"type":"AstStatBlock","location":"0,12 - 0,13","body":[]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatForIn") +{ + AstStat* statement = expectParseStatement("for a in b do end"); + + std::string_view expected = + R"({"type":"AstStatForIn","location":"0,0 - 0,17","vars":[{"type":null,"name":"a","location":"0,4 - 0,5"}],"values":[{"type":"AstExprGlobal","location":"0,9 - 0,10","global":"b"}],"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasIn":true,"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatCompoundAssign") +{ + AstStat* statement = expectParseStatement("a += b"); + + std::string_view expected = + R"({"type":"AstStatCompoundAssign","location":"0,0 - 0,6","op":"Add","var":{"type":"AstExprGlobal","location":"0,0 - 0,1","global":"a"},"value":{"type":"AstExprGlobal","location":"0,5 - 0,6","global":"b"}})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatLocalFunction") +{ + AstStat* statement = expectParseStatement("local function a(b) return end"); + + std::string_view expected = + R"({"type":"AstStatLocalFunction","location":"0,0 - 0,30","name":{"type":null,"name":"a","location":"0,15 - 0,16"},"func":{"type":"AstExprFunction","location":"0,0 - 0,30","generics":[],"genericPacks":[],"args":[{"type":null,"name":"b","location":"0,17 - 0,18"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,19 - 0,27","body":[{"type":"AstStatReturn","location":"0,20 - 0,26","list":[]}]},"functionDepth":1,"debugname":"a","hasEnd":true}})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatTypeAlias") +{ + AstStat* statement = expectParseStatement("type A = B"); + + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,10","name":"A","generics":[],"genericPacks":[],"type":{"type":"AstTypeReference","location":"0,9 - 0,10","name":"B","parameters":[]},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareFunction") +{ + AstStat* statement = expectParseStatement("declare function foo(x: number): string"); + + std::string_view expected = + R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","name":"foo","params":{"types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","parameters":[]}]},"retTypes":{"types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","parameters":[]}]},"generics":[],"genericPacks":[]})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") +{ + AstStatBlock* root = expectParse(R"( + declare class Foo + prop: number + function method(self, foo: number): string + end + + declare class Bar extends Foo + prop2: string + end + )"); + + REQUIRE(2 == root->body.size); + + std::string_view expected1 = + R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","parameters":[]}},{"name":"method","type":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","parameters":[]}]},"returnTypes":{"types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","parameters":[]}]}}}]})"; + CHECK(toJson(root->body.data[0]) == expected1); + + std::string_view expected2 = + R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","parameters":[]}}]})"; + CHECK(toJson(root->body.data[1]) == expected2); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_annotation") +{ + AstStat* statement = expectParseStatement("type T = ((number) -> (string | nil)) & ((string) -> ())"); + + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,35","generics":[],"genericPacks":[],"argTypes":{"types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","parameters":[]}]},"returnTypes":{"types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","generics":[],"genericPacks":[],"argTypes":{"types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","parameters":[]}]},"returnTypes":{"types":[]}}]},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypeError") +{ + ParseResult parseResult = parse("type T = "); + REQUIRE(1 == parseResult.root->body.size); + + AstStat* statement = parseResult.root->body.data[0]; + + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,9","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeError","location":"0,8 - 0,9","types":[],"messageIndex":0},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypePackExplicit") +{ + AstStatBlock* root = expectParse(R"( + type A = () -> T... + local a: A<(number, string)> + )"); + + CHECK(2 == root->body.size); + + std::string_view expected = + R"({"type":"AstStatLocal","location":"2,8 - 2,36","vars":[{"type":{"type":"AstTypeReference","location":"2,17 - 2,36","name":"A","parameters":[{"type":"AstTypePackExplicit","location":"2,19 - 2,20","typeList":{"types":[{"type":"AstTypeReference","location":"2,20 - 2,26","name":"number","parameters":[]},{"type":"AstTypeReference","location":"2,28 - 2,34","name":"string","parameters":[]}]}}]},"name":"a","location":"2,14 - 2,15"}],"values":[]})"; + + CHECK(toJson(root->body.data[1]) == expected); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 9ce9a4c2..202aeceb 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -75,7 +75,7 @@ _ = 6 CHECK_EQ(result.warnings.size(), 0); } -TEST_CASE_FIXTURE(Fixture, "BuiltinGlobalWrite") +TEST_CASE_FIXTURE(BuiltinsFixture, "BuiltinGlobalWrite") { LintResult result = lint(R"( math = {} @@ -309,7 +309,7 @@ print(arg) CHECK_EQ(result.warnings[0].text, "Variable 'arg' shadows previous declaration at line 2"); } -TEST_CASE_FIXTURE(Fixture, "LocalShadowGlobal") +TEST_CASE_FIXTURE(BuiltinsFixture, "LocalShadowGlobal") { LintResult result = lint(R"( local math = math @@ -597,8 +597,6 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { - ScopedFastFlag sff("LuauLintNoRobloxBits", true); - unfreeze(typeChecker.globalTypes); TableTypeVar::Props instanceProps{ {"ClassName", {typeChecker.anyType}}, @@ -1438,7 +1436,8 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") { unfreeze(typeChecker.globalTypes); - TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}}); + TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); + persist(instanceType); typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; getMutable(instanceType)->props = { @@ -1471,7 +1470,7 @@ end CHECK_EQ(result.warnings[2].text, "Member 'Instance.DataCost' is deprecated"); } -TEST_CASE_FIXTURE(Fixture, "TableOperations") +TEST_CASE_FIXTURE(BuiltinsFixture, "TableOperations") { LintResult result = lintTyped(R"( local t = {} diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 82b7a350..7c2f4d1c 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Clone.h" #include "Luau/Module.h" #include "Luau/Scope.h" +#include "Luau/RecursionCounter.h" #include "Fixture.h" @@ -8,6 +10,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("ModuleTests"); TEST_CASE_FIXTURE(Fixture, "is_within_comment") @@ -41,29 +45,23 @@ TEST_CASE_FIXTURE(Fixture, "is_within_comment") TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; // numberType is persistent. We leave it as-is. - TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks, cloneState); + TypeId newNumber = clone(typeChecker.numberType, dest, cloneState); CHECK_EQ(newNumber, typeChecker.numberType); } TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; // Create a new number type that isn't persistent unfreeze(typeChecker.globalTypes); TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveTypeVar{PrimitiveTypeVar::Number}); freeze(typeChecker.globalTypes); - TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks, cloneState); + TypeId newNumber = clone(oldNumber, dest, cloneState); CHECK_NE(newNumber, oldNumber); CHECK_EQ(*oldNumber, *newNumber); @@ -89,12 +87,9 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") TypeId counterType = requireType("Cyclic"); - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - CloneState cloneState; - TypeArena dest; - TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks, cloneState); + CloneState cloneState; + TypeId counterCopy = clone(counterType, dest, cloneState); TableTypeVar* ttv = getMutable(counterCopy); REQUIRE(ttv != nullptr); @@ -107,15 +102,18 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") const FunctionTypeVar* ftv = get(methodType); REQUIRE(ftv != nullptr); - std::optional methodReturnType = first(ftv->retType); + std::optional methodReturnType = first(ftv->retTypes); REQUIRE(methodReturnType); CHECK_EQ(methodReturnType, counterCopy); - CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type - CHECK_EQ(2, dest.typeVars.size()); // One table and one function + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(3, dest.typePacks.size()); // function args, its return type, and the hidden any... pack + else + CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type + CHECK_EQ(2, dest.typeVars.size()); // One table and one function } -TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") { CheckResult result = check(R"( return {sign=math.sign} @@ -142,15 +140,12 @@ TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") TEST_CASE_FIXTURE(Fixture, "deepClone_union") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldUnion = typeChecker.globalTypes.addType(UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks, cloneState); + TypeId newUnion = clone(oldUnion, dest, cloneState); CHECK_NE(newUnion, oldUnion); CHECK_EQ("number | string", toString(newUnion)); @@ -160,15 +155,12 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks, cloneState); + TypeId newIntersection = clone(oldIntersection, dest, cloneState); CHECK_NE(newIntersection, oldIntersection); CHECK_EQ("number & string", toString(newIntersection)); @@ -181,21 +173,18 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") { {"__add", {typeChecker.anyType}}, }, - std::nullopt, std::nullopt, {}, {}}}; + std::nullopt, std::nullopt, {}, {}, "Test"}}; TypeVar exampleClass{ClassTypeVar{"ExampleClass", { {"PropOne", {typeChecker.numberType}}, {"PropTwo", {typeChecker.stringType}}, }, - std::nullopt, &exampleMetaClass, {}, {}}}; + std::nullopt, &exampleMetaClass, {}, {}, "Test"}}; TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks, cloneState); + TypeId cloned = clone(&exampleClass, dest, cloneState); const ClassTypeVar* ctv = get(cloned); REQUIRE(ctv != nullptr); @@ -207,49 +196,56 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") CHECK_EQ("ExampleClassMeta", metatable->name); } -TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") +TEST_CASE_FIXTURE(Fixture, "clone_free_types") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - TypeVar freeTy(FreeTypeVar{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, cloneState); - CHECK_EQ("any", toString(clonedTy)); - CHECK(cloneState.encounteredFreeType); + TypeId clonedTy = clone(&freeTy, dest, cloneState); + CHECK(get(clonedTy)); cloneState = {}; - TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, cloneState); - CHECK_EQ("...any", toString(clonedTp)); - CHECK(cloneState.encounteredFreeType); + TypePackId clonedTp = clone(&freeTp, dest, cloneState); + CHECK(get(clonedTp)); } -TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") +TEST_CASE_FIXTURE(Fixture, "clone_free_tables") { TypeVar tableTy{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tableTy); ttv->state = TableState::Free; TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, cloneState); + TypeId cloned = clone(&tableTy, dest, cloneState); const TableTypeVar* clonedTtv = get(cloned); - CHECK_EQ(clonedTtv->state, TableState::Sealed); - CHECK(cloneState.encounteredFreeType); + CHECK_EQ(clonedTtv->state, TableState::Free); } -TEST_CASE_FIXTURE(Fixture, "clone_self_property") +TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") { - ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; + TypeArena src; + TypeId constrained = src.addType(ConstrainedTypeVar{TypeLevel{}, {getSingletonTypes().numberType, getSingletonTypes().stringType}}); + + TypeArena dest; + CloneState cloneState; + + TypeId cloned = clone(constrained, dest, cloneState); + CHECK_NE(constrained, cloned); + + const ConstrainedTypeVar* ctv = get(cloned); + REQUIRE_EQ(2, ctv->parts.size()); + CHECK_EQ(getSingletonTypes().numberType, ctv->parts[0]); + CHECK_EQ(getSingletonTypes().stringType, ctv->parts[1]); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") +{ fileResolver.source["Module/A"] = R"( --!nonstrict local a = {} @@ -298,11 +294,29 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") } TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - CHECK_THROWS_AS(clone(table, dest, seenTypes, seenTypePacks, cloneState), std::runtime_error); + CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); +} + +TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") +{ + ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; + + fileResolver.source["Module/A"] = R"( +export type A = B +type B = A + )"; + + FrontendOptions opts; + opts.retainFullTypeGraphs = false; + CheckResult result = frontend.check("Module/A", opts); + LUAU_REQUIRE_ERRORS(result); + + auto mod = frontend.moduleResolver.getModule("Module/A"); + auto it = mod->getModuleScope()->exportedTypeBindings.find("A"); + REQUIRE(it != mod->getModuleScope()->exportedTypeBindings.end()); + CHECK(toString(it->second.type) == "any"); } TEST_SUITE_END(); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index d3faea2a..50dcbad0 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -13,6 +13,77 @@ using namespace Luau; TEST_SUITE_BEGIN("NonstrictModeTests"); +TEST_CASE_FIXTURE(Fixture, "globals") +{ + CheckResult result = check(R"( + --!nonstrict + foo = true + foo = "now i'm a string!" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "globals2") +{ + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + --!nonstrict + foo = function() return 1 end + foo = "now i'm a string!" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("() -> number", toString(tm->wantedType)); + CHECK_EQ("string", toString(tm->givenType)); + CHECK_EQ("() -> number", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "globals_everywhere") +{ + CheckResult result = check(R"( + --!nonstrict + foo = 1 + + if true then + bar = 2 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_number_or_string") +{ + ScopedFastFlag sff[]{{"LuauReturnTypeInferenceInNonstrict", true}, {"LuauLowerBoundsCalculation", true}}; + + CheckResult result = check(R"( + --!nonstrict + local function f() + if math.random() > 0.5 then + return 5 + else + return "hi" + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("() -> number | string" == toString(requireType("f"))); +} + TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") { CheckResult result = check(R"( @@ -31,12 +102,17 @@ TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") REQUIRE_EQ("any", toString(args[0])); REQUIRE_EQ("any", toString(args[1])); - auto rets = flatten(ftv->retType).first; + auto rets = flatten(ftv->retTypes).first; REQUIRE_EQ(0, rets.size()); } -TEST_CASE_FIXTURE(Fixture, "infer_the_maximum_number_of_values_the_function_could_return") +TEST_CASE_FIXTURE(Fixture, "first_return_type_dictates_number_of_return_types") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nonstrict function getMinCardCountForWidth(width) @@ -51,22 +127,18 @@ TEST_CASE_FIXTURE(Fixture, "infer_the_maximum_number_of_values_the_function_coul TypeId t = requireType("getMinCardCountForWidth"); REQUIRE(t); - REQUIRE_EQ("(any) -> (...any)", toString(t)); + REQUIRE_EQ("(any) -> number", toString(t)); } -#if 0 -// Maybe we want this? TEST_CASE_FIXTURE(Fixture, "return_annotation_is_still_checked") { CheckResult result = check(R"( + --!nonstrict function foo(x): number return 'hello' end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE_NE(*typeChecker.anyType, *requireType("foo")); } -#endif TEST_CASE_FIXTURE(Fixture, "function_parameters_are_any") { @@ -126,8 +198,6 @@ TEST_CASE_FIXTURE(Fixture, "parameters_having_type_any_are_optional") TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") { - ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; - CheckResult result = check(R"( --!nonstrict local T = {} @@ -145,8 +215,6 @@ TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") TEST_CASE_FIXTURE(Fixture, "offer_a_hint_if_you_use_a_dot_instead_of_a_colon") { - ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; - CheckResult result = check(R"( --!nonstrict local T = {} @@ -200,7 +268,7 @@ TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") CHECK_MESSAGE(get(ttv->props["three"].type), "Should be a function: " << *ttv->props["three"].type); } -TEST_CASE_FIXTURE(Fixture, "for_in_iterator_variables_are_any") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_iterator_variables_are_any") { CheckResult result = check(R"( --!nonstrict @@ -219,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_iterator_variables_are_any") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_dot_insert_and_recursive_calls") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_insert_and_recursive_calls") { CheckResult result = check(R"( --!nonstrict @@ -256,6 +324,11 @@ TEST_CASE_FIXTURE(Fixture, "delay_function_does_not_require_its_argument_to_retu TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nonstrict @@ -272,7 +345,41 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("any", toString(getMainModule()->getModuleScope()->returnType)); + REQUIRE_EQ("((any) -> string) | {| foo: any |}", toString(getMainModule()->getModuleScope()->returnType)); +} + +TEST_CASE_FIXTURE(Fixture, "returning_insufficient_return_values") +{ + CheckResult result = check(R"( + --!nonstrict + + function foo(): (boolean, string?) + if true then + return true, "hello" + else + return false + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "returning_too_many_values") +{ + CheckResult result = check(R"( + --!nonstrict + + function foo(): boolean + if true then + return true, "hello" + else + return false + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp new file mode 100644 index 00000000..a474b6e7 --- /dev/null +++ b/tests/Normalize.test.cpp @@ -0,0 +1,1092 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +#include "Luau/Normalize.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +struct NormalizeFixture : Fixture +{ + ScopedFastFlag sff1{"LuauLowerBoundsCalculation", true}; +}; + +void createSomeClasses(TypeChecker& typeChecker) +{ + auto& arena = typeChecker.globalTypes; + + unfreeze(arena); + + TypeId parentType = arena.addType(ClassTypeVar{"Parent", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); + + ClassTypeVar* parentClass = getMutable(parentType); + parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})}; + + parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})}; + + addGlobalBinding(typeChecker, "Parent", {parentType}); + typeChecker.globalScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; + + TypeId childType = arena.addType(ClassTypeVar{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); + + ClassTypeVar* childClass = getMutable(childType); + childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})}; + + addGlobalBinding(typeChecker, "Child", {childType}); + typeChecker.globalScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; + + TypeId unrelatedType = arena.addType(ClassTypeVar{"Unrelated", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); + + addGlobalBinding(typeChecker, "Unrelated", {unrelatedType}); + typeChecker.globalScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; + + for (const auto& [name, ty] : typeChecker.globalScope->exportedTypeBindings) + persist(ty.type); + + freeze(arena); +} + +static bool isSubtype(TypeId a, TypeId b) +{ + InternalErrorReporter ice; + return isSubtype(a, b, ice); +} + +TEST_SUITE_BEGIN("isSubtype"); + +TEST_CASE_FIXTURE(NormalizeFixture, "primitives") +{ + check(R"( + local a = 41 + local b = 32 + + local c = "hello" + local d = "world" + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(isSubtype(d, c)); + CHECK(!isSubtype(d, a)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions") +{ + check(R"( + function a(x: number): number return x end + function b(x: number): number return x end + + function c(x: number?): number return x end + function d(x: number): number? return x end + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(isSubtype(c, a)); + CHECK(!isSubtype(d, a)); + CHECK(isSubtype(a, d)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_and_any") +{ + check(R"( + function a(n: number) return "string" end + function b(q: any) return 5 :: any end + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + // Intuition: + // We cannot use b where a is required because we cannot rely on b to return a string. + // We cannot use a where b is required because we cannot rely on a to accept non-number arguments. + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(a, b)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arities") +{ + check(R"( + type A = (any) -> () + type B = (any, any) -> () + type T = A & B + + local a: A + local b: B + local t: T + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(a, b)); // !! + CHECK(!isSubtype(b, a)); + + CHECK("((any) -> ()) & ((any, any) -> ())" == toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity") +{ + check(R"( + local a: (number) -> () + local b: () -> () + + local c: () -> number + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(c, a)); + + CHECK(!isSubtype(a, b)); + CHECK(!isSubtype(c, b)); + + CHECK(!isSubtype(a, c)); + CHECK(!isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_optional_parameters") +{ + /* + * (T0..TN) <: (T0..TN, A?) + * (T0..TN) <: (T0..TN, any) + * (T0..TN, A?) R <: U -> S if U <: T and R <: S + * A | B <: T if A <: T and B <: T + * T <: A | B if T <: A or T <: B + */ + check(R"( + local a: (number?) -> () + local b: (number) -> () + local c: (number, number?) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + /* + * (number) -> () () + * because number? () + * because number? () <: (number) -> () + * because number <: number? (because number <: number) + */ + CHECK(isSubtype(a, b)); + + /* + * (number, number?) -> () <: (number) -> (number) + * The packs have inequal lengths, but (number) <: (number, number?) + * and number <: number + */ + CHECK(!isSubtype(c, b)); + + /* + * (number?) -> () () + * because (number, number?) () () + * because (number, number?) () + local b: (number) -> () + local c: (number, any) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + /* + * (number) -> () () + * because number? () + * because number? () <: (number) -> () + * because number <: number? (because number <: number) + */ + CHECK(isSubtype(a, b)); + + /* + * (number, any) -> () (number) + * The packs have inequal lengths + */ + CHECK(!isSubtype(c, b)); + + /* + * (number?) -> () () + * The packs have inequal lengths + */ + CHECK(!isSubtype(a, c)); + + /* + * (number) -> () () + * The packs have inequal lengths + */ + CHECK(!isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "variadic_functions_with_no_head") +{ + check(R"( + local a: (...number) -> () + local b: (...number?) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); +} + +#if 0 +TEST_CASE_FIXTURE(NormalizeFixture, "variadic_function_with_head") +{ + check(R"( + local a: (...number) -> () + local b: (number, number) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); +} +#endif + +TEST_CASE_FIXTURE(NormalizeFixture, "union") +{ + check(R"( + local a: number | string + local b: number + local c: string + local d: number? + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(!isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(b, d)); + CHECK(!isSubtype(d, b)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_with_union_prop") +{ + check(R"( + local a: {x: number} + local b: {x: number?} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_with_any_prop") +{ + check(R"( + local a: {x: number} + local b: {x: any} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection") +{ + check(R"( + local a: number & string + local b: number + local c: string + local d: number & nil + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); + + CHECK(!isSubtype(c, a)); + CHECK(isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_and_intersection") +{ + check(R"( + local a: number & string + local b: number | nil + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_with_table_prop") +{ + check(R"( + type T = {x: {y: number}} & {x: {y: string}} + local a: T + )"); + + CHECK_EQ("{| x: {| y: number & string |} |}", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "tables") +{ + check(R"( + local a: {x: number} + local b: {x: any} + local c: {y: number} + local d: {x: number, y: number} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); + + CHECK(!isSubtype(c, a)); + CHECK(!isSubtype(a, c)); + + CHECK(isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(d, b)); + CHECK(!isSubtype(b, d)); +} + +#if 0 +TEST_CASE_FIXTURE(NormalizeFixture, "table_indexers_are_invariant") +{ + check(R"( + local a: {[string]: number} + local b: {[string]: any} + local c: {[string]: number} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(isSubtype(a, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "mismatched_indexers") +{ + check(R"( + local a: {x: number} + local b: {[string]: number} + local c: {} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(!isSubtype(c, b)); + CHECK(isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_table") +{ + check(R"( + type A = {method: (A) -> ()} + local a: A + + type B = {method: (any) -> ()} + local b: B + + type C = {method: (C) -> ()} + local c: C + + type D = {method: (D) -> (), another: (D) -> ()} + local d: D + + type E = {method: (A) -> (), another: (E) -> ()} + local e: E + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + TypeId e = requireType("e"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(e, a)); + CHECK(!isSubtype(a, e)); +} +#endif + +TEST_CASE_FIXTURE(NormalizeFixture, "classes") +{ + createSomeClasses(typeChecker); + + TypeId p = typeChecker.globalScope->lookupType("Parent")->type; + TypeId c = typeChecker.globalScope->lookupType("Child")->type; + TypeId u = typeChecker.globalScope->lookupType("Unrelated")->type; + + CHECK(isSubtype(c, p)); + CHECK(!isSubtype(p, c)); + CHECK(!isSubtype(u, p)); + CHECK(!isSubtype(p, u)); +} + +#if 0 +TEST_CASE_FIXTURE(NormalizeFixture, "metatable" * doctest::expected_failures{1}) +{ + check(R"( + local T = {} + T.__index = T + function T.new() + return setmetatable({}, T) + end + + function T:method() end + + local a: typeof(T.new) + local b: {method: (any) -> ()} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); +} +#endif + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_tables") +{ + check(R"( + type T = {x: number} & ({x: number} & {y: string?}) + local t: T + )"); + + CHECK("{| x: number, y: string? |}" == toString(requireType("t"))); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("Normalize"); + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_disjoint_tables") +{ + check(R"( + type T = {a: number} & {b: number} + local t: T + )"); + + CHECK_EQ("{| a: number, b: number |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_overlapping_tables") +{ + check(R"( + type T = {a: number, b: string} & {b: number, c: string} + local t: T + )"); + + CHECK_EQ("{| a: number, b: number & string, c: string |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_confluent_overlapping_tables") +{ + check(R"( + type T = {a: number, b: string} & {b: string, c: string} + local t: T + )"); + + CHECK_EQ("{| a: number, b: string, c: string |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_with_overlapping_field_that_has_a_subtype_relationship") +{ + check(R"( + local t: {x: number} | {x: number?} + )"); + + ModulePtr tempModule{new Module}; + + // HACK: Normalization is an in-place operation. We need to cheat a little here and unfreeze + // the arena that the type lives in. + ModulePtr mainModule = getMainModule(); + unfreeze(mainModule->internalTypes); + + TypeId tType = requireType("t"); + normalize(tType, tempModule, *typeChecker.iceHandler); + + CHECK_EQ("{| x: number? |}", toString(tType, {true})); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions") +{ + check(R"( + type T = ((any) -> string) & ((number) -> string) + local t: T + )"); + + CHECK_EQ("(any) -> string", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "normalize_module_return_type") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauReturnTypeInferenceInNonstrict", true}, + }; + + check(R"( + --!nonstrict + + if Math.random() then + return function(initialState, handlers) + return function(state, action) + return state + end + end + else + return function(initialState, handlers) + return function(state, action) + return state + end + end + end + )"); + + CHECK_EQ("(any, any) -> (any, any) -> any", toString(getMainModule()->getModuleScope()->returnType)); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_is_not_a_constrained_intersection") +{ + check(R"( + function foo(x:number, y:number) + return x + y + end + )"); + + CHECK_EQ("(number, number) -> number", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function") +{ + check(R"( + function apply(f, x) + return f(x) + end + + local a = apply(function(x: number) return x + x end, 5) + )"); + + TypeId aType = requireType("a"); + CHECK_MESSAGE(isNumber(follow(aType)), "Expected a number but got ", toString(aType)); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_with_annotation") +{ + check(R"( + function apply(f: (a) -> b, x) + return f(x) + end + )"); + + CHECK_EQ("((a) -> b, a) -> b", toString(requireType("apply"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_marked_normal") +{ + ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", true}, {"LuauNormalizeFlagIsConservative", false}}; + + check(R"( + type Fiber = { + return_: Fiber? + } + + local f: Fiber + )"); + + TypeId t = requireType("f"); + CHECK(t->normal); +} + +// Unfortunately, getting this right in the general case is difficult. +TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_not_marked_normal") +{ + ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", true}, {"LuauNormalizeFlagIsConservative", true}}; + + check(R"( + type Fiber = { + return_: Fiber? + } + + local f: Fiber + )"); + + TypeId t = requireType("f"); + CHECK(!t->normal); +} + +TEST_CASE_FIXTURE(Fixture, "variadic_tail_is_marked_normal") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + type Weirdo = (...{x: number}) -> () + + local w: Weirdo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("w"); + auto ftv = get(t); + REQUIRE(ftv); + + auto [argHead, argTail] = flatten(ftv->argTypes); + CHECK(argHead.empty()); + REQUIRE(argTail.has_value()); + + auto vtp = get(*argTail); + REQUIRE(vtp); + CHECK(vtp->ty->normal); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_table_normalizes_sensibly") +{ + CheckResult result = check(R"( + local Cyclic = {} + function Cyclic.get() + return Cyclic + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = requireType("Cyclic"); + CHECK_EQ("t1 where t1 = { get: () -> t1 }", toString(ty, {true})); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "union_of_distinct_free_types") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function fussy(a, b) + if math.random() > 0.5 then + return a + else + return b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("(a, b) -> a | b" == toString(requireType("fussy"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "constrained_intersection_of_intersections") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + local f : (() -> number) | ((number) -> number) + local g : (() -> number) | ((string) -> number) + + function h() + if math.random() then + return f + else + return g + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId h = requireType("h"); + + CHECK("() -> (() -> number) | ((number) -> number) | ((string) -> number)" == toString(h)); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + type X = {} + type Y = {y: number} + type Z = {z: string} + type W = {w: boolean} + type T = {x: Y & X} & {x:Z & W} + + local x: X + local y: Y + local z: Z + local w: W + local t: T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("{| |}" == toString(requireType("x"), {true})); + CHECK("{| y: number |}" == toString(requireType("y"), {true})); + CHECK("{| z: string |}" == toString(requireType("z"), {true})); + CHECK("{| w: boolean |}" == toString(requireType("w"), {true})); + CHECK("{| x: {| w: boolean, y: number, z: string |} |}" == toString(requireType("t"), {true})); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_2") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(w, x, y, z) + y.y = 5 + z.z = "five" + w.w = true + + type Z = {x: typeof(x) & typeof(y)} & {x: typeof(w) & typeof(z)} + + return ((nil :: any) :: Z) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("strange"); + auto ftv = get(t); + REQUIRE(ftv != nullptr); + + std::vector args = flatten(ftv->argTypes).first; + + REQUIRE(4 == args.size()); + CHECK("{+ w: boolean +}" == toString(args[0])); + CHECK("a" == toString(args[1])); + CHECK("{+ y: number +}" == toString(args[2])); + CHECK("{+ z: string +}" == toString(args[3])); + + std::vector ret = flatten(ftv->retTypes).first; + + REQUIRE(1 == ret.size()); + CHECK("{| x: a & {+ w: boolean, y: number, z: string +} |}" == toString(ret[0])); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_3") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(x, y, z) + x.x = true + y.y = y + z.z = "five" + + type Z = {x: typeof(y)} & {x: typeof(x) & typeof(z)} + + return ((nil :: any) :: Z) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("strange"); + auto ftv = get(t); + REQUIRE(ftv != nullptr); + + std::vector args = flatten(ftv->argTypes).first; + + REQUIRE(3 == args.size()); + CHECK("{+ x: boolean +}" == toString(args[0])); + CHECK("t1 where t1 = {+ y: t1 +}" == toString(args[1])); + CHECK("{+ z: string +}" == toString(args[2])); + + std::vector ret = flatten(ftv->retTypes).first; + + REQUIRE(1 == ret.size()); + CHECK("{| x: {+ x: boolean, y: t1, z: string +} |} where t1 = {+ y: t1 +}" == toString(ret[0])); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_4") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(x, y, z) + x.x = true + z.z = "five" + + type R = {x: typeof(y)} & {x: typeof(x) & typeof(z)} + local r: R + + y.y = r + + return r + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("strange"); + auto ftv = get(t); + REQUIRE(ftv != nullptr); + + std::vector args = flatten(ftv->argTypes).first; + + REQUIRE(3 == args.size()); + CHECK("{+ x: boolean +}" == toString(args[0])); + CHECK("{+ y: t1 +} where t1 = {| x: {+ x: boolean, y: t1, z: string +} |}" == toString(args[1])); + CHECK("{+ z: string +}" == toString(args[2])); + + std::vector ret = flatten(ftv->retTypes).first; + + REQUIRE(1 == ret.size()); + CHECK("t1 where t1 = {| x: {+ x: boolean, y: t1, z: string +} |}" == toString(ret[0])); +} + +TEST_CASE_FIXTURE(Fixture, "nested_table_normalization_with_non_table__no_ice") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeCombineTableFix", true}, + }; + // CLI-52787 + // ends up combining {_:any} with any, recursively + // which used to ICE because this combines a table with a non-table. + CheckResult result = check(R"( + export type t0 = any & { _: {_:any} } & { _:any } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "visiting_a_type_twice_is_not_considered_normal") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + --!strict + function f(a, b) + local function g() + if math.random() > 0.5 then + return a() + else + return b + end + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(() -> a, a) -> ()", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "fuzz_failure_instersection_combine_must_follow") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + export type t0 = {_:{_:any} & {_:any|string}} & {_:{_:{}}} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "fuzz_failure_bound_type_is_normal_but_not_its_bounded_to") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + type t252 = ((t0)|(any))|(any) + type t0 = t252,t24...> + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +// We had an issue where a normal BoundTypeVar might point at a non-normal BoundTypeVar if it in turn pointed to a +// normal TypeVar because we were calling follow() in an improper place. +TEST_CASE_FIXTURE(Fixture, "bound_typevars_should_only_be_marked_normal_if_their_pointee_is_normal") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeFlagIsConservative", true}, + }; + + CheckResult result = check(R"( + local T = {} + + function T:M() + local function f(a) + print(self.prop) + self:g(a) + self.prop = a + end + end + + return T + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "skip_force_normal_on_external_types") +{ + createSomeClasses(typeChecker); + + CheckResult result = check(R"( +export type t0 = { a: Child } +export type t1 = { a: typeof(string.byte) } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_combine_on_bound_self") +{ + ScopedFastFlag luauNormalizeCombineEqFix{"LuauNormalizeCombineEqFix", true}; + + CheckResult result = check(R"( +export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,})) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_does_not_convert_ever") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + --!strict + local function f() + if math.random() > 0.5 then + return true + end + type Ret = typeof(f()) + if math.random() > 0.5 then + return "something" + end + return "something" :: Ret + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("() -> boolean | string", toString(requireType("f"))); +} + +TEST_SUITE_END(); diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp new file mode 100644 index 00000000..ed1c25ec --- /dev/null +++ b/tests/NotNull.test.cpp @@ -0,0 +1,157 @@ +#include "Luau/NotNull.h" + +#include "doctest.h" + +#include +#include +#include + +using Luau::NotNull; + +namespace +{ + +struct Test +{ + int x; + float y; + + static int count; + Test() + { + ++count; + } + + ~Test() + { + --count; + } +}; + +int Test::count = 0; + +} + +int foo(NotNull p) +{ + return *p; +} + +void bar(int* q) +{} + +TEST_SUITE_BEGIN("NotNull"); + +TEST_CASE("basic_stuff") +{ + NotNull a = NotNull{new int(55)}; // Does runtime test + NotNull b{new int(55)}; // As above + // NotNull c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull in the general case is not good. + + // a = nullptr; // nope + + NotNull d = a; // No runtime test. a is known not to be null. + + int e = *d; + *d = 1; + CHECK(e == 55); + + const NotNull f = d; + *f = 5; // valid: there is a difference between const NotNull and NotNull + // f = a; // nope + + CHECK_EQ(a, d); + CHECK(a != b); + + NotNull g(a); + CHECK(g == a); + + // *g = 123; // nope + + (void)f; + + NotNull t{new Test}; + t->x = 5; + t->y = 3.14f; + + const NotNull u = t; + u->x = 44; + int v = u->x; + CHECK(v == 44); + + bar(a); + + // a++; // nope + // a[41]; // nope + // a + 41; // nope + // a - 41; // nope + + delete a; + delete b; + delete t; + + CHECK_EQ(0, Test::count); +} + +TEST_CASE("hashable") +{ + std::unordered_map, const char*> map; + int a_ = 8; + int b_ = 10; + + NotNull a{&a_}; + NotNull b{&b_}; + + std::string hello = "hello"; + std::string world = "world"; + + map[a] = hello.c_str(); + map[b] = world.c_str(); + + CHECK_EQ(2, map.size()); + CHECK_EQ(hello.c_str(), map[a]); + CHECK_EQ(world.c_str(), map[b]); +} + +TEST_CASE("const") +{ + int p = 0; + int q = 0; + + NotNull n{&p}; + + *n = 123; + + NotNull m = n; // Conversion from NotNull to NotNull is allowed + + CHECK(123 == *m); // readonly access of m is ok + + // *m = 321; // nope. m points at const data. + + // NotNull o = m; // nope. Conversion from NotNull to NotNull is forbidden + + NotNull n2{&q}; + m = n2; // ok. m points to const data, but is not itself const + + const NotNull m2 = n; + // m2 = n2; // nope. m2 is const. + *m2 = 321; // ok. m2 is const, but points to mutable data + + CHECK(321 == *n); +} + +TEST_CASE("const_compatibility") +{ + int* raw = new int(8); + + NotNull a(raw); + NotNull b(raw); + NotNull c = a; + // NotNull d = c; // nope - no conversion from const to non-const + + CHECK_EQ(*c, 8); + + delete raw; +} + +TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 7f6a6c0d..878023e3 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1604,6 +1604,35 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_of_functions_unions_and_intersections") CHECK_EQ((Position{3, 42}), block->body.data[2]->location.end); } +TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") +{ + AstStatBlock* block = parse(R"( + type F = number + --comment + print('hello') + )"); + + REQUIRE_EQ(2, block->body.size); + CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); +} + +TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments_even_with_capture") +{ + // Same should hold when comments are captured + ParseOptions opts; + opts.captureComments = true; + + AstStatBlock* block = parse(R"( + type F = number + --comment + print('hello') + )", + opts); + + REQUIRE_EQ(2, block->body.size); + CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); +} + TEST_CASE_FIXTURE(Fixture, "parse_error_loop_control") { matchParseError("break", "break statement must be inside a loop"); @@ -2008,6 +2037,13 @@ TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); } +TEST_CASE_FIXTURE(Fixture, "parse_type_pack_errors") +{ + matchParseError("type Y = {a: T..., b: number}", "Unexpected '...' after type name; type pack is not allowed in this context", + Location{{0, 20}, {0, 23}}); + matchParseError("type Y = {a: (number | string)...", "Unexpected '...' after type annotation", Location{{0, 36}, {0, 39}}); +} + TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") { { @@ -2576,4 +2612,42 @@ type Y = (T...) -> U... CHECK_EQ(1, result.errors.size()); } +TEST_CASE_FIXTURE(Fixture, "recover_unexpected_type_pack") +{ + ParseResult result = tryParse(R"( +type X = { a: T..., b: number } +type Y = { a: T..., b: number } +type Z = { a: string | T..., b: number } + )"); + REQUIRE_EQ(3, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "recover_function_return_type_annotations") +{ + ScopedFastFlag sff{"LuauReturnTypeTokenConfusion", true}; + ParseResult result = tryParse(R"( +type Custom = { x: A, y: B, z: C } +type Packed = { x: (A...) -> () } +type F = (number): Custom +type G = Packed<(number): (string, number, boolean)> +local function f(x: number) -> Custom +end + )"); + REQUIRE_EQ(3, result.errors.size()); + CHECK_EQ(result.errors[0].getMessage(), "Return types in function type annotations are written after '->' instead of ':'"); + CHECK_EQ(result.errors[1].getMessage(), "Return types in function type annotations are written after '->' instead of ':'"); + CHECK_EQ(result.errors[2].getMessage(), "Function return type annotations are written after ':' instead of '->'"); +} + +TEST_CASE_FIXTURE(Fixture, "error_message_for_using_function_as_type_annotation") +{ + ScopedFastFlag sff{"LuauParserFunctionKeywordAsTypeHelp", true}; + ParseResult result = tryParse(R"( + type Foo = function + )"); + REQUIRE_EQ(1, result.errors.size()); + CHECK_EQ("Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> ...any'", + result.errors[0].getMessage()); +} + TEST_SUITE_END(); diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp new file mode 100644 index 00000000..6619147b --- /dev/null +++ b/tests/RuntimeLimits.test.cpp @@ -0,0 +1,276 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +/* Tests in this source file are meant to be a bellwether to verify that the numeric limits we've set are sufficient for + * most real-world scripts. + * + * If a change breaks a test in this source file, please don't adjust the flag values set in the fixture. Instead, + * consider it a latent performance problem by default. + * + * We should periodically revisit this to retest the limits. + */ + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + +struct LimitFixture : BuiltinsFixture +{ +#if defined(_NOOPT) || defined(_DEBUG) + ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; +#endif +}; + +template +bool hasError(const CheckResult& result, T* = nullptr) +{ + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](const TypeError& a) { + return nullptr != get(a); + }); + return it != result.errors.end(); +} + +TEST_SUITE_BEGIN("RuntimeLimits"); + +TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") +{ + constexpr const char* src = R"LUA( + --!strict + + -- Big thanks to Dionysusnu by letting us use this code as part of our test suite! + -- https://github.com/Dionysusnu/rbxts-rust-classes + -- Licensed under the MPL 2.0: https://raw.githubusercontent.com/Dionysusnu/rbxts-rust-classes/master/LICENSE + + local TS = _G[script] + local lazyGet = TS.import(script, script.Parent.Parent, "util", "lazyLoad").lazyGet + local unit = TS.import(script, script.Parent.Parent, "util", "Unit").unit + local Iterator + lazyGet("Iterator", function(c) + Iterator = c + end) + local Option + lazyGet("Option", function(c) + Option = c + end) + local Vec + lazyGet("Vec", function(c) + Vec = c + end) + local Result + do + Result = setmetatable({}, { + __tostring = function() + return "Result" + end, + }) + Result.__index = Result + function Result.new(...) + local self = setmetatable({}, Result) + self:constructor(...) + return self + end + function Result:constructor(okValue, errValue) + self.okValue = okValue + self.errValue = errValue + end + function Result:ok(val) + return Result.new(val, nil) + end + function Result:err(val) + return Result.new(nil, val) + end + function Result:fromCallback(c) + local _0 = c + local _1, _2 = pcall(_0) + local result = _1 and { + success = true, + value = _2, + } or { + success = false, + error = _2, + } + return result.success and Result:ok(result.value) or Result:err(Option:wrap(result.error)) + end + function Result:fromVoidCallback(c) + local _0 = c + local _1, _2 = pcall(_0) + local result = _1 and { + success = true, + value = _2, + } or { + success = false, + error = _2, + } + return result.success and Result:ok(unit()) or Result:err(Option:wrap(result.error)) + end + Result.fromPromise = TS.async(function(self, p) + local _0, _1 = TS.try(function() + return TS.TRY_RETURN, { Result:ok(TS.await(p)) } + end, function(e) + return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } + end) + if _0 then + return unpack(_1) + end + end) + Result.fromVoidPromise = TS.async(function(self, p) + local _0, _1 = TS.try(function() + TS.await(p) + return TS.TRY_RETURN, { Result:ok(unit()) } + end, function(e) + return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } + end) + if _0 then + return unpack(_1) + end + end) + function Result:isOk() + return self.okValue ~= nil + end + function Result:isErr() + return self.errValue ~= nil + end + function Result:contains(x) + return self.okValue == x + end + function Result:containsErr(x) + return self.errValue == x + end + function Result:okOption() + return Option:wrap(self.okValue) + end + function Result:errOption() + return Option:wrap(self.errValue) + end + function Result:map(func) + return self:isOk() and Result:ok(func(self.okValue)) or Result:err(self.errValue) + end + function Result:mapOr(def, func) + local _0 + if self:isOk() then + _0 = func(self.okValue) + else + _0 = def + end + return _0 + end + function Result:mapOrElse(def, func) + local _0 + if self:isOk() then + _0 = func(self.okValue) + else + _0 = def(self.errValue) + end + return _0 + end + function Result:mapErr(func) + return self:isErr() and Result:err(func(self.errValue)) or Result:ok(self.okValue) + end + Result["and"] = function(self, other) + return self:isErr() and Result:err(self.errValue) or other + end + function Result:andThen(func) + return self:isErr() and Result:err(self.errValue) or func(self.okValue) + end + Result["or"] = function(self, other) + return self:isOk() and Result:ok(self.okValue) or other + end + function Result:orElse(other) + return self:isOk() and Result:ok(self.okValue) or other(self.errValue) + end + function Result:expect(msg) + if self:isOk() then + return self.okValue + else + error(msg) + end + end + function Result:unwrap() + return self:expect("called `Result.unwrap()` on an `Err` value: " .. tostring(self.errValue)) + end + function Result:unwrapOr(def) + local _0 + if self:isOk() then + _0 = self.okValue + else + _0 = def + end + return _0 + end + function Result:unwrapOrElse(gen) + local _0 + if self:isOk() then + _0 = self.okValue + else + _0 = gen(self.errValue) + end + return _0 + end + function Result:expectErr(msg) + if self:isErr() then + return self.errValue + else + error(msg) + end + end + function Result:unwrapErr() + return self:expectErr("called `Result.unwrapErr()` on an `Ok` value: " .. tostring(self.okValue)) + end + function Result:transpose() + return self:isOk() and self.okValue:map(function(some) + return Result:ok(some) + end) or Option:some(Result:err(self.errValue)) + end + function Result:flatten() + return self:isOk() and Result.new(self.okValue.okValue, self.okValue.errValue) or Result:err(self.errValue) + end + function Result:match(ifOk, ifErr) + local _0 + if self:isOk() then + _0 = ifOk(self.okValue) + else + _0 = ifErr(self.errValue) + end + return _0 + end + function Result:asPtr() + local _0 = (self.okValue) + if _0 == nil then + _0 = (self.errValue) + end + return _0 + end + end + local resultMeta = Result + resultMeta.__eq = function(a, b) + return b:match(function(ok) + return a:contains(ok) + end, function(err) + return a:containsErr(err) + end) + end + resultMeta.__tostring = function(result) + return result:match(function(ok) + return "Result.ok(" .. tostring(ok) .. ")" + end, function(err) + return "Result.err(" .. tostring(err) .. ")" + end) + end + return { + Result = Result, + } + )LUA"; + + CheckResult result = check(src); + CodeTooComplex ctc; + + if (FFlag::LuauLowerBoundsCalculation) + LUAU_REQUIRE_ERRORS(result); + else + CHECK(hasError(result, &ctc)); +} + +TEST_SUITE_END(); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 29bdd866..95dcd70a 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauLowerBoundsCalculation) + using namespace Luau; struct ToDotClassFixture : Fixture @@ -19,18 +21,21 @@ struct ToDotClassFixture : Fixture TypeId baseClassMetaType = arena.addType(TableTypeVar{}); - TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}}); + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}, "Test"}); getMutable(baseClassInstanceType)->props = { {"BaseField", {typeChecker.numberType}}, }; typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; - TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}}); + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}, "Test"}); getMutable(childClassInstanceType)->props = { {"ChildField", {typeChecker.stringType}}, }; typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + for (const auto& [name, ty] : typeChecker.globalScope->exportedTypeBindings) + persist(ty.type); + freeze(arena); } }; @@ -101,9 +106,34 @@ local function f(a, ...: string) return a end )"); LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(a, ...string) -> a", toString(requireType("f"))); + ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + + if (FFlag::LuauLowerBoundsCalculation) + { + CHECK_EQ(R"(digraph graphname { +n1 [label="FunctionTypeVar 1"]; +n1 -> n2 [label="arg"]; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="GenericTypeVar 3"]; +n2 -> n4 [label="tail"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n1 -> n6 [label="ret"]; +n6 [label="TypePack 6"]; +n6 -> n7; +n7 [label="BoundTypeVar 7"]; +n7 -> n3; +})", + toDot(requireType("f"), opts)); + } + else + { + CHECK_EQ(R"(digraph graphname { n1 [label="FunctionTypeVar 1"]; n1 -> n2 [label="arg"]; n2 [label="TypePack 2"]; @@ -119,7 +149,8 @@ n6 -> n7; n7 [label="TypePack 7"]; n7 -> n3; })", - toDot(requireType("f"), opts)); + toDot(requireType("f"), opts)); + } } TEST_CASE_FIXTURE(Fixture, "union") @@ -196,7 +227,7 @@ n1 -> n4 [label="typePackParam"]; (void)toDot(requireType("a")); } -TEST_CASE_FIXTURE(Fixture, "metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable") { CheckResult result = check(R"( local a: typeof(setmetatable({}, {})) @@ -361,4 +392,50 @@ n3 [label="number"]; toDot(*ty, opts)); } +TEST_CASE_FIXTURE(Fixture, "constrained") +{ + // ConstrainedTypeVars never appear in the final type graph, so we have to create one directly + // to dotify it. + TypeVar t{ConstrainedTypeVar{TypeLevel{}, {typeChecker.numberType, typeChecker.stringType, typeChecker.nilType}}}; + + ToDotOptions opts; + opts.showPointers = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="ConstrainedTypeVar 1"]; +n1 -> n2; +n2 [label="number"]; +n1 -> n3; +n3 [label="string"]; +n1 -> n4; +n4 [label="nil"]; +})", + toDot(&t, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "singletontypes") +{ + CheckResult result = check(R"( + local x: "hi" | "\"hello\"" | true | false + )"); + + ToDotOptions opts; + opts.showPointers = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="UnionTypeVar 1"]; +n1 -> n2; +n2 [label="SingletonTypeVar string: hi"]; +n1 -> n3; +)" + "n3 [label=\"SingletonTypeVar string: \\\"hello\\\"\"];" + R"( +n1 -> n4; +n4 [label="SingletonTypeVar boolean: true"]; +n1 -> n5; +n5 [label="SingletonTypeVar boolean: false"]; +})", + toDot(requireType("x"), opts)); +} + TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 3051e209..e03069a9 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); + TEST_SUITE_BEGIN("ToString"); TEST_CASE_FIXTURE(Fixture, "primitive") @@ -58,7 +60,43 @@ TEST_CASE_FIXTURE(Fixture, "named_table") CHECK_EQ("TheTable", toString(&table)); } -TEST_CASE_FIXTURE(Fixture, "exhaustive_toString_of_cyclic_table") +TEST_CASE_FIXTURE(Fixture, "empty_table") +{ + ScopedFastFlag LuauToStringTableBracesNewlines("LuauToStringTableBracesNewlines", true); + CheckResult result = check(R"( + local a: {} + )"); + + CHECK_EQ("{| |}", toString(requireType("a"))); + + // Should stay the same with useLineBreaks enabled + ToStringOptions opts; + opts.useLineBreaks = true; + CHECK_EQ("{| |}", toString(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") +{ + ScopedFastFlag LuauToStringTableBracesNewlines("LuauToStringTableBracesNewlines", true); + CheckResult result = check(R"( + local a: { prop: string, anotherProp: number, thirdProp: boolean } + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + opts.indent = true; + + //clang-format off + CHECK_EQ("{|\n" + " anotherProp: number,\n" + " prop: string,\n" + " thirdProp: boolean\n" + "|}", + toString(requireType("a"), opts)); + //clang-format on +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") { CheckResult result = check(R"( --!strict @@ -124,6 +162,39 @@ TEST_CASE_FIXTURE(Fixture, "functions_are_always_parenthesized_in_unions_or_inte CHECK_EQ(toString(&itv), "((number, string) -> (string, number)) & ((string, number) -> (number, string))"); } +TEST_CASE_FIXTURE(Fixture, "intersections_respects_use_line_breaks") +{ + CheckResult result = check(R"( + local a: ((string) -> string) & ((number) -> number) + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + + //clang-format off + CHECK_EQ("((number) -> number)\n" + "& ((string) -> string)", + toString(requireType("a"), opts)); + //clang-format on +} + +TEST_CASE_FIXTURE(Fixture, "unions_respects_use_line_breaks") +{ + CheckResult result = check(R"( + local a: string | number | boolean + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + + //clang-format off + CHECK_EQ("boolean\n" + "| number\n" + "| string", + toString(requireType("a"), opts)); + //clang-format on +} + TEST_CASE_FIXTURE(Fixture, "quit_stringifying_table_type_when_length_is_exceeded") { TableTypeVar ttv{}; @@ -336,10 +407,8 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") REQUIRE_EQ("c", toString(params[2], opts)); } -TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") +TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") { - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - CheckResult result = check(R"( local base = {} function base:one() return 1 end @@ -470,7 +539,6 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function id(x) return x end )"); @@ -483,7 +551,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function map(arr, fn) local t = {} @@ -502,7 +569,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(a: number, b: string) end local function test(...: T...): U... @@ -519,7 +585,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") TEST_CASE("toStringNamedFunction_unit_f") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; TypePackVar empty{TypePack{}}; FunctionTypeVar ftv{&empty, &empty, {}, false}; CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); @@ -527,7 +592,6 @@ TEST_CASE("toStringNamedFunction_unit_f") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(x: a, ...): (a, a, b...) return x, x, ... @@ -542,7 +606,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(): ...number return 1, 2, 3 @@ -557,7 +620,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(): (string, ...number) return 'a', 1, 2, 3 @@ -572,7 +634,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local f: (number, y: number) -> number )"); @@ -585,7 +646,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_ar TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(x: T, g: (T) -> U)): () end @@ -601,8 +661,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; - CheckResult result = check(R"( local function test(a, b : string, ... : number) return a end )"); @@ -615,4 +673,50 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") CHECK_EQ("test(first: a, second: string, ...: number): a", toStringNamedFunction("test", *ftv, opts)); } +TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_generics") +{ + ScopedFastFlag sff[] = { + {"LuauAlwaysQuantify", true}, + }; + + CheckResult result = check(R"( + function foo(x: a, y) end + )"); + + CHECK("(a, b) -> ()" == toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") +{ + CheckResult result = check(R"( + local foo = {} + function foo:method(arg: string): () + end + )"); + + TypeId parentTy = requireType("foo"); + auto ttv = get(follow(parentTy)); + auto ftv = get(ttv->props.at("method").type); + + CHECK_EQ("foo:method(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv)); +} + + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") +{ + CheckResult result = check(R"( + local foo = {} + function foo:method(arg: string): () + end + )"); + + TypeId parentTy = requireType("foo"); + auto ttv = get(follow(parentTy)); + auto ftv = get(ttv->props.at("method").type); + + ToStringOptions opts; + opts.hideFunctionSelfArgument = true; + CHECK_EQ("foo:method(arg: string): ()", toStringNamedFunction("foo:method", *ftv, opts)); +} + TEST_SUITE_END(); diff --git a/tests/TopoSort.test.cpp b/tests/TopoSort.test.cpp index 9b990866..1f14ae88 100644 --- a/tests/TopoSort.test.cpp +++ b/tests/TopoSort.test.cpp @@ -340,26 +340,28 @@ TEST_CASE_FIXTURE(Fixture, "nested_type_annotations_depends_on_later_typealiases TEST_CASE_FIXTURE(Fixture, "return_comes_last") { - CheckResult result = check(R"( -export type Module = { bar: (number) -> boolean, foo: () -> string } + AstStatBlock* program = parse(R"( + local module = {} -return function() : Module - local module = {} + local function confuseCompiler() return module.foo() end - local function confuseCompiler() return module.foo() end - - module.foo = function() return "" end + module.foo = function() return "" end - function module.bar(x:number) - confuseCompiler() - return true - end - - return module -end + function module.bar(x:number) + confuseCompiler() + return true + end + + return module )"); - LUAU_REQUIRE_NO_ERRORS(result); + auto sorted = toposort(*program); + + CHECK_EQ(sorted[0], program->body.data[0]); + CHECK_EQ(sorted[2], program->body.data[1]); + CHECK_EQ(sorted[1], program->body.data[2]); + CHECK_EQ(sorted[3], program->body.data[3]); + CHECK_EQ(sorted[4], program->body.data[4]); } TEST_CASE_FIXTURE(Fixture, "break_comes_last") diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 5ac45ff2..b02a52b2 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -388,7 +388,7 @@ TEST_CASE_FIXTURE(Fixture, "type_lists_should_be_emitted_correctly") std::string actual = decorateWithTypes(code); - CHECK_EQ(expected, decorateWithTypes(code)); + CHECK_EQ(expected, actual); } TEST_CASE_FIXTURE(Fixture, "function_type_location") @@ -661,4 +661,21 @@ type t4 = false CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_array_types") +{ + std::string code = R"( +type t1 = {number} +type t2 = {[string]: number} + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple_types") +{ + std::string code = "for k:string,v:boolean in next,{}do end"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index b2e76052..d6f0a0c8 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -7,10 +7,21 @@ using namespace Luau; -LUAU_FASTFLAG(LuauFixIncorrectLineNumberDuplicateType) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) TEST_SUITE_BEGIN("TypeAliases"); +TEST_CASE_FIXTURE(Fixture, "basic_alias") +{ + CheckResult result = check(R"( + type T = number + local x: T = 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("x"))); +} + TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") { CheckResult result = check(R"( @@ -26,6 +37,63 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); } +TEST_CASE_FIXTURE(Fixture, "names_are_ascribed") +{ + CheckResult result = check(R"( + type T = { x: number } + local x: T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("T", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") +{ + // This is a tricky case. In order to support recursive type aliases, + // we first walk the block and generate free types as placeholders. + // We then walk the AST as normal. If we declare a type alias as below, + // we generate a free type. We then begin our normal walk, examining + // local x: T = "foo", which establishes two constraints: + // a <: b + // string <: a + // We then visit the type alias, and establish that + // b <: number + // Then, when solving these constraints, we dispatch them in the order + // they appear above. This means that a ~ b, and a ~ string, thus + // b ~ string. This means the b <: number constraint has no effect. + // Essentially we've "stolen" the alias's type out from under it. + // This test ensures that we don't actually do this. + CheckResult result = check(R"( + local x: T = "foo" + type T = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK(result.errors[0] == TypeError{ + Location{{1, 21}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + getSingletonTypes().numberType, + getSingletonTypes().stringType, + }, + }); + } + else + { + CHECK(result.errors[0] == TypeError{ + Location{{1, 8}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + getSingletonTypes().numberType, + getSingletonTypes().stringType, + }, + }); + } +} + TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") { CheckResult result = check(R"( @@ -43,7 +111,22 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_whe CHECK_EQ(typeChecker.numberType, tm->givenType); } -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_aliases") +{ + CheckResult result = check(R"( + --!strict + type T = { f: number, g: U } + type U = { h: number, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = 3, g = { h = 5, i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_generic_aliases") { CheckResult result = check(R"( --!strict @@ -257,11 +340,7 @@ TEST_CASE_FIXTURE(Fixture, "reported_location_is_correct_when_type_alias_are_dup auto dtd = get(result.errors[0]); REQUIRE(dtd); CHECK_EQ(dtd->name, "B"); - - if (FFlag::LuauFixIncorrectLineNumberDuplicateType) - CHECK_EQ(dtd->previousLocation.begin.line + 1, 3); - else - CHECK_EQ(dtd->previousLocation.begin.line + 1, 1); + CHECK_EQ(dtd->previousLocation.begin.line + 1, 3); } TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") @@ -285,7 +364,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") CHECK_EQ("Node", toString(e->wantedType)); } -TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") +TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_multi_assign") { fileResolver.source["workspace/A"] = R"( export type myvec2 = {x: number, y: number} @@ -323,7 +402,7 @@ TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") REQUIRE(bType->props.size() == 3); } -TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") { CheckResult result = check("type t10 = typeof(table)"); LUAU_REQUIRE_NO_ERRORS(result); @@ -391,7 +470,7 @@ type Cool = typeof(c) CHECK_EQ(ttv->name, "Cool"); } -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_of_an_imported_recursive_type") { fileResolver.source["game/A"] = R"( export type X = { a: number, b: X? } @@ -416,7 +495,7 @@ type X = Import.X CHECK_EQ(follow(*ty1), follow(*ty2)); } -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_of_an_imported_recursive_generic_type") { fileResolver.source["game/A"] = R"( export type X = { a: T, b: U, C: X? } @@ -495,8 +574,6 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") { - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - CheckResult result = check(R"( -- OK because forwarded types are used with their parameters. type Tree = { data: T, children: Forest } @@ -508,8 +585,6 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") { - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - CheckResult result = check(R"( -- Not OK because forwarded types are used with different types than their parameters. type Forest = {Tree<{T}>} @@ -531,8 +606,6 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") { - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - CheckResult result = check(R"( type Tree1 = { data: T, children: {Tree2} } type Tree2 = { data: U, children: {Tree1} } @@ -576,7 +649,7 @@ TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") * * We solved this by ascribing a unique subLevel to each prototyped alias. */ -TEST_CASE_FIXTURE(Fixture, "do_not_quantify_unresolved_aliases") +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_quantify_unresolved_aliases") { CheckResult result = check(R"( --!strict @@ -627,8 +700,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_ */ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any") { - ScopedFastFlag sff[] = {{"LuauTwoPassAliasDefinitionFix", true}}; - CheckResult result = check(R"( local function x() local y: FutureType = {}::any @@ -645,13 +716,6 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2") { - ScopedFastFlag sff[] = { - {"LuauTwoPassAliasDefinitionFix", true}, - - // We also force this flag because it surfaced an unfortunate interaction. - {"LuauErrorRecoveryType", true}, - }; - CheckResult result = check(R"( local B = {} B.bar = 4 @@ -687,8 +751,6 @@ TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") { - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - CheckResult result = check(R"( -- this would be an infinite type if we allowed it type Tree = { data: T, children: {Tree<{T}>} } diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 2ad11d01..3e2ad6dc 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -30,11 +30,21 @@ TEST_CASE_FIXTURE(Fixture, "successful_check") dumpErrors(result); } +TEST_CASE_FIXTURE(Fixture, "variable_type_is_supertype") +{ + CheckResult result = check(R"( + local x: number = 1 + local y: number? = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "function_parameters_can_have_annotations") { CheckResult result = check(R"( function double(x: number) - return x * 2 + return 2 end local four = double(2) @@ -47,7 +57,7 @@ TEST_CASE_FIXTURE(Fixture, "function_parameter_annotations_are_checked") { CheckResult result = check(R"( function double(x: number) - return x * 2 + return 2 end local four = double("two") @@ -70,13 +80,13 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotations_are_checked") const FunctionTypeVar* ftv = get(fiftyType); REQUIRE(ftv != nullptr); - TypePackId retPack = ftv->retType; + TypePackId retPack = follow(ftv->retTypes); const TypePack* tp = get(retPack); REQUIRE(tp != nullptr); REQUIRE_EQ(1, tp->head.size()); - REQUIRE_EQ(typeChecker.anyType, tp->head[0]); + REQUIRE_EQ(typeChecker.anyType, follow(tp->head[0])); } TEST_CASE_FIXTURE(Fixture, "function_return_multret_annotations_are_checked") @@ -116,6 +126,23 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotation_should_continuously_parse LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "unknown_type_reference_generates_error") +{ + CheckResult result = check(R"( + local x: IDoNotExist + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{ + Location{{1, 17}, {1, 28}}, + getMainSourceModule()->name, + UnknownSymbol{ + "IDoNotExist", + UnknownSymbol::Context::Type, + }, + }); +} + TEST_CASE_FIXTURE(Fixture, "typeof_variable_type_annotation_should_return_its_type") { CheckResult result = check(R"( @@ -221,8 +248,6 @@ TEST_CASE_FIXTURE(Fixture, "as_expr_is_bidirectional") TEST_CASE_FIXTURE(Fixture, "as_expr_warns_on_unrelated_cast") { - ScopedFastFlag sff2{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( local a = 55 :: string )"); @@ -407,8 +432,6 @@ TEST_CASE_FIXTURE(Fixture, "typeof_expr") TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( type A = B type B = A @@ -532,7 +555,7 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti CHECK_EQ(recordType, bType); } -TEST_CASE_FIXTURE(Fixture, "use_type_required_from_another_file") +TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") { addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); @@ -558,7 +581,7 @@ TEST_CASE_FIXTURE(Fixture, "use_type_required_from_another_file") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "cannot_use_nonexported_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_use_nonexported_type") { addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); @@ -584,7 +607,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_nonexported_type") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "builtin_types_are_not_exported") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_are_not_exported") { addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); @@ -636,7 +659,10 @@ int AssertionCatcher::tripped; TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice") { - ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; + ScopedFastFlag sffs[] = { + {"DebugLuauMagicTypes", true}, + {"LuauUseInternalCompilerErrorException", false}, + }; AssertionCatcher ac; @@ -650,9 +676,10 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice") TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_handler") { - ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; - - AssertionCatcher ac; + ScopedFastFlag sffs[] = { + {"DebugLuauMagicTypes", true}, + {"LuauUseInternalCompilerErrorException", false}, + }; bool caught = false; @@ -666,8 +693,44 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_handler") std::runtime_error); CHECK_EQ(true, caught); +} - frontend.iceHandler.onInternalError = {}; +TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag") +{ + ScopedFastFlag sffs[] = { + {"DebugLuauMagicTypes", true}, + {"LuauUseInternalCompilerErrorException", true}, + }; + + AssertionCatcher ac; + + CHECK_THROWS_AS(check(R"( + local a: _luau_ice = 55 + )"), + InternalCompilerError); + + LUAU_ASSERT(1 == AssertionCatcher::tripped); +} + +TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag_handler") +{ + ScopedFastFlag sffs[] = { + {"DebugLuauMagicTypes", true}, + {"LuauUseInternalCompilerErrorException", true}, + }; + + bool caught = false; + + frontend.iceHandler.onInternalError = [&](const char*) { + caught = true; + }; + + CHECK_THROWS_AS(check(R"( + local a: _luau_ice = 55 + )"), + InternalCompilerError); + + CHECK_EQ(true, caught); } TEST_CASE_FIXTURE(Fixture, "luau_ice_is_not_special_without_the_flag") @@ -680,7 +743,7 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_is_not_special_without_the_flag") )"); } -TEST_CASE_FIXTURE(Fixture, "luau_print_is_magic_if_the_flag_is_set") +TEST_CASE_FIXTURE(BuiltinsFixture, "luau_print_is_magic_if_the_flag_is_set") { // Luau::resetPrintLine(); ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; @@ -753,4 +816,14 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_intersection_typevar") REQUIRE(ocf); } +TEST_CASE_FIXTURE(Fixture, "instantiation_clone_has_to_follow") +{ + CheckResult result = check(R"( + export type t8 = (t0)&(((true)|(any))->"") + export type t0 = ({})&({_:{[any]:number},}) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 5224b5d8..bc55940e 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -237,7 +237,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") CHECK_EQ("*unknown*", toString(requireType("a"))); } -TEST_CASE_FIXTURE(Fixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") +TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") { CheckResult result = check(R"( local a: any @@ -285,7 +285,7 @@ end LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "metatable_of_any_can_be_a_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_of_any_can_be_a_table") { CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index c6fbebed..2f0266ec 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -8,9 +8,11 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("BuiltinTests"); -TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") +TEST_CASE_FIXTURE(BuiltinsFixture, "math_things_are_defined") { CheckResult result = check(R"( local a00 = math.frexp @@ -48,7 +50,7 @@ TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "next_iterator_should_infer_types_and_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "next_iterator_should_infer_types_and_type_check") { CheckResult result = check(R"( local a: string, b: number = next({ 1 }) @@ -61,7 +63,7 @@ TEST_CASE_FIXTURE(Fixture, "next_iterator_should_infer_types_and_type_check") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "pairs_iterator_should_infer_types_and_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_iterator_should_infer_types_and_type_check") { CheckResult result = check(R"( type Map = { [K]: V } @@ -73,7 +75,7 @@ TEST_CASE_FIXTURE(Fixture, "pairs_iterator_should_infer_types_and_type_check") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "ipairs_iterator_should_infer_types_and_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_iterator_should_infer_types_and_type_check") { CheckResult result = check(R"( type Map = { [K]: V } @@ -85,7 +87,7 @@ TEST_CASE_FIXTURE(Fixture, "ipairs_iterator_should_infer_types_and_type_check") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_dot_remove_optionally_returns_generic") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_remove_optionally_returns_generic") { CheckResult result = check(R"( local t = { 1 } @@ -96,7 +98,7 @@ TEST_CASE_FIXTURE(Fixture, "table_dot_remove_optionally_returns_generic") CHECK_EQ(toString(requireType("n")), "number?"); } -TEST_CASE_FIXTURE(Fixture, "table_concat_returns_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_concat_returns_string") { CheckResult result = check(R"( local r = table.concat({1,2,3,4}, ",", 2); @@ -106,7 +108,7 @@ TEST_CASE_FIXTURE(Fixture, "table_concat_returns_string") CHECK_EQ(*typeChecker.stringType, *requireType("r")); } -TEST_CASE_FIXTURE(Fixture, "sort") +TEST_CASE_FIXTURE(BuiltinsFixture, "sort") { CheckResult result = check(R"( local t = {1, 2, 3}; @@ -116,7 +118,7 @@ TEST_CASE_FIXTURE(Fixture, "sort") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sort_with_predicate") +TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_predicate") { CheckResult result = check(R"( --!strict @@ -128,7 +130,7 @@ TEST_CASE_FIXTURE(Fixture, "sort_with_predicate") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sort_with_bad_predicate") +TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_bad_predicate") { CheckResult result = check(R"( --!strict @@ -138,6 +140,12 @@ TEST_CASE_FIXTURE(Fixture, "sort_with_bad_predicate") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '(number, number) -> boolean' could not be converted into '((a, a) -> boolean)?' +caused by: + None of the union options are compatible. For example: Type '(number, number) -> boolean' could not be converted into '(a, a) -> boolean' +caused by: + Argument #1 type is not compatible. Type 'string' could not be converted into 'number')", + toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "strings_have_methods") @@ -150,7 +158,7 @@ TEST_CASE_FIXTURE(Fixture, "strings_have_methods") CHECK_EQ(*typeChecker.stringType, *requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "math_max_variatic") +TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_variatic") { CheckResult result = check(R"( local n = math.max(1,2,3,4,5,6,7,8,9,0) @@ -160,16 +168,17 @@ TEST_CASE_FIXTURE(Fixture, "math_max_variatic") CHECK_EQ(*typeChecker.numberType, *requireType("n")); } -TEST_CASE_FIXTURE(Fixture, "math_max_checks_for_numbers") +TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_checks_for_numbers") { CheckResult result = check(R"( local n = math.max(1,2,"3") )"); CHECK(!result.errors.empty()); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "builtin_tables_sealed") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_tables_sealed") { CheckResult result = check(R"LUA( local b = bit32 @@ -181,7 +190,7 @@ TEST_CASE_FIXTURE(Fixture, "builtin_tables_sealed") CHECK_EQ(bit32t->state, TableState::Sealed); } -TEST_CASE_FIXTURE(Fixture, "lua_51_exported_globals_all_exist") +TEST_CASE_FIXTURE(BuiltinsFixture, "lua_51_exported_globals_all_exist") { // Extracted from lua5.1 CheckResult result = check(R"( @@ -338,7 +347,7 @@ TEST_CASE_FIXTURE(Fixture, "lua_51_exported_globals_all_exist") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "setmetatable_unpacks_arg_types_correctly") +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_unpacks_arg_types_correctly") { CheckResult result = check(R"( setmetatable({}, setmetatable({}, {})) @@ -346,7 +355,7 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_unpacks_arg_types_correctly") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_2_args_overload") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_2_args_overload") { CheckResult result = check(R"( local t = {} @@ -358,7 +367,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_2_args_o CHECK_EQ(typeChecker.stringType, requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_3_args_overload") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_3_args_overload") { CheckResult result = check(R"( local t = {} @@ -370,7 +379,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_3_args_o CHECK_EQ("string", toString(requireType("s"))); } -TEST_CASE_FIXTURE(Fixture, "table_pack") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack") { CheckResult result = check(R"( local t = table.pack(1, "foo", true) @@ -380,7 +389,7 @@ TEST_CASE_FIXTURE(Fixture, "table_pack") CHECK_EQ("{| [number]: boolean | number | string, n: number |}", toString(requireType("t"))); } -TEST_CASE_FIXTURE(Fixture, "table_pack_variadic") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack_variadic") { CheckResult result = check(R"( --!strict @@ -395,7 +404,7 @@ local t = table.pack(f()) CHECK_EQ("{| [number]: number | string, n: number |}", toString(requireType("t"))); } -TEST_CASE_FIXTURE(Fixture, "table_pack_reduce") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack_reduce") { CheckResult result = check(R"( local t = table.pack(1, 2, true) @@ -412,7 +421,7 @@ TEST_CASE_FIXTURE(Fixture, "table_pack_reduce") CHECK_EQ("{| [number]: string, n: number |}", toString(requireType("t"))); } -TEST_CASE_FIXTURE(Fixture, "gcinfo") +TEST_CASE_FIXTURE(BuiltinsFixture, "gcinfo") { CheckResult result = check(R"( local n = gcinfo() @@ -422,12 +431,12 @@ TEST_CASE_FIXTURE(Fixture, "gcinfo") CHECK_EQ(*typeChecker.numberType, *requireType("n")); } -TEST_CASE_FIXTURE(Fixture, "getfenv") +TEST_CASE_FIXTURE(BuiltinsFixture, "getfenv") { LUAU_REQUIRE_NO_ERRORS(check("getfenv(1)")); } -TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "os_time_takes_optional_date_table") { CheckResult result = check(R"( local n1 = os.time() @@ -441,7 +450,7 @@ TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") CHECK_EQ(*typeChecker.numberType, *requireType("n3")); } -TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "thread_is_a_type") { CheckResult result = check(R"( local co = coroutine.create(function() end) @@ -451,7 +460,7 @@ TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") CHECK_EQ(*typeChecker.threadType, *requireType("co")); } -TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") +TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_resume_anything_goes") { CheckResult result = check(R"( local function nifty(x, y) @@ -469,7 +478,7 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") +TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_wrap_anything_goes") { CheckResult result = check(R"( --!nonstrict @@ -488,7 +497,7 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_should_not_mutate_persisted_types") { CheckResult result = check(R"( local string = string @@ -503,7 +512,7 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") REQUIRE(ttv); } -TEST_CASE_FIXTURE(Fixture, "string_format_arg_types_inference") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_arg_types_inference") { CheckResult result = check(R"( --!strict @@ -516,7 +525,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_arg_types_inference") CHECK_EQ("(number, number, string) -> string", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "string_format_arg_count_mismatch") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_arg_count_mismatch") { CheckResult result = check(R"( --!strict @@ -532,7 +541,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_arg_count_mismatch") CHECK_EQ(result.errors[2].location.begin.line, 4); } -TEST_CASE_FIXTURE(Fixture, "string_format_correctly_ordered_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_correctly_ordered_types") { CheckResult result = check(R"( --!strict @@ -546,7 +555,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_correctly_ordered_types") CHECK_EQ(tm->givenType, typeChecker.numberType); } -TEST_CASE_FIXTURE(Fixture, "xpcall") +TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall") { CheckResult result = check(R"( --!strict @@ -557,12 +566,12 @@ TEST_CASE_FIXTURE(Fixture, "xpcall") )"); LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("boolean", toString(requireType("a"))); - REQUIRE_EQ("number", toString(requireType("b"))); - REQUIRE_EQ("boolean", toString(requireType("c"))); + CHECK_EQ("boolean", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + CHECK_EQ("boolean", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "see_thru_select") +TEST_CASE_FIXTURE(BuiltinsFixture, "see_thru_select") { CheckResult result = check(R"( local a:number, b:boolean = select(2,"hi", 10, true) @@ -571,7 +580,7 @@ TEST_CASE_FIXTURE(Fixture, "see_thru_select") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "see_thru_select_count") +TEST_CASE_FIXTURE(BuiltinsFixture, "see_thru_select_count") { CheckResult result = check(R"( local a = select("#","hi", 10, true) @@ -581,7 +590,7 @@ TEST_CASE_FIXTURE(Fixture, "see_thru_select_count") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "select_with_decimal_argument_is_rounded_down") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_decimal_argument_is_rounded_down") { CheckResult result = check(R"( local a: number, b: boolean = select(2.9, "foo", 1, true) @@ -591,7 +600,7 @@ TEST_CASE_FIXTURE(Fixture, "select_with_decimal_argument_is_rounded_down") } // Could be flaky if the fix has regressed. -TEST_CASE_FIXTURE(Fixture, "bad_select_should_not_crash") +TEST_CASE_FIXTURE(BuiltinsFixture, "bad_select_should_not_crash") { CheckResult result = check(R"( do end @@ -603,10 +612,12 @@ TEST_CASE_FIXTURE(Fixture, "bad_select_should_not_crash") end )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Argument count mismatch. Function expects at least 1 argument, but none are specified", toString(result.errors[0])); + CHECK_EQ("Argument count mismatch. Function expects 1 argument, but none are specified", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "select_way_out_of_range") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_way_out_of_range") { CheckResult result = check(R"( select(5432598430953240958) @@ -617,7 +628,7 @@ TEST_CASE_FIXTURE(Fixture, "select_way_out_of_range") CHECK_EQ("bad argument #1 to select (index out of range)", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "select_slightly_out_of_range") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_slightly_out_of_range") { CheckResult result = check(R"( select(3, "a", 1) @@ -628,7 +639,7 @@ TEST_CASE_FIXTURE(Fixture, "select_slightly_out_of_range") CHECK_EQ("bad argument #1 to select (index out of range)", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail") { CheckResult result = check(R"( --!nonstrict @@ -647,7 +658,7 @@ TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail") CHECK_EQ("any", toString(requireType("quux"))); } -TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail_and_string_head") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_string_head") { CheckResult result = check(R"( --!nonstrict @@ -701,7 +712,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument2") CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "debug_traceback_is_crazy") +TEST_CASE_FIXTURE(BuiltinsFixture, "debug_traceback_is_crazy") { CheckResult result = check(R"( local co: thread = ... @@ -718,7 +729,7 @@ debug.traceback(co, "msg", 1) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "debug_info_is_crazy") +TEST_CASE_FIXTURE(BuiltinsFixture, "debug_info_is_crazy") { CheckResult result = check(R"( local co: thread, f: ()->() = ... @@ -732,7 +743,7 @@ debug.info(f, "n") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "aliased_string_format") +TEST_CASE_FIXTURE(BuiltinsFixture, "aliased_string_format") { CheckResult result = check(R"( local fmt = string.format @@ -743,7 +754,7 @@ TEST_CASE_FIXTURE(Fixture, "aliased_string_format") CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "string_lib_self_noself") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_lib_self_noself") { CheckResult result = check(R"( --!nonstrict @@ -762,7 +773,7 @@ TEST_CASE_FIXTURE(Fixture, "string_lib_self_noself") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "gmatch_definition") +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_definition") { CheckResult result = check(R"_( local a, b, c = ("hey"):gmatch("(.)(.)(.)")() @@ -775,7 +786,7 @@ end LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "select_on_variadic") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_on_variadic") { CheckResult result = check(R"( local function f(): (number, ...(boolean | number)) @@ -791,7 +802,7 @@ TEST_CASE_FIXTURE(Fixture, "select_on_variadic") CHECK_EQ("any", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_positions") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_report_all_type_errors_at_correct_positions") { CheckResult result = check(R"( ("%s%d%s"):format(1, "hello", true) @@ -823,7 +834,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[5].data); } -TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "tonumber_returns_optional_number_type") { CheckResult result = check(R"( --!strict @@ -834,7 +845,7 @@ TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type2") +TEST_CASE_FIXTURE(BuiltinsFixture, "tonumber_returns_optional_number_type2") { CheckResult result = check(R"( --!strict @@ -844,7 +855,7 @@ TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type2") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_add_definitions_to_persistent_types") { CheckResult result = check(R"( local f = math.sin @@ -866,14 +877,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") REQUIRE(gtv->definition); } -TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types") { - ScopedFastFlag sff[]{ - {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions2", true}, - {"LuauWidenIfSupertypeIsFree2", true}, - }; - CheckResult result = check(R"( local function f(x: (number | boolean)?) return assert(x) @@ -881,17 +886,14 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); + else + CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types2") { - ScopedFastFlag sff[]{ - {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions2", true}, - {"LuauWidenIfSupertypeIsFree2", true}, - }; - CheckResult result = check(R"( local function f(x: (number | boolean)?): number | true return assert(x) @@ -902,13 +904,8 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type") { - ScopedFastFlag sff[]{ - {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local function f(...: number?) return assert(...) @@ -919,13 +916,8 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types_even_from_type_pack_tail_ CHECK_EQ("(...number?) -> (number, ...number?)", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy") { - ScopedFastFlag sff[]{ - {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local function f(x: nil) return assert(x, "hmm") @@ -936,7 +928,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_returns_false_and_string_iff_it_knows_the_fir CHECK_EQ("(nil) -> nil", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "table_freeze_is_generic") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") { CheckResult result = check(R"( local t1: {a: number} = {a = 42} @@ -963,7 +955,7 @@ TEST_CASE_FIXTURE(Fixture, "table_freeze_is_generic") CHECK_EQ("*unknown*", toString(requireType("d"))); } -TEST_CASE_FIXTURE(Fixture, "set_metatable_needs_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") { ScopedFastFlag sff{"LuauSetMetaTableArgsCheck", true}; CheckResult result = check(R"( @@ -986,7 +978,7 @@ local function f(a: typeof(f)) end CHECK_EQ("Unknown global 'f'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") +TEST_CASE_FIXTURE(BuiltinsFixture, "no_persistent_typelevel_change") { TypeId mathTy = requireType(typeChecker.globalScope, "math"); REQUIRE(mathTy); @@ -1003,7 +995,7 @@ TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") CHECK(ftv->level.subLevel == original.subLevel); } -TEST_CASE_FIXTURE(Fixture, "global_singleton_types_are_sealed") +TEST_CASE_FIXTURE(BuiltinsFixture, "global_singleton_types_are_sealed") { CheckResult result = check(R"( local function f(x: string) diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 98fa66eb..6f4191e3 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -10,7 +10,7 @@ using namespace Luau; using std::nullopt; -struct ClassFixture : Fixture +struct ClassFixture : BuiltinsFixture { ClassFixture() { @@ -19,13 +19,13 @@ struct ClassFixture : Fixture unfreeze(arena); - TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}}); + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(baseClassInstanceType)->props = { {"BaseMethod", {makeFunction(arena, baseClassInstanceType, {numberType}, {})}}, {"BaseField", {numberType}}, }; - TypeId baseClassType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}}); + TypeId baseClassType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(baseClassType)->props = { {"StaticMethod", {makeFunction(arena, nullopt, {}, {numberType})}}, {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, @@ -34,39 +34,39 @@ struct ClassFixture : Fixture typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; addGlobalBinding(typeChecker, "BaseClass", baseClassType, "@test"); - TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}}); + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(childClassInstanceType)->props = { {"Method", {makeFunction(arena, childClassInstanceType, {}, {typeChecker.stringType})}}, }; - TypeId childClassType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassType, nullopt, {}, {}}); + TypeId childClassType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(childClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; addGlobalBinding(typeChecker, "ChildClass", childClassType, "@test"); - TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}}); + TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(grandChildInstanceType)->props = { {"Method", {makeFunction(arena, grandChildInstanceType, {}, {typeChecker.stringType})}}, }; - TypeId grandChildType = arena.addType(ClassTypeVar{"GrandChild", {}, baseClassType, nullopt, {}, {}}); + TypeId grandChildType = arena.addType(ClassTypeVar{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(grandChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; addGlobalBinding(typeChecker, "GrandChild", childClassType, "@test"); - TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}}); + TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(anotherChildInstanceType)->props = { {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {typeChecker.stringType})}}, }; - TypeId anotherChildType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassType, nullopt, {}, {}}); + TypeId anotherChildType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(anotherChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, }; @@ -75,13 +75,13 @@ struct ClassFixture : Fixture TypeId vector2MetaType = arena.addType(TableTypeVar{}); - TypeId vector2InstanceType = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, vector2MetaType, {}, {}}); + TypeId vector2InstanceType = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test"}); getMutable(vector2InstanceType)->props = { {"X", {numberType}}, {"Y", {numberType}}, }; - TypeId vector2Type = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, nullopt, {}, {}}); + TypeId vector2Type = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(vector2Type)->props = { {"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}}, }; @@ -91,6 +91,9 @@ struct ClassFixture : Fixture typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; addGlobalBinding(typeChecker, "Vector2", vector2Type, "@test"); + for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) + persist(tf.type); + freeze(arena); } }; @@ -465,4 +468,16 @@ caused by: toString(result.errors[0])); } +TEST_CASE_FIXTURE(ClassFixture, "class_type_mismatch_with_name_conflict") +{ + CheckResult result = check(R"( +local i = ChildClass.New() +type ChildClass = { x: number } +local a: ChildClass = i + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'ChildClass' from 'Test' could not be converted into 'ChildClass' from 'MainModule'", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 898d8902..4545b8db 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -295,8 +295,6 @@ TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_type TEST_CASE_FIXTURE(Fixture, "single_class_type_identity_in_global_types") { - ScopedFastFlag luauCloneDeclaredGlobals{"LuauCloneDeclaredGlobals", true}; - loadDefinition(R"( declare class Cls end diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index da4ea074..036a667a 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -13,6 +13,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("TypeInferFunctions"); TEST_CASE_FIXTURE(Fixture, "tc_function") @@ -43,7 +45,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_return_type") const FunctionTypeVar* takeFiveType = get(requireType("take_five")); REQUIRE(takeFiveType != nullptr); - std::vector retVec = flatten(takeFiveType->retType).first; + std::vector retVec = flatten(takeFiveType->retTypes).first; REQUIRE(!retVec.empty()); REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); @@ -83,7 +85,7 @@ TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") +TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") { CheckResult result = check(R"( local T = {} @@ -98,7 +100,7 @@ TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") end return result - end + end return T )"); @@ -274,6 +276,10 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") { + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( function f(g) return f(f) @@ -281,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); + CHECK_EQ("t1 where t1 = (t1) -> (a...)", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") @@ -339,7 +345,7 @@ TEST_CASE_FIXTURE(Fixture, "local_function") const FunctionTypeVar* ftv = get(h); REQUIRE(ftv != nullptr); - std::optional rt = first(ftv->retType); + std::optional rt = first(ftv->retTypes); REQUIRE(bool(rt)); TypeId retType = follow(*rt); @@ -355,7 +361,7 @@ TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") LUAU_REQUIRE_NO_ERRORS(result); const Luau::FunctionTypeVar* fn = get(requireType("p")); REQUIRE(fn); - auto ret = first(fn->retType); + auto ret = first(fn->retTypes); REQUIRE(ret); REQUIRE(get(follow(*ret))); } @@ -454,7 +460,7 @@ TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotat const FunctionTypeVar* functionType = get(requireType("most_of_the_natural_numbers")); - std::optional retType = first(functionType->retType); + std::optional retType = first(functionType->retTypes); REQUIRE(retType); CHECK(get(*retType)); } @@ -481,10 +487,10 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") std::vector fArgs = flatten(fType->argTypes).first; - TypeId xType = argVec[1]; + TypeId xType = follow(argVec[1]); CHECK_EQ(1, fArgs.size()); - CHECK_EQ(xType, fArgs[0]); + CHECK_EQ(xType, follow(fArgs[0])); } TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") @@ -549,7 +555,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") CHECK(bool(argType->indexer)); } -TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") +TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4") { CheckResult result = check(R"( function bottomupmerge(comp, a, b, left, mid, right) @@ -614,7 +620,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); } -TEST_CASE_FIXTURE(Fixture, "mutual_recursion") +TEST_CASE_FIXTURE(BuiltinsFixture, "mutual_recursion") { CheckResult result = check(R"( --!strict @@ -633,7 +639,7 @@ TEST_CASE_FIXTURE(Fixture, "mutual_recursion") dumpErrors(result); } -TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") +TEST_CASE_FIXTURE(BuiltinsFixture, "toposort_doesnt_break_mutual_recursion") { CheckResult result = check(R"( --!strict @@ -650,6 +656,11 @@ TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nonstrict @@ -658,14 +669,14 @@ TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") end return function() - return f():andThen() + return f() end )"); LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") +TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") { CheckResult result = check(R"( function onerror() end @@ -783,16 +794,20 @@ TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields }})); } -TEST_CASE_FIXTURE(Fixture, "calling_function_with_anytypepack_doesnt_leak_free_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "calling_function_with_anytypepack_doesnt_leak_free_types") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nonstrict - function Test(a) + function Test(a): ...any return 1, "" end - local tab = {} table.insert(tab, Test(1)); )"); @@ -936,8 +951,6 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( type Overload = ((string) -> string) & ((number, number) -> number) local abc: Overload @@ -953,7 +966,7 @@ TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") CHECK_EQ("string", toString(requireType("z"))); } -TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") { // Simple direct arg to arg propagation CheckResult result = check(R"( @@ -1043,16 +1056,19 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - // Return type doesn't inference 'nil' - result = check(R"( -function f(a: (number) -> nil) return a(4) end -f(function(x) print(x) end) - )"); + if (!FFlag::LuauLowerBoundsCalculation) + { + // Return type doesn't inference 'nil' + result = check(R"( + function f(a: (number) -> nil) return a(4) end + f(function(x) print(x) end) + )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + } } -TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") { // Simple direct arg to arg propagation CheckResult result = check(R"( @@ -1142,13 +1158,16 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - // Return type doesn't inference 'nil' - result = check(R"( -function f(a: (number) -> nil) return a(4) end -f(function(x) print(x) end) - )"); + if (!FFlag::LuauLowerBoundsCalculation) + { + // Return type doesn't inference 'nil' + result = check(R"( + function f(a: (number) -> nil) return a(4) end + f(function(x) print(x) end) + )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") @@ -1268,10 +1287,8 @@ caused by: Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); } -TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_quantify_right_type") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; - fileResolver.source["game/isAMagicMock"] = R"( --!nonstrict return function(value) @@ -1292,23 +1309,193 @@ end LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; - CheckResult result = check(R"( function string.len(): number return 1 end )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + + // if 'string' library property was replaced with an internal module type, it will be freed and the next check will crash + frontend.clear(); + + result = check(R"( +print(string.len('hello')) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite_2") +{ + CheckResult result = check(R"( +local t: { f: ((x: number) -> number)? } = {} + +function t.f(x) + print(x + 5) + return x .. "asd" -- 1st error: we know that return type is a number, not a string +end + +t.f = function(x) + print(x + 5) + return x .. "asd" -- 2nd error: we know that return type is a number, not a string +end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'string' could not be converted into 'number')"); + CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "inconsistent_return_types") +{ + const ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function foo(a: boolean, b: number) + if a then + return nil + else + return b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(boolean, number) -> number?", toString(requireType("foo"))); + + // TODO: Test multiple returns + // Think of various cases where typepacks need to grow. maybe consult other tests + // Basic normalization of ConstrainedTypeVars during quantification +} + +TEST_CASE_FIXTURE(Fixture, "inconsistent_higher_order_function") +{ + const ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function foo(f) + f(5) + f("six") + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("((number | string) -> (a...)) -> ()", toString(requireType("foo"))); +} + + +/* The bug here is that we are using the same level 2.0 for both the body of resolveDispatcher and the + * lambda useCallback. + * + * I think what we want to do is, at each scope level, never reuse the same sublevel. + * + * We also adjust checkBlock to consider the syntax `local x = function() ... end` to be sortable + * in the same way as `local function x() ... end`. This causes the function `resolveDispatcher` to be + * checked before the lambda. + */ +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + --!strict + + local function resolveDispatcher() + return (nil :: any) :: {useCallback: (any) -> any} + end + + local useCallback = function(deps: any) + return resolveDispatcher().useCallback(deps) + end + )"); + + // LUAU_REQUIRE_NO_ERRORS is particularly unhelpful when this test is broken. + // You get a TypeMismatch error where both types stringify the same. + + CHECK(result.errors.empty()); + if (!result.errors.empty()) + { + for (const auto& e : result.errors) + printf("%s: %s\n", toString(e.location).c_str(), toString(e).c_str()); + } +} + +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time2") +{ + CheckResult result = check(R"( + --!strict + + local function resolveDispatcher() + return (nil :: any) :: {useContext: (number?) -> any} + end + + local useContext + useContext = function(unstable_observedBits: number?) + resolveDispatcher().useContext(unstable_observedBits) + end + )"); + + // LUAU_REQUIRE_NO_ERRORS is particularly unhelpful when this test is broken. + // You get a TypeMismatch error where both types stringify the same. + + CHECK(result.errors.empty()); + if (!result.errors.empty()) + { + for (const auto& e : result.errors) + printf("%s %s: %s\n", e.moduleName.c_str(), toString(e.location).c_str(), toString(e).c_str()); + } +} + +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time3") +{ + CheckResult result = check(R"( + local foo + + foo():bar(function() + return foo() + end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_unsealed_overwrite") +{ + CheckResult result = check(R"( +local t = { f = nil :: ((x: number) -> number)? } + +function t.f(x: string): string -- 1st error: new function value type is incompatible + return x .. "asd" +end + +t.f = function(x) + print(x + 5) + return x .. "asd" -- 2nd error: we know that return type is a number, not a string +end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(string) -> string' could not be converted into '((number) -> number)?' +caused by: + None of the union options are compatible. For example: Type '(string) -> string' could not be converted into '(number) -> number' +caused by: + Argument #1 type is not compatible. Type 'number' could not be converted into 'string')"); + CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); } TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") { - ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; - CheckResult result = check(R"( local function f(x: any) end f() @@ -1319,8 +1506,6 @@ TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") TEST_CASE_FIXTURE(Fixture, "function_statement_sealed_table_assignment_through_indexer") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; - CheckResult result = check(R"( local t: {[string]: () -> number} = {} @@ -1337,7 +1522,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") { - ScopedFastFlag sff{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; CheckResult result = check(R"( function test(a: number, b: string, ...) end @@ -1359,8 +1543,6 @@ TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic") { - ScopedFastFlag sff1{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; - ScopedFastFlag sff2{"LuauFixArgumentCountMismatchAmountWithGenericTypes", true}; CheckResult result = check(R"( function test(a: number, b: string, ...) return 1 @@ -1384,10 +1566,8 @@ wrapper(test) CHECK(acm->isVariadic); } -TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic2") +TEST_CASE_FIXTURE(BuiltinsFixture, "too_few_arguments_variadic_generic2") { - ScopedFastFlag sff1{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; - ScopedFastFlag sff2{"LuauFixArgumentCountMismatchAmountWithGenericTypes", true}; CheckResult result = check(R"( function test(a: number, b: string, ...) return 1 @@ -1411,4 +1591,84 @@ pcall(wrapper, test) CHECK(acm->isVariadic); } +TEST_CASE_FIXTURE(Fixture, "occurs_check_failure_in_function_return_type") +{ + CheckResult result = check(R"( + function f() + return 5, f() + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(nullptr != get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") +{ + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + local function f() return end + local g = function() return f() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_constrained_types") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + --!strict + local function foo(f) + f(5) + f("hi") + local function g() + return f + end + local h = g() + h(true) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("((boolean | number | string) -> (a...)) -> ()", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "call_o_with_another_argument_after_foo_was_quantified") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + local function f(o) + local t = {} + t[o] = true + + local function foo(o) + o.m1(5) + t[o] = nil + end + + o.m1("hi") + + return t + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // TODO: check the normalized type of f +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index f360a77c..97ba0808 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -11,8 +11,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauFixArgumentCountMismatchAmountWithGenericTypes) - TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") @@ -69,7 +67,7 @@ TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_polytypes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") +TEST_CASE_FIXTURE(BuiltinsFixture, "inferred_local_vars_can_be_polytypes") { CheckResult result = check(R"( local function id(x) return x end @@ -81,7 +79,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_instantiated_polytypes") +TEST_CASE_FIXTURE(BuiltinsFixture, "local_vars_can_be_instantiated_polytypes") { CheckResult result = check(R"( local function id(x) return x end @@ -226,12 +224,12 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function") const FunctionTypeVar* idFun = get(idType); REQUIRE(idFun); auto [args, varargs] = flatten(idFun->argTypes); - auto [rets, varrets] = flatten(idFun->retType); + auto [rets, varrets] = flatten(idFun->retTypes); CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); - CHECK_EQ(args[0], idFun->generics[0]); - CHECK_EQ(rets[0], idFun->generics[0]); + CHECK_EQ(follow(args[0]), follow(idFun->generics[0])); + CHECK_EQ(follow(rets[0]), follow(idFun->generics[0])); } TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") @@ -249,12 +247,12 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") const FunctionTypeVar* idFun = get(idType); REQUIRE(idFun); auto [args, varargs] = flatten(idFun->argTypes); - auto [rets, varrets] = flatten(idFun->retType); + auto [rets, varrets] = flatten(idFun->retTypes); CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); - CHECK_EQ(args[0], idFun->generics[0]); - CHECK_EQ(rets[0], idFun->generics[0]); + CHECK_EQ(follow(args[0]), follow(idFun->generics[0])); + CHECK_EQ(follow(rets[0]), follow(idFun->generics[0])); } TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") @@ -611,7 +609,7 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") CHECK(requireType("y1") == requireType("y2")); } -TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") +TEST_CASE_FIXTURE(BuiltinsFixture, "bound_tables_do_not_clone_original_fields") { CheckResult result = check(R"( local exports = {} @@ -677,10 +675,8 @@ local d: D = c R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); } -TEST_CASE_FIXTURE(Fixture, "generic_functions_dont_cache_type_parameters") +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_functions_dont_cache_type_parameters") { - ScopedFastFlag sff{"LuauGenericFunctionsDontCacheTypeParams", true}; - CheckResult result = check(R"( -- See https://github.com/Roblox/luau/issues/332 -- This function has a type parameter with the same name as clones, @@ -704,13 +700,6 @@ end TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe") { - ScopedFastFlag sffs[] = { - { "LuauTableSubtypingVariance2", true }, - { "LuauUnsealedTableLiteral", true }, - { "LuauPropertiesGetExpectedType", true }, - { "LuauRecursiveTypeParameterRestriction", true }, - }; - CheckResult result = check(R"( --!strict -- At one point this produced a UAF @@ -733,8 +722,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification1") { - ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; - CheckResult result = check(R"( --!strict type Dispatcher = { @@ -753,8 +740,6 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification2") { - ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; - CheckResult result = check(R"( --!strict type Dispatcher = { @@ -773,8 +758,6 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification3") { - ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; - CheckResult result = check(R"( --!strict type Dispatcher = { @@ -805,11 +788,7 @@ wrapper(test) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 1 argument, but 1 is specified)"); + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); } TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_many") @@ -826,11 +805,7 @@ wrapper(test2, 1, "", 3) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 3 arguments, but 4 are specified)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 1 argument, but 4 are specified)"); + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 3 arguments, but 4 are specified)"); } TEST_CASE_FIXTURE(Fixture, "generic_function") @@ -843,6 +818,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_function") LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(a) -> a", toString(requireType("id"))); CHECK_EQ(*typeChecker.numberType, *requireType("a")); CHECK_EQ(*typeChecker.nilType, *requireType("b")); } @@ -901,7 +877,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") const FunctionTypeVar* foo = get(follow(fooProp->type)); REQUIRE(bool(foo)); - std::optional ret_ = first(foo->retType); + std::optional ret_ = first(foo->retTypes); REQUIRE(bool(ret_)); TypeId ret = follow(*ret_); @@ -998,8 +974,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - // Mutability in type function application right now can create strange recursive types CheckResult result = check(R"( type Table = { a: number } @@ -1020,7 +994,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") type t0 = t0 | {} )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERRORS(result); std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); @@ -1032,30 +1006,42 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") CHECK(it != result.errors.end()); } -TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument") { - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - CheckResult result = check(R"( -local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end -return sum(2, 3, function(a, b) return a + b end) + local function sum(x: a, y: a, f: (a, a) -> a) + return f(x, y) + end + return sum(2, 3, function(a, b) return a + b end) )"); LUAU_REQUIRE_NO_ERRORS(result); result = check(R"( -local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end -local a = {1, 2, 3} -local r = map(a, function(a) return a + a > 100 end) + local function map(arr: {a}, f: (a) -> b) + local r = {} + for i,v in ipairs(arr) do + table.insert(r, f(v)) + end + return r + end + local a = {1, 2, 3} + local r = map(a, function(a) return a + a > 100 end) )"); LUAU_REQUIRE_NO_ERRORS(result); REQUIRE_EQ("{boolean}", toString(requireType("r"))); check(R"( -local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end -local a = {1, 2, 3} -local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) + local function foldl(arr: {a}, init: b, f: (b, a) -> b) + local r = init + for i,v in ipairs(arr) do + r = f(r, v) + end + return r + end + local a = {1, 2, 3} + local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1065,31 +1051,25 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") { CheckResult result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + local g12: ((T, (T) -> T) -> T) & ((T, T, (T, T) -> T) -> T) -local g12: typeof(g1) & typeof(g2) - -g12(1, function(x) return x + x end) -g12(1, 2, function(x, y) return x + y end) + g12(1, function(x) return x + x end) + g12(1, 2, function(x, y) return x + y end) )"); LUAU_REQUIRE_NO_ERRORS(result); result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + local g12: ((T, (T) -> T) -> T) & ((T, T, (T, T) -> T) -> T) -local g12: typeof(g1) & typeof(g2) - -g12({x=1}, function(x) return {x=-x.x} end) -g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) + g12({x=1}, function(x) return {x=-x.x} end) + g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) )"); LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_lib_function_function_argument") { CheckResult result = check(R"( local a = {{x=4}, {x=7}, {x=1}} @@ -1121,15 +1101,85 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") { CheckResult result = check(R"( -type A = { x: number } -local a: A = { x = 1 } -local b = a -type B = typeof(b) -type X = T -local c: X + type A = { x: number } + local a: A = { x = 1 } + local b = a + type B = typeof(b) + type X = T + local c: X )"); LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics1") +{ + // https://github.com/Roblox/luau/issues/484 + CheckResult result = check(R"( +--!strict +type MyObject = { + getReturnValue: (cb: () -> V) -> V +} +local object: MyObject = { + getReturnValue = function(cb: () -> U): U + return cb() + end, +} + +type ComplexObject = { + id: T, + nested: MyObject +} + +local complex: ComplexObject = { + id = "Foo", + nested = object, +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics2") +{ + // https://github.com/Roblox/luau/issues/484 + CheckResult result = check(R"( +--!strict +type MyObject = { + getReturnValue: (cb: () -> V) -> V +} +type ComplexObject = { + id: T, + nested: MyObject +} + +local complex2: ComplexObject = nil + +local x = complex2.nested.getReturnValue(function(): string + return "" +end) + +local y = complex2.nested.getReturnValue(function() + return 3 +end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_functions_even_if_they_have_an_explicit_generic") +{ + ScopedFastFlag sff[] = { + {"LuauAlwaysQuantify", true}, + }; + + CheckResult result = check(R"( + function foo(f, x: X) + return f(x) + end + )"); + + CHECK("((X) -> (a...), X) -> (a...)" == toString(requireType("foo"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index d146f4e8..818d0124 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("IntersectionTypes"); TEST_CASE_FIXTURE(Fixture, "select_correct_union_fn") @@ -175,8 +177,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guarante TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_depth") { - ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; - CheckResult result = check(R"( type A = {x: {y: {z: {thing: string}}}} type B = {x: {y: {z: {thing: string}}}} @@ -306,16 +306,50 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table '{| x: number, y: number |}'"); + else + CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); } TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") { CheckResult result = check(R"( - type X = { x: (number) -> number } - type Y = { y: (string) -> string } + type X = { x: (number) -> number } + type Y = { y: (string) -> string } - type XY = X & Y + type XY = X & Y + + local xy : XY = { + x = function(a: number) return -a end, + y = function(a: string) return a .. "b" end + } + function xy.z(a:number) return a * 10 end + function xy:y(a:number) return a * 10 end + function xy:w(a:number) return a * 10 end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table '{| x: (number) -> number, y: (string) -> string |}'"); + else + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); + CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); + + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table '{| x: (number) -> number, y: (string) -> string |}'"); + else + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); +} + +TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") +{ + // After normalization, previous 'table_intersection_write_sealed_indirect' is identical to this one + CheckResult result = check(R"( + type XY = { x: (number) -> number, y: (string) -> string } local xy : XY = { x = function(a: number) return -a end, @@ -326,13 +360,16 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") function xy:w(a:number) return a * 10 end )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); - CHECK_EQ(toString(result.errors[1]), "Cannot add property 'y' to table 'X & Y'"); - CHECK_EQ(toString(result.errors[2]), "Cannot add property 'w' to table 'X & Y'"); + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'XY'"); + CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'XY'"); } -TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_intersection_setmetatable") { CheckResult result = check(R"( local t: {} & {} @@ -344,6 +381,8 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part") { + ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", false}}; + CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -362,6 +401,8 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") { + ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", false}}; + CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -396,13 +437,13 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") repeat type t0 = ((any)|((any)&((any)|((any)&((any)|(any))))))&(t0) function _(l0):(t0)&(t0) - while nil do - end + while nil do + end end until _(_)(_)._ )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERRORS(result); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 30df717b..1c6fe1d8 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -29,7 +29,7 @@ TEST_CASE_FIXTURE(Fixture, "for_loop") CHECK_EQ(*typeChecker.numberType, *requireType("q")); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") { CheckResult result = check(R"( local n @@ -46,7 +46,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop") CHECK_EQ(*typeChecker.stringType, *requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_next") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_next") { CheckResult result = check(R"( local n @@ -85,9 +85,10 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Cannot call non-function string", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_just_one_iterator_is_ok") { CheckResult result = check(R"( local function keys(dictionary) @@ -106,7 +107,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "for_in_with_a_custom_iterator_should_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_a_custom_iterator_should_type_check") { CheckResult result = check(R"( local function range(l, h): () -> number @@ -158,7 +159,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") REQUIRE(get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") { CheckResult result = check(R"( local function hasDivisors(value: number, table) @@ -207,7 +208,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right CHECK_EQ(typeChecker.stringType, tm->givenType); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") { CheckResult result = check(R"( function prime_iter(state, index) @@ -285,7 +286,7 @@ TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") +TEST_CASE_FIXTURE(BuiltinsFixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") { CheckResult result = check(R"( repeat @@ -298,7 +299,7 @@ TEST_CASE_FIXTURE(Fixture, "symbols_in_repeat_block_should_not_be_visible_beyond LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") +TEST_CASE_FIXTURE(BuiltinsFixture, "varlist_declared_by_for_in_loop_should_be_free") { CheckResult result = check(R"( local T = {} @@ -313,7 +314,7 @@ TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "properly_infer_iteratee_is_a_free_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "properly_infer_iteratee_is_a_free_table") { // In this case, we cannot know the element type of the table {}. It could be anything. // We therefore must initially ascribe a free typevar to iter. @@ -326,7 +327,7 @@ TEST_CASE_FIXTURE(Fixture, "properly_infer_iteratee_is_a_free_table") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_while") { CheckResult result = check(R"( while true do @@ -343,7 +344,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") CHECK_EQ(us->name, "a"); } -TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indices") +TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_produces_integral_indices") { CheckResult result = check(R"( local key @@ -375,7 +376,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") )"); } -TEST_CASE_FIXTURE(Fixture, "unreachable_code_after_infinite_loop") +TEST_CASE_FIXTURE(BuiltinsFixture, "unreachable_code_after_infinite_loop") { { CheckResult result = check(R"( @@ -457,7 +458,7 @@ TEST_CASE_FIXTURE(Fixture, "unreachable_code_after_infinite_loop") } } -TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional") { CheckResult result = check(R"( local t = {} @@ -470,4 +471,76 @@ TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") LUAU_REQUIRE_ERROR_COUNT(2, result); } +TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") +{ + // Just check that this doesn't assert + check(R"( + --!nonstrict + function _(l0:number) + return _ + end + for _ in _(8) do + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") +{ + CheckResult result = check(R"( + local t: {string} = {} + local key + for k: number in t do + end + for k: number, v: string in t do + end + for k, v in t do + key = k + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + CHECK_EQ(*typeChecker.numberType, *requireType("key")); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") +{ + CheckResult result = check(R"( + local t: {string} = {} + local extra + for k, v, e in t do + extra = e + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + CHECK_EQ(*typeChecker.nilType, *requireType("extra")); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") +{ + CheckResult result = check(R"( + local t = {} + for k, v in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("Cannot iterate over a table without indexer", ge->message); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_iter_metamethod") +{ + CheckResult result = check(R"( + local t = {} + setmetatable(t, { __iter = function(o) return next, o.children end }) + for k: number, v: string in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index e5eeae31..a0f670f1 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -5,7 +5,6 @@ #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" #include "Fixture.h" @@ -13,11 +12,9 @@ using namespace Luau; -LUAU_FASTFLAG(LuauTableSubtypingVariance2) - TEST_SUITE_BEGIN("TypeInferModules"); -TEST_CASE_FIXTURE(Fixture, "require") +TEST_CASE_FIXTURE(BuiltinsFixture, "require") { fileResolver.source["game/A"] = R"( local function hooty(x: number): string @@ -55,7 +52,7 @@ TEST_CASE_FIXTURE(Fixture, "require") REQUIRE_EQ("number", toString(*hType)); } -TEST_CASE_FIXTURE(Fixture, "require_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "require_types") { fileResolver.source["workspace/A"] = R"( export type Point = {x: number, y: number} @@ -70,7 +67,7 @@ TEST_CASE_FIXTURE(Fixture, "require_types") )"; CheckResult bResult = frontend.check("workspace/B"); - dumpErrors(bResult); + LUAU_REQUIRE_NO_ERRORS(bResult); ModulePtr b = frontend.moduleResolver.modules["workspace/B"]; REQUIRE(b != nullptr); @@ -79,7 +76,7 @@ TEST_CASE_FIXTURE(Fixture, "require_types") REQUIRE_MESSAGE(bool(get(hType)), "Expected table but got " << toString(hType)); } -TEST_CASE_FIXTURE(Fixture, "require_a_variadic_function") +TEST_CASE_FIXTURE(BuiltinsFixture, "require_a_variadic_function") { fileResolver.source["game/A"] = R"( local T = {} @@ -122,7 +119,7 @@ TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") REQUIRE_EQ(result.errors[0], (TypeError{Location{{1, 17}, {1, 40}}, UnknownSymbol{"SomeModule.DoesNotExist"}})); } -TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") +TEST_CASE_FIXTURE(BuiltinsFixture, "require_module_that_does_not_export") { const std::string sourceA = R"( )"; @@ -149,7 +146,7 @@ TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") CHECK_EQ("*unknown*", toString(hootyType)); } -TEST_CASE_FIXTURE(Fixture, "warn_if_you_try_to_require_a_non_modulescript") +TEST_CASE_FIXTURE(BuiltinsFixture, "warn_if_you_try_to_require_a_non_modulescript") { fileResolver.source["Modules/A"] = ""; fileResolver.sourceTypes["Modules/A"] = SourceCode::Local; @@ -165,7 +162,7 @@ TEST_CASE_FIXTURE(Fixture, "warn_if_you_try_to_require_a_non_modulescript") CHECK(get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "general_require_call_expression") +TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_call_expression") { fileResolver.source["game/A"] = R"( --!strict @@ -184,7 +181,7 @@ a = tbl.abc.def CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "general_require_type_mismatch") +TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_type_mismatch") { fileResolver.source["game/A"] = R"( return { def = 4 } @@ -220,7 +217,7 @@ return m LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "custom_require_global") +TEST_CASE_FIXTURE(BuiltinsFixture, "custom_require_global") { CheckResult result = check(R"( --!nonstrict @@ -232,7 +229,7 @@ local crash = require(game.A) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "require_failed_module") +TEST_CASE_FIXTURE(BuiltinsFixture, "require_failed_module") { fileResolver.source["game/A"] = R"( return unfortunately() @@ -250,7 +247,7 @@ local ModuleA = require(game.A) CHECK_EQ("*unknown*", toString(*oty)); } -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types") { fileResolver.source["game/A"] = R"( export type Type = { unrelated: boolean } @@ -265,10 +262,10 @@ function x:Destroy(): () end )"; CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_ERROR_COUNT(2, result); } -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_2") +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_2") { fileResolver.source["game/A"] = R"( export type Type = { x: { a: number } } @@ -286,7 +283,7 @@ type Rename = typeof(x.x) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_3") +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_3") { fileResolver.source["game/A"] = R"( local y = setmetatable({}, {}) @@ -305,10 +302,8 @@ type Rename = typeof(x.x) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "module_type_conflict") +TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict") { - ScopedFastFlag luauTypeMismatchModuleName{"LuauTypeMismatchModuleName", true}; - fileResolver.source["game/A"] = R"( export type T = { x: number } return {} @@ -329,22 +324,13 @@ local b: B.T = a CheckResult result = frontend.check("game/C"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauTableSubtypingVariance2) - { - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); - } - else - { - CHECK_EQ(toString(result.errors[0]), "Type 'T' from 'game/A' could not be converted into 'T' from 'game/B'"); - } } -TEST_CASE_FIXTURE(Fixture, "module_type_conflict_instantiated") +TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict_instantiated") { - ScopedFastFlag luauTypeMismatchModuleName{"LuauTypeMismatchModuleName", true}; - fileResolver.source["game/A"] = R"( export type Wrap = { x: T } return {} @@ -372,16 +358,9 @@ local b: B.T = a CheckResult result = frontend.check("game/D"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauTableSubtypingVariance2) - { - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); - } - else - { - CHECK_EQ(toString(result.errors[0]), "Type 'T' from 'game/B' could not be converted into 'T' from 'game/C'"); - } } TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 40831bf6..41690704 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -142,7 +142,7 @@ TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocat CHECK_GE(50, module->internalTypes.typeVars.size()); } -TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") +TEST_CASE_FIXTURE(BuiltinsFixture, "object_constructor_can_refer_to_method_of_self") { // CLI-30902 CheckResult result = check(R"( @@ -199,16 +199,16 @@ end TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") { CheckResult result = check(R"( ---!nonstrict -local f = {} -function f:foo(a: number, b: number) end + --!nonstrict + local f = {} + function f:foo(a: number, b: number) end -function bar(...) - f.foo(f, 1, ...) -end + function bar(...) + f.foo(f, 1, ...) + end -bar(2) -)"); + bar(2) + )"); LUAU_REQUIRE_NO_ERRORS(result); } @@ -243,7 +243,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_ )"); } -TEST_CASE_FIXTURE(Fixture, "table_oop") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_oop") { CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 6a8a9d93..fd9b1dd4 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -77,7 +77,7 @@ TEST_CASE_FIXTURE(Fixture, "and_or_ternary") CHECK_EQ(toString(*requireType("s")), "number | string"); } -TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") { CheckResult result = check(R"( function add(a: number, b: string) @@ -90,8 +90,9 @@ TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") const FunctionTypeVar* functionType = get(requireType("add")); - std::optional retType = first(functionType->retType); - CHECK_EQ(std::optional(typeChecker.numberType), retType); + std::optional retType = first(functionType->retTypes); + REQUIRE(retType.has_value()); + CHECK_EQ(typeChecker.numberType, follow(*retType)); CHECK_EQ(requireType("n"), typeChecker.numberType); CHECK_EQ(requireType("s"), typeChecker.stringType); } @@ -139,10 +140,8 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") CHECK_EQ("number", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_intersection") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( --!strict local Vec3 = {} @@ -175,10 +174,8 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("e"))); } -TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( --!strict local Vec3 = {} @@ -248,7 +245,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_m REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); } -TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") +TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") { CheckResult result = check(R"( local M = {} @@ -269,7 +266,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_ov REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); } -TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") { CheckResult result = check(R"( --!strict @@ -292,7 +289,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_meta REQUIRE_EQ((Location{{11, 18}, {11, 23}}), result.errors[1].location); } -TEST_CASE_FIXTURE(Fixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") +TEST_CASE_FIXTURE(BuiltinsFixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") { CheckResult result = check(R"( --!strict @@ -364,7 +361,7 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); } -TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") { CheckResult result = check(R"( --!strict @@ -384,7 +381,7 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable") CHECK_EQ(0, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_mismatch_metatable") { CheckResult result = check(R"( --!strict @@ -431,7 +428,7 @@ local x = false LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") { CheckResult result = check(R"( --!strict @@ -464,7 +461,7 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); } -TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean") +TEST_CASE_FIXTURE(BuiltinsFixture, "unary_not_is_boolean") { CheckResult result = check(R"( local b = not "string" @@ -476,7 +473,7 @@ TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean") REQUIRE_EQ("boolean", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") +TEST_CASE_FIXTURE(BuiltinsFixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") { CheckResult result = check(R"( --!strict @@ -576,7 +573,7 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "and_binexps_dont_unify") +TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") { CheckResult result = check(R"( --!strict @@ -631,7 +628,7 @@ TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") CHECK_EQ("Type contains a self-recursive construct that cannot be resolved", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") +TEST_CASE_FIXTURE(BuiltinsFixture, "UnknownGlobalCompoundAssign") { // In non-strict mode, global definition is still allowed { @@ -731,8 +728,6 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: string | number, b: boolean | number) return a == b @@ -758,8 +753,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_and_or") TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") { - ScopedFastFlag sff{"LuauDecoupleOperatorInferenceFromUnifiedTypeInference", true}; - CheckResult result = check(Mode::Strict, R"( local function f(x, y) return x + y @@ -782,4 +775,45 @@ TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") // the case right now, though. } +TEST_CASE_FIXTURE(BuiltinsFixture, "equality_operations_succeed_if_any_union_branch_succeeds") +{ + CheckResult result = check(R"( + local mm = {} + type Foo = typeof(setmetatable({}, mm)) + local x: Foo + local y: Foo? + + local v1 = x == y + local v2 = y == x + local v3 = x ~= y + local v4 = y ~= x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CheckResult result2 = check(R"( + local mm1 = { + x = "foo", + } + + local mm2 = { + y = "bar", + } + + type Foo = typeof(setmetatable({}, mm1)) + type Bar = typeof(setmetatable({}, mm2)) + + local x1: Foo + local x2: Foo? + local y1: Bar + local y2: Bar? + + local v1 = x1 == y1 + local v2 = x2 == y2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result2); + CHECK(toString(result2.errors[0]) == "Types Foo and Bar cannot be compared with == because they do not have the same metatable"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 44b7b0d0..e1684df7 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -85,8 +85,6 @@ TEST_CASE_FIXTURE(Fixture, "string_function_other") TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( local x: number = 9999 function x:y(z: number) @@ -95,6 +93,8 @@ end )"); LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Cannot add method to non-table type 'number'"); + CHECK_EQ(toString(result.errors[1]), "Type 'number' could not be converted into 'string'"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 2e16b21e..487e5979 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -7,7 +7,7 @@ #include -LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) using namespace Luau; @@ -52,7 +52,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") CHECK_EQ(expected, decorateWithTypes(code)); } -TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns") +TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns") { const std::string code = R"( local a, b, c = xpcall(function() return 1, "foo" end, function() return "foo", 1 end) @@ -104,7 +104,7 @@ TEST_CASE_FIXTURE(Fixture, "it_should_be_agnostic_of_actual_size") // Ideally setmetatable's second argument would be an optional free table. // For now, infer it as just a free table. -TEST_CASE_FIXTURE(Fixture, "setmetatable_constrains_free_type_into_free_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_constrains_free_type_into_free_table") { CheckResult result = check(R"( local a = {} @@ -145,7 +145,7 @@ TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") // Originally from TypeInfer.test.cpp. // I dont think type checking the metamethod at every site of == is the correct thing to do. // We should be type checking the metamethod at the call site of setmetatable. -TEST_CASE_FIXTURE(Fixture, "error_on_eq_metamethod_returning_a_type_other_than_boolean") +TEST_CASE_FIXTURE(BuiltinsFixture, "error_on_eq_metamethod_returning_a_type_other_than_boolean") { CheckResult result = check(R"( local tab = {a = 1} @@ -182,8 +182,6 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") // We'll need to not only report an error on `a == b`, but also to refine both operands as `never` in the `==` branch. TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: string, b: boolean?) if a == b then @@ -207,8 +205,6 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") // Just needs to fully support equality refinement. Which is annoying without type states. TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") { - ScopedFastFlag sff{"LuauDiscriminableUnions2", true}; - CheckResult result = check(R"( type T = {x: string, y: number} | {x: nil, y: nil} @@ -267,242 +263,6 @@ TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doct } } -TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doctest::timeout(1.0)) -{ - ScopedFastInt sffi{"LuauTarjanChildLimit", 400}; - - CheckResult result = check(R"LUA( - --!strict - local TS = _G[script] - local lazyGet = TS.import(script, script.Parent.Parent, "util", "lazyLoad").lazyGet - local unit = TS.import(script, script.Parent.Parent, "util", "Unit").unit - local Iterator - lazyGet("Iterator", function(c) - Iterator = c - end) - local Option - lazyGet("Option", function(c) - Option = c - end) - local Vec - lazyGet("Vec", function(c) - Vec = c - end) - local Result - do - Result = setmetatable({}, { - __tostring = function() - return "Result" - end, - }) - Result.__index = Result - function Result.new(...) - local self = setmetatable({}, Result) - self:constructor(...) - return self - end - function Result:constructor(okValue, errValue) - self.okValue = okValue - self.errValue = errValue - end - function Result:ok(val) - return Result.new(val, nil) - end - function Result:err(val) - return Result.new(nil, val) - end - function Result:fromCallback(c) - local _0 = c - local _1, _2 = pcall(_0) - local result = _1 and { - success = true, - value = _2, - } or { - success = false, - error = _2, - } - return result.success and Result:ok(result.value) or Result:err(Option:wrap(result.error)) - end - function Result:fromVoidCallback(c) - local _0 = c - local _1, _2 = pcall(_0) - local result = _1 and { - success = true, - value = _2, - } or { - success = false, - error = _2, - } - return result.success and Result:ok(unit()) or Result:err(Option:wrap(result.error)) - end - Result.fromPromise = TS.async(function(self, p) - local _0, _1 = TS.try(function() - return TS.TRY_RETURN, { Result:ok(TS.await(p)) } - end, function(e) - return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } - end) - if _0 then - return unpack(_1) - end - end) - Result.fromVoidPromise = TS.async(function(self, p) - local _0, _1 = TS.try(function() - TS.await(p) - return TS.TRY_RETURN, { Result:ok(unit()) } - end, function(e) - return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } - end) - if _0 then - return unpack(_1) - end - end) - function Result:isOk() - return self.okValue ~= nil - end - function Result:isErr() - return self.errValue ~= nil - end - function Result:contains(x) - return self.okValue == x - end - function Result:containsErr(x) - return self.errValue == x - end - function Result:okOption() - return Option:wrap(self.okValue) - end - function Result:errOption() - return Option:wrap(self.errValue) - end - function Result:map(func) - return self:isOk() and Result:ok(func(self.okValue)) or Result:err(self.errValue) - end - function Result:mapOr(def, func) - local _0 - if self:isOk() then - _0 = func(self.okValue) - else - _0 = def - end - return _0 - end - function Result:mapOrElse(def, func) - local _0 - if self:isOk() then - _0 = func(self.okValue) - else - _0 = def(self.errValue) - end - return _0 - end - function Result:mapErr(func) - return self:isErr() and Result:err(func(self.errValue)) or Result:ok(self.okValue) - end - Result["and"] = function(self, other) - return self:isErr() and Result:err(self.errValue) or other - end - function Result:andThen(func) - return self:isErr() and Result:err(self.errValue) or func(self.okValue) - end - Result["or"] = function(self, other) - return self:isOk() and Result:ok(self.okValue) or other - end - function Result:orElse(other) - return self:isOk() and Result:ok(self.okValue) or other(self.errValue) - end - function Result:expect(msg) - if self:isOk() then - return self.okValue - else - error(msg) - end - end - function Result:unwrap() - return self:expect("called `Result.unwrap()` on an `Err` value: " .. tostring(self.errValue)) - end - function Result:unwrapOr(def) - local _0 - if self:isOk() then - _0 = self.okValue - else - _0 = def - end - return _0 - end - function Result:unwrapOrElse(gen) - local _0 - if self:isOk() then - _0 = self.okValue - else - _0 = gen(self.errValue) - end - return _0 - end - function Result:expectErr(msg) - if self:isErr() then - return self.errValue - else - error(msg) - end - end - function Result:unwrapErr() - return self:expectErr("called `Result.unwrapErr()` on an `Ok` value: " .. tostring(self.okValue)) - end - function Result:transpose() - return self:isOk() and self.okValue:map(function(some) - return Result:ok(some) - end) or Option:some(Result:err(self.errValue)) - end - function Result:flatten() - return self:isOk() and Result.new(self.okValue.okValue, self.okValue.errValue) or Result:err(self.errValue) - end - function Result:match(ifOk, ifErr) - local _0 - if self:isOk() then - _0 = ifOk(self.okValue) - else - _0 = ifErr(self.errValue) - end - return _0 - end - function Result:asPtr() - local _0 = (self.okValue) - if _0 == nil then - _0 = (self.errValue) - end - return _0 - end - end - local resultMeta = Result - resultMeta.__eq = function(a, b) - return b:match(function(ok) - return a:contains(ok) - end, function(err) - return a:containsErr(err) - end) - end - resultMeta.__tostring = function(result) - return result:match(function(ok) - return "Result.ok(" .. tostring(ok) .. ")" - end, function(err) - return "Result.err(" .. tostring(err) .. ")" - end) - end - return { - Result = Result, - } - )LUA"); - - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& a) { - return nullptr != get(a); - }); - if (it == result.errors.end()) - { - dumpErrors(result); - FAIL("Expected a UnificationTooComplex error"); - } -} - // Should be in TypeInfer.tables.test.cpp // It's unsound to instantiate tables containing generic methods, // since mutating properties means table properties should be invariant. @@ -527,6 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_NO_ERRORS(result); } +// FIXME: Move this test to another source file when removing FFlag::LuauLowerBoundsCalculation TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") { ScopedFastFlag sff[]{ @@ -556,10 +317,19 @@ TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type LUAU_REQUIRE_NO_ERRORS(result); - // f and g should have the type () -> () - CHECK_EQ("() -> (a...)", toString(requireType("f"))); - CHECK_EQ("() -> (a...)", toString(requireType("g"))); - CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now + if (FFlag::LuauLowerBoundsCalculation) + { + CHECK_EQ("() -> ()", toString(requireType("f"))); + CHECK_EQ("() -> ()", toString(requireType("g"))); + CHECK_EQ("nil", toString(requireType("x"))); + } + else + { + // f and g should have the type () -> () + CHECK_EQ("() -> (a...)", toString(requireType("f"))); + CHECK_EQ("() -> (a...)", toString(requireType("g"))); + CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now + } } TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") @@ -573,18 +343,12 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") LUAU_REQUIRE_ERRORS(result); // Should not have any errors. } -TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") -{ - CheckResult result = check(R"( - local function f() return end - local g = function() return f() end - )"); - - LUAU_REQUIRE_ERRORS(result); // Should not have any errors. -} - TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") { + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", false}, + }; + CheckResult result = check(R"( --!strict local function f(...) return ... end @@ -594,4 +358,145 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") LUAU_REQUIRE_ERRORS(result); // Should not have any errors. } +TEST_CASE_FIXTURE(Fixture, "lower_bounds_calculation_is_too_permissive_with_overloaded_higher_order_functions") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function foo(f) + f(5, 'a') + f('b', 6) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // We incorrectly infer that the argument to foo could be called with (number, number) or (string, string) + // even though that is strictly more permissive than the actual source text shows. + CHECK("((number | string, number | string) -> (a...)) -> ()" == toString(requireType("foo"))); +} + +// Once fixed, move this to Normalize.test.cpp +TEST_CASE_FIXTURE(Fixture, "normalization_fails_on_certain_kinds_of_cyclic_tables") +{ +#if defined(_DEBUG) || defined(_NOOPT) + ScopedFastInt sfi("LuauNormalizeIterationLimit", 500); +#endif + + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(x, y) + x.x = y + y.x = x + + type R = {x: typeof(x)} & {x: typeof(y)} + local r: R + + return r + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(nullptr != get(result.errors[0])); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(BuiltinsFixture, "pcall_returns_at_least_two_value_but_function_returns_nothing") +{ + CheckResult result = check(R"( + local function f(): () end + local ok, res = pcall(f) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Function only returns 1 value. 2 are required here", toString(result.errors[0])); + // LUAU_REQUIRE_NO_ERRORS(result); + // CHECK_EQ("boolean", toString(requireType("ok"))); + // CHECK_EQ("any", toString(requireType("res"))); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(BuiltinsFixture, "choose_the_right_overload_for_pcall") +{ + CheckResult result = check(R"( + local function f(): number + if math.random() > 0.5 then + return 5 + else + error("something") + end + end + + local ok, res = pcall(f) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("boolean", toString(requireType("ok"))); + CHECK_EQ("number", toString(requireType("res"))); + // CHECK_EQ("any", toString(requireType("res"))); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_many_things_but_first_of_it_is_forgotten") +{ + CheckResult result = check(R"( + local function f(): (number, string, boolean) + if math.random() > 0.5 then + return 5, "hello", true + else + error("something") + end + end + + local ok, res, s, b = pcall(f) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("boolean", toString(requireType("ok"))); + CHECK_EQ("number", toString(requireType("res"))); + // CHECK_EQ("any", toString(requireType("res"))); + CHECK_EQ("string", toString(requireType("s"))); + CHECK_EQ("boolean", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeFlagIsConservative", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + local function f(o) + local t = {} + t[o] = true + + local function foo(o) + o:m1() + t[o] = nil + end + + local function bar(o) + o:m2() + t[o] = true + end + + return t + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // TODO: We're missing generics b... + CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index cddeab6e..3f5dad3d 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Normalize.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" @@ -6,15 +7,14 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauDiscriminableUnions2) -LUAU_FASTFLAG(LuauWeakEqConstraint) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) using namespace Luau; namespace { -std::optional> magicFunctionInstanceIsA( - TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +std::optional> magicFunctionInstanceIsA( + TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { if (expr.args.size != 1) return std::nullopt; @@ -32,7 +32,7 @@ std::optional> magicFunctionInstanceIsA( unfreeze(typeChecker.globalTypes); TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); freeze(typeChecker.globalTypes); - return ExprResult{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; + return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } struct RefinementClassFixture : Fixture @@ -42,35 +42,44 @@ struct RefinementClassFixture : Fixture TypeArena& arena = typeChecker.globalTypes; unfreeze(arena); - TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr}); + TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); getMutable(vec3)->props = { {"X", Property{typeChecker.numberType}}, {"Y", Property{typeChecker.numberType}}, {"Z", Property{typeChecker.numberType}}, }; + normalize(vec3, arena, *typeChecker.iceHandler); - TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); + TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); getMutable(isA)->magicFunction = magicFunctionInstanceIsA; + normalize(isA, arena, *typeChecker.iceHandler); getMutable(inst)->props = { {"Name", Property{typeChecker.stringType}}, {"IsA", Property{isA}}, }; + normalize(inst, arena, *typeChecker.iceHandler); - TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); - TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); + TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); + normalize(folder, arena, *typeChecker.iceHandler); + TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); getMutable(part)->props = { {"Position", Property{vec3}}, }; + normalize(part, arena, *typeChecker.iceHandler); typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; + + for (const auto& [name, ty] : typeChecker.globalScope->exportedTypeBindings) + persist(ty.type); + freeze(typeChecker.globalTypes); } }; @@ -233,7 +242,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position") CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); } -TEST_CASE_FIXTURE(Fixture, "typeguard_in_assert_position") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") { CheckResult result = check(R"( local a @@ -261,18 +270,10 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") end )"); - if (FFlag::LuauDiscriminableUnions2) - { - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44}))); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38}))); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); - } + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44}))); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38}))); } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -293,7 +294,7 @@ TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "impossible_type_narrow_is_not_an_error") +TEST_CASE_FIXTURE(BuiltinsFixture, "impossible_type_narrow_is_not_an_error") { // This unit test serves as a reminder to not implement this warning until Luau is intelligent enough. // For instance, getting a value out of the indexer and checking whether the value exists is not an error. @@ -326,7 +327,7 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") CHECK_EQ("number?", toString(requireType("bar"))); } -TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property") +TEST_CASE_FIXTURE(BuiltinsFixture, "index_on_a_refined_property") { CheckResult result = check(R"( local t: {x: {y: string}?} = {x = {y = "hello!"}} @@ -339,7 +340,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_constraints") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_non_binary_expressions_actually_resolve_constraints") { CheckResult result = check(R"( local foo: string? = "hello" @@ -352,8 +353,6 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local t: {x: number?} = {x = nil} @@ -371,8 +370,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?, b: boolean?) if a == b then @@ -385,28 +382,15 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "nil"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "nil"); // a == b - - CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b - } + CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b } TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?) if a == 1 then @@ -419,24 +403,12 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1 - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number"); // a == 1 - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 - } + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1 + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 } TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") { - ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local function f(a: (string | number)?) if "hello" == a then @@ -455,8 +427,6 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?) if a ~= nil then @@ -469,23 +439,12 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "nil"); // a == nil - } + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil } TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") { - ScopedFastFlag sff{"LuauDiscriminableUnions2", true}; - ScopedFastFlag sff2{"LuauWeakEqConstraint", true}; - CheckResult result = check(R"( local function f(a, b: string?) if a == b then @@ -502,8 +461,6 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: any, b: {x: number}?) if a ~= b then @@ -514,22 +471,12 @@ TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_e LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}"); // a ~= b - } + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b } TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local t: {string} = {"hello"} @@ -547,18 +494,8 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b - } - else - { - // This is technically not wrong, but it's also wrong at the same time. - // The refinement code is none the wiser about the fact we pulled a string out of an array, so it has no choice but to narrow as just string. - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string"); // a == b - } + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b } TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") @@ -587,16 +524,7 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") end )"); - if (FFlag::LuauDiscriminableUnions2) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - // This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type. - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0])); - } + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); } @@ -697,7 +625,10 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ("{| x: number, y: number |}", toString(requireTypeAtPosition({4, 28}))); + else + CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } @@ -720,7 +651,7 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } -TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") { CheckResult result = check(R"( local function f(t: {x: number}) @@ -836,7 +767,7 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") CHECK_EQ("{| x: boolean |}?", toString(requireTypeAtPosition({3, 28}))); } -TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") { CheckResult result = check(R"( local a: (number | string)? @@ -852,7 +783,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") CHECK_EQ("number", toString(requireTypeAtPosition({5, 18}))); } -TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") +TEST_CASE_FIXTURE(BuiltinsFixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") { // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. CheckResult result = check(R"( @@ -889,7 +820,7 @@ TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_n CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); } -TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") +TEST_CASE_FIXTURE(BuiltinsFixture, "is_truthy_constraint_ifelse_expression") { CheckResult result = check(R"( function f(v:string?) @@ -903,7 +834,7 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") CHECK_EQ("nil", toString(requireTypeAtPosition({2, 45}))); } -TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") +TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_ifelse_expression") { CheckResult result = check(R"( function f(v:string?) @@ -935,7 +866,7 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); } -TEST_CASE_FIXTURE(Fixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") { CheckResult result = check(R"( local foo: string? = "hi" @@ -999,9 +930,7 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") { - ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions2", true}, - }; + ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; CheckResult result = check(R"( type T = {tag: "missing", x: nil} | {tag: "exists", x: string} @@ -1018,15 +947,11 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ(R"({| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ(R"({| tag: "exists", x: string |} | {| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); } TEST_CASE_FIXTURE(Fixture, "discriminate_tag") { - ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( type Cat = {tag: "Cat", name: string, catfood: string} type Dog = {tag: "Dog", name: string, dogfood: string} @@ -1060,11 +985,6 @@ TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauAssertStripsFalsyTypes", true}, - }; - CheckResult result = check(R"( local function is_true(b: true) end local function is_false(b: false) end @@ -1083,11 +1003,6 @@ TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauAssertStripsFalsyTypes", true}, - }; - CheckResult result = check(R"( type Ok = { ok: true, value: T } type Err = { ok: false, error: E } @@ -1107,8 +1022,6 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_ TEST_CASE_FIXTURE(Fixture, "refine_a_property_not_to_be_nil_through_an_intersection_table") { - ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; - CheckResult result = check(R"( type T = {} & {f: ((string) -> string)?} local function f(t: T, x) @@ -1123,10 +1036,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_not_to_be_nil_through_an_intersect TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") { - ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( type T = {tag: "Part", x: Part} | {tag: "Folder", x: Folder} @@ -1161,14 +1070,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") end )"); - if (FFlag::LuauDiscriminableUnions2) - LUAU_REQUIRE_NO_ERRORS(result); - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); - } + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" @@ -1289,7 +1191,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") { - const std::string code = R"( + CheckResult result = check(R"( function f(a) if type(a) == "boolean" then local a1 = a @@ -1299,10 +1201,30 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") local a3 = a end end - )"; - CheckResult result = check(code); + )"); + LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") +{ + ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; + + CheckResult result = check(R"( + local function f(t: {number}) + local x = t[1] + if not x then + local foo = x + else + local bar = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("number", toString(requireTypeAtPosition({6, 28}))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index d39341ea..4a88abee 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -5,8 +5,6 @@ #include "doctest.h" #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAG(BetterDiagnosticCodesInStudio) - using namespace Luau; TEST_SUITE_BEGIN("TypeSingletons"); @@ -166,10 +164,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") { - ScopedFastFlag sffs[] = { - {"LuauExpectedTypesOfProperties", true}, - }; - CheckResult result = check(R"( type Dog = { tag: "Dog", howls: boolean } type Cat = { tag: "Cat", meows: boolean } @@ -261,22 +255,11 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::BetterDiagnosticCodesInStudio) - { - CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); - } - else - { - CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); - } + CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") { - ScopedFastFlag sffs[]{ - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( --!strict local x: { ["<>"] : number } @@ -290,10 +273,6 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") { - ScopedFastFlag sffs[] = { - {"LuauExpectedTypesOfProperties", true}, - }; - CheckResult result = check(R"( type Cat = { tag: 'cat', catfood: string } type Dog = { tag: 'dog', dogfood: string } @@ -311,10 +290,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") { - ScopedFastFlag sffs[] = { - {"LuauExpectedTypesOfProperties", true}, - }; - CheckResult result = check(R"( type Good = { success: true, result: string } type Bad = { success: false, error: string } @@ -332,10 +307,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") { - ScopedFastFlag sffs[] = { - {"LuauExpectedTypesOfProperties", true}, - }; - CheckResult result = check(R"( type Cat = { tag: 'cat', catfood: string } type Dog = { tag: 'dog', dogfood: string } @@ -349,13 +320,6 @@ local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_singleton") { - ScopedFastFlag sff[]{ - {"LuauEqConstraint", true}, - {"LuauDiscriminableUnions2", true}, - {"LuauWidenIfSupertypeIsFree2", true}, - {"LuauWeakEqConstraint", false}, - }; - CheckResult result = check(R"( local function foo(f, x) if x == "hi" then @@ -374,14 +338,6 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauEqConstraint", true}, - {"LuauWidenIfSupertypeIsFree2", true}, - {"LuauWeakEqConstraint", false}, - {"LuauDoNotAccidentallyDependOnPointerOrdering", true}, - }; - CheckResult result = check(R"( local function foo(f, x): "hello"? -- anyone there? return if x == "hi" @@ -399,10 +355,6 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") { - ScopedFastFlag sff[]{ - {"LuauWidenIfSupertypeIsFree2", true}, - }; - CheckResult result = check(R"( local foo: "foo" = "foo" local copy = foo @@ -414,11 +366,6 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauWidenIfSupertypeIsFree2", true}, - }; - CheckResult result = check(R"( type Cat = {tag: "Cat", meows: boolean} type Dog = {tag: "Dog", barks: boolean} @@ -440,11 +387,9 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") { - ScopedFastFlag sff[]{ - {"LuauWidenIfSupertypeIsFree2", true}, - }; + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; CheckResult result = check(R"( local function foo(t, x) @@ -466,10 +411,6 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") TEST_CASE_FIXTURE(Fixture, "functions_are_not_to_be_widened") { - ScopedFastFlag sff[]{ - {"LuauWidenIfSupertypeIsFree2", true}, - }; - CheckResult result = check(R"( local function foo(my_enum: "A" | "B") end )"); @@ -481,10 +422,6 @@ TEST_CASE_FIXTURE(Fixture, "functions_are_not_to_be_widened") TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local a: string = "hi" if a == "hi" then @@ -499,10 +436,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local a: string = "hi" if a == "hi" or a == "bye" then @@ -517,10 +450,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local a: string = "hi" if a == "hi" then @@ -535,10 +464,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local a: string = "hi" if a == "hi" or a == "bye" then diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 0cc12d19..77a2928c 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -11,6 +11,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("TableTests"); TEST_CASE_FIXTURE(Fixture, "basic") @@ -199,7 +201,7 @@ TEST_CASE_FIXTURE(Fixture, "used_dot_instead_of_colon") REQUIRE(it != result.errors.end()); } -TEST_CASE_FIXTURE(Fixture, "used_colon_correctly") +TEST_CASE_FIXTURE(BuiltinsFixture, "used_colon_correctly") { CheckResult result = check(R"( --!nonstrict @@ -274,8 +276,6 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification") TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local a = {} a.x = 99 @@ -345,8 +345,6 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict function foo(o) @@ -368,8 +366,6 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local T = {} T.bar = 'hello' @@ -475,8 +471,6 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict local t = { u = {} } @@ -510,8 +504,6 @@ TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_ TEST_CASE_FIXTURE(Fixture, "width_subtyping") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict function f(x : { q : number }) @@ -640,7 +632,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") const TableTypeVar* argType = get(follow(argVec[0])); REQUIRE(argType != nullptr); - std::vector retVec = flatten(ftv->retType).first; + std::vector retVec = flatten(ftv->retTypes).first; const TableTypeVar* retType = get(follow(retVec[0])); REQUIRE(retType != nullptr); @@ -689,7 +681,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_value_property_in_literal") const FunctionTypeVar* fType = get(requireType("f")); REQUIRE(fType != nullptr); - auto retType_ = first(fType->retType); + auto retType_ = first(fType->retTypes); REQUIRE(bool(retType_)); auto retType = get(follow(*retType_)); @@ -770,8 +762,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local t: { a: string, [number]: string } = { a = "foo" } )"); @@ -781,8 +771,6 @@ TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") TEST_CASE_FIXTURE(Fixture, "array_factory_function") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( function empty() return {} end local array: {string} = empty() @@ -881,7 +869,7 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s CHECK_EQ(*typeChecker.stringType, *propertyA); } -TEST_CASE_FIXTURE(Fixture, "oop_indexer_works") +TEST_CASE_FIXTURE(BuiltinsFixture, "oop_indexer_works") { CheckResult result = check(R"( local clazz = {} @@ -904,7 +892,7 @@ TEST_CASE_FIXTURE(Fixture, "oop_indexer_works") CHECK_EQ(*typeChecker.stringType, *requireType("words")); } -TEST_CASE_FIXTURE(Fixture, "indexer_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_table") { CheckResult result = check(R"( local clazz = {a="hello"} @@ -917,7 +905,7 @@ TEST_CASE_FIXTURE(Fixture, "indexer_table") CHECK_EQ(*typeChecker.stringType, *requireType("b")); } -TEST_CASE_FIXTURE(Fixture, "indexer_fn") +TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_fn") { CheckResult result = check(R"( local instanace = setmetatable({}, {__index=function() return 10 end}) @@ -928,7 +916,7 @@ TEST_CASE_FIXTURE(Fixture, "indexer_fn") CHECK_EQ(*typeChecker.numberType, *requireType("b")); } -TEST_CASE_FIXTURE(Fixture, "meta_add") +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add") { // Note: meta_add_inferred and this unit test are currently the same exact thing. // We'll want to change this one in particular when we add real syntax for metatables. @@ -945,7 +933,7 @@ TEST_CASE_FIXTURE(Fixture, "meta_add") CHECK_EQ(follow(requireType("a")), follow(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "meta_add_inferred") +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_inferred") { CheckResult result = check(R"( local a = {} @@ -958,7 +946,7 @@ TEST_CASE_FIXTURE(Fixture, "meta_add_inferred") CHECK_EQ(*requireType("a"), *requireType("c")); } -TEST_CASE_FIXTURE(Fixture, "meta_add_both_ways") +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_both_ways") { CheckResult result = check(R"( type VectorMt = { __add: (Vector, number) -> Vector } @@ -978,7 +966,7 @@ TEST_CASE_FIXTURE(Fixture, "meta_add_both_ways") // This test exposed a bug where we let go of the "seen" stack while unifying table types // As a result, type inference crashed with a stack overflow. -TEST_CASE_FIXTURE(Fixture, "unification_of_unions_in_a_self_referential_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "unification_of_unions_in_a_self_referential_type") { CheckResult result = check(R"( type A = {} @@ -1007,7 +995,7 @@ TEST_CASE_FIXTURE(Fixture, "unification_of_unions_in_a_self_referential_type") CHECK_EQ(bmtv->metatable, requireType("bmt")); } -TEST_CASE_FIXTURE(Fixture, "oop_polymorphic") +TEST_CASE_FIXTURE(BuiltinsFixture, "oop_polymorphic") { CheckResult result = check(R"( local animal = {} @@ -1058,7 +1046,7 @@ TEST_CASE_FIXTURE(Fixture, "user_defined_table_types_are_named") CHECK_EQ("Vector3", toString(requireType("v"))); } -TEST_CASE_FIXTURE(Fixture, "result_is_always_any_if_lhs_is_any") +TEST_CASE_FIXTURE(BuiltinsFixture, "result_is_always_any_if_lhs_is_any") { CheckResult result = check(R"( type Vector3MT = { @@ -1131,7 +1119,7 @@ TEST_CASE_FIXTURE(Fixture, "nice_error_when_trying_to_fetch_property_of_boolean" CHECK_EQ("Type 'boolean' does not have key 'some_prop'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_builtin_sealed_table_must_fail") +TEST_CASE_FIXTURE(BuiltinsFixture, "defining_a_method_for_a_builtin_sealed_table_must_fail") { CheckResult result = check(R"( function string.m() end @@ -1140,7 +1128,7 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_builtin_sealed_table_must_fa LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_builtin_sealed_table_must_fail") +TEST_CASE_FIXTURE(BuiltinsFixture, "defining_a_self_method_for_a_builtin_sealed_table_must_fail") { CheckResult result = check(R"( function string:m() end @@ -1173,8 +1161,6 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_sealed_table_must TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_unsealed_table_is_ok") { - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - CheckResult result = check(R"( local t = {x = 1} function t.m() end @@ -1185,8 +1171,6 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_unsealed_table_is_ok") TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_unsealed_table_is_ok") { - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - CheckResult result = check(R"( local t = {x = 1} function t:m() end @@ -1211,7 +1195,10 @@ TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_c )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(get(result.errors[0])); + if (FFlag::LuauLowerBoundsCalculation) + CHECK(get(result.errors[0])); + else + CHECK(get(result.errors[0])); } // This unit test could be flaky if the fix has regressed. @@ -1256,7 +1243,7 @@ TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_function_call") CHECK_EQ(toString(te), "Key 'fOo' not found in table 't'. Did you mean 'Foo'?"); } -TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_property_access") +TEST_CASE_FIXTURE(BuiltinsFixture, "found_like_key_in_table_property_access") { CheckResult result = check(R"( local t = {X = 1} @@ -1281,7 +1268,7 @@ TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_property_access") CHECK_EQ(toString(te), "Key 'x' not found in table 't'. Did you mean 'X'?"); } -TEST_CASE_FIXTURE(Fixture, "found_multiple_like_keys") +TEST_CASE_FIXTURE(BuiltinsFixture, "found_multiple_like_keys") { CheckResult result = check(R"( local t = {Foo = 1, foO = 2} @@ -1307,7 +1294,7 @@ TEST_CASE_FIXTURE(Fixture, "found_multiple_like_keys") CHECK_EQ(toString(te), "Key 'foo' not found in table 't'. Did you mean one of 'Foo', 'foO'?"); } -TEST_CASE_FIXTURE(Fixture, "dont_suggest_exact_match_keys") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_suggest_exact_match_keys") { CheckResult result = check(R"( local t = {} @@ -1334,7 +1321,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_exact_match_keys") CHECK_EQ(toString(te), "Key 'Foo' not found in table 't'. Did you mean 'foO'?"); } -TEST_CASE_FIXTURE(Fixture, "getmetatable_returns_pointer_to_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_pointer_to_metatable") { CheckResult result = check(R"( local t = {x = 1} @@ -1347,7 +1334,7 @@ TEST_CASE_FIXTURE(Fixture, "getmetatable_returns_pointer_to_metatable") CHECK_EQ(*requireType("mt"), *requireType("returnedMT")); } -TEST_CASE_FIXTURE(Fixture, "metatable_mismatch_should_fail") +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_mismatch_should_fail") { CheckResult result = check(R"( local t1 = {x = 1} @@ -1369,7 +1356,7 @@ TEST_CASE_FIXTURE(Fixture, "metatable_mismatch_should_fail") CHECK_EQ(*tm->givenType, *requireType("t2")); } -TEST_CASE_FIXTURE(Fixture, "property_lookup_through_tabletypevar_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "property_lookup_through_tabletypevar_metatable") { CheckResult result = check(R"( local t = {x = 1} @@ -1388,7 +1375,7 @@ TEST_CASE_FIXTURE(Fixture, "property_lookup_through_tabletypevar_metatable") CHECK_EQ(up->key, "z"); } -TEST_CASE_FIXTURE(Fixture, "missing_metatable_for_sealed_tables_do_not_get_inferred") +TEST_CASE_FIXTURE(BuiltinsFixture, "missing_metatable_for_sealed_tables_do_not_get_inferred") { CheckResult result = check(R"( local t = {x = 1} @@ -1463,11 +1450,6 @@ TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2") TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_indexer") { - ScopedFastFlag sff[]{ - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( type StringToStringMap = { [string]: string } local rt: StringToStringMap = { ["foo"] = 1 } @@ -1513,11 +1495,6 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { - ScopedFastFlag sff[]{ - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end foo({ a = 1 }) @@ -1604,8 +1581,6 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multipl TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local vec3 = {x = 1, y = 2, z = 3} local vec1 = {x = 1} @@ -1737,7 +1712,7 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") CHECK_EQ("Cannot add property 'b' to table '{| x: number |}'", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "builtin_table_names") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") { CheckResult result = check(R"( os.h = 2 @@ -1750,7 +1725,7 @@ TEST_CASE_FIXTURE(Fixture, "builtin_table_names") CHECK_EQ("Cannot add property 'k' to table 'string'", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "persistent_sealed_table_is_immutable") +TEST_CASE_FIXTURE(BuiltinsFixture, "persistent_sealed_table_is_immutable") { CheckResult result = check(R"( --!nonstrict @@ -1853,7 +1828,7 @@ local foos: {Foo} = { LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "quantifying_a_bound_var_works") +TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") { CheckResult result = check(R"( local clazz = {} @@ -1876,7 +1851,7 @@ TEST_CASE_FIXTURE(Fixture, "quantifying_a_bound_var_works") REQUIRE(prop.type); const FunctionTypeVar* ftv = get(follow(prop.type)); REQUIRE(ftv); - const TypePack* res = get(follow(ftv->retType)); + const TypePack* res = get(follow(ftv->retTypes)); REQUIRE(res); REQUIRE(res->head.size() == 1); const MetatableTypeVar* mtv = get(follow(res->head[0])); @@ -1886,7 +1861,7 @@ TEST_CASE_FIXTURE(Fixture, "quantifying_a_bound_var_works") REQUIRE_EQ(ttv->state, TableState::Sealed); } -TEST_CASE_FIXTURE(Fixture, "less_exponential_blowup_please") +TEST_CASE_FIXTURE(BuiltinsFixture, "less_exponential_blowup_please") { CheckResult result = check(R"( --!strict @@ -1915,7 +1890,7 @@ TEST_CASE_FIXTURE(Fixture, "less_exponential_blowup_please") newData:First() )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_ERROR_COUNT(1, result); } TEST_CASE_FIXTURE(Fixture, "common_table_element_union_in_call") @@ -1978,7 +1953,7 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") { CheckResult result = check(R"( --!nonstrict @@ -1991,10 +1966,8 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_strict") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict local buttons = {} @@ -2008,8 +1981,6 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - CheckResult result = check(R"( type A = { x: number, y: number } type B = { x: number, y: string } @@ -2026,8 +1997,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - CheckResult result = check(R"( type AS = { x: number, y: number } type BS = { x: number, y: string } @@ -2047,13 +2016,8 @@ caused by: Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); } -TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") +TEST_CASE_FIXTURE(BuiltinsFixture, "error_detailed_metatable_prop") { - ScopedFastFlag sff[]{ - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); local b1 = setmetatable({ x = 2, y = "hello" }, { __call = function(s) end }); @@ -2080,9 +2044,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") { - ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedIndexerError{"LuauExtendedIndexerError", true}; - CheckResult result = check(R"( type A = { [number]: string } type B = { [string]: string } @@ -2099,9 +2060,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") { - ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedIndexerError{"LuauExtendedIndexerError", true}; - CheckResult result = check(R"( type A = { [number]: number } type B = { [number]: string } @@ -2118,12 +2076,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") { - ScopedFastFlag sffs[]{ - {"LuauPropertiesGetExpectedType", true}, - {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance2", true}, - }; - CheckResult result = check(R"( --!strict type Super = { x : number } @@ -2139,13 +2091,6 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") { - ScopedFastFlag sffs[]{ - {"LuauPropertiesGetExpectedType", true}, - {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( --!strict type Super = { x : number } @@ -2167,12 +2112,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") { - ScopedFastFlag sffs[]{ - {"LuauPropertiesGetExpectedType", true}, - {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance2", true}, - }; - CheckResult result = check(R"( --!strict type Super = { x : number } @@ -2186,12 +2125,8 @@ a.p = { x = 9 } LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_type_call") { - ScopedFastFlag sff[]{ - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( local b b = setmetatable({}, {__call = b}) @@ -2204,11 +2139,6 @@ b() TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") { - ScopedFastFlag sffs[] = { - {"LuauTableSubtypingVariance2", true}, - {"LuauSubtypingAddOptPropsToUnsealedTables", true}, - }; - CheckResult result = check(R"( --!strict local function setNumber(t: { p: number? }, x:number) t.p = x end @@ -2280,10 +2210,8 @@ local y = #x LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") { - ScopedFastFlag sff{"LuauTerminateCyclicMetatableIndexLookup", true}; - // t :: t1 where t1 = {metatable {__index: t1, __tostring: (t1) -> string}} CheckResult result = check(R"( local mt = {} @@ -2299,7 +2227,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable CHECK_EQ("Type 't' does not have key 'p'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "give_up_after_one_metatable_index_look_up") +TEST_CASE_FIXTURE(BuiltinsFixture, "give_up_after_one_metatable_index_look_up") { CheckResult result = check(R"( local data = { x = 5 } @@ -2316,8 +2244,6 @@ TEST_CASE_FIXTURE(Fixture, "give_up_after_one_metatable_index_look_up") TEST_CASE_FIXTURE(Fixture, "confusing_indexing") { - ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; - CheckResult result = check(R"( type T = {} & {p: number | string} local function f(t: T) @@ -2334,8 +2260,6 @@ TEST_CASE_FIXTURE(Fixture, "confusing_indexing") TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table") { - ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter", true}; - CheckResult result = check(R"( local a: {x: number, y: number, [any]: any} | {y: number} @@ -2354,8 +2278,6 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") { - ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter", true}; - CheckResult result = check(R"( local a: {y: number} | {x: number, y: number, [any]: any} @@ -2374,8 +2296,6 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf1") { - ScopedFastFlag sff{"LuauTxnLogCheckForInvalidation", true}; - CheckResult result = check(R"( -- This example produced a UAF at one point, caused by pointers to table types becoming -- invalidated by child unifiers. (Calling log.concat can cause pointers to become invalid.) @@ -2406,8 +2326,6 @@ end TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf2") { - ScopedFastFlag sff{"LuauTxnLogCheckForInvalidation", true}; - CheckResult result = check(R"( -- Another example that UAFd, this time found by fuzzing. local _ @@ -2485,7 +2403,7 @@ TEST_CASE_FIXTURE(Fixture, "free_rhs_table_can_also_be_bound") )"); } -TEST_CASE_FIXTURE(Fixture, "table_unifies_into_map") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_unifies_into_map") { CheckResult result = check(R"( local Instance: any @@ -2571,7 +2489,7 @@ TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") * the generalization process), then it loses the knowledge that its metatable will have an :incr() * method. */ -TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_scope") { CheckResult result = check(R"( local Counter = {} @@ -2599,7 +2517,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") const FunctionTypeVar* newType = get(follow(counterType->props["new"].type)); REQUIRE(newType); - std::optional newRetType = *first(newType->retType); + std::optional newRetType = *first(newType->retTypes); REQUIRE(newRetType); const MetatableTypeVar* newRet = get(follow(*newRetType)); @@ -2613,7 +2531,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") } // TODO: CLI-39624 -TEST_CASE_FIXTURE(Fixture, "instantiate_tables_at_scope_level") +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_tables_at_scope_level") { CheckResult result = check(R"( --!strict @@ -2691,13 +2609,14 @@ do end )"); } -TEST_CASE_FIXTURE(Fixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") { CheckResult result = check("local x = setmetatable({})"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Argument count mismatch. Function expects 2 arguments, but only 1 is specified", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning") +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning") { CheckResult result = check(R"( --!nonstrict @@ -2718,10 +2637,8 @@ type t0 = any CHECK(ttv->instantiatedTypeParams.empty()); } -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_2") +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning_2") { - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - CheckResult result = check(R"( type X = T type K = X @@ -2739,8 +2656,6 @@ type K = X TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_3") { - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - CheckResult result = check(R"( type X = T local a = {} @@ -2774,7 +2689,7 @@ local baz = foo[bar] CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); } -TEST_CASE_FIXTURE(Fixture, "table_simple_call") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") { CheckResult result = check(R"( local a = setmetatable({ x = 2 }, { @@ -2790,7 +2705,7 @@ local c = a(2) -- too many arguments CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "access_index_metamethod_that_returns_variadic") +TEST_CASE_FIXTURE(BuiltinsFixture, "access_index_metamethod_that_returns_variadic") { CheckResult result = check(R"( type Foo = {x: string} @@ -2885,7 +2800,7 @@ TEST_CASE_FIXTURE(Fixture, "pairs_parameters_are_not_unsealed_tables") )"); } -TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_function_check_use_after_free") { CheckResult result = check(R"( local t = {} @@ -2922,4 +2837,120 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the LUAU_REQUIRE_NO_ERRORS(result); } +// The real bug here was that we weren't always uncondionally typechecking a trailing return statement last. +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_leak_free_table_props") +{ + CheckResult result = check(R"( + local function a(state) + print(state.blah) + end + + local function b(state) -- The bug was that we inferred state: {blah: any, gwar: any} + print(state.gwar) + end + + return function() + return function(state) + a(state) + b(state) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({+ blah: a +}) -> ()", toString(requireType("a"))); + CHECK_EQ("({+ gwar: a +}) -> ()", toString(requireType("b"))); + CHECK_EQ("() -> ({+ blah: a, gwar: b +}) -> ()", toString(getMainModule()->getModuleScope()->returnType)); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + check(R"( + function Base64FileReader(data) + local reader = {} + local index: number + + function reader:PeekByte() + return data:byte(index) + end + + function reader:Byte() + return data:byte(index - 1) + end + + return reader + end + )"); + + CHECK_EQ("(t1) -> {| Byte: (b) -> (a...), PeekByte: (c) -> (a...) |} where t1 = {+ byte: (t1, number) -> (a...) +}", + toString(requireType("Base64FileReader"))); +} + +TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") +{ + CheckResult result = check(R"( + local t: { [string]: number } = { 5, 6, 7 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[1])); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[2])); +} + +TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") +{ + CheckResult result = check(R"( + type X = { { x: boolean?, y: boolean? } } + + local l1: {[string]: X} = { key = { { x = true }, { y = true } } } + local l2: {[any]: X} = { key = { { x = true }, { y = true } } } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra_2") +{ + CheckResult result = check(R"( + type X = {[any]: string | boolean} + + local x: X = { key = "str" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "prop_access_on_key_whose_types_mismatches") +{ + ScopedFastFlag sff{"LuauReportErrorsOnIndexerKeyMismatch", true}; + + CheckResult result = check(R"( + local t: {number} = {} + local x = t.x + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Key 'x' not found in table '{number}'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "prop_access_on_unions_of_indexers_where_key_whose_types_mismatches") +{ + ScopedFastFlag sff{"LuauReportErrorsOnIndexerKeyMismatch", true}; + + CheckResult result = check(R"( + local t: { [number]: number } | { [boolean]: number } = {} + local u = t.x + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '{number} | {| [boolean]: number |}' does not have key 'x'", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 660ddcfc..6a048b26 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -13,8 +13,9 @@ #include -LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) -LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); using namespace Luau; @@ -43,10 +44,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_error") CheckResult result = check("local a = 7 local b = 'hi' a = b"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{ - requireType("a"), - requireType("b"), - }})); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); } TEST_CASE_FIXTURE(Fixture, "tc_error_2") @@ -85,16 +83,22 @@ TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") { + ScopedFastFlag sff[]{ + {"DebugLuauDeferredConstraintResolution", false}, + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nocheck function f(x) - return x + return 5 end -- we get type information even if there's type errors f(1, 2) )"); - CHECK_EQ("(any) -> (...any)", toString(requireType("f"))); + CHECK_EQ("(any) -> number", toString(requireType("f"))); LUAU_REQUIRE_NO_ERRORS(result); } @@ -155,7 +159,7 @@ TEST_CASE_FIXTURE(Fixture, "unify_nearly_identical_recursive_types") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") +TEST_CASE_FIXTURE(BuiltinsFixture, "warn_on_lowercase_parent_property") { CheckResult result = check(R"( local M = require(script.parent.DoesNotMatter) @@ -169,7 +173,7 @@ TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") REQUIRE_EQ("parent", ed->symbol); } -TEST_CASE_FIXTURE(Fixture, "weird_case") +TEST_CASE_FIXTURE(BuiltinsFixture, "weird_case") { CheckResult result = check(R"( local function f() return 4 end @@ -177,7 +181,6 @@ TEST_CASE_FIXTURE(Fixture, "weird_case") )"); LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); } TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") @@ -232,10 +235,14 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") CHECK_EQ("boolean", toString(err->table)); CHECK_EQ("x", err->key); - CHECK_EQ("*unknown*", toString(requireType("c"))); - CHECK_EQ("*unknown*", toString(requireType("d"))); - CHECK_EQ("*unknown*", toString(requireType("e"))); - CHECK_EQ("*unknown*", toString(requireType("f"))); + // TODO: Should we assert anything about these tests when DCR is being used? + if (!FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("*unknown*", toString(requireType("c"))); + CHECK_EQ("*unknown*", toString(requireType("d"))); + CHECK_EQ("*unknown*", toString(requireType("e"))); + CHECK_EQ("*unknown*", toString(requireType("f"))); + } } TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowing") @@ -293,7 +300,7 @@ TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") // In these tests, a successful parse is required, so we need the parser to return the AST and then we can test the recursion depth limit in type // checker. We also want it to somewhat match up with production values, so we push up the parser recursion limit a little bit instead. -TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") +TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_count") { #if defined(LUAU_ENABLE_ASAN) int limit = 250; @@ -302,12 +309,13 @@ TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") #else int limit = 600; #endif - ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; - ScopedFastInt luauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", limit - 100}; - ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", 0}; - CHECK_NOTHROW(check("print('Hello!')")); - CHECK_THROWS_AS(check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"), std::runtime_error); + ScopedFastInt sfi{"LuauCheckRecursionLimit", limit}; + + CheckResult result = check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(nullptr != get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "check_block_recursion_limit") @@ -347,35 +355,6 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") CHECK(nullptr != get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "globals") -{ - CheckResult result = check(R"( - --!nonstrict - foo = true - foo = "now i'm a string!" - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("foo"))); -} - -TEST_CASE_FIXTURE(Fixture, "globals2") -{ - CheckResult result = check(R"( - --!nonstrict - foo = function() return 1 end - foo = "now i'm a string!" - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("() -> (...any)", toString(tm->wantedType)); - CHECK_EQ("string", toString(tm->givenType)); - CHECK_EQ("() -> (...any)", toString(requireType("foo"))); -} - TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") { CheckResult result = check(R"( @@ -390,24 +369,7 @@ TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") CHECK_EQ("foo", us->name); } -TEST_CASE_FIXTURE(Fixture, "globals_everywhere") -{ - CheckResult result = check(R"( - --!nonstrict - foo = 1 - - if true then - bar = 2 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("foo"))); - CHECK_EQ("any", toString(requireType("bar"))); -} - -TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_do") { CheckResult result = check(R"( do @@ -437,21 +399,6 @@ TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") CHECK_EQ("any", toString(requireType("value"))); } -// TEST_CASE_FIXTURE(Fixture, "infer_method_signature_of_argument") -// { -// CheckResult result = check(R"( -// function f(a) -// if a.cond then -// return a.method() -// end -// end -// )"); - -// LUAU_REQUIRE_NO_ERRORS(result); - -// CHECK_EQ("A", toString(requireType("f"))); -// } - TEST_CASE_FIXTURE(Fixture, "cyclic_follow") { check(R"( @@ -522,7 +469,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_assert") LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_replacement_name_in_error") +TEST_CASE_FIXTURE(BuiltinsFixture, "tc_after_error_recovery_no_replacement_name_in_error") { { CheckResult result = check(R"( @@ -575,7 +522,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_replacement_name_in_error } } -TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") +TEST_CASE_FIXTURE(BuiltinsFixture, "index_expr_should_be_checked") { CheckResult result = check(R"( local foo: any @@ -671,7 +618,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") _(nil) )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERRORS(result); std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); @@ -683,7 +630,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") CHECK(it != result.errors.end()); } -TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional2") +TEST_CASE_FIXTURE(BuiltinsFixture, "no_stack_overflow_from_isoptional2") { CheckResult result = check(R"( function _(l0:({})|(t0)):((((typeof((xpcall)))|(t96))|(t13))&(t96),()->typeof(...)) @@ -710,10 +657,10 @@ TEST_CASE_FIXTURE(Fixture, "no_infinite_loop_when_trying_to_unify_uh_this") _() )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") +TEST_CASE_FIXTURE(BuiltinsFixture, "no_heap_use_after_free_error") { CheckResult result = check(R"( --!nonstrict @@ -721,13 +668,13 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") local l0 do end while _ do - function _:_() - _ += _(_._(_:n0(xpcall,_))) - end + function _:_() + _ += _(_._(_:n0(xpcall,_))) + end end )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_type_assertion_value_type") @@ -756,7 +703,7 @@ b, c = {2, "s"}, {"b", 4} LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types_mutable_lval") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_assignment_value_types_mutable_lval") { CheckResult result = check(R"( local a = {} @@ -824,7 +771,7 @@ local a: number? = if true then 1 else nil LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_3") +TEST_CASE_FIXTURE(BuiltinsFixture, "tc_if_else_expressions_expected_type_3") { CheckResult result = check(R"( local function times(n: any, f: () -> T) @@ -895,7 +842,7 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_found_this") )"); } -TEST_CASE_FIXTURE(Fixture, "recursive_metatable_crash") +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_crash") { CheckResult result = check(R"( local function getIt() @@ -940,8 +887,6 @@ end TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") { - ScopedFastFlag subtypingVariance{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict --!nolint @@ -978,4 +923,84 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") +{ + ScopedFastInt sfi("LuauTypeInferRecursionLimit", 2); + + CheckResult result = check(R"( + function complex() + function _(l0:t0): (any, ()->()) + return 0,_ + end + type t0 = t0 | {} + _(nil) + end + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") +{ + CheckResult result = check(R"( + local obj = {} + + function obj:Method() + self.fieldA = function(object) + if object.a then + self.arr[object] = true + elseif object.b then + self.fieldB[object] = object:Connect(function(arg) + self.arr[arg] = nil + end) + end + end + end + + return obj + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/** + * The problem we had here was that the type of q in B.h was initially inferring to {} | {prop: free} before we bound + * that second table to the enclosing union. + */ +TEST_CASE_FIXTURE(Fixture, "do_not_bind_a_free_table_to_a_union_containing_that_table") +{ + ScopedFastFlag flag[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + --!strict + + local A = {} + + function A:f() + local t = {} + + for key, value in pairs(self) do + t[key] = value + end + + return t + end + + local B = A:f() + + function B.g(t) + assert(type(t) == "table") + assert(t.prop ~= nil) + end + + function B.h(q) + q = q or {} + return q or {} + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index d8de2594..49deae71 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -126,8 +126,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( function f(arg: number) return arg end local a @@ -198,7 +196,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") CHECK_EQ(toString(tm->wantedType), "string"); } -TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unification") +TEST_CASE_FIXTURE(BuiltinsFixture, "cli_41095_concat_log_in_sealed_table_unification") { CheckResult result = check(R"( --!strict @@ -242,4 +240,26 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") state.tryUnify(&func, typeChecker.anyType); } +TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") +{ + TypeId a = arena.addType(TypeVar{FreeTypeVar{TypeLevel{}}}); + TypeId b = typeChecker.numberType; + + state.tryUnify(a, b); + state.log.commit(); + + CHECK_EQ(a->owningArena, &arena); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") +{ + TypePackId a = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); + TypePackId b = typeChecker.anyTypePack; + + state.tryUnify(a, b); + state.log.commit(); + + CHECK_EQ(a->owningArena, &arena); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 130f33d7..bcd30498 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("TypePackTests"); TEST_CASE_FIXTURE(Fixture, "infer_multi_return") @@ -24,11 +26,11 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return") const FunctionTypeVar* takeTwoType = get(requireType("take_two")); REQUIRE(takeTwoType != nullptr); - const auto& [returns, tail] = flatten(takeTwoType->retType); + const auto& [returns, tail] = flatten(takeTwoType->retTypes); CHECK_EQ(2, returns.size()); - CHECK_EQ(typeChecker.numberType, returns[0]); - CHECK_EQ(typeChecker.numberType, returns[1]); + CHECK_EQ(typeChecker.numberType, follow(returns[0])); + CHECK_EQ(typeChecker.numberType, follow(returns[1])); CHECK(!tail); } @@ -71,12 +73,12 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac const FunctionTypeVar* takeOneMoreType = get(requireType("take_three")); REQUIRE(takeOneMoreType != nullptr); - const auto& [rets, tail] = flatten(takeOneMoreType->retType); + const auto& [rets, tail] = flatten(takeOneMoreType->retTypes); REQUIRE_EQ(3, rets.size()); - CHECK_EQ(typeChecker.numberType, rets[0]); - CHECK_EQ(typeChecker.numberType, rets[1]); - CHECK_EQ(typeChecker.numberType, rets[2]); + CHECK_EQ(typeChecker.numberType, follow(rets[0])); + CHECK_EQ(typeChecker.numberType, follow(rets[1])); + CHECK_EQ(typeChecker.numberType, follow(rets[2])); CHECK(!tail); } @@ -91,26 +93,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* applyType = get(requireType("apply")); - REQUIRE(applyType != nullptr); - - std::vector applyArgs = flatten(applyType->argTypes).first; - REQUIRE_EQ(3, applyArgs.size()); - - const FunctionTypeVar* fType = get(follow(applyArgs[0])); - REQUIRE(fType != nullptr); - - const FunctionTypeVar* gType = get(follow(applyArgs[1])); - REQUIRE(gType != nullptr); - - std::vector gArgs = flatten(gType->argTypes).first; - REQUIRE_EQ(1, gArgs.size()); - - // function(function(t1, T2...): (t3, T4...), function(t5): (t1, T2...), t5): (t3, T4...) - - REQUIRE_EQ(*gArgs[0], *applyArgs[2]); - REQUIRE_EQ(toString(fType->argTypes), toString(gType->retType)); - REQUIRE_EQ(toString(fType->retType), toString(applyType->retType)); + CHECK_EQ("((b...) -> (c...), (a) -> (b...), a) -> (c...)", toString(requireType("apply"))); } TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned") @@ -122,10 +105,10 @@ TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned") LUAU_REQUIRE_NO_ERRORS(result); const FunctionTypeVar* fTy = get(requireType("f")); REQUIRE(fTy != nullptr); - CHECK_EQ(0, size(fTy->retType)); + CHECK_EQ(0, size(fTy->retTypes)); const FunctionTypeVar* gTy = get(requireType("g")); REQUIRE(gTy != nullptr); - CHECK_EQ(0, size(gTy->retType)); + CHECK_EQ(0, size(gTy->retTypes)); } TEST_CASE_FIXTURE(Fixture, "no_return_size_should_be_zero") @@ -142,15 +125,15 @@ TEST_CASE_FIXTURE(Fixture, "no_return_size_should_be_zero") const FunctionTypeVar* fTy = get(requireType("f")); REQUIRE(fTy != nullptr); - CHECK_EQ(1, size(follow(fTy->retType))); + CHECK_EQ(1, size(follow(fTy->retTypes))); const FunctionTypeVar* gTy = get(requireType("g")); REQUIRE(gTy != nullptr); - CHECK_EQ(0, size(gTy->retType)); + CHECK_EQ(0, size(gTy->retTypes)); const FunctionTypeVar* hTy = get(requireType("h")); REQUIRE(hTy != nullptr); - CHECK_EQ(0, size(hTy->retType)); + CHECK_EQ(0, size(hTy->retTypes)); } TEST_CASE_FIXTURE(Fixture, "varargs_inference_through_multiple_scopes") @@ -328,7 +311,10 @@ local c: Packed auto ttvA = get(requireType("a")); REQUIRE(ttvA); CHECK_EQ(toString(requireType("a")), "Packed"); - CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> number |}"); + else + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); REQUIRE(ttvA->instantiatedTypeParams.size() == 1); REQUIRE(ttvA->instantiatedTypePackParams.size() == 1); CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number"); @@ -353,7 +339,7 @@ local c: Packed CHECK_EQ(toString(ttvC->instantiatedTypePackParams[0], {true}), "number, boolean"); } -TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_type_packs_import") { fileResolver.source["game/A"] = R"( export type Packed = { a: T, b: (U...) -> () } @@ -383,7 +369,7 @@ local d: { a: typeof(c) } CHECK_EQ(toString(requireType("d")), "{| a: Packed |}"); } -TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_pack_type_parameters") { fileResolver.source["game/A"] = R"( export type Packed = { a: T, b: (U...) -> () } @@ -798,7 +784,7 @@ local a: Y<...number> LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "type_alias_default_export") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_default_export") { fileResolver.source["Module/Types"] = R"( export type A = { a: T, b: U } @@ -953,7 +939,7 @@ until _ )"); } -TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks") +TEST_CASE_FIXTURE(BuiltinsFixture, "detect_cyclic_typepacks") { CheckResult result = check(R"( type ( ... ) ( ) ; @@ -963,10 +949,10 @@ TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks") ( ... ) "" )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks2") +TEST_CASE_FIXTURE(BuiltinsFixture, "detect_cyclic_typepacks2") { CheckResult result = check(R"( function _(l0:((typeof((pcall)))|((((t0)->())|(typeof(-67108864)))|(any)))|(any),...):(((typeof(0))|(any))|(any),typeof(-67108864),any) @@ -975,7 +961,7 @@ TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks2") end )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERRORS(result); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 68b7c4fb..2b48133d 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -6,7 +6,7 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) using namespace Luau; @@ -103,7 +103,7 @@ TEST_CASE_FIXTURE(Fixture, "optional_arguments_table2") REQUIRE(!result.errors.empty()); } -TEST_CASE_FIXTURE(Fixture, "error_takes_optional_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "error_takes_optional_arguments") { CheckResult result = check(R"( error("message") @@ -254,11 +254,11 @@ local c = bf.a.y TEST_CASE_FIXTURE(Fixture, "optional_union_functions") { CheckResult result = check(R"( -local a = {} -function a.foo(x:number, y:number) return x + y end -type A = typeof(a) -local b: A? = a -local c = b.foo(1, 2) + local a = {} + function a.foo(x:number, y:number) return x + y end + type A = typeof(a) + local b: A? = a + local c = b.foo(1, 2) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -356,7 +356,10 @@ a.x = 2 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ("Value of type '{| x: number, y: number |}?' could be nil", toString(result.errors[0])); + else + CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "optional_length_error") @@ -425,12 +428,6 @@ y = x TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") { - ScopedFastFlag sffs[] = { - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - {"LuauSubtypingAddOptPropsToUnsealedTables", true}, - }; - CheckResult result = check(R"( -- the difference between this and unify_unsealed_table_union_check is the type annotation on x local t = { x = 3, y = true } @@ -513,4 +510,32 @@ TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") +{ + CheckResult result = check(R"( + type A = { x: number, y: (number) -> string } | { z: number, y: (number) -> string } + + local a:A = nil + + function a.y(x) + return tostring(x * 2) + end + + function a.y(x: string): number + return tonumber(x) or 0 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + // NOTE: union normalization will improve this message + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[0]), "Type '(string) -> number' could not be converted into '(number) -> string'\n" + "caused by:\n" + " Argument #1 type is not compatible. Type 'number' could not be converted into 'string'"); + else + CHECK_EQ(toString(result.errors[0]), + R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); +} + + TEST_SUITE_END(); diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index c4931578..8a5a65fe 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -197,4 +197,20 @@ TEST_CASE_FIXTURE(TypePackFixture, "std_distance") CHECK_EQ(4, std::distance(b, e)); } +TEST_CASE("content_reassignment") +{ + ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; + + TypePackVar myError{Unifiable::Error{}, /*presistent*/ true}; + + TypeArena arena; + + TypePackId futureError = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); + asMutable(futureError)->reassign(myError); + + CHECK(get(futureError) != nullptr); + CHECK(!futureError->persistent); + CHECK(futureError->owningArena == &arena); +} + TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index fd5f4dbc..4f8fc502 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -184,8 +184,6 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_empty_union") TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") { - ScopedFastFlag sff{"LuauSealExports", true}; - TypeVar ftv11{FreeTypeVar{TypeLevel{}}}; TypePackVar tp24{TypePack{{&ftv11}}}; @@ -275,7 +273,7 @@ TEST_CASE("tagging_tables") TEST_CASE("tagging_classes") { - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; CHECK(!Luau::hasTag(&base, "foo")); Luau::attachTag(&base, "foo"); CHECK(Luau::hasTag(&base, "foo")); @@ -283,8 +281,8 @@ TEST_CASE("tagging_classes") TEST_CASE("tagging_subclasses") { - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; - TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}}; + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; + TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr, "Test"}}; CHECK(!Luau::hasTag(&base, "foo")); CHECK(!Luau::hasTag(&derived, "foo")); @@ -315,23 +313,33 @@ TEST_CASE("tagging_props") CHECK(Luau::hasTag(prop, "foo")); } -struct VisitCountTracker +struct VisitCountTracker final : TypeVarOnceVisitor { std::unordered_map tyVisits; std::unordered_map tpVisits; - void cycle(TypeId) {} - void cycle(TypePackId) {} + void cycle(TypeId) override {} + void cycle(TypePackId) override {} template bool operator()(TypeId ty, const T& t) + { + return visit(ty); + } + + template + bool operator()(TypePackId tp, const T&) + { + return visit(tp); + } + + bool visit(TypeId ty) override { tyVisits[ty]++; return true; } - template - bool operator()(TypePackId tp, const T&) + bool visit(TypePackId tp) override { tpVisits[tp]++; return true; @@ -349,8 +357,7 @@ local b: (T, T, T) -> T TypeId bType = requireType("b"); VisitCountTracker tester; - DenseHashSet seen{nullptr}; - visitTypeVarOnce(bType, tester, seen); + tester.traverse(bType); for (auto [_, count] : tester.tyVisits) CHECK_EQ(count, 1); @@ -409,4 +416,24 @@ TEST_CASE("proof_that_isBoolean_uses_all_of") CHECK(!isBoolean(&union_)); } +TEST_CASE("content_reassignment") +{ + ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; + + TypeVar myAny{AnyTypeVar{}, /*presistent*/ true}; + myAny.normal = true; + myAny.documentationSymbol = "@global/any"; + + TypeArena arena; + + TypeId futureAny = arena.addType(FreeTypeVar{TypeLevel{}}); + asMutable(futureAny)->reassign(myAny); + + CHECK(get(futureAny) != nullptr); + CHECK(!futureAny->persistent); + CHECK(futureAny->normal); + CHECK(futureAny->documentationSymbol == "@global/any"); + CHECK(futureAny->owningArena == &arena); +} + TEST_SUITE_END(); diff --git a/tests/Variant.test.cpp b/tests/Variant.test.cpp index fcf37875..aa0731ca 100644 --- a/tests/Variant.test.cpp +++ b/tests/Variant.test.cpp @@ -13,6 +13,25 @@ struct Foo int x = 42; }; +struct Bar +{ + explicit Bar(int x) + : prop(x * 2) + { + ++count; + } + + ~Bar() + { + --count; + } + + int prop; + static int count; +}; + +int Bar::count = 0; + TEST_SUITE_BEGIN("Variant"); TEST_CASE("DefaultCtor") @@ -46,6 +65,29 @@ TEST_CASE("Create") CHECK(get_if(&v3)->x == 3); } +TEST_CASE("Emplace") +{ + { + Variant v1; + + CHECK(0 == Bar::count); + int& i = v1.emplace(5); + CHECK(5 == i); + + CHECK(0 == Bar::count); + + CHECK(get_if(&v1) == &i); + + Bar& bar = v1.emplace(11); + CHECK(22 == bar.prop); + CHECK(1 == Bar::count); + + CHECK(get_if(&v1) == &bar); + } + + CHECK(0 == Bar::count); +} + TEST_CASE("NonPOD") { // initialize (copy) diff --git a/tests/VisitTypeVar.test.cpp b/tests/VisitTypeVar.test.cpp new file mode 100644 index 00000000..4fba694a --- /dev/null +++ b/tests/VisitTypeVar.test.cpp @@ -0,0 +1,41 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "Luau/RecursionCounter.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTINT(LuauVisitRecursionLimit) + +TEST_SUITE_BEGIN("VisitTypeVar"); + +TEST_CASE_FIXTURE(Fixture, "throw_when_limit_is_exceeded") +{ + ScopedFastInt sfi{"LuauVisitRecursionLimit", 3}; + + CheckResult result = check(R"( + local t : {a: {b: {c: {d: {e: boolean}}}}} + )"); + + TypeId tType = requireType("t"); + + CHECK_THROWS_AS(toString(tType), RecursionLimitException); +} + +TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") +{ + ScopedFastInt sfi{"LuauVisitRecursionLimit", 8}; + + CheckResult result = check(R"( + local t : {a: {b: {c: {d: {e: boolean}}}}} + )"); + + TypeId tType = requireType("t"); + + (void)toString(tType); +} + +TEST_SUITE_END(); diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua index 7a4058b5..27416623 100644 --- a/tests/conformance/apicalls.lua +++ b/tests/conformance/apicalls.lua @@ -11,4 +11,15 @@ function create_with_tm(x) return setmetatable({ a = x }, m) end +local gen = 0 +function incuv() + gen += 1 + return gen +end + +pi = 3.1415926 +function getpi() + return pi +end + return('OK') diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index 297cf011..0b5aafed 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -167,6 +167,81 @@ if not limitedstack then end end +-- C stack overflow +if not limitedstack then + local count = 1 + local cso = setmetatable({}, { + __index = function(self, i) + count = count + 1 + return self[i] + end, + __newindex = function(self, i, v) + count = count + 1 + self[i] = v + end, + __tostring = function(self) + count = count + 1 + return tostring(self) + end + }) + + local ehline + local function ehassert(cond) + if not cond then + ehline = debug.info(2, "l") + error() + end + end + + local userdata = newproxy(true) + getmetatable(userdata).__index = print + assert(debug.info(print, "s") == "[C]") + + local s, e = xpcall(tostring, function(e) + ehassert(string.find(e, "C stack overflow")) + print("after __tostring C stack overflow", count) -- 198: 1 resume + 1 xpcall + 198 luaB_tostring calls (which runs our __tostring successfully 197 times, erroring on the last attempt) + ehassert(count > 1) + + local ps, pe + + -- __tostring overflow (lua_call) + count = 1 + ps, pe = pcall(tostring, cso) + print("after __tostring overflow in handler", count) -- 23: xpcall error handler + pcall + 23 luaB_tostring calls + ehassert(not ps and string.find(pe, "error in error handling")) + ehassert(count > 1) + + -- __index overflow (callTMres) + count = 1 + ps, pe = pcall(function() return cso[cso] end) + print("after __index overflow in handler", count) -- 23: xpcall error handler + pcall + 23 __index calls + ehassert(not ps and string.find(pe, "error in error handling")) + ehassert(count > 1) + + -- __newindex overflow (callTM) + count = 1 + ps, pe = pcall(function() cso[cso] = "kohuke" end) + print("after __newindex overflow in handler", count) -- 23: xpcall error handler + pcall + 23 __newindex calls + ehassert(not ps and string.find(pe, "error in error handling")) + ehassert(count > 1) + + -- test various C __index invocations on userdata + ehassert(pcall(function() return userdata[userdata] end)) -- LOP_GETTABLE + ehassert(pcall(function() return userdata[1] end)) -- LOP_GETTABLEN + ehassert(pcall(function() return userdata.StringConstant end)) -- LOP_GETTABLEKS (luau_callTM) + + -- lua_resume test + local coro = coroutine.create(function() end) + ps, pe = coroutine.resume(coro) + ehassert(not ps and string.find(pe, "C stack overflow")) + + return true + end, cso) + + assert(not s) + assert(e == true, "error in xpcall eh, line " .. tostring(ehline)) +end + --[[ local i=1 while stack[i] ~= l1 do @@ -307,4 +382,9 @@ assert(ecall(function() return "a" <= 5 end) == "attempt to compare string <= nu assert(ecall(function() local t = {} setmetatable(t, { __newindex = function(t,i,v) end }) t[nil] = 2 end) == "table index is nil") +-- for loop type errors +assert(ecall(function() for i='a',2 do end end) == "invalid 'for' initial value (number expected, got string)") +assert(ecall(function() for i=1,'a' do end end) == "invalid 'for' limit (number expected, got string)") +assert(ecall(function() for i=1,2,'a' do end end) == "invalid 'for' step (number expected, got string)") + return('OK') diff --git a/tests/conformance/iter.lua b/tests/conformance/iter.lua new file mode 100644 index 00000000..468ffafb --- /dev/null +++ b/tests/conformance/iter.lua @@ -0,0 +1,196 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print('testing iteration') + +-- basic for loop tests +do + local a + for a,b in pairs{} do error("not here") end + for i=1,0 do error("not here") end + for i=0,1,-1 do error("not here") end + a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) + a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) + a = 0; for i=0, 1, 0.1 do a=a+1 end; assert(a==11) +end + +-- precision tests for for loops +do + local a + --a = 0; for i=1, 0, -0.01 do a=a+1 end; assert(a==101) + a = 0; for i=0, 0.999999999, 0.1 do a=a+1 end; assert(a==10) + a = 0; for i=1, 1, 1 do a=a+1 end; assert(a==1) + a = 0; for i=1e10, 1e10, -1 do a=a+1 end; assert(a==1) + a = 0; for i=1, 0.99999, 1 do a=a+1 end; assert(a==0) + a = 0; for i=99999, 1e5, -1 do a=a+1 end; assert(a==0) + a = 0; for i=1, 0.99999, -1 do a=a+1 end; assert(a==1) +end + +-- for loops do string->number coercion +do + local a = 0; for i="10","1","-2" do a=a+1 end; assert(a==5) +end + +-- generic for with function iterators +do + local function f (n, p) + local t = {}; for i=1,p do t[i] = i*10 end + return function (_,n) + if n > 0 then + n = n-1 + return n, unpack(t) + end + end, nil, n + end + + local x = 0 + for n,a,b,c,d in f(5,3) do + x = x+1 + assert(a == 10 and b == 20 and c == 30 and d == nil) + end + assert(x == 5) +end + +-- generic for with __call (tables) +do + local f = {} + setmetatable(f, { __call = function(_, _, n) if n > 0 then return n - 1 end end }) + + local x = 0 + for n in f, nil, 5 do + x += n + end + assert(x == 10) +end + +-- generic for with __call (userdata) +do + local f = newproxy(true) + getmetatable(f).__call = function(_, _, n) if n > 0 then return n - 1 end end + + local x = 0 + for n in f, nil, 5 do + x += n + end + assert(x == 10) +end + +-- generic for with pairs +do + local x = 0 + for k, v in pairs({a = 1, b = 2, c = 3}) do + x += v + end + assert(x == 6) +end + +-- generic for with pairs with holes +do + local x = 0 + for k, v in pairs({1, 2, 3, nil, 5}) do + x += v + end + assert(x == 11) +end + +-- generic for with ipairs +do + local x = 0 + for k, v in ipairs({1, 2, 3, nil, 5}) do + x += v + end + assert(x == 6) +end + +-- generic for with __iter (tables) +do + local f = {} + setmetatable(f, { __iter = function(x) + assert(f == x) + return next, {1, 2, 3, 4} + end }) + + local x = 0 + for n in f do + x += n + end + assert(x == 10) +end + +-- generic for with __iter (userdata) +do + local f = newproxy(true) + getmetatable(f).__iter = function(x) + assert(f == x) + return next, {1, 2, 3, 4} + end + + local x = 0 + for n in f do + x += n + end + assert(x == 10) +end + +-- generic for with tables (dictionary) +do + local x = 0 + for k, v in {a = 1, b = 2, c = 3} do + print(k, v) + x += v + end + assert(x == 6) +end + +-- generic for with tables (arrays) +do + local x = '' + for k, v in {1, 2, 3, nil, 5} do + x ..= tostring(v) + end + assert(x == "1235") +end + +-- generic for with tables (mixed) +do + local x = 0 + for k, v in {1, 2, 3, nil, 5, a = 1, b = 2, c = 3} do + x += v + end + assert(x == 17) +end + +-- generic for over a non-iterable object +do + local ok, err = pcall(function() for x in 42 do end end) + assert(not ok and err:match("attempt to iterate")) +end + +-- generic for over an iterable object that doesn't return a function +do + local obj = {} + setmetatable(obj, { __iter = function() end }) + + local ok, err = pcall(function() for x in obj do end end) + assert(not ok and err:match("attempt to call a nil value")) +end + +-- it's okay to iterate through a table with a single variable +do + local x = 0 + for k in {1, 2, 3, 4, 5} do + x += k + end + assert(x == 15) +end + +-- all extra variables should be set to nil during builtin traversal +do + local x = 0 + for k,v,a,b,c,d,e in {1, 2, 3, 4, 5} do + x += k + assert(a == nil and b == nil and c == nil and d == nil and e == nil) + end + assert(x == 15) +end + +return"OK" diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index e85fcbe8..93c4ddf7 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -368,48 +368,6 @@ assert(next(a,nil) == 1000 and next(a,1000) == nil) assert(next({}) == nil) assert(next({}, nil) == nil) -for a,b in pairs{} do error("not here") end -for i=1,0 do error("not here") end -for i=0,1,-1 do error("not here") end -a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) -a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) - -a = 0; for i=0, 1, 0.1 do a=a+1 end; assert(a==11) --- precision problems ---a = 0; for i=1, 0, -0.01 do a=a+1 end; assert(a==101) -a = 0; for i=0, 0.999999999, 0.1 do a=a+1 end; assert(a==10) -a = 0; for i=1, 1, 1 do a=a+1 end; assert(a==1) -a = 0; for i=1e10, 1e10, -1 do a=a+1 end; assert(a==1) -a = 0; for i=1, 0.99999, 1 do a=a+1 end; assert(a==0) -a = 0; for i=99999, 1e5, -1 do a=a+1 end; assert(a==0) -a = 0; for i=1, 0.99999, -1 do a=a+1 end; assert(a==1) - --- conversion -a = 0; for i="10","1","-2" do a=a+1 end; assert(a==5) - - -collectgarbage() - - --- testing generic 'for' - -local function f (n, p) - local t = {}; for i=1,p do t[i] = i*10 end - return function (_,n) - if n > 0 then - n = n-1 - return n, unpack(t) - end - end, nil, n -end - -local x = 0 -for n,a,b,c,d in f(5,3) do - x = x+1 - assert(a == 10 and b == 20 and c == 30 and d == nil) -end -assert(x == 5) - -- testing table.create and table.find do local t = table.create(5) @@ -550,4 +508,88 @@ do assert(not pcall(table.clone, 42)) end +-- test boundary invariant maintenance during rehash +do + local arr = table.create(5, 42) + + arr[1] = nil + arr.a = 'a' -- trigger rehash + + assert(#arr == 5) -- technically 0 is also valid, but it happens to be 5 because array capacity is 5 +end + +-- test boundary invariant maintenance when replacing hash keys +do + local arr = {} + arr.a = 'a' + arr.a = nil + arr[1] = 1 -- should rehash and resize array part, otherwise # won't find the boundary in array part + + assert(#arr == 1) +end + +-- test boundary invariant maintenance when table is filled from the end +do + local arr = {} + for i=5,2,-1 do + arr[i] = i + assert(#arr == 0) + end + arr[1] = 1 + assert(#arr == 5) +end + +-- test boundary invariant maintenance when table is filled using SETLIST opcode +do + local arr = {[2]=2,1} + assert(#arr == 2) +end + +-- test boundary invariant maintenance when table is filled using table.move +do + local t1 = {1, 2, 3, 4, 5} + local t2 = {[6] = 6} + + table.move(t1, 1, 5, 1, t2) + assert(#t2 == 6) +end + +-- test table.unpack fastcall for rejecting large unpacks +do + local ok, res = pcall(function() + local a = table.create(7999, 0) + local b = table.create(8000, 0) + + local at = { table.unpack(a) } + local bt = { table.unpack(b) } + end) + + assert(not ok) +end + +-- test iteration with lightuserdata keys +do + function countud() + local t = {} + t[makelud(1)] = 1 + t[makelud(2)] = 2 + + local count = 0 + for k,v in pairs(t) do + count += v + end + + return count + end + + assert(countud() == 3) +end + +-- test iteration with lightuserdata keys with a substituted environment +do + local env = { makelud = makelud, pairs = pairs } + setfenv(countud, env) + assert(countud() == 3) +end + return"OK" diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index 84ac2ba1..969209fc 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -144,4 +144,21 @@ coroutine.resume(co) resumeerror(co, "fail") checkresults({ true, false, "fail" }, coroutine.resume(co)) +-- stack overflow needs to happen at the call limit +local calllimit = 20000 +function recurse(n) return n <= 1 and 1 or recurse(n-1) + 1 end + +-- we use one frame for top-level function and one frame is the service frame for coroutines +assert(recurse(calllimit - 2) == calllimit - 2) + +-- note that when calling through pcall, pcall eats one more frame +checkresults({ true, calllimit - 3 }, pcall(recurse, calllimit - 3)) +checkerror(pcall(recurse, calllimit - 2)) + +-- xpcall handler runs in context of the stack frame, but this works just fine since we allow extra stack consumption past stack overflow +checkresults({ false, "ok" }, xpcall(recurse, function() return string.reverse("ko") end, calllimit - 2)) + +-- however, if xpcall handler itself runs out of extra stack space, we get "error in error handling" +checkresults({ false, "error in error handling" }, xpcall(recurse, function() return recurse(calllimit) end, calllimit - 2)) + return 'OK' diff --git a/tests/conformance/userdata.lua b/tests/conformance/userdata.lua new file mode 100644 index 00000000..98392e25 --- /dev/null +++ b/tests/conformance/userdata.lua @@ -0,0 +1,45 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print('testing userdata') + +-- int64 is a userdata type defined in C++ + +-- equality +assert(int64(1) == int64(1)) +assert(int64(1) ~= int64(2)) + +-- relational +assert(not (int64(1) < int64(1))) +assert(int64(1) < int64(2)) +assert(int64(1) <= int64(1)) +assert(int64(1) <= int64(2)) + +-- arithmetics +assert(-int64(2) == int64(-2)) + +assert(int64(1) + int64(2) == int64(3)) +assert(int64(1) - int64(2) == int64(-1)) +assert(int64(2) * int64(3) == int64(6)) +assert(int64(4) / int64(2) == int64(2)) +assert(int64(4) % int64(3) == int64(1)) +assert(int64(2) ^ int64(3) == int64(8)) + +assert(int64(1) + 2 == int64(3)) +assert(int64(1) - 2 == int64(-1)) +assert(int64(2) * 3 == int64(6)) +assert(int64(4) / 2 == int64(2)) +assert(int64(4) % 3 == int64(1)) +assert(int64(2) ^ 3 == int64(8)) + +-- tostring +assert(tostring(int64(2)) == "2") + +-- index/newindex; note, mutable userdatas aren't very idiomatic but we need to test this +do + local v = int64(42) + assert(v.value == 42) + v.value = 4 + assert(v.value == 4) + assert(v == int64(4)) +end + +return 'OK' diff --git a/tools/lldb_formatters.py b/tools/lldb_formatters.py index 40f8d6be..ff610d09 100644 --- a/tools/lldb_formatters.py +++ b/tools/lldb_formatters.py @@ -37,7 +37,7 @@ def getType(target, typeName): return ty def luau_variant_summary(valobj, internal_dict, options): - type_id = valobj.GetChildMemberWithName("typeid").GetValueAsUnsigned() + type_id = valobj.GetChildMemberWithName("typeId").GetValueAsUnsigned() storage = valobj.GetChildMemberWithName("storage") params = templateParams(valobj.GetType().GetCanonicalType().GetName()) stored_type = params[type_id] @@ -89,7 +89,7 @@ class LuauVariantSyntheticChildrenProvider: return None def update(self): - self.type_index = self.valobj.GetChildMemberWithName("typeid").GetValueAsSigned() + self.type_index = self.valobj.GetChildMemberWithName("typeId").GetValueAsSigned() self.type_params = templateParams(self.valobj.GetType().GetCanonicalType().GetName()) if len(self.type_params) > self.type_index: @@ -97,7 +97,7 @@ class LuauVariantSyntheticChildrenProvider: if self.current_type: storage = self.valobj.GetChildMemberWithName("storage") - self.stored_value = storage.Cast(self.current_type.GetPointerType()).Dereference() + self.stored_value = storage.Cast(self.current_type) else: self.stored_value = None else: diff --git a/tools/natvis/Analysis.natvis b/tools/natvis/Analysis.natvis index 5de0140e..b9ea3141 100644 --- a/tools/natvis/Analysis.natvis +++ b/tools/natvis/Analysis.natvis @@ -6,40 +6,40 @@ - {{ index=0, value={*($T1*)storage} }} - {{ index=1, value={*($T2*)storage} }} - {{ index=2, value={*($T3*)storage} }} - {{ index=3, value={*($T4*)storage} }} - {{ index=4, value={*($T5*)storage} }} - {{ index=5, value={*($T6*)storage} }} - {{ index=6, value={*($T7*)storage} }} - {{ index=7, value={*($T8*)storage} }} - {{ index=8, value={*($T9*)storage} }} - {{ index=9, value={*($T10*)storage} }} - {{ index=10, value={*($T11*)storage} }} - {{ index=11, value={*($T12*)storage} }} - {{ index=12, value={*($T13*)storage} }} - {{ index=13, value={*($T14*)storage} }} - {{ index=14, value={*($T15*)storage} }} - {{ index=15, value={*($T16*)storage} }} - {{ index=16, value={*($T17*)storage} }} - {{ index=17, value={*($T18*)storage} }} - {{ index=18, value={*($T19*)storage} }} - {{ index=19, value={*($T20*)storage} }} - {{ index=20, value={*($T21*)storage} }} - {{ index=21, value={*($T22*)storage} }} - {{ index=22, value={*($T23*)storage} }} - {{ index=23, value={*($T24*)storage} }} - {{ index=24, value={*($T25*)storage} }} - {{ index=25, value={*($T26*)storage} }} - {{ index=26, value={*($T27*)storage} }} - {{ index=27, value={*($T28*)storage} }} - {{ index=28, value={*($T29*)storage} }} - {{ index=29, value={*($T30*)storage} }} - {{ index=30, value={*($T31*)storage} }} - {{ index=31, value={*($T32*)storage} }} + {{ typeId=0, value={*($T1*)storage} }} + {{ typeId=1, value={*($T2*)storage} }} + {{ typeId=2, value={*($T3*)storage} }} + {{ typeId=3, value={*($T4*)storage} }} + {{ typeId=4, value={*($T5*)storage} }} + {{ typeId=5, value={*($T6*)storage} }} + {{ typeId=6, value={*($T7*)storage} }} + {{ typeId=7, value={*($T8*)storage} }} + {{ typeId=8, value={*($T9*)storage} }} + {{ typeId=9, value={*($T10*)storage} }} + {{ typeId=10, value={*($T11*)storage} }} + {{ typeId=11, value={*($T12*)storage} }} + {{ typeId=12, value={*($T13*)storage} }} + {{ typeId=13, value={*($T14*)storage} }} + {{ typeId=14, value={*($T15*)storage} }} + {{ typeId=15, value={*($T16*)storage} }} + {{ typeId=16, value={*($T17*)storage} }} + {{ typeId=17, value={*($T18*)storage} }} + {{ typeId=18, value={*($T19*)storage} }} + {{ typeId=19, value={*($T20*)storage} }} + {{ typeId=20, value={*($T21*)storage} }} + {{ typeId=21, value={*($T22*)storage} }} + {{ typeId=22, value={*($T23*)storage} }} + {{ typeId=23, value={*($T24*)storage} }} + {{ typeId=24, value={*($T25*)storage} }} + {{ typeId=25, value={*($T26*)storage} }} + {{ typeId=26, value={*($T27*)storage} }} + {{ typeId=27, value={*($T28*)storage} }} + {{ typeId=28, value={*($T29*)storage} }} + {{ typeId=29, value={*($T30*)storage} }} + {{ typeId=30, value={*($T31*)storage} }} + {{ typeId=31, value={*($T32*)storage} }} - typeId + typeId *($T1*)storage *($T2*)storage *($T3*)storage diff --git a/tools/natvis/CodeGen.natvis b/tools/natvis/CodeGen.natvis new file mode 100644 index 00000000..5ff6e143 --- /dev/null +++ b/tools/natvis/CodeGen.natvis @@ -0,0 +1,56 @@ + + + + + noreg + rip + + al + cl + dl + bl + + eax + ecx + edx + ebx + esp + ebp + esi + edi + e{(int)index,d}d + + rax + rcx + rdx + rbx + rsp + rbp + rsi + rdi + r{(int)index,d} + + xmm{(int)index,d} + + ymm{(int)index,d} + + + + {base} + {memSize,en} ptr[{base} + {index}*{(int)scale,d} + {imm}] + {memSize,en} ptr[{index}*{(int)scale,d} + {imm}] + {memSize,en} ptr[{base} + {imm}] + {memSize,en} ptr[{imm}] + {imm} + + base + imm + memSize + base + index + scale + imm + + + + diff --git a/tools/patchtests.py b/tools/patchtests.py new file mode 100644 index 00000000..dcaf6083 --- /dev/null +++ b/tools/patchtests.py @@ -0,0 +1,76 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# This code can be used to patch Compiler.test.cpp following bytecode changes, based on error output +import sys +import re + +(_, filename) = sys.argv +input = sys.stdin.readlines() + +errors = [] + +# 0: looking for error, 1: looking for replacement, 2: collecting replacement, 3: collecting initial text +state = 0 + +# parse input into errors[] with the state machine; this is using doctest output and expects multi-line match failures +for line in input: + if state == 0: + match = re.match("tests/[^:]+:(\d+): ERROR: CHECK_EQ", line) + if match: + error_line = int(match[1]) + state = 1 + elif state == 1: + if re.match("\s*values: CHECK_EQ\(\s*$", line): + error_repl = [] + state = 2 + elif re.match("\s*values: CHECK_EQ", line): + state = 0 # skipping single-line checks since we can't patch them + elif state == 2: + if line.strip() == ",": + error_orig = [] + state = 3 + else: + error_repl.append(line) + elif state == 3: + if line.strip() == ")": + errors.append((error_line, error_orig, error_repl)) + state = 0 + else: + error_orig.append(line) + +# make sure we fully process each individual check +assert(state == 0) + +errors.sort(key = lambda e: e[0]) + +with open(filename, "r") as fp: + source = fp.readlines() + +# patch source text into result[] using errors[]; we expect every match to appear at or after the line error was reported at +result = [] + +current = 0 +index = 0 + +while index < len(source): + line = source[index] + error = errors[current] if current < len(errors) else None + + if not error or index < error[0] or line != error[1][0]: + result.append(line) + index += 1 + else: + # validate that the patch has a complete match in source text + for v in range(len(error[1])): + assert(source[index + v] == error[1][v]) + + result += error[2] + index += len(error[1]) + current += 1 + +# make sure we patch all errors +assert(current == len(errors)) + +with open(filename, "w") as fp: + fp.writelines(result)