diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 37382f13..55f9db39 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,5 +1,5 @@ blank_issues_enabled: false contact_links: - - name: Help and support + - name: Questions url: https://github.com/Roblox/luau/discussions about: Please use GitHub Discussions if you have questions or need help. diff --git a/.github/codecov.yml b/.github/codecov.yml new file mode 100644 index 00000000..69cb7601 --- /dev/null +++ b/.github/codecov.yml @@ -0,0 +1 @@ +comment: false diff --git a/.github/workflows/benchmark-dev.yml b/.github/workflows/benchmark-dev.yml new file mode 100644 index 00000000..d78cb0ba --- /dev/null +++ b/.github/workflows/benchmark-dev.yml @@ -0,0 +1,185 @@ +name: benchmark-dev + +on: + push: + branches: + - master + + paths-ignore: + - "docs/**" + - "papers/**" + - "rfcs/**" + - "*.md" + +jobs: + windows: + name: windows-${{matrix.arch}} + strategy: + fail-fast: false + matrix: + os: [windows-latest] + arch: [Win32, x64] + bench: + - { + script: "run-benchmarks", + timeout: 12, + title: "Luau Benchmarks", + } + benchResultsRepo: + - { name: "luau-lang/benchmark-data", branch: "main" } + + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Luau repository + uses: actions/checkout@v3 + + - name: Build Luau + shell: bash # necessary for fail-fast + run: | + mkdir build && cd build + cmake .. -DCMAKE_BUILD_TYPE=Release + cmake --build . --target Luau.Repl.CLI --config Release + cmake --build . --target Luau.Analyze.CLI --config Release + + - name: Move build files to root + run: | + move build/Release/* . + + - uses: actions/setup-python@v3 + with: + python-version: "3.9" + architecture: "x64" + + - name: Install python dependencies + run: | + python -m pip install requests + python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose + + - name: Run benchmark + run: | + python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt + + - name: Push benchmark results + id: pushBenchmarkAttempt1 + continue-on-error: true + uses: ./.github/workflows/push-results + with: + repository: ${{ matrix.benchResultsRepo.name }} + branch: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})" + bench_tool: "benchmarkluau" + bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" + bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" + + - name: Push benchmark results (Attempt 2) + id: pushBenchmarkAttempt2 + continue-on-error: true + if: steps.pushBenchmarkAttempt1.outcome == 'failure' + uses: ./.github/workflows/push-results + with: + repository: ${{ matrix.benchResultsRepo.name }} + branch: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})" + bench_tool: "benchmarkluau" + bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" + bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" + + - name: Push benchmark results (Attempt 3) + id: pushBenchmarkAttempt3 + continue-on-error: true + if: steps.pushBenchmarkAttempt2.outcome == 'failure' + uses: ./.github/workflows/push-results + with: + repository: ${{ matrix.benchResultsRepo.name }} + branch: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})" + bench_tool: "benchmarkluau" + bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" + bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" + + unix: + name: ${{matrix.os}} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + bench: + - { + script: "run-benchmarks", + timeout: 12, + title: "Luau Benchmarks", + } + benchResultsRepo: + - { name: "luau-lang/benchmark-data", branch: "main" } + + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Luau repository + uses: actions/checkout@v3 + + - name: Build Luau + run: make config=release luau luau-analyze + + - uses: actions/setup-python@v3 + with: + python-version: "3.9" + architecture: "x64" + + - name: Install python dependencies + run: | + python -m pip install requests + python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose + + - name: Run benchmark + run: | + python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt + + - name: Push benchmark results + id: pushBenchmarkAttempt1 + continue-on-error: true + uses: ./.github/workflows/push-results + with: + repository: ${{ matrix.benchResultsRepo.name }} + branch: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + bench_name: ${{ matrix.bench.title }} + bench_tool: "benchmarkluau" + bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" + bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" + + - name: Push benchmark results (Attempt 2) + id: pushBenchmarkAttempt2 + continue-on-error: true + if: steps.pushBenchmarkAttempt1.outcome == 'failure' + uses: ./.github/workflows/push-results + with: + repository: ${{ matrix.benchResultsRepo.name }} + branch: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + bench_name: ${{ matrix.bench.title }} + bench_tool: "benchmarkluau" + bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" + bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" + + - name: Push benchmark results (Attempt 3) + id: pushBenchmarkAttempt3 + continue-on-error: true + if: steps.pushBenchmarkAttempt2.outcome == 'failure' + uses: ./.github/workflows/push-results + with: + repository: ${{ matrix.benchResultsRepo.name }} + branch: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + bench_name: ${{ matrix.bench.title }} + bench_tool: "benchmarkluau" + bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" + bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml new file mode 100644 index 00000000..7fb88e21 --- /dev/null +++ b/.github/workflows/benchmark.yml @@ -0,0 +1,140 @@ +name: benchmark + +on: + push: + branches: + - master + paths-ignore: + - "docs/**" + - "papers/**" + - "rfcs/**" + - "*.md" + +jobs: + callgrind: + strategy: + matrix: + os: [ubuntu-22.04] + benchResultsRepo: + - { name: "luau-lang/benchmark-data", branch: "main" } + + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Luau repository + uses: actions/checkout@v3 + + - name: Install valgrind + run: | + sudo apt-get install valgrind + + - name: Build Luau (gcc) + run: | + CXX=g++ make config=profile luau + cp luau luau-gcc + + - name: Build Luau (codegen) + run: | + make config=profile clean + CXX=clang++ make config=profile native=1 luau + cp luau luau-codegen + + - name: Build Luau (clang) + run: | + make config=profile clean + CXX=clang++ make config=profile luau luau-analyze + + - name: Run benchmark (bench-gcc) + run: | + python bench/bench.py --callgrind --vm "./luau-gcc -O2" | tee -a bench-gcc-output.txt + + - name: Run benchmark (bench) + run: | + python bench/bench.py --callgrind --vm "./luau -O2" | tee -a bench-output.txt + + - name: Run benchmark (bench-codegen) + run: | + python bench/bench.py --callgrind --vm "./luau-codegen --codegen -O2" | tee -a bench-codegen-output.txt + + - name: Run benchmark (analyze) + run: | + filter() { + awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau-analyze"}' + } + valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-nonstrict | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-strict | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/regex.lua 2>&1 | filter regex-nonstrict | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/regex.lua 2>&1 | filter regex-strict | tee -a analyze-output.txt + + - name: Run benchmark (compile) + run: | + filter() { + awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau --compile"}' + } + valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O0 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt + + - name: Checkout benchmark results + uses: actions/checkout@v3 + with: + repository: ${{ matrix.benchResultsRepo.name }} + ref: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + + - name: Store results (bench) + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: callgrind clang + tool: "benchmarkluau" + output-file-path: ./bench-output.txt + external-data-json-path: ./gh-pages/bench.json + + - name: Store results (bench-codegen) + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: callgrind codegen + tool: "benchmarkluau" + output-file-path: ./bench-codegen-output.txt + external-data-json-path: ./gh-pages/bench-codegen.json + + - name: Store results (bench-gcc) + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: callgrind gcc + tool: "benchmarkluau" + output-file-path: ./bench-gcc-output.txt + external-data-json-path: ./gh-pages/bench-gcc.json + + - name: Store results (analyze) + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: luau-analyze + tool: "benchmarkluau" + output-file-path: ./analyze-output.txt + external-data-json-path: ./gh-pages/analyze.json + + - name: Store results (compile) + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: luau --compile + tool: "benchmarkluau" + output-file-path: ./compile-output.txt + external-data-json-path: ./gh-pages/compile.json + + - name: Push benchmark results + if: github.event_name == 'push' + run: | + echo "Pushing benchmark results..." + cd gh-pages + git config user.name github-actions + git config user.email github@users.noreply.github.com + git add *.json + git commit -m "Add benchmarks results for ${{ github.sha }}" + git push + cd .. diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a1c9db5b..5145e76b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -4,6 +4,11 @@ on: push: branches: - 'master' + paths-ignore: + - 'docs/**' + - 'papers/**' + - 'rfcs/**' + - '*.md' pull_request: paths-ignore: - 'docs/**' @@ -20,15 +25,24 @@ jobs: runs-on: ${{matrix.os}}-latest steps: - uses: actions/checkout@v1 - - name: make test + - name: make tests run: | - make -j2 config=sanitize test - - name: make test w/flags + make -j2 config=sanitize werror=1 native=1 luau-tests + - name: run tests run: | - make -j2 config=sanitize flags=true test + ./luau-tests + ./luau-tests --fflags=true + - name: run extra conformance tests + run: | + ./luau-tests -ts=Conformance -O2 + ./luau-tests -ts=Conformance -O2 --fflags=true + ./luau-tests -ts=Conformance --codegen + ./luau-tests -ts=Conformance --codegen --fflags=true + ./luau-tests -ts=Conformance --codegen -O2 + ./luau-tests -ts=Conformance --codegen -O2 --fflags=true - name: make cli run: | - make -j2 config=sanitize luau luau-analyze # match config with tests to improve build time + make -j2 config=sanitize werror=1 luau luau-analyze # match config with tests to improve build time ./luau tests/conformance/assert.lua ./luau-analyze tests/conformance/assert.lua @@ -40,18 +54,25 @@ jobs: steps: - uses: actions/checkout@v1 - name: cmake configure - run: cmake . -A ${{matrix.arch}} - - name: cmake test + run: cmake . -A ${{matrix.arch}} -DLUAU_WERROR=ON -DLUAU_NATIVE=ON + - name: cmake build + run: cmake --build . --target Luau.UnitTest Luau.Conformance --config Debug + - name: run tests shell: bash # necessary for fail-fast run: | - cmake --build . --target Luau.UnitTest Luau.Conformance --config Debug Debug/Luau.UnitTest.exe Debug/Luau.Conformance.exe - - name: cmake test w/flags - shell: bash # necessary for fail-fast - run: | Debug/Luau.UnitTest.exe --fflags=true Debug/Luau.Conformance.exe --fflags=true + - name: run extra conformance tests + shell: bash # necessary for fail-fast + run: | + Debug/Luau.Conformance.exe -O2 + Debug/Luau.Conformance.exe -O2 --fflags=true + Debug/Luau.Conformance.exe --codegen + Debug/Luau.Conformance.exe --codegen --fflags=true + Debug/Luau.Conformance.exe --codegen -O2 + Debug/Luau.Conformance.exe --codegen -O2 --fflags=true - name: cmake cli shell: bash # necessary for fail-fast run: | @@ -62,19 +83,34 @@ jobs: coverage: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v2 - name: install run: | sudo apt install llvm - name: make coverage run: | - CXX=clang++-10 make -j2 config=coverage coverage + CXX=clang++ make -j2 config=coverage native=1 coverage - name: upload coverage - uses: coverallsapp/github-action@master + uses: codecov/codecov-action@v3 with: - path-to-lcov: ./coverage.info - github-token: ${{ secrets.GITHUB_TOKEN }} - - uses: actions/upload-artifact@v2 + files: ./coverage.info + token: ${{ secrets.CODECOV_TOKEN }} + + web: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v1 + - uses: actions/checkout@v2 with: - name: coverage - path: coverage + repository: emscripten-core/emsdk + path: emsdk + - name: emsdk install + run: | + cd emsdk + ./emsdk install latest + ./emsdk activate latest + - name: make + run: | + source emsdk/emsdk_env.sh + emcmake cmake . -DLUAU_BUILD_WEB=ON -DCMAKE_BUILD_TYPE=Release + make -j2 Luau.Web diff --git a/.github/workflows/new-release.yml b/.github/workflows/new-release.yml new file mode 100644 index 00000000..90143323 --- /dev/null +++ b/.github/workflows/new-release.yml @@ -0,0 +1,83 @@ +name: new-release + +on: + workflow_dispatch: + inputs: + version: + required: true + description: Release version including 0. + +permissions: + contents: write + +jobs: + create-release: + runs-on: ubuntu-latest + outputs: + upload_url: ${{ steps.create_release.outputs.upload_url }} + steps: + - name: create release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ github.event.inputs.version }} + release_name: ${{ github.event.inputs.version }} + draft: true + + build: + needs: ["create-release"] + strategy: + matrix: + os: [ubuntu, macos, windows] + name: ${{matrix.os}} + runs-on: ${{matrix.os}}-latest + steps: + - uses: actions/checkout@v1 + - name: configure + run: cmake . -DCMAKE_BUILD_TYPE=Release + - name: build + run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI --config Release -j 2 + - name: pack + if: matrix.os != 'windows' + run: zip luau-${{matrix.os}}.zip luau* + - name: pack + if: matrix.os == 'windows' + run: 7z a luau-${{matrix.os}}.zip .\Release\luau*.exe + - uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ needs.create-release.outputs.upload_url }} + asset_path: luau-${{matrix.os}}.zip + asset_name: luau-${{matrix.os}}.zip + asset_content_type: application/octet-stream + + web: + needs: ["create-release"] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v1 + - uses: actions/checkout@v2 + with: + repository: emscripten-core/emsdk + path: emsdk + - name: emsdk install + run: | + cd emsdk + ./emsdk install latest + ./emsdk activate latest + - name: make + run: | + source emsdk/emsdk_env.sh + emcmake cmake . -DLUAU_BUILD_WEB=ON -DCMAKE_BUILD_TYPE=Release + make -j2 Luau.Web + - uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ needs.create-release.outputs.upload_url }} + asset_path: Luau.Web.js + asset_name: Luau.Web.js + asset_content_type: application/octet-stream diff --git a/.github/workflows/push-results/action.yml b/.github/workflows/push-results/action.yml new file mode 100644 index 00000000..b5ffebee --- /dev/null +++ b/.github/workflows/push-results/action.yml @@ -0,0 +1,63 @@ +name: Checkout & push results +description: Checkout a given repo and push results to GitHub +inputs: + repository: + required: true + type: string + description: The benchmark results repository to check out + branch: + required: true + type: string + description: The benchmark results repository's branch to check out + token: + required: true + type: string + description: The GitHub token to use for pushing results + path: + required: true + type: string + description: The path to check out the results repository to + bench_name: + required: true + type: string + bench_tool: + required: true + type: string + bench_output_file_path: + required: true + type: string + bench_external_data_json_path: + required: true + type: string + +runs: + using: "composite" + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + repository: ${{ inputs.repository }} + ref: ${{ inputs.branch }} + token: ${{ inputs.token }} + path: ${{ inputs.path }} + + - name: Store results + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ inputs.bench_name }} + tool: ${{ inputs.bench_tool }} + gh-pages-branch: ${{ inputs.branch }} + output-file-path: ${{ inputs.bench_output_file_path }} + external-data-json-path: ${{ inputs.bench_external_data_json_path }} + + - name: Push benchmark results + shell: bash + run: | + echo "Pushing benchmark results..." + cd gh-pages + git config user.name github-actions + git config user.email github@users.noreply.github.com + git add *.json + git commit -m "Add benchmarks results for ${{ github.sha }}" + git push + cd .. diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8cfc27f6..36d3534c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -8,6 +8,7 @@ on: - 'docs/**' - 'papers/**' - 'rfcs/**' + - '*.md' jobs: build: @@ -21,7 +22,7 @@ jobs: - name: configure run: cmake . -DCMAKE_BUILD_TYPE=Release - name: build - run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI --config Release + run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI --config Release -j 2 - uses: actions/upload-artifact@v2 if: matrix.os != 'windows' with: @@ -32,3 +33,26 @@ jobs: with: name: luau-${{matrix.os}} path: Release\luau*.exe + + web: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v1 + - uses: actions/checkout@v2 + with: + repository: emscripten-core/emsdk + path: emsdk + - name: emsdk install + run: | + cd emsdk + ./emsdk install latest + ./emsdk activate latest + - name: make + run: | + source emsdk/emsdk_env.sh + emcmake cmake . -DLUAU_BUILD_WEB=ON -DCMAKE_BUILD_TYPE=Release + make -j2 Luau.Web + - uses: actions/upload-artifact@v2 + with: + name: Luau.Web.js + path: Luau.Web.js diff --git a/.gitignore b/.gitignore index f92fb488..af77f73c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,13 @@ -build/ -^coverage/ -^fuzz/luau.pb.* -^crash-* -^default.prof* -^fuzz-* -^luau$ +/build/ +/build[.-]*/ +/coverage/ +/.vs/ +/.vscode/ +/fuzz/luau.pb.* +/crash-* +/default.prof* +/fuzz-* +/luau +/luau-tests +/luau-analyze +__pycache__ diff --git a/Analysis/include/Luau/Anyification.h b/Analysis/include/Luau/Anyification.h new file mode 100644 index 00000000..a6f3e2a9 --- /dev/null +++ b/Analysis/include/Luau/Anyification.h @@ -0,0 +1,42 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/NotNull.h" +#include "Luau/Substitution.h" +#include "Luau/TypeVar.h" + +#include + +namespace Luau +{ + +struct TypeArena; +struct Scope; +struct InternalErrorReporter; +using ScopePtr = std::shared_ptr; + +// A substitution which replaces free types by any +struct Anyification : Substitution +{ + Anyification(TypeArena* arena, NotNull scope, NotNull singletonTypes, InternalErrorReporter* iceHandler, TypeId anyType, + TypePackId anyTypePack); + Anyification(TypeArena* arena, const ScopePtr& scope, NotNull singletonTypes, InternalErrorReporter* iceHandler, TypeId anyType, + TypePackId anyTypePack); + NotNull scope; + NotNull singletonTypes; + InternalErrorReporter* iceHandler; + + TypeId anyType; + TypePackId anyTypePack; + bool normalizationTooComplex = false; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; + + bool ignoreChildren(TypeId ty) override; + bool ignoreChildren(TypePackId ty) override; +}; + +} // namespace Luau \ No newline at end of file diff --git a/Analysis/include/Luau/ApplyTypeFunction.h b/Analysis/include/Luau/ApplyTypeFunction.h new file mode 100644 index 00000000..8da3bc42 --- /dev/null +++ b/Analysis/include/Luau/ApplyTypeFunction.h @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Substitution.h" +#include "Luau/TxnLog.h" +#include "Luau/TypeVar.h" + +namespace Luau +{ + +// A substitution which replaces the type parameters of a type function by arguments +struct ApplyTypeFunction : Substitution +{ + ApplyTypeFunction(TypeArena* arena) + : Substitution(TxnLog::empty(), arena) + , encounteredForwardedType(false) + { + } + + // Never set under deferred constraint resolution. + bool encounteredForwardedType; + std::unordered_map typeArguments; + std::unordered_map typePackArguments; + bool ignoreChildren(TypeId ty) override; + bool ignoreChildren(TypePackId tp) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/JsonEncoder.h b/Analysis/include/Luau/AstJsonEncoder.h similarity index 67% rename from Analysis/include/Luau/JsonEncoder.h rename to Analysis/include/Luau/AstJsonEncoder.h index aa00390b..e79d9e62 100644 --- a/Analysis/include/Luau/JsonEncoder.h +++ b/Analysis/include/Luau/AstJsonEncoder.h @@ -2,12 +2,15 @@ #pragma once #include +#include namespace Luau { class AstNode; +struct Comment; std::string toJson(AstNode* node); +std::string toJson(AstNode* node, const std::vector& commentLocations); } // namespace Luau diff --git a/Analysis/include/Luau/AstQuery.h b/Analysis/include/Luau/AstQuery.h index d38976ef..950a19da 100644 --- a/Analysis/include/Luau/AstQuery.h +++ b/Analysis/include/Luau/AstQuery.h @@ -42,12 +42,28 @@ struct ExprOrLocal { return expr ? expr->location : (local ? local->location : std::optional{}); } + std::optional getName() + { + if (expr) + { + if (AstName name = getIdentifier(expr); name.value) + { + return name; + } + } + else if (local) + { + return local->name; + } + return std::nullopt; + } private: AstExpr* expr = nullptr; AstLocal* local = nullptr; }; +std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos); std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos); AstNode* findNodeAtPosition(const SourceModule& source, Position pos); AstExpr* findExprAtPosition(const SourceModule& source, Position pos); diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h index 58534293..f40f8b49 100644 --- a/Analysis/include/Luau/Autocomplete.h +++ b/Analysis/include/Luau/Autocomplete.h @@ -19,6 +19,17 @@ struct TypeChecker; using ModulePtr = std::shared_ptr; +enum class AutocompleteContext +{ + Unknown, + Expression, + Statement, + Property, + Type, + Keyword, + String, +}; + enum class AutocompleteEntryKind { Property, @@ -66,11 +77,13 @@ struct AutocompleteResult { AutocompleteEntryMap entryMap; std::vector ancestry; + AutocompleteContext context = AutocompleteContext::Unknown; AutocompleteResult() = default; - AutocompleteResult(AutocompleteEntryMap entryMap, std::vector ancestry) + AutocompleteResult(AutocompleteEntryMap entryMap, std::vector ancestry, AutocompleteContext context) : entryMap(std::move(entryMap)) , ancestry(std::move(ancestry)) + , context(context) { } }; @@ -78,14 +91,6 @@ struct AutocompleteResult using ModuleName = std::string; using StringCompletionCallback = std::function(std::string tag, std::optional ctx)>; -struct OwningAutocompleteResult -{ - AutocompleteResult result; - ModulePtr module; - std::unique_ptr sourceModule; -}; - AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); -OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback); } // namespace Luau diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 8f17fff6..16cccafe 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -1,12 +1,22 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "TypeInfer.h" +#include "Luau/Scope.h" +#include "Luau/TypeVar.h" + +#include namespace Luau { -void registerBuiltinTypes(TypeChecker& typeChecker); +struct Frontend; +struct TypeChecker; +struct TypeArena; + +void registerBuiltinTypes(Frontend& frontend); + +void registerBuiltinGlobals(TypeChecker& typeChecker); +void registerBuiltinGlobals(Frontend& frontend); TypeId makeUnion(TypeArena& arena, std::vector&& types); TypeId makeIntersection(TypeArena& arena, std::vector&& types); @@ -14,6 +24,7 @@ TypeId makeIntersection(TypeArena& arena, std::vector&& types); /** Build an optional 't' */ TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t); +TypeId makeOption(Frontend& frontend, TypeArena& arena, TypeId t); /** Small utility function for building up type definitions from C++. */ @@ -33,19 +44,25 @@ TypeId makeFunction( // Polymorphic std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); void attachMagicFunction(TypeId ty, MagicFunction fn); -void attachFunctionTag(TypeId ty, std::string constraint); +void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn); +void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn); Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName); std::string getBuiltinDefinitionSource(); -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName); void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding); +void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName); void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding); -std::optional tryGetGlobalBinding(TypeChecker& typeChecker, const std::string& name); +void addGlobalBinding(Frontend& frontend, const std::string& name, TypeId ty, const std::string& packageName); +void addGlobalBinding(Frontend& frontend, const std::string& name, Binding binding); +void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); +void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, Binding binding); +std::optional tryGetGlobalBinding(Frontend& frontend, const std::string& name); Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name); +TypeId getGlobalBinding(Frontend& frontend, const std::string& name); TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name); } // namespace Luau diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h new file mode 100644 index 00000000..f003c242 --- /dev/null +++ b/Analysis/include/Luau/Clone.h @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include "Luau/TypeArena.h" +#include "Luau/TypeVar.h" + +#include + +namespace Luau +{ + +// Only exposed so they can be unit tested. +using SeenTypes = std::unordered_map; +using SeenTypePacks = std::unordered_map; + +struct CloneState +{ + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + int recursionCount = 0; +}; + +TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); +TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); + +TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false); +TypeId shallowClone(TypeId ty, NotNull dest); + +} // namespace Luau diff --git a/Analysis/include/Luau/Config.h b/Analysis/include/Luau/Config.h index 56cdfe78..8ba4ffa5 100644 --- a/Analysis/include/Luau/Config.h +++ b/Analysis/include/Luau/Config.h @@ -17,12 +17,9 @@ constexpr const char* kConfigName = ".luaurc"; struct Config { - Config() - { - enabledLint.setDefaults(); - } + Config(); - Mode mode = Mode::NoCheck; + Mode mode; ParseOptions parseOptions; diff --git a/Analysis/include/Luau/Connective.h b/Analysis/include/Luau/Connective.h new file mode 100644 index 00000000..4a6be93c --- /dev/null +++ b/Analysis/include/Luau/Connective.h @@ -0,0 +1,70 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Def.h" +#include "Luau/TypedAllocator.h" +#include "Luau/Variant.h" + +#include + +namespace Luau +{ + +struct TypeVar; +using TypeId = const TypeVar*; + +struct Negation; +struct Conjunction; +struct Disjunction; +struct Equivalence; +struct Proposition; +using Connective = Variant; +using ConnectiveId = Connective*; // Can and most likely is nullptr. + +struct Negation +{ + ConnectiveId connective; +}; + +struct Conjunction +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Disjunction +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Equivalence +{ + ConnectiveId lhs; + ConnectiveId rhs; +}; + +struct Proposition +{ + DefId def; + TypeId discriminantTy; +}; + +template +const T* get(ConnectiveId connective) +{ + return get_if(connective); +} + +struct ConnectiveArena +{ + TypedAllocator allocator; + + ConnectiveId negation(ConnectiveId connective); + ConnectiveId conjunction(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId disjunction(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId equivalence(ConnectiveId lhs, ConnectiveId rhs); + ConnectiveId proposition(DefId def, TypeId discriminantTy); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h new file mode 100644 index 00000000..9eea9c28 --- /dev/null +++ b/Analysis/include/Luau/Constraint.h @@ -0,0 +1,207 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" // Used for some of the enumerations +#include "Luau/Def.h" +#include "Luau/DenseHash.h" +#include "Luau/NotNull.h" +#include "Luau/TypeVar.h" +#include "Luau/Variant.h" + +#include +#include +#include + +namespace Luau +{ + +struct Scope; + +struct TypeVar; +using TypeId = const TypeVar*; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +// subType <: superType +struct SubtypeConstraint +{ + TypeId subType; + TypeId superType; +}; + +// subPack <: superPack +struct PackSubtypeConstraint +{ + TypePackId subPack; + TypePackId superPack; +}; + +// generalizedType ~ gen sourceType +struct GeneralizationConstraint +{ + TypeId generalizedType; + TypeId sourceType; +}; + +// subType ~ inst superType +struct InstantiationConstraint +{ + TypeId subType; + TypeId superType; +}; + +struct UnaryConstraint +{ + AstExprUnary::Op op; + TypeId operandType; + TypeId resultType; +}; + +// let L : leftType +// let R : rightType +// in +// L op R : resultType +struct BinaryConstraint +{ + AstExprBinary::Op op; + TypeId leftType; + TypeId rightType; + TypeId resultType; + + // When we dispatch this constraint, we update the key at this map to record + // the overload that we selected. + AstExpr* expr; + DenseHashMap* astOriginalCallTypes; + DenseHashMap* astOverloadResolvedTypes; +}; + +// iteratee is iterable +// iterators is the iteration types. +struct IterableConstraint +{ + TypePackId iterator; + TypePackId variables; +}; + +// name(namedType) = name +struct NameConstraint +{ + TypeId namedType; + std::string name; +}; + +// target ~ inst target +struct TypeAliasExpansionConstraint +{ + // Must be a PendingExpansionTypeVar. + TypeId target; +}; + +struct FunctionCallConstraint +{ + std::vector> innerConstraints; + TypeId fn; + TypePackId argsPack; + TypePackId result; + class AstExprCall* callSite; +}; + +// result ~ prim ExpectedType SomeSingletonType MultitonType +// +// If ExpectedType is potentially a singleton (an actual singleton or a union +// that contains a singleton), then result ~ SomeSingletonType +// +// else result ~ MultitonType +struct PrimitiveTypeConstraint +{ + TypeId resultType; + TypeId expectedType; + TypeId singletonType; + TypeId multitonType; +}; + +// result ~ hasProp type "prop_name" +// +// If the subject is a table, bind the result to the named prop. If the table +// has an indexer, bind it to the index result type. If the subject is a union, +// bind the result to the union of its constituents' properties. +// +// It would be nice to get rid of this constraint and someday replace it with +// +// T <: {p: X} +// +// Where {} describes an inexact shape type. +struct HasPropConstraint +{ + TypeId resultType; + TypeId subjectType; + std::string prop; +}; + +// result ~ setProp subjectType ["prop", "prop2", ...] propType +// +// If the subject is a table or table-like thing that already has the named +// property chain, we unify propType with that existing property type. +// +// If the subject is a free table, we augment it in place. +// +// If the subject is an unsealed table, result is an augmented table that +// includes that new prop. +struct SetPropConstraint +{ + TypeId resultType; + TypeId subjectType; + std::vector path; + TypeId propType; +}; + +// if negation: +// result ~ if isSingleton D then ~D else unknown where D = discriminantType +// if not negation: +// result ~ if isSingleton D then D else unknown where D = discriminantType +struct SingletonOrTopTypeConstraint +{ + TypeId resultType; + TypeId discriminantType; + bool negated; +}; + +using ConstraintV = Variant; + +struct Constraint +{ + Constraint(NotNull scope, const Location& location, ConstraintV&& c); + + Constraint(const Constraint&) = delete; + Constraint& operator=(const Constraint&) = delete; + + NotNull scope; + Location location; // TODO: Extract this out into only the constraints that needs a location. Not all constraints needs locations. + ConstraintV c; + + std::vector> dependencies; +}; + +using ConstraintPtr = std::unique_ptr; + +inline Constraint& asMutable(const Constraint& c) +{ + return const_cast(c); +} + +template +T* getMutable(Constraint& c) +{ + return ::Luau::get_if(&c.c); +} + +template +const T* get(const Constraint& c) +{ + return getMutable(asMutable(c)); +} + +} // namespace Luau diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h new file mode 100644 index 00000000..c25a5537 --- /dev/null +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -0,0 +1,279 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Connective.h" +#include "Luau/Constraint.h" +#include "Luau/DataFlowGraph.h" +#include "Luau/Module.h" +#include "Luau/ModuleResolver.h" +#include "Luau/NotNull.h" +#include "Luau/Symbol.h" +#include "Luau/TypeVar.h" +#include "Luau/Variant.h" + +#include +#include +#include + +namespace Luau +{ + +struct Scope; +using ScopePtr = std::shared_ptr; + +struct DcrLogger; + +struct Inference +{ + TypeId ty = nullptr; + ConnectiveId connective = nullptr; + + Inference() = default; + + explicit Inference(TypeId ty, ConnectiveId connective = nullptr) + : ty(ty) + , connective(connective) + { + } +}; + +struct InferencePack +{ + TypePackId tp = nullptr; + std::vector connectives; + + InferencePack() = default; + + explicit InferencePack(TypePackId tp, const std::vector& connectives = {}) + : tp(tp) + , connectives(connectives) + { + } +}; + +struct ConstraintGraphBuilder +{ + // A list of all the scopes in the module. This vector holds ownership of the + // scope pointers; the scopes themselves borrow pointers to other scopes to + // define the scope hierarchy. + std::vector> scopes; + + ModuleName moduleName; + ModulePtr module; + NotNull singletonTypes; + const NotNull arena; + // The root scope of the module we're generating constraints for. + // This is null when the CGB is initially constructed. + Scope* rootScope; + + // Constraints that go straight to the solver. + std::vector constraints; + + // Constraints that do not go to the solver right away. Other constraints + // will enqueue them during solving. + std::vector unqueuedConstraints; + + // A mapping of AST node to TypeId. + DenseHashMap astTypes{nullptr}; + + // A mapping of AST node to TypePackId. + DenseHashMap astTypePacks{nullptr}; + + // If the node was applied as a function, this is the unspecialized type of + // that expression. + DenseHashMap astOriginalCallTypes{nullptr}; + + // If overload resolution was performed on this element, this is the + // overload that was selected. + DenseHashMap astOverloadResolvedTypes{nullptr}; + + // Types resolved from type annotations. Analogous to astTypes. + DenseHashMap astResolvedTypes{nullptr}; + + // Type packs resolved from type annotations. Analogous to astTypePacks. + DenseHashMap astResolvedTypePacks{nullptr}; + + // Defining scopes for AST nodes. + DenseHashMap astTypeAliasDefiningScopes{nullptr}; + + NotNull dfg; + ConnectiveArena connectiveArena; + + int recursionCount = 0; + + // It is pretty uncommon for constraint generation to itself produce errors, but it can happen. + std::vector errors; + + // Needed to resolve modules to make 'require' import types properly. + NotNull moduleResolver; + // Occasionally constraint generation needs to produce an ICE. + const NotNull ice; + + ScopePtr globalScope; + DcrLogger* logger; + + ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, + NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, + NotNull dfg); + + /** + * Fabricates a new free type belonging to a given scope. + * @param scope the scope the free type belongs to. + */ + TypeId freshType(const ScopePtr& scope); + + /** + * Fabricates a new free type pack belonging to a given scope. + * @param scope the scope the free type pack belongs to. + */ + TypePackId freshTypePack(const ScopePtr& scope); + + /** + * Fabricates a scope that is a child of another scope. + * @param node the lexical node that the scope belongs to. + * @param parent the parent scope of the new scope. Must not be null. + */ + ScopePtr childScope(AstNode* node, const ScopePtr& parent); + + /** + * Adds a new constraint with no dependencies to a given scope. + * @param scope the scope to add the constraint to. + * @param cv the constraint variant to add. + * @return the pointer to the inserted constraint + */ + NotNull addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv); + + /** + * Adds a constraint to a given scope. + * @param scope the scope to add the constraint to. Must not be null. + * @param c the constraint to add. + * @return the pointer to the inserted constraint + */ + NotNull addConstraint(const ScopePtr& scope, std::unique_ptr c); + + void applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective); + + /** + * The entry point to the ConstraintGraphBuilder. This will construct a set + * of scopes, constraints, and free types that can be solved later. + * @param block the root block to generate constraints for. + */ + void visit(AstStatBlock* block); + + void visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); + + void visit(const ScopePtr& scope, AstStat* stat); + void visit(const ScopePtr& scope, AstStatBlock* block); + void visit(const ScopePtr& scope, AstStatLocal* local); + void visit(const ScopePtr& scope, AstStatFor* for_); + void visit(const ScopePtr& scope, AstStatForIn* forIn); + void visit(const ScopePtr& scope, AstStatWhile* while_); + void visit(const ScopePtr& scope, AstStatRepeat* repeat); + void visit(const ScopePtr& scope, AstStatLocalFunction* function); + void visit(const ScopePtr& scope, AstStatFunction* function); + void visit(const ScopePtr& scope, AstStatReturn* ret); + void visit(const ScopePtr& scope, AstStatAssign* assign); + void visit(const ScopePtr& scope, AstStatCompoundAssign* assign); + void visit(const ScopePtr& scope, AstStatIf* ifStatement); + void visit(const ScopePtr& scope, AstStatTypeAlias* alias); + void visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); + void visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); + void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); + void visit(const ScopePtr& scope, AstStatError* error); + + InferencePack checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes = {}); + InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes = {}); + + InferencePack checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes); + + /** + * Checks an expression that is expected to evaluate to one type. + * @param scope the scope the expression is contained within. + * @param expr the expression to check. + * @param expectedType the type of the expression that is expected from its + * surrounding context. Used to implement bidirectional type checking. + * @return the type of the expression. + */ + Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}, bool forceSingleton = false); + + Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton); + Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType, bool forceSingleton); + Inference check(const ScopePtr& scope, AstExprLocal* local); + Inference check(const ScopePtr& scope, AstExprGlobal* global); + Inference check(const ScopePtr& scope, AstExprIndexName* indexName); + Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); + Inference check(const ScopePtr& scope, AstExprUnary* unary); + Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); + Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); + Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); + std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); + + TypePackId checkLValues(const ScopePtr& scope, AstArray exprs); + + TypeId checkLValue(const ScopePtr& scope, AstExpr* expr); + + struct FunctionSignature + { + // The type of the function. + TypeId signature; + // The scope that encompasses the function's signature. May be nullptr + // if there was no need for a signature scope (the function has no + // generics). + ScopePtr signatureScope; + // The scope that encompasses the function's body. Is a child scope of + // signatureScope, if present. + ScopePtr bodyScope; + }; + + FunctionSignature checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType = {}); + + /** + * Checks the body of a function expression. + * @param scope the interior scope of the body of the function. + * @param fn the function expression to check. + */ + void checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn); + + /** + * Resolves a type from its AST annotation. + * @param scope the scope that the type annotation appears within. + * @param ty the AST annotation to resolve. + * @param topLevel whether the annotation is a "top-level" annotation. + * @return the type of the AST annotation. + **/ + TypeId resolveType(const ScopePtr& scope, AstType* ty, bool topLevel = false); + + /** + * Resolves a type pack from its AST annotation. + * @param scope the scope that the type annotation appears within. + * @param tp the AST annotation to resolve. + * @return the type pack of the AST annotation. + **/ + TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp); + + TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list); + + std::vector> createGenerics(const ScopePtr& scope, AstArray generics); + std::vector> createGenericPacks(const ScopePtr& scope, AstArray packs); + + Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack); + + void reportError(Location location, TypeErrorData err); + void reportCodeTooComplex(Location location); + + /** Scan the program for global definitions. + * + * ConstraintGraphBuilder needs to differentiate between globals and accesses to undefined symbols. Doing this "for + * real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an + * initial scan of the AST and note what globals are defined. + */ + void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); +}; + +/** Borrow a vector of pointers from a vector of owning pointers to constraints. + */ +std::vector> borrowConstraints(const std::vector& constraints); + +} // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h new file mode 100644 index 00000000..c02cd4d5 --- /dev/null +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -0,0 +1,230 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Constraint.h" +#include "Luau/Error.h" +#include "Luau/Module.h" +#include "Luau/Normalize.h" +#include "Luau/ToString.h" +#include "Luau/TypeVar.h" +#include "Luau/Variant.h" + +#include + +namespace Luau +{ + +struct DcrLogger; + +// TypeId, TypePackId, or Constraint*. It is impossible to know which, but we +// never dereference this pointer. +using BlockedConstraintId = const void*; + +struct ModuleResolver; + +struct InstantiationSignature +{ + TypeFun fn; + std::vector arguments; + std::vector packArguments; + + bool operator==(const InstantiationSignature& rhs) const; + bool operator!=(const InstantiationSignature& rhs) const + { + return !((*this) == rhs); + } +}; + +struct HashInstantiationSignature +{ + size_t operator()(const InstantiationSignature& signature) const; +}; + +struct ConstraintSolver +{ + TypeArena* arena; + NotNull singletonTypes; + InternalErrorReporter iceReporter; + NotNull normalizer; + // The entire set of constraints that the solver is trying to resolve. + std::vector> constraints; + NotNull rootScope; + ModuleName currentModuleName; + + // Constraints that the solver has generated, rather than sourcing from the + // scope tree. + std::vector> solverConstraints; + + // This includes every constraint that has not been fully solved. + // A constraint can be both blocked and unsolved, for instance. + std::vector> unsolvedConstraints; + + // A mapping of constraint pointer to how many things the constraint is + // blocked on. Can be empty or 0 for constraints that are not blocked on + // anything. + std::unordered_map, size_t> blockedConstraints; + // A mapping of type/pack pointers to the constraints they block. + std::unordered_map>> blocked; + // Memoized instantiations of type aliases. + DenseHashMap instantiatedAliases{{}}; + + // Recorded errors that take place within the solver. + ErrorVec errors; + + NotNull moduleResolver; + std::vector requireCycles; + + DcrLogger* logger; + + explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); + + // Randomize the order in which to dispatch constraints + void randomize(unsigned seed); + + /** + * Attempts to dispatch all pending constraints and reach a type solution + * that satisfies all of the constraints. + **/ + void run(); + + bool isDone(); + + void finalizeModule(); + + /** Attempt to dispatch a constraint. Returns true if it was successful. If + * tryDispatch() returns false, the constraint remains in the unsolved set + * and will be retried later. + */ + bool tryDispatch(NotNull c, bool force); + + bool tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const IterableConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const NameConstraint& c, NotNull constraint); + bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint); + bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); + bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); + bool tryDispatch(const HasPropConstraint& c, NotNull constraint); + bool tryDispatch(const SetPropConstraint& c, NotNull constraint); + bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); + + // for a, ... in some_table do + // also handles __iter metamethod + bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force); + + // for a, ... in next_function, t, ... do + bool tryDispatchIterableFunction( + TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); + + std::optional lookupTableProp(TypeId subjectType, const std::string& propName); + + void block(NotNull target, NotNull constraint); + /** + * Block a constraint on the resolution of a TypeVar. + * @returns false always. This is just to allow tryDispatch to return the result of block() + */ + bool block(TypeId target, NotNull constraint); + bool block(TypePackId target, NotNull constraint); + + // Traverse the type. If any blocked or pending typevars are found, block + // the constraint on them. + // + // Returns false if a type blocks the constraint. + // + // FIXME: This use of a boolean for the return result is an appalling + // interface. + bool recursiveBlock(TypeId target, NotNull constraint); + bool recursiveBlock(TypePackId target, NotNull constraint); + + void unblock(NotNull progressed); + void unblock(TypeId progressed); + void unblock(TypePackId progressed); + void unblock(const std::vector& types); + void unblock(const std::vector& packs); + + /** + * @returns true if the TypeId is in a blocked state. + */ + bool isBlocked(TypeId ty); + + /** + * @returns true if the TypePackId is in a blocked state. + */ + bool isBlocked(TypePackId tp); + + /** + * Returns whether the constraint is blocked on anything. + * @param constraint the constraint to check. + */ + bool isBlocked(NotNull constraint); + + /** + * Creates a new Unifier and performs a single unification operation. Commits + * the result. + * @param subType the sub-type to unify. + * @param superType the super-type to unify. + */ + void unify(TypeId subType, TypeId superType, NotNull scope); + + /** + * Creates a new Unifier and performs a single unification operation. Commits + * the result. + * @param subPack the sub-type pack to unify. + * @param superPack the super-type pack to unify. + */ + void unify(TypePackId subPack, TypePackId superPack, NotNull scope); + + /** Pushes a new solver constraint to the solver. + * @param cv the body of the constraint. + **/ + void pushConstraint(NotNull scope, const Location& location, ConstraintV cv); + + /** + * Attempts to resolve a module from its module information. Returns the + * module-level return type of the module, or the error type if one cannot + * be found. Reports errors to the solver if the module cannot be found or + * the require is illegal. + * @param module the module information to look up. + * @param location the location where the require is taking place; used for + * error locations. + **/ + TypeId resolveModule(const ModuleInfo& module, const Location& location); + + void reportError(TypeErrorData&& data, const Location& location); + void reportError(TypeError e); + +private: + /** + * Marks a constraint as being blocked on a type or type pack. The constraint + * solver will not attempt to dispatch blocked constraints until their + * dependencies have made progress. + * @param target the type or type pack pointer that the constraint is blocked on. + * @param constraint the constraint to block. + **/ + void block_(BlockedConstraintId target, NotNull constraint); + + /** + * Informs the solver that progress has been made on a type or type pack. The + * solver will wake up all constraints that are blocked on the type or type pack, + * and will resume attempting to dispatch them. + * @param progressed the type or type pack pointer that has progressed. + **/ + void unblock_(BlockedConstraintId progressed); + + TypeId errorRecoveryType() const; + TypePackId errorRecoveryTypePack() const; + + TypeId unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes); + + ToStringOptions opts; +}; + +void dump(NotNull rootScope, struct ToStringOptions& opts); + +} // namespace Luau diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h new file mode 100644 index 00000000..bd096ea9 --- /dev/null +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -0,0 +1,120 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +// Do not include LValue. It should never be used here. +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" +#include "Luau/Def.h" +#include "Luau/Symbol.h" + +#include + +namespace Luau +{ + +struct DataFlowGraph +{ + DataFlowGraph(DataFlowGraph&&) = default; + DataFlowGraph& operator=(DataFlowGraph&&) = default; + + // TODO: AstExprLocal, AstExprGlobal, and AstLocal* are guaranteed never to return nullopt. + // We leave them to return an optional as we build it out, but the end state is for them to return a non-optional DefId. + std::optional getDef(const AstExpr* expr) const; + std::optional getDef(const AstLocal* local) const; + + /// Retrieve the Def that corresponds to the given Symbol. + /// + /// We do not perform dataflow analysis on globals, so this function always + /// yields nullopt when passed a global Symbol. + std::optional getDef(const Symbol& symbol) const; + +private: + DataFlowGraph() = default; + + DataFlowGraph(const DataFlowGraph&) = delete; + DataFlowGraph& operator=(const DataFlowGraph&) = delete; + + DefArena arena; + DenseHashMap astDefs{nullptr}; + DenseHashMap localDefs{nullptr}; + + friend struct DataFlowGraphBuilder; +}; + +struct DfgScope +{ + DfgScope* parent; + DenseHashMap bindings{Symbol{}}; +}; + +struct ExpressionFlowGraph +{ + std::optional def; +}; + +// Currently unsound. We do not presently track the control flow of the program. +// Additionally, we do not presently track assignments. +struct DataFlowGraphBuilder +{ + static DataFlowGraph build(AstStatBlock* root, NotNull handle); + +private: + DataFlowGraphBuilder() = default; + + DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete; + DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; + + DataFlowGraph graph; + NotNull arena{&graph.arena}; + struct InternalErrorReporter* handle; + std::vector> scopes; + + // Does not belong in DataFlowGraphBuilder, but the old solver allows properties to escape the scope they were defined in, + // so we will need to be able to emulate this same behavior here too. We can kill this once we have better flow sensitivity. + DenseHashMap> props{nullptr}; + + DfgScope* childScope(DfgScope* scope); + + std::optional use(DfgScope* scope, Symbol symbol, AstExpr* e); + DefId use(DefId def, AstExprIndexName* e); + + void visit(DfgScope* scope, AstStatBlock* b); + void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); + + // TODO: visit type aliases + void visit(DfgScope* scope, AstStat* s); + void visit(DfgScope* scope, AstStatIf* i); + void visit(DfgScope* scope, AstStatWhile* w); + void visit(DfgScope* scope, AstStatRepeat* r); + void visit(DfgScope* scope, AstStatBreak* b); + void visit(DfgScope* scope, AstStatContinue* c); + void visit(DfgScope* scope, AstStatReturn* r); + void visit(DfgScope* scope, AstStatExpr* e); + void visit(DfgScope* scope, AstStatLocal* l); + void visit(DfgScope* scope, AstStatFor* f); + void visit(DfgScope* scope, AstStatForIn* f); + void visit(DfgScope* scope, AstStatAssign* a); + void visit(DfgScope* scope, AstStatCompoundAssign* c); + void visit(DfgScope* scope, AstStatFunction* f); + void visit(DfgScope* scope, AstStatLocalFunction* l); + + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExpr* e); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprLocal* l); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprGlobal* g); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprCall* c); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexName* i); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexExpr* i); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprFunction* f); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTable* t); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprUnary* u); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprBinary* b); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTypeAssertion* t); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIfElse* i); + ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprInterpString* i); + + // TODO: visitLValue + // TODO: visitTypes (because of typeof which has access to values namespace, needs unreachable scope) + // TODO: visitTypePacks (because of typeof which has access to values namespace, needs unreachable scope) +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/DcrLogger.h b/Analysis/include/Luau/DcrLogger.h new file mode 100644 index 00000000..30d2e15e --- /dev/null +++ b/Analysis/include/Luau/DcrLogger.h @@ -0,0 +1,133 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Constraint.h" +#include "Luau/NotNull.h" +#include "Luau/Scope.h" +#include "Luau/ToString.h" +#include "Luau/Error.h" +#include "Luau/Variant.h" + +#include +#include +#include + +namespace Luau +{ + +struct ErrorSnapshot +{ + std::string message; + Location location; +}; + +struct BindingSnapshot +{ + std::string typeId; + std::string typeString; + Location location; +}; + +struct TypeBindingSnapshot +{ + std::string typeId; + std::string typeString; +}; + +struct ConstraintGenerationLog +{ + std::string source; + std::unordered_map constraintLocations; + std::vector errors; +}; + +struct ScopeSnapshot +{ + std::unordered_map bindings; + std::unordered_map typeBindings; + std::unordered_map typePackBindings; + std::vector children; +}; + +enum class ConstraintBlockKind +{ + TypeId, + TypePackId, + ConstraintId, +}; + +struct ConstraintBlock +{ + ConstraintBlockKind kind; + std::string stringification; +}; + +struct ConstraintSnapshot +{ + std::string stringification; + std::vector blocks; +}; + +struct BoundarySnapshot +{ + std::unordered_map constraints; + ScopeSnapshot rootScope; +}; + +struct StepSnapshot +{ + std::string currentConstraint; + bool forced; + std::unordered_map unsolvedConstraints; + ScopeSnapshot rootScope; +}; + +struct TypeSolveLog +{ + BoundarySnapshot initialState; + std::vector stepStates; + BoundarySnapshot finalState; +}; + +struct TypeCheckLog +{ + std::vector errors; +}; + +using ConstraintBlockTarget = Variant>; + +struct DcrLogger +{ + std::string compileOutput(); + + void captureSource(std::string source); + void captureGenerationError(const TypeError& error); + void captureConstraintLocation(NotNull constraint, Location location); + + void pushBlock(NotNull constraint, TypeId block); + void pushBlock(NotNull constraint, TypePackId block); + void pushBlock(NotNull constraint, NotNull block); + void popBlock(TypeId block); + void popBlock(TypePackId block); + void popBlock(NotNull block); + + void captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints); + StepSnapshot prepareStepSnapshot( + const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints); + void commitStepSnapshot(StepSnapshot snapshot); + void captureFinalSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints); + + void captureTypeCheckError(const TypeError& error); + +private: + ConstraintGenerationLog generationLog; + std::unordered_map, std::vector> constraintBlocks; + TypeSolveLog solveLog; + TypeCheckLog checkLog; + + ToStringOptions opts; + + std::vector snapshotBlocks(NotNull constraint); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Def.h b/Analysis/include/Luau/Def.h new file mode 100644 index 00000000..1eef7dfd --- /dev/null +++ b/Analysis/include/Luau/Def.h @@ -0,0 +1,91 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/NotNull.h" +#include "Luau/TypedAllocator.h" +#include "Luau/Variant.h" + +#include +#include + +namespace Luau +{ + +struct Def; +using DefId = NotNull; + +struct FieldMetadata +{ + DefId parent; + std::string propName; +}; + +/** + * A cell is a "single-object" value. + * + * Leaky implementation note: sometimes "multiple-object" values, but none of which were interesting enough to warrant creating a phi node instead. + * That can happen because there's no point in creating a phi node that points to either resultant in `if math.random() > 0.5 then 5 else "hello"`. + * This might become of utmost importance if we wanted to do some backward reasoning, e.g. if `5` is taken, then `cond` must be `truthy`. + */ +struct Cell +{ + std::optional field; +}; + +/** + * A phi node is a union of cells. + * + * We need this because we're statically evaluating a program, and sometimes a place may be assigned with + * different cells, and when that happens, we need a special data type that merges in all the cells + * that will flow into that specific place. For example, consider this simple program: + * + * ``` + * x-1 + * if cond() then + * x-2 = 5 + * else + * x-3 = "hello" + * end + * x-4 : {x-2, x-3} + * ``` + * + * At x-4, we know for a fact statically that either `5` or `"hello"` can flow into the variable `x` after the branch, but + * we cannot make any definitive decisions about which one, so we just take in both. + */ +struct Phi +{ + std::vector operands; +}; + +/** + * We statically approximate a value at runtime using a symbolic value, which we call a Def. + * + * DataFlowGraphBuilder will allocate these defs as a stand-in for some Luau values, and bind them to places that + * can hold a Luau value, and then observes how those defs will commute as it statically evaluate the program. + * + * It must also be noted that defs are a cyclic graph, so it is not safe to recursively traverse into it expecting it to terminate. + */ +struct Def +{ + using V = Variant; + + V v; +}; + +template +const T* get(DefId def) +{ + return get_if(&def->v); +} + +struct DefArena +{ + TypedAllocator allocator; + + DefId freshCell(); + DefId freshCell(DefId parent, const std::string& prop); + // TODO: implement once we have cases where we need to merge in definitions + // DefId phi(const std::vector& defs); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Documentation.h b/Analysis/include/Luau/Documentation.h index 7b609b4f..67a9feb1 100644 --- a/Analysis/include/Luau/Documentation.h +++ b/Analysis/include/Luau/Documentation.h @@ -1,3 +1,4 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/DenseHash.h" @@ -12,10 +13,18 @@ namespace Luau struct FunctionDocumentation; struct TableDocumentation; struct OverloadedFunctionDocumentation; +struct BasicDocumentation; -using Documentation = Luau::Variant; +using Documentation = Luau::Variant; using DocumentationSymbol = std::string; +struct BasicDocumentation +{ + std::string documentation; + std::string learnMoreLink; + std::string codeSample; +}; + struct FunctionParameterDocumentation { std::string name; @@ -29,6 +38,8 @@ struct FunctionDocumentation std::string documentation; std::vector parameters; std::vector returns; + std::string learnMoreLink; + std::string codeSample; }; struct OverloadedFunctionDocumentation @@ -43,6 +54,8 @@ struct TableDocumentation { std::string documentation; Luau::DenseHashMap keys; + std::string learnMoreLink; + std::string codeSample; }; using DocumentationDatabase = Luau::DenseHashMap; diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 946bc928..739354f8 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -1,7 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/FileResolver.h" #include "Luau/Location.h" #include "Luau/TypeVar.h" #include "Luau/Variant.h" @@ -9,10 +8,33 @@ namespace Luau { +struct FileResolver; +struct TypeArena; +struct TypeError; + struct TypeMismatch { - TypeId wantedType; - TypeId givenType; + enum Context + { + CovariantContext, + InvariantContext + }; + + TypeMismatch() = default; + TypeMismatch(TypeId wantedType, TypeId givenType); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional error); + + TypeMismatch(TypeId wantedType, TypeId givenType, Context context); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, Context context); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional error, Context context); + + TypeId wantedType = nullptr; + TypeId givenType = nullptr; + Context context = CovariantContext; + + std::string reason; + std::shared_ptr error; bool operator==(const TypeMismatch& rhs) const; }; @@ -23,7 +45,6 @@ struct UnknownSymbol { Binding, Type, - Generic }; Name name; Context context; @@ -71,7 +92,7 @@ struct OnlyTablesCanHaveMethods struct DuplicateTypeDefinition { Name name; - Location previousLocation; + std::optional previousLocation; bool operator==(const DuplicateTypeDefinition& rhs) const; }; @@ -81,12 +102,16 @@ struct CountMismatch enum Context { Arg, - Result, + FunctionResult, + ExprListResult, Return, }; size_t expected; + std::optional maximum; size_t actual; Context context = Arg; + bool isVariadic = false; + std::string function; bool operator==(const CountMismatch& rhs) const; }; @@ -98,8 +123,6 @@ struct FunctionDoesNotTakeSelf struct FunctionRequiresSelf { - int requiredExtraNils = 0; - bool operator==(const FunctionRequiresSelf& rhs) const; }; @@ -120,6 +143,7 @@ struct IncorrectGenericParameterCount Name name; TypeFun typeFun; size_t actualParameters; + size_t actualPackParameters; bool operator==(const IncorrectGenericParameterCount& rhs) const; }; @@ -159,6 +183,13 @@ struct GenericError bool operator==(const GenericError& rhs) const; }; +struct InternalError +{ + std::string message; + + bool operator==(const InternalError& rhs) const; +}; + struct CannotCallNonFunction { TypeId ty; @@ -267,11 +298,57 @@ struct MissingUnionProperty bool operator==(const MissingUnionProperty& rhs) const; }; +struct TypesAreUnrelated +{ + TypeId left; + TypeId right; + + bool operator==(const TypesAreUnrelated& rhs) const; +}; + +struct NormalizationTooComplex +{ + bool operator==(const NormalizationTooComplex&) const + { + return true; + } +}; + +struct TypePackMismatch +{ + TypePackId wantedTp; + TypePackId givenTp; + + bool operator==(const TypePackMismatch& rhs) const; +}; + +struct DynamicPropertyLookupOnClassesUnsafe +{ + TypeId ty; + + bool operator==(const DynamicPropertyLookupOnClassesUnsafe& rhs) const; +}; + using TypeErrorData = Variant; + DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, + TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch, DynamicPropertyLookupOnClassesUnsafe>; + +struct TypeErrorSummary +{ + Location location; + ModuleName moduleName; + int code; + + TypeErrorSummary(const Location& location, const ModuleName& moduleName, int code) + : location(location) + , moduleName(moduleName) + , code(code) + { + } +}; struct TypeError { @@ -279,6 +356,7 @@ struct TypeError ModuleName moduleName; TypeErrorData data; + static int minCode(); int code() const; TypeError() = default; @@ -296,6 +374,8 @@ struct TypeError } bool operator==(const TypeError& rhs) const; + + TypeErrorSummary summary() const; }; template @@ -312,7 +392,13 @@ T* get(TypeError& e) using ErrorVec = std::vector; +struct TypeErrorToStringOptions +{ + FileResolver* fileResolver = nullptr; +}; + std::string toString(const TypeError& error); +std::string toString(const TypeError& error, TypeErrorToStringOptions options); bool containsParseErrorName(const TypeError& error); @@ -329,4 +415,29 @@ struct InternalErrorReporter [[noreturn]] void ice(const std::string& message); }; +class InternalCompilerError : public std::exception +{ +public: + explicit InternalCompilerError(const std::string& message) + : message(message) + { + } + explicit InternalCompilerError(const std::string& message, const std::string& moduleName) + : message(message) + , moduleName(moduleName) + { + } + explicit InternalCompilerError(const std::string& message, const std::string& moduleName, const Location& location) + : message(message) + , moduleName(moduleName) + , location(location) + { + } + virtual const char* what() const throw(); + + const std::string message; + const std::optional moduleName; + const std::optional location; +}; + } // namespace Luau diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 71f9464b..0fdcce16 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -25,51 +25,32 @@ struct SourceCode Type type; }; +struct ModuleInfo +{ + ModuleName name; + bool optional = false; +}; + struct FileResolver { virtual ~FileResolver() {} - /** Fetch the source code associated with the provided ModuleName. - * - * FIXME: This requires a string copy! - * - * @returns The actual Lua code on success. - * @returns std::nullopt if no such file exists. When this occurs, type inference will report an UnknownRequire error. - */ virtual std::optional readSource(const ModuleName& name) = 0; - /** Does the module exist? - * - * Saves a string copy over reading the source and throwing it away. - */ - virtual bool moduleExists(const ModuleName& name) const = 0; + virtual std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) + { + return std::nullopt; + } - virtual std::optional fromAstFragment(AstExpr* expr) const = 0; - - /** Given a valid module name and a string of arbitrary data, figure out the concatenation. - */ - virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; - - /** Goes "up" a level in the hierarchy that the ModuleName represents. - * - * For instances, this is analogous to someInstance.Parent; for paths, this is equivalent to removing the last - * element of the path. Other ModuleName representations may have other ways of doing this. - * - * @returns The parent ModuleName, if one exists. - * @returns std::nullopt if there is no parent for this module name. - */ - virtual std::optional getParentModuleName(const ModuleName& name) const = 0; - - virtual std::optional getHumanReadableModuleName_(const ModuleName& name) const + virtual std::string getHumanReadableModuleName(const ModuleName& name) const { return name; } - virtual std::optional getEnvironmentForModule(const ModuleName& name) const = 0; - - /** LanguageService only: - * std::optional fromInstance(Instance* inst) - */ + virtual std::optional getEnvironmentForModule(const ModuleName& name) const + { + return std::nullopt; + } }; struct NullFileResolver : FileResolver @@ -78,26 +59,6 @@ struct NullFileResolver : FileResolver { return std::nullopt; } - bool moduleExists(const ModuleName& name) const override - { - return false; - } - std::optional fromAstFragment(AstExpr* expr) const override - { - return std::nullopt; - } - ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override - { - return lhs; - } - std::optional getParentModuleName(const ModuleName& name) const override - { - return std::nullopt; - } - std::optional getEnvironmentForModule(const ModuleName& name) const override - { - return std::nullopt; - } }; } // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 07a0296a..b2662c68 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -5,6 +5,7 @@ #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/RequireTracer.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" @@ -24,6 +25,7 @@ struct TypeChecker; struct FileResolver; struct ModuleResolver; struct ParseResult; +struct HotComment; struct LoadDefinitionFileResult { @@ -35,7 +37,7 @@ struct LoadDefinitionFileResult LoadDefinitionFileResult loadDefinitionFile( TypeChecker& typeChecker, ScopePtr targetScope, std::string_view definition, const std::string& packageName); -std::optional parseMode(const std::vector& hotcomments); +std::optional parseMode(const std::vector& hotcomments); std::vector parsePathExpr(const AstExpr& pathExpr); @@ -54,10 +56,23 @@ std::optional pathExprToModuleName(const ModuleName& currentModuleNa struct SourceNode { + bool hasDirtySourceModule() const + { + return dirtySourceModule; + } + + bool hasDirtyModule(bool forAutocomplete) const + { + return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; + } + ModuleName name; - std::unordered_set requires; + std::unordered_set requireSet; std::vector> requireLocations; - bool dirty = true; + bool dirtySourceModule = true; + bool dirtyModule = true; + bool dirtyModuleForAutocomplete = true; + double autocompleteLimitsMult = 1.0; }; struct FrontendOptions @@ -68,14 +83,19 @@ struct FrontendOptions // is complete. bool retainFullTypeGraphs = false; - // When true, we run typechecking twice, one in the regular mode, ond once in strict mode - // in order to get more precise type information (e.g. for autocomplete). - bool typecheckTwice = false; + // Run typechecking only in mode required for autocomplete (strict mode in + // order to get more precise type information) + bool forAutocomplete = false; + + // If not empty, randomly shuffle the constraint set before attempting to + // solve. Use this value to seed the random number generator. + std::optional randomizeConstraintResolutionSeed; }; struct CheckResult { std::vector errors; + std::vector timeoutHits; }; struct FrontendModuleResolver : ModuleResolver @@ -109,20 +129,12 @@ struct Frontend Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {}); - CheckResult check(const ModuleName& name); // new shininess - LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); + CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess + LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); - /** Lint some code that has no associated DataModel object - * - * Since this source fragment has no name, we cannot cache its AST. Instead, - * we return it to the caller to use as they wish. - */ - std::pair lintFragment(std::string_view source, std::optional enabledLintWarnings = {}); + LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); - CheckResult check(const SourceModule& module); // OLD. TODO KILL - LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); - - bool isDirty(const ModuleName& name) const; + bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); /** Borrow a pointer into the SourceModule cache. @@ -145,20 +157,31 @@ struct Frontend void registerBuiltinDefinition(const std::string& name, std::function); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); + LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName); + + ScopePtr getGlobalScope(); + private: + ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope, std::vector requireCycles, + bool forAutocomplete = false); + std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); - bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root); + bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete); static LintResult classifyLints(const std::vector& warnings, const Config& config); - ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config); + ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete = false); std::unordered_map environments; std::unordered_map> builtinDefinitions; + SingletonTypes singletonTypes_; + public: + const NotNull singletonTypes; + FileResolver* fileResolver; FrontendModuleResolver moduleResolver; FrontendModuleResolver moduleResolverForAutocomplete; @@ -167,13 +190,16 @@ public: ConfigResolver* configResolver; FrontendOptions options; InternalErrorReporter iceHandler; - TypeArena arenaForAutocomplete; + TypeArena globalTypes; std::unordered_map sourceNodes; std::unordered_map sourceModules; - std::unordered_map requires; + std::unordered_map requireTrace; Stats stats = {}; + +private: + ScopePtr globalScope; }; } // namespace Luau diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h new file mode 100644 index 00000000..cd88d33a --- /dev/null +++ b/Analysis/include/Luau/Instantiation.h @@ -0,0 +1,57 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Substitution.h" +#include "Luau/TypeVar.h" +#include "Luau/Unifiable.h" + +namespace Luau +{ + +struct TypeArena; +struct TxnLog; + +// A substitution which replaces generic types in a given set by free types. +struct ReplaceGenerics : Substitution +{ + ReplaceGenerics(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope, const std::vector& generics, + const std::vector& genericPacks) + : Substitution(log, arena) + , level(level) + , scope(scope) + , generics(generics) + , genericPacks(genericPacks) + { + } + + TypeLevel level; + Scope* scope; + std::vector generics; + std::vector genericPacks; + bool ignoreChildren(TypeId ty) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +// A substitution which replaces generic functions by monomorphic functions +struct Instantiation : Substitution +{ + Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope) + : Substitution(log, arena) + , level(level) + , scope(scope) + { + } + + TypeLevel level; + Scope* scope; + bool ignoreChildren(TypeId ty) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/IostreamHelpers.h b/Analysis/include/Luau/IostreamHelpers.h index f9e9cd48..05b94516 100644 --- a/Analysis/include/Luau/IostreamHelpers.h +++ b/Analysis/include/Luau/IostreamHelpers.h @@ -30,12 +30,17 @@ std::ostream& operator<<(std::ostream& lhs, const OccursCheckFailed& error); std::ostream& operator<<(std::ostream& lhs, const UnknownRequire& error); std::ostream& operator<<(std::ostream& lhs, const UnknownPropButFoundLikeProp& e); std::ostream& operator<<(std::ostream& lhs, const GenericError& error); +std::ostream& operator<<(std::ostream& lhs, const InternalError& error); std::ostream& operator<<(std::ostream& lhs, const FunctionExitsWithoutReturning& error); std::ostream& operator<<(std::ostream& lhs, const MissingProperties& error); std::ostream& operator<<(std::ostream& lhs, const IllegalRequire& error); std::ostream& operator<<(std::ostream& lhs, const ModuleHasCyclicDependency& error); std::ostream& operator<<(std::ostream& lhs, const DuplicateGenericParameter& error); std::ostream& operator<<(std::ostream& lhs, const CannotInferBinaryOperation& error); +std::ostream& operator<<(std::ostream& lhs, const SwappedGenericTypeParameter& error); +std::ostream& operator<<(std::ostream& lhs, const OptionalValueAccess& error); +std::ostream& operator<<(std::ostream& lhs, const MissingUnionProperty& error); +std::ostream& operator<<(std::ostream& lhs, const TypesAreUnrelated& error); std::ostream& operator<<(std::ostream& lhs, const TableState& tv); std::ostream& operator<<(std::ostream& lhs, const TypeVar& tv); diff --git a/Analysis/include/Luau/JsonEmitter.h b/Analysis/include/Luau/JsonEmitter.h new file mode 100644 index 00000000..1a416586 --- /dev/null +++ b/Analysis/include/Luau/JsonEmitter.h @@ -0,0 +1,247 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include +#include +#include +#include + +#include "Luau/NotNull.h" + +namespace Luau::Json +{ + +struct JsonEmitter; + +/// Writes a value to the JsonEmitter. Note that this can produce invalid JSON +/// if you do not insert commas or appropriate object / array syntax. +template +void write(JsonEmitter&, T) = delete; + +/// Writes a boolean to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param b the boolean to write. +void write(JsonEmitter& emitter, bool b); + +/// Writes an integer to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param i the integer to write. +void write(JsonEmitter& emitter, int i); + +/// Writes an integer to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param i the integer to write. +void write(JsonEmitter& emitter, long i); + +/// Writes an integer to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param i the integer to write. +void write(JsonEmitter& emitter, long long i); + +/// Writes an integer to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param i the integer to write. +void write(JsonEmitter& emitter, unsigned int i); + +/// Writes an integer to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param i the integer to write. +void write(JsonEmitter& emitter, unsigned long i); + +/// Writes an integer to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param i the integer to write. +void write(JsonEmitter& emitter, unsigned long long i); + +/// Writes a double to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param d the double to write. +void write(JsonEmitter& emitter, double d); + +/// Writes a string to a JsonEmitter. The string will be escaped. +/// @param emitter the emitter to write to. +/// @param sv the string to write. +void write(JsonEmitter& emitter, std::string_view sv); + +/// Writes a character to a JsonEmitter as a single-character string. The +/// character will be escaped. +/// @param emitter the emitter to write to. +/// @param c the string to write. +void write(JsonEmitter& emitter, char c); + +/// Writes a string to a JsonEmitter. The string will be escaped. +/// @param emitter the emitter to write to. +/// @param str the string to write. +void write(JsonEmitter& emitter, const char* str); + +/// Writes a string to a JsonEmitter. The string will be escaped. +/// @param emitter the emitter to write to. +/// @param str the string to write. +void write(JsonEmitter& emitter, const std::string& str); + +/// Writes null to a JsonEmitter. +/// @param emitter the emitter to write to. +void write(JsonEmitter& emitter, std::nullptr_t); + +/// Writes null to a JsonEmitter. +/// @param emitter the emitter to write to. +void write(JsonEmitter& emitter, std::nullopt_t); + +struct ObjectEmitter; +struct ArrayEmitter; + +struct JsonEmitter +{ + JsonEmitter(); + + /// Converts the current contents of the JsonEmitter to a string value. This + /// does not invalidate the emitter, but it does not clear it either. + std::string str(); + + /// Returns the current comma state and resets it to false. Use popComma to + /// restore the old state. + /// @returns the previous comma state. + bool pushComma(); + + /// Restores a previous comma state. + /// @param c the comma state to restore. + void popComma(bool c); + + /// Writes a raw sequence of characters to the buffer, without escaping or + /// other processing. + /// @param sv the character sequence to write. + void writeRaw(std::string_view sv); + + /// Writes a character to the buffer, without escaping or other processing. + /// @param c the character to write. + void writeRaw(char c); + + /// Writes a comma if this wasn't the first time writeComma has been + /// invoked. Otherwise, sets the comma state to true. + /// @see pushComma + /// @see popComma + void writeComma(); + + /// Begins writing an object to the emitter. + /// @returns an ObjectEmitter that can be used to write key-value pairs. + ObjectEmitter writeObject(); + + /// Begins writing an array to the emitter. + /// @returns an ArrayEmitter that can be used to write values. + ArrayEmitter writeArray(); + +private: + bool comma = false; + std::vector chunks; + + void newChunk(); +}; + +/// An interface for writing an object into a JsonEmitter instance. +/// @see JsonEmitter::writeObject +struct ObjectEmitter +{ + ObjectEmitter(NotNull emitter); + ~ObjectEmitter(); + + NotNull emitter; + bool comma; + bool finished; + + /// Writes a key-value pair to the associated JsonEmitter. Keys will be escaped. + /// @param name the name of the key-value pair. + /// @param value the value to write. + template + void writePair(std::string_view name, T value) + { + if (finished) + { + return; + } + + emitter->writeComma(); + write(*emitter, name); + emitter->writeRaw(':'); + write(*emitter, value); + } + + /// Finishes writing the object, appending a closing `}` character and + /// resetting the comma state of the associated emitter. This can only be + /// called once, and once called will render the emitter unusable. This + /// method is also called when the ObjectEmitter is destructed. + void finish(); +}; + +/// An interface for writing an array into a JsonEmitter instance. Array values +/// do not need to be the same type. +/// @see JsonEmitter::writeArray +struct ArrayEmitter +{ + ArrayEmitter(NotNull emitter); + ~ArrayEmitter(); + + NotNull emitter; + bool comma; + bool finished; + + /// Writes a value to the array. + /// @param value the value to write. + template + void writeValue(T value) + { + if (finished) + { + return; + } + + emitter->writeComma(); + write(*emitter, value); + } + + /// Finishes writing the object, appending a closing `]` character and + /// resetting the comma state of the associated emitter. This can only be + /// called once, and once called will render the emitter unusable. This + /// method is also called when the ArrayEmitter is destructed. + void finish(); +}; + +/// Writes a vector as an array to a JsonEmitter. +/// @param emitter the emitter to write to. +/// @param vec the vector to write. +template +void write(JsonEmitter& emitter, const std::vector& vec) +{ + ArrayEmitter a = emitter.writeArray(); + + for (const T& value : vec) + a.writeValue(value); + + a.finish(); +} + +/// Writes an optional to a JsonEmitter. Will write the contained value, if +/// present, or null, if no value is present. +/// @param emitter the emitter to write to. +/// @param v the value to write. +template +void write(JsonEmitter& emitter, const std::optional& v) +{ + if (v.has_value()) + write(emitter, *v); + else + emitter.writeRaw("null"); +} + +template +void write(JsonEmitter& emitter, const std::unordered_map& map) +{ + ObjectEmitter o = emitter.writeObject(); + + for (const auto& [k, v] : map) + o.writePair(k, v); + + o.finish(); +} + +} // namespace Luau::Json diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h new file mode 100644 index 00000000..518cbfaf --- /dev/null +++ b/Analysis/include/Luau/LValue.h @@ -0,0 +1,53 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Variant.h" +#include "Luau/Symbol.h" + +#include +#include + +namespace Luau +{ + +struct TypeVar; +using TypeId = const TypeVar*; + +struct Field; + +// Deprecated. Do not use in new work. +using LValue = Variant; + +struct Field +{ + std::shared_ptr parent; + std::string key; + + bool operator==(const Field& rhs) const; + bool operator!=(const Field& rhs) const; +}; + +struct LValueHasher +{ + size_t operator()(const LValue& lvalue) const; +}; + +const LValue* baseof(const LValue& lvalue); + +std::optional tryGetLValue(const class AstExpr& expr); + +// Utility function: breaks down an LValue to get at the Symbol +Symbol getBaseSymbol(const LValue& lvalue); + +template +const T* get(const LValue& lvalue) +{ + return get_if(&lvalue); +} + +using RefinementMap = std::unordered_map; + +void merge(RefinementMap& l, const RefinementMap& r, std::function f); +void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty); + +} // namespace Luau diff --git a/Analysis/include/Luau/Linter.h b/Analysis/include/Luau/Linter.h index 1f7f7f9d..6bbc3d66 100644 --- a/Analysis/include/Luau/Linter.h +++ b/Analysis/include/Luau/Linter.h @@ -4,6 +4,7 @@ #include "Luau/Location.h" #include +#include #include namespace Luau @@ -14,6 +15,7 @@ class AstStat; class AstNameTable; struct TypeChecker; struct Module; +struct HotComment; using ScopePtr = std::shared_ptr; @@ -49,6 +51,10 @@ struct LintWarning Code_DeprecatedApi = 22, Code_TableOperations = 23, Code_DuplicateCondition = 24, + Code_MisleadingAndOr = 25, + Code_CommentDirective = 26, + Code_IntegerParsing = 27, + Code_ComparisonPrecedence = 28, Code__Count }; @@ -59,7 +65,7 @@ struct LintWarning static const char* getName(Code code); static Code parseName(const char* name); - static uint64_t parseMask(const std::vector& hotcomments); + static uint64_t parseMask(const std::vector& hotcomments); }; struct LintResult @@ -89,7 +95,8 @@ struct LintOptions void setDefaults(); }; -std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const LintOptions& options); +std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, + const std::vector& hotcomments, const LintOptions& options); std::vector getDeprecatedGlobals(const AstNameTable& names); diff --git a/Analysis/include/Luau/Metamethods.h b/Analysis/include/Luau/Metamethods.h new file mode 100644 index 00000000..84b0092f --- /dev/null +++ b/Analysis/include/Luau/Metamethods.h @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" + +#include + +namespace Luau +{ + +static const std::unordered_map kBinaryOpMetamethods{ + {AstExprBinary::Op::CompareEq, "__eq"}, + {AstExprBinary::Op::CompareNe, "__eq"}, + {AstExprBinary::Op::CompareGe, "__lt"}, + {AstExprBinary::Op::CompareGt, "__le"}, + {AstExprBinary::Op::CompareLe, "__le"}, + {AstExprBinary::Op::CompareLt, "__lt"}, + {AstExprBinary::Op::Add, "__add"}, + {AstExprBinary::Op::Sub, "__sub"}, + {AstExprBinary::Op::Mul, "__mul"}, + {AstExprBinary::Op::Div, "__div"}, + {AstExprBinary::Op::Pow, "__pow"}, + {AstExprBinary::Op::Mod, "__mod"}, + {AstExprBinary::Op::Concat, "__concat"}, +}; + +static const std::unordered_map kUnaryOpMetamethods{ + {AstExprUnary::Op::Minus, "__unm"}, + {AstExprUnary::Op::Len, "__len"}, +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 413b68f4..d22aad12 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -1,12 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/FileResolver.h" -#include "Luau/TypePack.h" -#include "Luau/TypedAllocator.h" -#include "Luau/ParseOptions.h" #include "Luau/Error.h" -#include "Luau/Parser.h" +#include "Luau/FileResolver.h" +#include "Luau/ParseOptions.h" +#include "Luau/ParseResult.h" +#include "Luau/Scope.h" +#include "Luau/TypeArena.h" #include #include @@ -21,6 +21,9 @@ struct Module; using ScopePtr = std::shared_ptr; using ModulePtr = std::shared_ptr; +class AstType; +class AstTypePack; + /// Root of the AST of a parsed source file struct SourceModule { @@ -29,14 +32,14 @@ struct SourceModule std::optional environmentName; bool cyclic = false; - std::unique_ptr allocator; - std::unique_ptr names; + std::shared_ptr allocator; + std::shared_ptr names; std::vector parseErrors; AstStatBlock* root = nullptr; std::optional mode; - uint64_t ignoreLints = 0; + std::vector hotcomments; std::vector commentLocations; SourceModule() @@ -48,40 +51,12 @@ struct SourceModule bool isWithinComment(const SourceModule& sourceModule, Position pos); -struct TypeArena +struct RequireCycle { - TypedAllocator typeVars; - TypedAllocator typePacks; - - void clear(); - - template - TypeId addType(T tv) - { - return addTV(TypeVar(std::move(tv))); - } - - TypeId addTV(TypeVar&& tv); - - TypeId freshType(TypeLevel level); - - TypePackId addTypePack(std::initializer_list types); - TypePackId addTypePack(std::vector types); - TypePackId addTypePack(TypePack pack); - TypePackId addTypePack(TypePackVar pack); + Location location; + std::vector path; // one of the paths for a require() to go all the way back to the originating module }; -void freeze(TypeArena& arena); -void unfreeze(TypeArena& arena); - -// Only exposed so they can be unit tested. -using SeenTypes = std::unordered_map; -using SeenTypePacks = std::unordered_map; - -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); - struct Module { ~Module(); @@ -89,23 +64,33 @@ struct Module TypeArena interfaceTypes; TypeArena internalTypes; + // Scopes and AST types refer to parse data, so we need to keep that alive + std::shared_ptr allocator; + std::shared_ptr names; + std::vector> scopes; // never empty - std::unordered_map astTypes; - std::unordered_map astExpectedTypes; - std::unordered_map astOriginalCallTypes; - std::unordered_map astOverloadResolvedTypes; + + DenseHashMap astTypes{nullptr}; + DenseHashMap astTypePacks{nullptr}; + DenseHashMap astExpectedTypes{nullptr}; + DenseHashMap astOriginalCallTypes{nullptr}; + DenseHashMap astOverloadResolvedTypes{nullptr}; + DenseHashMap astResolvedTypes{nullptr}; + DenseHashMap astResolvedTypePacks{nullptr}; + // Map AST nodes to the scope they create. Cannot be NotNull because we need a sentinel value for the map. + DenseHashMap astScopes{nullptr}; + std::unordered_map declaredGlobals; ErrorVec errors; Mode mode; SourceCode::Type type; + bool timeout = false; ScopePtr getModuleScope() const; // Once a module has been typechecked, we clone its public interface into a separate arena. // This helps us to force TypeVar ownership into a DAG rather than a DCG. - // Returns true if there were any free types encountered in the public interface. This - // indicates a bug in the type checker that we want to surface. - bool clonePublicInterface(); + void clonePublicInterface(NotNull singletonTypes, InternalErrorReporter& ice); }; } // namespace Luau diff --git a/Analysis/include/Luau/ModuleResolver.h b/Analysis/include/Luau/ModuleResolver.h index a394a21b..d892ccd7 100644 --- a/Analysis/include/Luau/ModuleResolver.h +++ b/Analysis/include/Luau/ModuleResolver.h @@ -15,12 +15,6 @@ struct Module; using ModulePtr = std::shared_ptr; -struct ModuleInfo -{ - ModuleName name; - bool optional = false; -}; - struct ModuleResolver { virtual ~ModuleResolver() {} diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h new file mode 100644 index 00000000..39257315 --- /dev/null +++ b/Analysis/include/Luau/Normalize.h @@ -0,0 +1,329 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/NotNull.h" +#include "Luau/TypeVar.h" +#include "Luau/UnifierSharedState.h" + +#include + +namespace Luau +{ + +struct InternalErrorReporter; +struct Module; +struct Scope; +struct SingletonTypes; + +using ModulePtr = std::shared_ptr; + +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); +bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice); + +class TypeIds +{ +private: + std::unordered_set types; + std::vector order; + std::size_t hash = 0; + +public: + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + + TypeIds(const TypeIds&) = delete; + TypeIds(TypeIds&&) = default; + TypeIds() = default; + ~TypeIds() = default; + TypeIds& operator=(TypeIds&&) = default; + + void insert(TypeId ty); + /// Erase every element that does not also occur in tys + void retain(const TypeIds& tys); + void clear(); + + TypeId front() const; + iterator begin(); + iterator end(); + const_iterator begin() const; + const_iterator end() const; + iterator erase(const_iterator it); + + size_t size() const; + bool empty() const; + size_t count(TypeId ty) const; + + template + void insert(Iterator begin, Iterator end) + { + for (Iterator it = begin; it != end; ++it) + insert(*it); + } + + bool operator==(const TypeIds& there) const; + size_t getHash() const; +}; + +} // namespace Luau + +template<> +struct std::hash +{ + std::size_t operator()(const Luau::TypeIds& tys) const + { + return tys.getHash(); + } +}; + +template<> +struct std::hash +{ + std::size_t operator()(const Luau::TypeIds* tys) const + { + return tys->getHash(); + } +}; + +template<> +struct std::equal_to +{ + bool operator()(const Luau::TypeIds& here, const Luau::TypeIds& there) const + { + return here == there; + } +}; + +template<> +struct std::equal_to +{ + bool operator()(const Luau::TypeIds* here, const Luau::TypeIds* there) const + { + return *here == *there; + } +}; + +namespace Luau +{ + +/** A normalized string type is either `string` (represented by `nullopt`) or a + * union of string singletons. + * + * The representation is as follows: + * + * * A union of string singletons is finite and includes the singletons named by + * the `singletons` field. + * * An intersection of negated string singletons is cofinite and includes the + * singletons excluded by the `singletons` field. It is implied that cofinite + * values are exclusions from `string` itself. + * * The `string` data type is a cofinite set minus zero elements. + * * The `never` data type is a finite set plus zero elements. + */ +struct NormalizedStringType +{ + // When false, this type represents a union of singleton string types. + // eg "a" | "b" | "c" + // + // When true, this type represents string intersected with negated string + // singleton types. + // eg string & ~"a" & ~"b" & ... + bool isCofinite = false; + + std::map singletons; + + void resetToString(); + void resetToNever(); + + bool isNever() const; + bool isString() const; + + /// Returns true if the string has finite domain. + /// + /// Important subtlety: This method returns true for `never`. The empty set + /// is indeed an empty set. + bool isUnion() const; + + /// Returns true if the string has infinite domain. + bool isIntersection() const; + + bool includes(const std::string& str) const; + + static const NormalizedStringType never; + + NormalizedStringType(); + NormalizedStringType(bool isCofinite, std::map singletons); +}; + +bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr); + +// A normalized function type can be `never`, the top function type `function`, +// or an intersection of function types. +// +// NOTE: type normalization can fail on function types with generics (e.g. +// because we do not support unions and intersections of generic type packs), so +// this type may contain `error`. +struct NormalizedFunctionType +{ + NormalizedFunctionType(); + + bool isTop = false; + // TODO: Remove this wrapping optional when clipping + // FFlagLuauNegatedFunctionTypes. + std::optional parts; + + void resetToNever(); + void resetToTop(); + + bool isNever() const; +}; + +// A normalized generic/free type is a union, where each option is of the form (X & T) where +// * X is either a free type or a generic +// * T is a normalized type. +struct NormalizedType; +using NormalizedTyvars = std::unordered_map>; + +bool isInhabited_DEPRECATED(const NormalizedType& norm); + +// A normalized type is either any, unknown, or one of the form P | T | F | G where +// * P is a union of primitive types (including singletons, classes and the error type) +// * T is a union of table types +// * F is a union of an intersection of function types +// * G is a union of generic/free normalized types, intersected with a normalized type +struct NormalizedType +{ + // The top part of the type. + // This type is either never, unknown, or any. + // If this type is not never, all the other fields are null. + TypeId tops; + + // The boolean part of the type. + // This type is either never, boolean type, or a boolean singleton. + TypeId booleans; + + // The class part of the type. + // Each element of this set is a class, and none of the classes are subclasses of each other. + TypeIds classes; + + // The error part of the type. + // This type is either never or the error type. + TypeId errors; + + // The nil part of the type. + // This type is either never or nil. + TypeId nils; + + // The number part of the type. + // This type is either never or number. + TypeId numbers; + + // The string part of the type. + // This may be the `string` type, or a union of singletons. + NormalizedStringType strings; + + // The thread part of the type. + // This type is either never or thread. + TypeId threads; + + // The (meta)table part of the type. + // Each element of this set is a (meta)table type. + TypeIds tables; + + // The function part of the type. + NormalizedFunctionType functions; + + // The generic/free part of the type. + NormalizedTyvars tyvars; + + NormalizedType(NotNull singletonTypes); + + NormalizedType() = delete; + ~NormalizedType() = default; + + NormalizedType(const NormalizedType&) = delete; + NormalizedType& operator=(const NormalizedType&) = delete; + + NormalizedType(NormalizedType&&) = default; + NormalizedType& operator=(NormalizedType&&) = default; +}; + +class Normalizer +{ + std::unordered_map> cachedNormals; + std::unordered_map cachedIntersections; + std::unordered_map cachedUnions; + std::unordered_map> cachedTypeIds; + bool withinResourceLimits(); + +public: + TypeArena* arena; + NotNull singletonTypes; + NotNull sharedState; + + Normalizer(TypeArena* arena, NotNull singletonTypes, NotNull sharedState); + Normalizer(const Normalizer&) = delete; + Normalizer(Normalizer&&) = delete; + Normalizer() = delete; + ~Normalizer() = default; + Normalizer& operator=(Normalizer&&) = delete; + Normalizer& operator=(Normalizer&) = delete; + + // If this returns null, the typechecker should emit a "too complex" error + const NormalizedType* normalize(TypeId ty); + void clearNormal(NormalizedType& norm); + + // ------- Cached TypeIds + TypeId unionType(TypeId here, TypeId there); + TypeId intersectionType(TypeId here, TypeId there); + const TypeIds* cacheTypeIds(TypeIds tys); + void clearCaches(); + + // ------- Normalizing unions + void unionTysWithTy(TypeIds& here, TypeId there); + TypeId unionOfTops(TypeId here, TypeId there); + TypeId unionOfBools(TypeId here, TypeId there); + void unionClassesWithClass(TypeIds& heres, TypeId there); + void unionClasses(TypeIds& heres, const TypeIds& theres); + void unionStrings(NormalizedStringType& here, const NormalizedStringType& there); + std::optional unionOfTypePacks(TypePackId here, TypePackId there); + std::optional unionOfFunctions(TypeId here, TypeId there); + std::optional unionSaturatedFunctions(TypeId here, TypeId there); + void unionFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); + void unionFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); + void unionTablesWithTable(TypeIds& heres, TypeId there); + void unionTables(TypeIds& heres, const TypeIds& theres); + bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); + bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); + + // ------- Negations + std::optional negateNormal(const NormalizedType& here); + TypeIds negateAll(const TypeIds& theres); + TypeId negate(TypeId there); + void subtractPrimitive(NormalizedType& here, TypeId ty); + void subtractSingleton(NormalizedType& here, TypeId ty); + + // ------- Normalizing intersections + TypeId intersectionOfTops(TypeId here, TypeId there); + TypeId intersectionOfBools(TypeId here, TypeId there); + void intersectClasses(TypeIds& heres, const TypeIds& theres); + void intersectClassesWithClass(TypeIds& heres, TypeId there); + void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); + std::optional intersectionOfTypePacks(TypePackId here, TypePackId there); + std::optional intersectionOfTables(TypeId here, TypeId there); + void intersectTablesWithTable(TypeIds& heres, TypeId there); + void intersectTables(TypeIds& heres, const TypeIds& theres); + std::optional intersectionOfFunctions(TypeId here, TypeId there); + void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); + void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); + bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there); + bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); + bool intersectNormalWithTy(NormalizedType& here, TypeId there); + + // Check for inhabitance + bool isInhabited(TypeId ty, std::unordered_set seen = {}); + bool isInhabited(const NormalizedType* norm, std::unordered_set seen = {}); + + // -------- Convert back from a normalized type to a type + TypeId typeFromNormal(const NormalizedType& norm); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/NotNull.h b/Analysis/include/Luau/NotNull.h new file mode 100644 index 00000000..ecdcb476 --- /dev/null +++ b/Analysis/include/Luau/NotNull.h @@ -0,0 +1,104 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include + +namespace Luau +{ + +/** A non-owning, non-null pointer to a T. + * + * A NotNull is notionally identical to a T* with the added restriction that + * it can never store nullptr. + * + * The sole conversion rule from T* to NotNull is the single-argument + * constructor, which is intentionally marked explicit. This constructor + * performs a runtime test to verify that the passed pointer is never nullptr. + * + * Pointer arithmetic, increment, decrement, and array indexing are all + * forbidden. + * + * An implicit coersion from NotNull to T* is afforded, as are the pointer + * indirection and member access operators. (*p and p->prop) + * + * The explicit delete statement is permitted (but not recommended) on a + * NotNull through this implicit conversion. + */ +template +struct NotNull +{ + explicit NotNull(T* t) + : ptr(t) + { + LUAU_ASSERT(t); + } + + explicit NotNull(std::nullptr_t) = delete; + void operator=(std::nullptr_t) = delete; + + template + NotNull(NotNull other) + : ptr(other.get()) + { + } + + operator T*() const noexcept + { + return ptr; + } + + T& operator*() const noexcept + { + return *ptr; + } + + T* operator->() const noexcept + { + return ptr; + } + + template + bool operator==(NotNull other) const noexcept + { + return get() == other.get(); + } + + template + bool operator!=(NotNull other) const noexcept + { + return get() != other.get(); + } + + operator bool() const noexcept = delete; + + T& operator[](int) = delete; + + T& operator+(int) = delete; + T& operator-(int) = delete; + + T* get() const noexcept + { + return ptr; + } + +private: + T* ptr; +}; + +} // namespace Luau + +namespace std +{ + +template +struct hash> +{ + size_t operator()(const Luau::NotNull& p) const + { + return std::hash()(p.get()); + } +}; + +} // namespace std diff --git a/Analysis/include/Luau/Predicate.h b/Analysis/include/Luau/Predicate.h index a5e8b6ae..df93b4f4 100644 --- a/Analysis/include/Luau/Predicate.h +++ b/Analysis/include/Luau/Predicate.h @@ -1,12 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Variant.h" #include "Luau/Location.h" -#include "Luau/Symbol.h" +#include "Luau/LValue.h" +#include "Luau/Variant.h" -#include -#include #include namespace Luau @@ -15,34 +13,6 @@ namespace Luau struct TypeVar; using TypeId = const TypeVar*; -struct Field; -using LValue = Variant; - -struct Field -{ - std::shared_ptr parent; // TODO: Eventually use unique_ptr to enforce non-copyable trait. - std::string key; -}; - -std::optional tryGetLValue(const class AstExpr& expr); - -// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. -std::pair> getFullName(const LValue& lvalue); - -std::string toString(const LValue& lvalue); - -template -const T* get(const LValue& lvalue) -{ - return get_if(&lvalue); -} - -// Key is a stringified encoding of an LValue. -using RefinementMap = std::map; - -void merge(RefinementMap& l, const RefinementMap& r, std::function f); -void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty); - struct TruthyPredicate; struct IsAPredicate; struct TypeGuardPredicate; diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h new file mode 100644 index 00000000..7edf23b8 --- /dev/null +++ b/Analysis/include/Luau/Quantify.h @@ -0,0 +1,15 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeVar.h" + +namespace Luau +{ + +struct TypeArena; +struct Scope; + +void quantify(TypeId ty, TypeLevel level); +TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope); + +} // namespace Luau diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h index 89632cea..77af10a0 100644 --- a/Analysis/include/Luau/RecursionCounter.h +++ b/Analysis/include/Luau/RecursionCounter.h @@ -2,12 +2,22 @@ #pragma once #include "Luau/Common.h" +#include "Luau/Error.h" #include +#include namespace Luau { +struct RecursionLimitException : public InternalCompilerError +{ + RecursionLimitException() + : InternalCompilerError("Internal recursion counter limit exceeded") + { + } +}; + struct RecursionCounter { RecursionCounter(int* count) @@ -32,7 +42,9 @@ struct RecursionLimiter : RecursionCounter : RecursionCounter(count) { if (limit > 0 && *count > limit) - throw std::runtime_error("Internal recursion counter limit exceeded"); + { + throw RecursionLimitException(); + } } }; diff --git a/Analysis/include/Luau/RequireTracer.h b/Analysis/include/Luau/RequireTracer.h index e9778876..718a6cc1 100644 --- a/Analysis/include/Luau/RequireTracer.h +++ b/Analysis/include/Luau/RequireTracer.h @@ -6,6 +6,7 @@ #include "Luau/Location.h" #include +#include namespace Luau { @@ -17,12 +18,11 @@ struct AstLocal; struct RequireTraceResult { - DenseHashMap exprs{0}; - DenseHashMap optional{0}; + DenseHashMap exprs{nullptr}; - std::vector> requires; + std::vector> requireList; }; -RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName); +RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName); } // namespace Luau diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h new file mode 100644 index 00000000..851ed1a7 --- /dev/null +++ b/Analysis/include/Luau/Scope.h @@ -0,0 +1,84 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" +#include "Luau/NotNull.h" +#include "Luau/TypeVar.h" + +#include +#include +#include + +namespace Luau +{ + +struct Scope; + +using ScopePtr = std::shared_ptr; + +struct Binding +{ + TypeId typeId; + Location location; + bool deprecated = false; + std::string deprecatedSuggestion; + std::optional documentationSymbol; +}; + +struct Scope +{ + explicit Scope(TypePackId returnType); // root scope + explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. + + const ScopePtr parent; // null for the root + + // All the children of this scope. + std::vector> children; + std::unordered_map bindings; + TypePackId returnType; + std::optional varargPack; + + TypeLevel level; + + std::unordered_map exportedTypeBindings; + std::unordered_map privateTypeBindings; + std::unordered_map typeAliasLocations; + std::unordered_map> importedTypeBindings; + + DenseHashSet builtinTypeNames{""}; + void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun); + + std::optional lookup(Symbol sym) const; + std::optional lookup(DefId def) const; + std::optional> lookupEx(Symbol sym); + + std::optional lookupType(const Name& name); + std::optional lookupImportedType(const Name& moduleAlias, const Name& name); + + std::unordered_map privateTypePackBindings; + std::optional lookupPack(const Name& name); + + // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) + std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const; + + RefinementMap refinements; + DenseHashMap dcrRefinements{nullptr}; + + // For mutually recursive type aliases, it's important that + // they use the same types for the same names. + // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` + // we need that the generic type `T` in both cases is the same, so we use a cache. + std::unordered_map typeAliasTypeParameters; + std::unordered_map typeAliasTypePackParameters; +}; + +// Returns true iff the left scope encloses the right scope. A Scope* equal to +// nullptr is considered to be the outermost-possible scope. +bool subsumesStrict(Scope* left, Scope* right); + +// Returns true if the left scope encloses the right scope, or if they are the +// same scope. As in subsumesStrict(), nullptr is considered to be the +// outermost-possible scope. +bool subsumes(Scope* left, Scope* right); + +} // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 6ac868f7..6ad38f9d 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -1,8 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Module.h" -#include "Luau/ModuleResolver.h" +#include "Luau/TypeArena.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/DenseHash.h" @@ -52,11 +51,11 @@ // `T`, and the type of `f` are in the same SCC, which is why `f` gets // replaced. -LUAU_FASTFLAG(DebugLuauTrackOwningArena) - namespace Luau { +struct TxnLog; + enum class TarjanResult { TooManyChildren, @@ -90,6 +89,11 @@ struct Tarjan std::vector lowlink; int childCount = 0; + int childLimit = 0; + + // This should never be null; ensure you initialize it before calling + // substitution methods. + const TxnLog* log = nullptr; std::vector edgesTy; std::vector edgesTp; @@ -97,9 +101,6 @@ struct Tarjan // This is hot code, so we optimize recursion to a stack. TarjanResult loop(); - // Clear the state - void clear(); - // Find or create the index for a vertex. // Return a boolean which is `true` if it's a freshly created index. std::pair indexify(TypeId ty); @@ -138,6 +139,8 @@ struct FindDirty : Tarjan { std::vector dirty; + void clearTarjan(); + // Get/set the dirty bit for an index (grows the vector if needed) bool getDirty(int index); void setDirty(int index, bool d); @@ -162,9 +165,21 @@ struct FindDirty : Tarjan // and replaces them with clean ones. struct Substitution : FindDirty { - ModulePtr currentModule; +protected: + Substitution(const TxnLog* log_, TypeArena* arena) + : arena(arena) + { + log = log_; + LUAU_ASSERT(log); + LUAU_ASSERT(arena); + } + +public: + TypeArena* arena; DenseHashMap newTypes{nullptr}; DenseHashMap newPacks{nullptr}; + DenseHashSet replacedTypes{nullptr}; + DenseHashSet replacedTypePacks{nullptr}; std::optional substitute(TypeId ty); std::optional substitute(TypePackId tp); @@ -188,20 +203,13 @@ struct Substitution : FindDirty template TypeId addType(const T& tv) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(tv); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return arena->addType(tv); } + template TypePackId addTypePack(const T& tp) { - TypePackId allocated = currentModule->internalTypes.typePacks.allocate(tp); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return arena->addTypePack(TypePackVar{tp}); } }; diff --git a/Analysis/include/Luau/Symbol.h b/Analysis/include/Luau/Symbol.h index b5dd9c89..0432946c 100644 --- a/Analysis/include/Luau/Symbol.h +++ b/Analysis/include/Luau/Symbol.h @@ -6,10 +6,11 @@ #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + namespace Luau { -// TODO Rename this to Name once the old type alias is gone. struct Symbol { Symbol() @@ -30,6 +31,9 @@ struct Symbol { } + template + Symbol(const T&) = delete; + AstLocal* local; AstName global; @@ -37,9 +41,12 @@ struct Symbol { if (local) return local == rhs.local; - if (global.value) + else if (global.value) return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. - return false; + else if (FFlag::DebugLuauDeferredConstraintResolution) + return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. + else + return false; } bool operator!=(const Symbol& rhs) const @@ -55,8 +62,8 @@ struct Symbol return global < rhs.global; else if (local) return true; - else - return false; + + return false; } AstName astName() const diff --git a/Analysis/include/Luau/ToDot.h b/Analysis/include/Luau/ToDot.h new file mode 100644 index 00000000..ce518d3a --- /dev/null +++ b/Analysis/include/Luau/ToDot.h @@ -0,0 +1,31 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include + +namespace Luau +{ +struct TypeVar; +using TypeId = const TypeVar*; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +struct ToDotOptions +{ + bool showPointers = true; // Show pointer value in the node label + bool duplicatePrimitives = true; // Display primitive types and 'any' as separate nodes +}; + +std::string toDot(TypeId ty, const ToDotOptions& opts); +std::string toDot(TypePackId tp, const ToDotOptions& opts); + +std::string toDot(TypeId ty); +std::string toDot(TypePackId tp); + +void dumpDot(TypeId ty); +void dumpDot(TypePackId tp); + +} // namespace Luau diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 0897ec85..71c0e359 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -2,12 +2,12 @@ #pragma once #include "Luau/Common.h" -#include "Luau/TypeVar.h" -#include -#include #include +#include #include +#include +#include LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength) @@ -15,6 +15,22 @@ LUAU_FASTINT(LuauTypeMaximumStringifierLength) namespace Luau { +class AstExpr; + +struct Scope; + +struct TypeVar; +using TypeId = const TypeVar*; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +struct FunctionTypeVar; +struct Constraint; + +struct Position; +struct Location; + struct ToStringNameMap { std::unordered_map typeVars; @@ -23,20 +39,23 @@ struct ToStringNameMap struct ToStringOptions { - bool exhaustive = false; // If true, we produce complete output rather than comprehensible output - bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. - bool functionTypeArguments = false; // If true, output function type argument names when they are available - bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + bool exhaustive = false; // If true, we produce complete output rather than comprehensible output + bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. + bool functionTypeArguments = false; // If true, output function type argument names when they are available + bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. + bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self + bool DEPRECATED_indent = false; // TODO Deprecated field, prune when clipping flag FFlagLuauLineBreaksDeterminIndents size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); - std::optional nameMap; + ToStringNameMap nameMap; std::shared_ptr scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid' + std::vector namedFunctionOverrideArgNames; // If present, named function argument names will be overridden }; struct ToStringResult { std::string name; - ToStringNameMap nameMap; bool invalid = false; bool error = false; @@ -44,11 +63,24 @@ struct ToStringResult bool truncated = false; }; -ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts = {}); -ToStringResult toStringDetailed(TypePackId ty, const ToStringOptions& opts = {}); +ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts); +ToStringResult toStringDetailed(TypePackId ty, ToStringOptions& opts); -std::string toString(TypeId ty, const ToStringOptions& opts); -std::string toString(TypePackId ty, const ToStringOptions& opts); +std::string toString(TypeId ty, ToStringOptions& opts); +std::string toString(TypePackId ty, ToStringOptions& opts); + +// These overloads are selected when a temporary ToStringOptions is passed. (eg +// via an initializer list) +inline std::string toString(TypePackId ty, ToStringOptions&& opts) +{ + // Delegate to the overload (TypePackId, ToStringOptions&) + return toString(ty, opts); +} +inline std::string toString(TypeId ty, ToStringOptions&& opts) +{ + // Delegate to the overload (TypeId, ToStringOptions&) + return toString(ty, opts); +} // These are offered as overloads rather than a default parameter so that they can be easily invoked from within the MSVC debugger. // You can use them in watch expressions! @@ -61,12 +93,54 @@ inline std::string toString(TypePackId ty) return toString(ty, ToStringOptions{}); } -std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); -std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); +std::string toString(const Constraint& c, ToStringOptions& opts); + +inline std::string toString(const Constraint& c, ToStringOptions&& opts) +{ + return toString(c, opts); +} + +inline std::string toString(const Constraint& c) +{ + return toString(c, ToStringOptions{}); +} + +std::string toString(const TypeVar& tv, ToStringOptions& opts); +std::string toString(const TypePackVar& tp, ToStringOptions& opts); + +inline std::string toString(const TypeVar& tv) +{ + ToStringOptions opts; + return toString(tv, opts); +} + +inline std::string toString(const TypePackVar& tp) +{ + ToStringOptions opts; + return toString(tp, opts); +} + +std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, ToStringOptions& opts); + +inline std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv) +{ + ToStringOptions opts; + return toStringNamedFunction(funcName, ftv, opts); +} + +std::optional getFunctionNameAsString(const AstExpr& expr); // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression -void dump(TypeId ty); -void dump(TypePackId ty); +std::string dump(TypeId ty); +std::string dump(TypePackId ty); +std::string dump(const Constraint& c); + +std::string dump(const std::shared_ptr& scope, const char* name); + +std::string generateName(size_t n); + +std::string toString(const Position& position); +std::string toString(const Location& location); } // namespace Luau diff --git a/Analysis/include/Luau/TopoSortStatements.h b/Analysis/include/Luau/TopoSortStatements.h index 751694f0..4a4acfa3 100644 --- a/Analysis/include/Luau/TopoSortStatements.h +++ b/Analysis/include/Luau/TopoSortStatements.h @@ -12,6 +12,7 @@ struct AstArray; class AstStat; bool containsFunctionCall(const AstStat& stat); +bool containsFunctionCallOrReturn(const AstStat& stat); bool isFunction(const AstStat& stat); void toposort(std::vector& stats); diff --git a/Analysis/include/Luau/Transpiler.h b/Analysis/include/Luau/Transpiler.h index 817459fe..df01008c 100644 --- a/Analysis/include/Luau/Transpiler.h +++ b/Analysis/include/Luau/Transpiler.h @@ -18,6 +18,7 @@ struct TranspileResult std::string parseError; // Nonempty if the transpile failed }; +std::string toString(AstNode* node); void dump(AstNode* node); // Never fails on a well-formed AST @@ -25,6 +26,6 @@ std::string transpile(AstStatBlock& ast); std::string transpileWithTypes(AstStatBlock& block); // Only fails when parsing fails -TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{}); +TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{}, bool withTypes = false); } // namespace Luau diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 055441ce..82605bff 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -2,17 +2,92 @@ #pragma once #include "Luau/TypeVar.h" +#include "Luau/TypePack.h" + +#include +#include namespace Luau { -// Log of where what TypeIds we are rebinding and what they used to be +using TypeOrPackId = const void*; + +// Pending state for a TypeVar. Generated by a TxnLog and committed via +// TxnLog::commit. +struct PendingType +{ + // The pending TypeVar state. + TypeVar pending; + + explicit PendingType(TypeVar state) + : pending(std::move(state)) + { + } +}; + +std::string toString(PendingType* pending); +std::string dump(PendingType* pending); + +// Pending state for a TypePackVar. Generated by a TxnLog and committed via +// TxnLog::commit. +struct PendingTypePack +{ + // The pending TypePackVar state. + TypePackVar pending; + + explicit PendingTypePack(TypePackVar state) + : pending(std::move(state)) + { + } +}; + +std::string toString(PendingTypePack* pending); +std::string dump(PendingTypePack* pending); + +template +T* getMutable(PendingType* pending) +{ + // We use getMutable here because this state is intended to be mutated freely. + return getMutable(&pending->pending); +} + +template +T* getMutable(PendingTypePack* pending) +{ + // We use getMutable here because this state is intended to be mutated freely. + return getMutable(&pending->pending); +} + +// Log of what TypeIds we are rebinding, to be committed later. struct TxnLog { - TxnLog() = default; + TxnLog() + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , ownedSeen() + , sharedSeen(&ownedSeen) + { + } - explicit TxnLog(const std::vector>& seen) - : seen(seen) + explicit TxnLog(TxnLog* parent) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , parent(parent) + { + if (parent) + { + sharedSeen = parent->sharedSeen; + } + else + { + sharedSeen = &ownedSeen; + } + } + + explicit TxnLog(std::vector>* sharedSeen) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , sharedSeen(sharedSeen) { } @@ -22,25 +97,209 @@ struct TxnLog TxnLog(TxnLog&&) = default; TxnLog& operator=(TxnLog&&) = default; - void operator()(TypeId a); - void operator()(TypePackId a); - void operator()(TableTypeVar* a); - - void rollback(); + // Gets an empty TxnLog pointer. This is useful for constructs that + // take a TxnLog, like TypePackIterator - use the empty log if you + // don't have a TxnLog to give it. + static const TxnLog* empty(); + // Joins another TxnLog onto this one. You should use std::move to avoid + // copying the rhs TxnLog. + // + // If both logs talk about the same type, pack, or table, the rhs takes + // priority. void concat(TxnLog rhs); + void concatAsIntersections(TxnLog rhs, NotNull arena); + void concatAsUnion(TxnLog rhs, NotNull arena); - bool haveSeen(TypeId lhs, TypeId rhs); + // Commits the TxnLog, rebinding all type pointers to their pending states. + // Clears the TxnLog afterwards. + void commit(); + + // Clears the TxnLog without committing any pending changes. + void clear(); + + // Computes an inverse of this TxnLog at the current time. + // This method should be called before commit is called in order to give an + // accurate result. Committing the inverse of a TxnLog will undo the changes + // made by commit, assuming the inverse log is accurate. + TxnLog inverse(); + + bool haveSeen(TypeId lhs, TypeId rhs) const; void pushSeen(TypeId lhs, TypeId rhs); void popSeen(TypeId lhs, TypeId rhs); + bool haveSeen(TypePackId lhs, TypePackId rhs) const; + void pushSeen(TypePackId lhs, TypePackId rhs); + void popSeen(TypePackId lhs, TypePackId rhs); + + // Queues a type for modification. The original type will not change until commit + // is called. Use pending to get the pending state. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* queue(TypeId ty); + + // Queues a type pack for modification. The original type pack will not change + // until commit is called. Use pending to get the pending state. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* queue(TypePackId tp); + + // Returns the pending state of a type, or nullptr if there isn't any. It is important + // to note that this pending state is not transitive: the pending state may reference + // non-pending types freely, so you may need to call pending multiple times to view the + // entire pending state of a type graph. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* pending(TypeId ty) const; + + // Returns the pending state of a type pack, or nullptr if there isn't any. It is + // important to note that this pending state is not transitive: the pending state may + // reference non-pending types freely, so you may need to call pending multiple times + // to view the entire pending state of a type graph. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* pending(TypePackId tp) const; + + // Queues a replacement of a type with another type. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* replace(TypeId ty, TypeVar replacement); + + // Queues a replacement of a type pack with another type pack. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* replace(TypePackId tp, TypePackVar replacement); + + // Queues a replacement of a table type with another table type that is bound + // to a specific value. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* bindTable(TypeId ty, std::optional newBoundTo); + + // Queues a replacement of a type with a level with a duplicate of that type + // with a new type level. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* changeLevel(TypeId ty, TypeLevel newLevel); + + // Queues a replacement of a type pack with a level with a duplicate of that + // type pack with a new type level. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* changeLevel(TypePackId tp, TypeLevel newLevel); + + // Queues the replacement of a type's scope with the provided scope. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* changeScope(TypeId ty, NotNull scope); + + // Queues the replacement of a type pack's scope with the provided scope. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* changeScope(TypePackId tp, NotNull scope); + + // Queues a replacement of a table type with another table type with a new + // indexer. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* changeIndexer(TypeId ty, std::optional indexer); + + // Returns the type level of the pending state of the type, or the level of that + // type, if no pending state exists. If the type doesn't have a notion of a level, + // returns nullopt. If the pending state doesn't have a notion of a level, but the + // original state does, returns nullopt. + std::optional getLevel(TypeId ty) const; + + // Follows a type, accounting for pending type states. The returned type may have + // pending state; you should use `pending` or `get` to find out. + TypeId follow(TypeId ty) const; + + // Follows a type pack, accounting for pending type states. The returned type pack + // may have pending state; you should use `pending` or `get` to find out. + TypePackId follow(TypePackId tp) const; + + // Replaces a given type's state with a new variant. Returns the new pending state + // of that type. + // + // The pointer returned lives until `commit` or `clear` is called. + template + PendingType* replace(TypeId ty, T replacement) + { + return replace(ty, TypeVar(replacement)); + } + + // Replaces a given type pack's state with a new variant. Returns the new + // pending state of that type pack. + // + // The pointer returned lives until `commit` or `clear` is called. + template + PendingTypePack* replace(TypePackId tp, T replacement) + { + return replace(tp, TypePackVar(replacement)); + } + + // Returns T if a given type or type pack is this variant, respecting the + // log's pending state. + // + // Do not retain this pointer; it has the potential to be invalidated when + // commit or clear is called. + template + T* getMutable(TID ty) const + { + auto* pendingTy = pending(ty); + if (pendingTy) + return Luau::getMutable(pendingTy); + + return Luau::getMutable(ty); + } + + template + const T* get(TID ty) const + { + return this->getMutable(ty); + } + + // Returns whether a given type or type pack is a given state, respecting the + // log's pending state. + // + // This method will not assert if called on a BoundTypeVar or BoundTypePack. + template + bool is(TID ty) const + { + // We do not use getMutable here because this method can be called on + // BoundTypeVars, which triggers an assertion. + auto* pendingTy = pending(ty); + if (pendingTy) + return Luau::get_if(&pendingTy->pending.ty) != nullptr; + + return Luau::get_if(&ty->ty) != nullptr; + } + + std::pair, std::vector> getChanges() const; + private: - std::vector> typeVarChanges; - std::vector> typePackChanges; - std::vector>> tableChanges; + // unique_ptr is used to give us stable pointers across insertions into the + // map. Otherwise, it would be really easy to accidentally invalidate the + // pointers returned from queue/pending. + DenseHashMap> typeVarChanges; + DenseHashMap> typePackChanges; + + TxnLog* parent = nullptr; + + // Owned version of sharedSeen. This should not be accessed directly in + // TxnLogs; use sharedSeen instead. This field exists because in the tree + // of TxnLogs, the root must own its seen set. In all descendant TxnLogs, + // this is an empty vector. + std::vector> ownedSeen; + + bool haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const; + void pushSeen(TypeOrPackId lhs, TypeOrPackId rhs); + void popSeen(TypeOrPackId lhs, TypeOrPackId rhs); public: - std::vector> seen; // used to avoid infinite recursion when types are cyclic + // Used to avoid infinite recursion when types are cyclic. + // Shared with all the descendent TxnLogs. + std::vector>* sharedSeen; }; } // namespace Luau diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h new file mode 100644 index 00000000..c67f643b --- /dev/null +++ b/Analysis/include/Luau/TypeArena.h @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypedAllocator.h" +#include "Luau/TypeVar.h" +#include "Luau/TypePack.h" + +#include + +namespace Luau +{ + +struct TypeArena +{ + TypedAllocator typeVars; + TypedAllocator typePacks; + + void clear(); + + template + TypeId addType(T tv) + { + if constexpr (std::is_same_v) + LUAU_ASSERT(tv.options.size() >= 2); + + return addTV(TypeVar(std::move(tv))); + } + + TypeId addTV(TypeVar&& tv); + + TypeId freshType(TypeLevel level); + TypeId freshType(Scope* scope); + TypeId freshType(Scope* scope, TypeLevel level); + + TypePackId freshTypePack(Scope* scope); + + TypePackId addTypePack(std::initializer_list types); + TypePackId addTypePack(std::vector types, std::optional tail = {}); + TypePackId addTypePack(TypePack pack); + TypePackId addTypePack(TypePackVar pack); + + template + TypePackId addTypePack(T tp) + { + return addTypePack(TypePackVar(std::move(tp))); + } +}; + +void freeze(TypeArena& arena); +void unfreeze(TypeArena& arena); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h new file mode 100644 index 00000000..a9cd6ec8 --- /dev/null +++ b/Analysis/include/Luau/TypeChecker2.h @@ -0,0 +1,17 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Module.h" +#include "Luau/NotNull.h" + +namespace Luau +{ + +struct DcrLogger; +struct SingletonTypes; + +void check(NotNull singletonTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index ec2a1a26..c6f153d1 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -1,21 +1,24 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Anyification.h" #include "Luau/Predicate.h" #include "Luau/Error.h" #include "Luau/Module.h" #include "Luau/Symbol.h" -#include "Luau/Parser.h" #include "Luau/Substitution.h" #include "Luau/TxnLog.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" +#include "Luau/UnifierSharedState.h" #include #include #include +LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) + namespace Luau { @@ -31,77 +34,39 @@ bool doesCallError(const AstExprCall* call); bool hasBreak(AstStat* node); const AstStat* getFallthrough(const AstStat* node); +struct UnifierOptions; struct Unifier; -// A substitution which replaces generic types in a given set by free types. -struct ReplaceGenerics : Substitution +struct GenericTypeDefinitions { - TypeLevel level; - std::vector generics; - std::vector genericPacks; - bool ignoreChildren(TypeId ty) override; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; + std::vector genericTypes; + std::vector genericPacks; }; -// A substitution which replaces generic functions by monomorphic functions -struct Instantiation : Substitution +struct HashBoolNamePair { - TypeLevel level; - ReplaceGenerics replaceGenerics; - bool ignoreChildren(TypeId ty) override; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; + size_t operator()(const std::pair& pair) const; }; -// A substitution which replaces free types by generic types. -struct Quantification : Substitution +class TimeLimitError : public InternalCompilerError { - TypeLevel level; - std::vector generics; - std::vector genericPacks; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; -}; - -// A substitution which replaces free types by any -struct Anyification : Substitution -{ - TypeId anyType; - TypePackId anyTypePack; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; -}; - -// A substitution which replaces the type parameters of a type function by arguments -struct ApplyTypeFunction : Substitution -{ - TypeLevel level; - bool encounteredForwardedType; - std::unordered_map arguments; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; +public: + explicit TimeLimitError(const std::string& moduleName) + : InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName) + { + } }; // All TypeVars are retained via Environment::typeVars. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker { - explicit TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler); + explicit TypeChecker(ModuleResolver* resolver, NotNull singletonTypes, InternalErrorReporter* iceHandler); TypeChecker(const TypeChecker&) = delete; TypeChecker& operator=(const TypeChecker&) = delete; ModulePtr check(const SourceModule& module, Mode mode, std::optional environmentScope = std::nullopt); + ModulePtr checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope = std::nullopt); std::vector> getScopes() const; @@ -118,31 +83,37 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatForIn& forin); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); - void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare = false); + void check(const ScopePtr& scope, const AstStatTypeAlias& typealias); void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); + void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0); + void prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); + void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); + void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); - ExprResult checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr(const ScopePtr& scope, const AstExprLocal& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprCall& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexName& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr(const ScopePtr& scope, const AstExprUnary& expr); + WithPredicate checkExpr( + const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt, bool forceSingleton = false); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprLocal& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprCall& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIndexName& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprUnary& expr); TypeId checkRelationalOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); TypeId checkBinaryOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); - ExprResult checkExpr(const ScopePtr& scope, const AstExprBinary& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprError& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprError& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprInterpString& expr); TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType); @@ -150,35 +121,38 @@ struct TypeChecker // Returns the type of the lvalue. TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr); - // Returns both the type of the lvalue and its binding (if the caller wants to mutate the binding). - // Note: the binding may be null. - std::pair checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); + // Returns the type of the lvalue. + TypeId checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); - TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName); + TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, - std::optional originalNameLoc, std::optional expectedType); + std::optional originalNameLoc, std::optional selfType, std::optional expectedType); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); - void checkArgumentList( - const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector& argLocations); + void checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId paramPack, TypePackId argPack, + const std::vector& argLocations); + + WithPredicate checkExprPack(const ScopePtr& scope, const AstExpr& expr); + + WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr); + WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr); - ExprResult checkExprPack(const ScopePtr& scope, const AstExpr& expr); - ExprResult checkExprPack(const ScopePtr& scope, const AstExprCall& expr); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); - std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& errors); + + std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, + std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); - ExprResult reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, const std::vector& errors); - ExprResult checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, + WithPredicate checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil = false, const std::vector& lhsAnnotations = {}, const std::vector>& expectedTypes = {}); @@ -193,51 +167,44 @@ struct TypeChecker */ TypeId anyIfNonstrict(TypeId ty) const; - /** Attempt to unify the types left and right. Treat any failures as type errors - * in the final typecheck report. + /** Attempt to unify the types. + * Treat any failures as type errors in the final typecheck report. */ - bool unify(TypeId left, TypeId right, const Location& location); - bool unify(TypePackId left, TypePackId right, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); + bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); + bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options); + bool unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location, + CountMismatch::Context ctx = CountMismatch::Context::Arg); - /** Attempt to unify the types left and right. - * If this fails, and the right type can be instantiated, do so and try unification again. + /** Attempt to unify the types. + * If this fails, and the subTy type can be instantiated, do so and try unification again. */ - bool unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location); - void unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state); + bool unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); + void unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, Unifier& state); - /** Attempt to unify left with right. + /** Attempt to unify. * If there are errors, undo everything and return the errors. * If there are no errors, commit and return an empty error vector. */ - ErrorVec tryUnify(TypeId left, TypeId right, const Location& location); - ErrorVec tryUnify(TypePackId left, TypePackId right, const Location& location); + template + ErrorVec tryUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location); + ErrorVec tryUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); + ErrorVec tryUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location); // Test whether the two type vars unify. Never commits the result. - ErrorVec canUnify(TypeId superTy, TypeId subTy, const Location& location); - ErrorVec canUnify(TypePackId superTy, TypePackId subTy, const Location& location); + template + ErrorVec canUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location); + ErrorVec canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location); + ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location); - // Variant that takes a preexisting 'seen' set. We need this in certain cases to avoid infinitely recursing - // into cyclic types. - ErrorVec canUnify(const std::vector>& seen, TypeId left, TypeId right, const Location& location); - - std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); - std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); + std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors); + std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors); std::optional getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); - - // Reduces the union to its simplest possible shape. - // (A | B) | B | C yields A | B | C - std::vector reduceUnion(const std::vector& types); + std::optional getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); std::optional tryStripUnionFromNil(TypeId ty); TypeId stripFromNilAndReport(TypeId ty, const Location& location); - template - ErrorVec tryUnify_(Id left, Id right, const Location& location); - - template - ErrorVec canUnify_(Id left, Id right, const Location& location); - public: /* * Convert monotype into a a polytype, by replacing any metavariables in descendant scopes @@ -258,49 +225,66 @@ public: * {method: ({method: () -> a}) -> a} * */ - TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location); - // Removed by FFlag::LuauRankNTypes - TypePackId DEPRECATED_instantiate(const ScopePtr& scope, TypePackId ty, Location location); + TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log = TxnLog::empty()); // Replace any free types or type packs by `any`. // This is used when exporting types from modules, to make sure free types don't leak. TypeId anyify(const ScopePtr& scope, TypeId ty, Location location); TypePackId anyify(const ScopePtr& scope, TypePackId ty, Location location); + TypePackId anyifyModuleReturnTypePackGenerics(TypePackId ty); + void reportError(const TypeError& error); void reportError(const Location& location, TypeErrorData error); void reportErrors(const ErrorVec& errors); [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + [[noreturn]] void throwTimeLimitError(); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); - ScopePtr childScope(const ScopePtr& parent, const Location& location, int subLevel = 0); + ScopePtr childScope(const ScopePtr& parent, const Location& location); // Wrapper for merge(l, r, toUnion) but without the lambda junk. void merge(RefinementMap& l, const RefinementMap& r); + // Produce an "emergency backup type" for recovery from type errors. + // This comes in two flavours, depening on whether or not we can make a good guess + // for an error recovery type. + TypeId errorRecoveryType(TypeId guess); + TypePackId errorRecoveryTypePack(TypePackId guess); + TypeId errorRecoveryType(const ScopePtr& scope); + TypePackId errorRecoveryTypePack(const ScopePtr& scope); + private: void prepareErrorsForDisplay(ErrorVec& errVec); void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data); void reportErrorCodeTooComplex(const Location& location); private: - Unifier mkUnifier(const Location& location); - Unifier mkUnifier(const std::vector>& seen, const Location& location); + Unifier mkUnifier(const ScopePtr& scope, const Location& location); // These functions are only safe to call when we are in the process of typechecking a module. // Produce a new free type var. TypeId freshType(const ScopePtr& scope); TypeId freshType(TypeLevel level); - TypeId DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric = false); - TypeId DEPRECATED_freshType(TypeLevel level, bool canBeGeneric = false); - // Returns nullopt if the predicate filters down the TypeId to 0 options. - std::optional filterMap(TypeId type, TypeIdPredicate predicate); + // Produce a new singleton type var. + TypeId singletonType(bool value); + TypeId singletonType(std::string value); - TypeId unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes = true); + TypeIdPredicate mkTruthyPredicate(bool sense, TypeId emptySetTy); + + // TODO: Return TypeId only. + std::optional filterMapImpl(TypeId type, TypeIdPredicate predicate); + std::pair, bool> filterMap(TypeId type, TypeIdPredicate predicate); + +public: + std::pair, bool> pickTypesFromSense(TypeId type, bool sense, TypeId emptySetTy); + +private: + TypeId unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes = true); // ex // TypeId id = addType(FreeTypeVar()); @@ -310,8 +294,6 @@ private: return addTV(TypeVar(tv)); } - TypeId addType(const UnionTypeVar& utv); - TypeId addTV(TypeVar&& tv); TypePackId addTypePack(TypePackVar&& tp); @@ -322,36 +304,38 @@ private: TypePackId addTypePack(std::initializer_list&& ty); TypePackId freshTypePack(const ScopePtr& scope); TypePackId freshTypePack(TypeLevel level); - TypePackId DEPRECATED_freshTypePack(const ScopePtr& scope, bool canBeGeneric = false); - TypePackId DEPRECATED_freshTypePack(TypeLevel level, bool canBeGeneric = false); - TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false); + TypeId resolveType(const ScopePtr& scope, const AstType& annotation); + TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); - TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const Location& location); + TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, + const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. - std::pair, std::vector> createGenericTypes( - const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, + const AstArray& genericNames, const AstArray& genericPackNames, bool useCache = false); public: - ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); + void resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); private: + void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate); + std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); - void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); - void resolve(const Predicate& predicate, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); - void resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); - void resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const PredicateVec& predicates, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); + void resolve(const Predicate& predicate, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); + void resolve(const TruthyPredicate& truthyP, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); + void resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const OrPredicate& orP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const TypeGuardPredicate& typeguardP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense); bool isNonstrictMode() const; + bool useConstrainedIntersections() const; public: /** Extract the types in a type pack, given the assumption that the pack must have some exact length. @@ -371,14 +355,20 @@ public: ModulePtr currentModule; ModuleName currentModuleName; - Instantiation instantiation; - Quantification quantification; - Anyification anyification; - ApplyTypeFunction applyTypeFunction; - std::function prepareModuleScope; + NotNull singletonTypes; InternalErrorReporter* iceHandler; + UnifierSharedState unifierState; + Normalizer normalizer; + + std::vector requireCycles; + + // Type inference limits + std::optional finishTime; + std::optional instantiationChildLimit; + std::optional unifierIterationLimit; + public: const TypeId nilType; const TypeId numberType; @@ -386,68 +376,37 @@ public: const TypeId booleanType; const TypeId threadType; const TypeId anyType; - - const TypeId errorType; - const TypeId optionalNumberType; + const TypeId unknownType; + const TypeId neverType; const TypePackId anyTypePack; - const TypePackId errorTypePack; + const TypePackId neverTypePack; + const TypePackId uninhabitableTypePack; private: int checkRecursionCount = 0; int recursionCount = 0; + + /** + * We use this to avoid doing second-pass analysis of type aliases that are duplicates. We record a pair + * (exported, name) to properly deal with the case where the two duplicates do not have the same export status. + */ + DenseHashSet, HashBoolNamePair> duplicateTypeAliases; + + /** + * A set of incorrect class definitions which is used to avoid a second-pass analysis. + */ + DenseHashSet incorrectClassDefinitions{nullptr}; + + std::vector> deferredQuantification; }; -struct Binding -{ - TypeId typeId; - Location location; - bool deprecated = false; - std::string deprecatedSuggestion; - std::optional documentationSymbol; -}; +using PrintLineProc = void (*)(const std::string&); -struct Scope -{ - explicit Scope(TypePackId returnType); // root scope - explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. - - const ScopePtr parent; // null for the root - std::unordered_map bindings; - TypePackId returnType; - bool breakOk = false; - std::optional varargPack; - - TypeLevel level; - - std::unordered_map exportedTypeBindings; - std::unordered_map privateTypeBindings; - std::unordered_map typeAliasLocations; - - std::unordered_map> importedTypeBindings; - - std::optional lookup(const Symbol& name); - - std::optional lookupType(const Name& name); - std::optional lookupImportedType(const Name& moduleAlias, const Name& name); - - std::unordered_map privateTypePackBindings; - std::optional lookupPack(const Name& name); - - // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) - std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true); - - RefinementMap refinements; - - // For mutually recursive type aliases, it's important that - // they use the same types for the same names. - // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` - // we need that the generic type `T` in both cases is the same, so we use a cache. - std::unordered_map typeAliasParameters; -}; +extern PrintLineProc luauPrintLine; // Unit test hook -void setPrintLine(void (*pl)(const std::string& s)); +void setPrintLine(PrintLineProc pl); void resetPrintLine(); } // namespace Luau diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 0d0adce7..29688094 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -8,8 +8,6 @@ #include #include -LUAU_FASTFLAG(LuauAddMissingFollow) - namespace Luau { @@ -17,14 +15,17 @@ struct TypeArena; struct TypePack; struct VariadicTypePack; +struct BlockedTypePack; struct TypePackVar; +struct TxnLog; + using TypePackId = const TypePackVar*; using FreeTypePack = Unifiable::Free; using BoundTypePack = Unifiable::Bound; using GenericTypePack = Unifiable::Generic; -using TypePackVariant = Unifiable::Variant; +using TypePackVariant = Unifiable::Variant; /* A TypePack is a rope-like string of TypeIds. We use this structure to encode * notions like packs of unknown length and packs of any length, as well as more @@ -40,6 +41,18 @@ struct TypePack struct VariadicTypePack { TypeId ty; + bool hidden = false; // if true, we don't display this when toString()ing a pack with this variadic as its tail. +}; + +/** + * Analogous to a BlockedTypeVar. + */ +struct BlockedTypePack +{ + BlockedTypePack(); + size_t index; + + static size_t nextIndex; }; struct TypePackVar @@ -47,16 +60,24 @@ struct TypePackVar explicit TypePackVar(const TypePackVariant& ty); explicit TypePackVar(TypePackVariant&& ty); TypePackVar(TypePackVariant&& ty, bool persistent); + bool operator==(const TypePackVar& rhs) const; + TypePackVar& operator=(TypePackVariant&& tp); + TypePackVar& operator=(const TypePackVar& rhs); + + // Re-assignes the content of the pack, but doesn't change the owning arena and can't make pack persistent. + void reassign(const TypePackVar& rhs) + { + ty = rhs.ty; + } + TypePackVariant ty; + bool persistent = false; - // Pointer to the type arena that allocated this type. - // Do not depend on the value of this under any circumstances. This is for - // debugging purposes only. This is only set in debug builds; it is nullptr - // in all other environments. + // Pointer to the type arena that allocated this pack. TypeArena* owningArena = nullptr; }; @@ -86,6 +107,7 @@ struct TypePackIterator TypePackIterator() = default; explicit TypePackIterator(TypePackId tp); + TypePackIterator(TypePackId tp, const TxnLog* log); TypePackIterator& operator++(); TypePackIterator operator++(int); @@ -106,20 +128,25 @@ private: TypePackId currentTypePack = nullptr; const TypePack* tp = nullptr; size_t currentIndex = 0; + + const TxnLog* log; }; TypePackIterator begin(TypePackId tp); +TypePackIterator begin(TypePackId tp, const TxnLog* log); TypePackIterator end(TypePackId tp); -using SeenSet = std::set>; +using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); +TypePackId follow(TypePackId tp, std::function mapper); -size_t size(const TypePackId tp); -size_t size(const TypePack& tp); -std::optional first(TypePackId tp); +size_t size(TypePackId tp, TxnLog* log = nullptr); +bool finite(TypePackId tp, TxnLog* log = nullptr); +size_t size(const TypePack& tp, TxnLog* log = nullptr); +std::optional first(TypePackId tp, bool ignoreHiddenVariadics = true); TypePackVar* asMutable(TypePackId tp); TypePack* asMutable(const TypePack* tp); @@ -127,13 +154,10 @@ TypePack* asMutable(const TypePack* tp); template const T* get(TypePackId tp) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tp); + LUAU_ASSERT(tp); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tp->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tp->ty) == nullptr); return get_if(&(tp->ty)); } @@ -141,13 +165,10 @@ const T* get(TypePackId tp) template T* getMutable(TypePackId tp) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tp); + LUAU_ASSERT(tp); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tp->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tp->ty) == nullptr); return get_if(&(asMutable(tp)->ty)); } @@ -157,5 +178,16 @@ bool isEmpty(TypePackId tp); /// Flattens out a type pack. Also returns a valid TypePackId tail if the type pack's full size is not known std::pair, std::optional> flatten(TypePackId tp); +std::pair, std::optional> flatten(TypePackId tp, const TxnLog& log); + +/// Returs true if the type pack arose from a function that is declared to be variadic. +/// Returns *false* for function argument packs that are inferred to be safe to oversaturate! +bool isVariadic(TypePackId tp); +bool isVariadic(TypePackId tp, const TxnLog& log); + +// Returns true if the TypePack is Generic or Variadic. Does not walk TypePacks!! +bool isVariadicTail(TypePackId tp, const TxnLog& log, bool includeHiddenVariadics = false); + +bool containsNever(TypePackId tp); } // namespace Luau diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index ffddfe4b..6ed70f46 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -4,6 +4,7 @@ #include "Luau/Error.h" #include "Luau/Location.h" #include "Luau/TypeVar.h" +#include "Luau/TypePack.h" #include #include @@ -11,9 +12,40 @@ namespace Luau { +struct TxnLog; +struct TypeArena; + using ScopePtr = std::shared_ptr; -std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globalScope, TypeId type, std::string entry, Location location); -std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const ScopePtr& globalScope, TypeId ty, Name name, Location location); +std::optional findMetatableEntry( + NotNull singletonTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location); +std::optional findTablePropertyRespectingMeta( + NotNull singletonTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location); + +// Returns the minimum and maximum number of types the argument list can accept. +std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics = false); + +// Extend the provided pack to at least `length` types. +// Returns a temporary TypePack that contains those types plus a tail. +TypePack extendTypePack(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length); + +/** + * Reduces a union by decomposing to the any/error type if it appears in the + * type list, and by merging child unions. Also strips out duplicate (by pointer + * identity) types. + * @param types the input type list to reduce. + * @returns the reduced type list. + */ +std::vector reduceUnion(const std::vector& types); + +/** + * Tries to remove nil from a union type, if there's another option. T | nil + * reduces to T, but nil itself does not reduce. + * @param singletonTypes the singleton types to use + * @param arena the type arena to allocate the new type in, if necessary + * @param ty the type to remove nil from + * @returns a type with nil removed, or nil itself if that were the only option. + */ +TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty); } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 90a28b20..852a4054 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -1,29 +1,36 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Ast.h" +#include "Luau/Common.h" +#include "Luau/Connective.h" +#include "Luau/DataFlowGraph.h" +#include "Luau/DenseHash.h" +#include "Luau/Def.h" +#include "Luau/NotNull.h" #include "Luau/Predicate.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" -#include "Luau/Common.h" -#include -#include -#include -#include -#include -#include #include +#include #include #include +#include +#include +#include +#include +#include LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength) -LUAU_FASTFLAG(LuauAddMissingFollow) namespace Luau { struct TypeArena; +struct Scope; +using ScopePtr = std::shared_ptr; /** * There are three kinds of type variables: @@ -83,6 +90,24 @@ using Tags = std::vector; using ModuleName = std::string; +/** A TypeVar that cannot be computed. + * + * BlockedTypeVars essentially serve as a way to encode partial ordering on the + * constraint graph. Until a BlockedTypeVar is unblocked by its owning + * constraint, nothing at all can be said about it. Constraints that need to + * process a BlockedTypeVar cannot be dispatched. + * + * Whenever a BlockedTypeVar is added to the graph, we also record a constraint + * that will eventually unblock it. + */ +struct BlockedTypeVar +{ + BlockedTypeVar(); + int index; + + static int nextIndex; +}; + struct PrimitiveTypeVar { enum Type @@ -92,6 +117,7 @@ struct PrimitiveTypeVar Number, String, Thread, + Function, }; Type type; @@ -109,6 +135,95 @@ struct PrimitiveTypeVar } }; +// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md +// Types for true and false +struct BooleanSingleton +{ + bool value; + + bool operator==(const BooleanSingleton& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const BooleanSingleton& rhs) const + { + return !(*this == rhs); + } +}; + +// Types for "foo", "bar" etc. +struct StringSingleton +{ + std::string value; + + bool operator==(const StringSingleton& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const StringSingleton& rhs) const + { + return !(*this == rhs); + } +}; + +// No type for float singletons, partly because === isn't any equalivalence on floats +// (NaN != NaN). + +using SingletonVariant = Luau::Variant; + +struct SingletonTypeVar +{ + explicit SingletonTypeVar(const SingletonVariant& variant) + : variant(variant) + { + } + + explicit SingletonTypeVar(SingletonVariant&& variant) + : variant(std::move(variant)) + { + } + + // Default operator== is C++20. + bool operator==(const SingletonTypeVar& rhs) const + { + return variant == rhs.variant; + } + + bool operator!=(const SingletonTypeVar& rhs) const + { + return !(*this == rhs); + } + + SingletonVariant variant; +}; + +template +const T* get(const SingletonTypeVar* stv) +{ + if (stv) + return get_if(&stv->variant); + else + return nullptr; +} + +struct GenericTypeDefinition +{ + TypeId ty; + std::optional defaultValue; + + bool operator==(const GenericTypeDefinition& rhs) const; +}; + +struct GenericTypePackDefinition +{ + TypePackId tp; + std::optional defaultValue; + + bool operator==(const GenericTypePackDefinition& rhs) const; +}; + struct FunctionArgument { Name name; @@ -127,42 +242,72 @@ struct FunctionDefinition // TODO: Do we actually need this? We'll find out later if we can delete this. // Does not exactly belong in TypeVar.h, but this is the only way to appease the compiler. template -struct ExprResult +struct WithPredicate { T type; PredicateVec predicates; }; -using MagicFunction = std::function>( - struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, ExprResult)>; +using MagicFunction = std::function>( + struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; + +struct MagicFunctionCallContext +{ + NotNull solver; + const class AstExprCall* callSite; + TypePackId arguments; + TypePackId result; +}; + +using DcrMagicFunction = bool (*)(MagicFunctionCallContext); + +struct MagicRefinementContext +{ + ScopePtr scope; + NotNull cgb; + NotNull dfg; + NotNull connectiveArena; + std::vector argumentConnectives; + const class AstExprCall* callSite; +}; + +using DcrMagicRefinement = std::vector (*)(const MagicRefinementContext&); struct FunctionTypeVar { // Global monomorphic function - FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Global polymorphic function - FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local monomorphic function - FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar( + TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local polymorphic function - FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, TypePackId argTypes, + TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); - TypeLevel level; + std::optional definition; /// These should all be generic std::vector generics; std::vector genericPacks; - TypePackId argTypes; std::vector> argNames; - TypePackId retType; - std::optional definition; - MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr. - bool hasSelf; Tags tags; + TypeLevel level; + Scope* scope = nullptr; + TypePackId argTypes; + TypePackId retTypes; + MagicFunction magicFunction = nullptr; + DcrMagicFunction dcrMagicFunction = nullptr; // Fired only while solving constraints + DcrMagicRefinement dcrMagicRefinement = nullptr; // Fired only while generating constraints + bool hasSelf; + bool hasNoGenerics = false; }; enum class TableState @@ -212,26 +357,31 @@ struct TableTypeVar using Props = std::map; TableTypeVar() = default; - explicit TableTypeVar(TableState state, TypeLevel level); - TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, TableState state = TableState::Unsealed); + explicit TableTypeVar(TableState state, TypeLevel level, Scope* scope = nullptr); + TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, TableState state); + TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, Scope* scope, TableState state); Props props; std::optional indexer; TableState state = TableState::Unsealed; TypeLevel level; + Scope* scope = nullptr; std::optional name; // Sometimes we throw a type on a name to make for nicer error messages, but without creating any entry in the type namespace // We need to know which is which when we stringify types. std::optional syntheticName; - std::map methodDefinitionLocations; std::vector instantiatedTypeParams; + std::vector instantiatedTypePackParams; ModuleName definitionModuleName; std::optional boundTo; Tags tags; + + // Methods of this table that have an untyped self will use the same shared self type. + std::optional selfTy; }; // Represents a metatable attached to a table typevar. Somewhat analogous to a bound typevar. @@ -269,23 +419,26 @@ struct ClassTypeVar std::optional metatable; // metaclass? Tags tags; std::shared_ptr userData; + ModuleName definitionModuleName; - ClassTypeVar( - Name name, Props props, std::optional parent, std::optional metatable, Tags tags, std::shared_ptr userData) + ClassTypeVar(Name name, Props props, std::optional parent, std::optional metatable, Tags tags, + std::shared_ptr userData, ModuleName definitionModuleName) : name(name) , props(props) , parent(parent) , metatable(metatable) , tags(tags) , userData(userData) + , definitionModuleName(definitionModuleName) { } }; struct TypeFun { - /// These should all be generic - std::vector typeParams; + // These should all be generic + std::vector typeParams; + std::vector typePackParams; /** The underlying type. * @@ -293,6 +446,48 @@ struct TypeFun * You must first use TypeChecker::instantiateTypeFun to turn it into a real type. */ TypeId type; + + TypeFun() = default; + + explicit TypeFun(TypeId ty) + : type(ty) + { + } + + TypeFun(std::vector typeParams, TypeId type) + : typeParams(std::move(typeParams)) + , type(type) + { + } + + TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) + : typeParams(std::move(typeParams)) + , typePackParams(std::move(typePackParams)) + , type(type) + { + } + + bool operator==(const TypeFun& rhs) const; +}; + +/** Represents a pending type alias instantiation. + * + * In order to afford (co)recursive type aliases, we need to reason about a + * partially-complete instantiation. This requires encoding more information in + * a type variable than a BlockedTypeVar affords, hence this. Each + * PendingExpansionTypeVar has a corresponding TypeAliasExpansionConstraint + * enqueued in the solver to convert it to an actual instantiated type + */ +struct PendingExpansionTypeVar +{ + PendingExpansionTypeVar(std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments); + std::optional prefix; + AstName name; + std::vector typeArguments; + std::vector packArguments; + size_t index; + + static size_t nextIndex; }; // Anything! All static checking is off. @@ -300,11 +495,13 @@ struct AnyTypeVar { }; +// T | U struct UnionTypeVar { std::vector options; }; +// T & U struct IntersectionTypeVar { std::vector parts; @@ -315,10 +512,27 @@ struct LazyTypeVar std::function thunk; }; +struct UnknownTypeVar +{ +}; + +struct NeverTypeVar +{ +}; + +// ~T +// TODO: Some simplification step that overwrites the type graph to make sure negation +// types disappear from the user's view, and (?) a debug flag to disable that +struct NegationTypeVar +{ + TypeId ty; +}; + using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = + Unifiable::Variant; struct TypeVar final { @@ -338,6 +552,13 @@ struct TypeVar final { } + // Re-assignes the content of the type, but doesn't change the owning arena and can't make type persistent. + void reassign(const TypeVar& rhs) + { + ty = rhs.ty; + documentationSymbol = rhs.documentationSymbol; + } + TypeVariant ty; // Kludge: A persistent TypeVar is one that belongs to the global scope. @@ -348,9 +569,6 @@ struct TypeVar final std::optional documentationSymbol; // Pointer to the type arena that allocated this type. - // Do not depend on the value of this under any circumstances. This is for - // debugging purposes only. This is only set in debug builds; it is nullptr - // in all other environments. TypeArena* owningArena = nullptr; bool operator==(const TypeVar& rhs) const; @@ -358,13 +576,16 @@ struct TypeVar final TypeVar& operator=(const TypeVariant& rhs); TypeVar& operator=(TypeVariant&& rhs); + + TypeVar& operator=(const TypeVar& rhs); }; -using SeenSet = std::set>; +using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs); // Follow BoundTypeVars until we get to something real TypeId follow(TypeId t); +TypeId follow(TypeId t, std::function mapper); std::vector flattenIntersection(TypeId ty); @@ -378,7 +599,10 @@ bool isOptional(TypeId ty); bool isTableIntersection(TypeId ty); bool isOverloadedFunction(TypeId ty); -std::optional getMetatable(TypeId type); +// True when string is a subtype of ty +bool maybeString(TypeId ty); + +std::optional getMetatable(TypeId type, NotNull singletonTypes); TableTypeVar* getMutableTableType(TypeId type); const TableTypeVar* getTableType(TypeId type); @@ -386,83 +610,84 @@ const TableTypeVar* getTableType(TypeId type); // Returns nullptr if the type has no name. const std::string* getName(TypeId type); +// Returns name of the module where type was defined if type has that information +std::optional getDefinitionModuleName(TypeId type); + // Checks whether a union contains all types of another union. bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub); -// Checks if a type conains generic type binders +// Checks if a type contains generic type binders bool isGeneric(const TypeId ty); // Checks if a type may be instantiated to one containing generic type binders bool maybeGeneric(const TypeId ty); +// Checks if a type is of the form T1|...|Tn where one of the Ti is a singleton +bool maybeSingleton(TypeId ty); + +// Checks if the length operator can be applied on the value of type +bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount); + struct SingletonTypes { - const TypeId nilType = &nilType_; - const TypeId numberType = &numberType_; - const TypeId stringType = &stringType_; - const TypeId booleanType = &booleanType_; - const TypeId threadType = &threadType_; - const TypeId anyType = &anyType_; - const TypeId errorType = &errorType_; - SingletonTypes(); + ~SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; void operator=(const SingletonTypes&) = delete; + TypeId errorRecoveryType(TypeId guess); + TypePackId errorRecoveryTypePack(TypePackId guess); + TypeId errorRecoveryType(); + TypePackId errorRecoveryTypePack(); + private: std::unique_ptr arena; - TypeVar nilType_; - TypeVar numberType_; - TypeVar stringType_; - TypeVar booleanType_; - TypeVar threadType_; - TypeVar anyType_; - TypeVar errorType_; + bool debugFreezeArena = false; TypeId makeStringMetatable(); -}; -extern SingletonTypes singletonTypes; +public: + const TypeId nilType; + const TypeId numberType; + const TypeId stringType; + const TypeId booleanType; + const TypeId threadType; + const TypeId functionType; + const TypeId trueType; + const TypeId falseType; + const TypeId anyType; + const TypeId unknownType; + const TypeId neverType; + const TypeId errorType; + const TypeId falsyType; // No type binding! + const TypeId truthyType; // No type binding! + + const TypePackId anyTypePack; + const TypePackId neverTypePack; + const TypePackId uninhabitableTypePack; + const TypePackId errorTypePack; +}; void persist(TypeId ty); void persist(TypePackId tp); -struct ToDotOptions -{ - bool showPointers = true; // Show pointer value in the node label - bool duplicatePrimitives = true; // Display primitive types and 'any' as separate nodes -}; - -std::string toDot(TypeId ty, const ToDotOptions& opts); -std::string toDot(TypePackId tp, const ToDotOptions& opts); - -std::string toDot(TypeId ty); -std::string toDot(TypePackId tp); - -void dumpDot(TypeId ty); -void dumpDot(TypePackId tp); - const TypeLevel* getLevel(TypeId ty); TypeLevel* getMutableLevel(TypeId ty); +std::optional getLevel(TypePackId tp); + const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name); bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent); -bool hasGeneric(TypeId ty); -bool hasGeneric(TypePackId tp); - TypeVar* asMutable(TypeId ty); template const T* get(TypeId tv) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tv); + LUAU_ASSERT(tv); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tv->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); return get_if(&tv->ty); } @@ -470,23 +695,33 @@ const T* get(TypeId tv) template T* getMutable(TypeId tv) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tv); + LUAU_ASSERT(tv); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tv->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); return get_if(&asMutable(tv)->ty); } -/* Traverses the UnionTypeVar yielding each TypeId. - * If the iterator encounters a nested UnionTypeVar, it will instead yield each TypeId within. - * - * Beware: the iterator does not currently filter for unique TypeIds. This may change in the future. +const std::vector& getTypes(const UnionTypeVar* utv); +const std::vector& getTypes(const IntersectionTypeVar* itv); + +template +struct TypeIterator; + +using UnionTypeVarIterator = TypeIterator; +UnionTypeVarIterator begin(const UnionTypeVar* utv); +UnionTypeVarIterator end(const UnionTypeVar* utv); + +using IntersectionTypeVarIterator = TypeIterator; +IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv); +IntersectionTypeVarIterator end(const IntersectionTypeVar* itv); + +/* Traverses the type T yielding each TypeId. + * If the iterator encounters a nested type T, it will instead yield each TypeId within. */ -struct UnionTypeVarIterator +template +struct TypeIterator { using value_type = Luau::TypeId; using pointer = value_type*; @@ -494,38 +729,126 @@ struct UnionTypeVarIterator using difference_type = size_t; using iterator_category = std::input_iterator_tag; - explicit UnionTypeVarIterator(const UnionTypeVar* utv); + explicit TypeIterator(const T* t) + { + LUAU_ASSERT(t); - UnionTypeVarIterator& operator++(); - UnionTypeVarIterator operator++(int); - bool operator!=(const UnionTypeVarIterator& rhs); - bool operator==(const UnionTypeVarIterator& rhs); + const std::vector& types = getTypes(t); + if (!types.empty()) + stack.push_front({t, 0}); - const TypeId& operator*(); + seen.insert(t); + descend(); + } - friend UnionTypeVarIterator end(const UnionTypeVar* utv); + TypeIterator& operator++() + { + advance(); + descend(); + return *this; + } + + TypeIterator operator++(int) + { + TypeIterator copy = *this; + ++copy; + return copy; + } + + bool operator==(const TypeIterator& rhs) const + { + if (!stack.empty() && !rhs.stack.empty()) + return stack.front() == rhs.stack.front(); + + return stack.empty() && rhs.stack.empty(); + } + + bool operator!=(const TypeIterator& rhs) const + { + return !(*this == rhs); + } + + const TypeId& operator*() + { + descend(); + + LUAU_ASSERT(!stack.empty()); + + auto [t, currentIndex] = stack.front(); + LUAU_ASSERT(t); + + const std::vector& types = getTypes(t); + LUAU_ASSERT(currentIndex < types.size()); + + const TypeId& ty = types[currentIndex]; + LUAU_ASSERT(!get(follow(ty))); + + return ty; + } + + // Normally, we'd have `begin` and `end` be a template but there's too much trouble + // with templates portability in this area, so not worth it. Thanks MSVC. + friend UnionTypeVarIterator end(const UnionTypeVar*); + friend IntersectionTypeVarIterator end(const IntersectionTypeVar*); private: - UnionTypeVarIterator() = default; + TypeIterator() = default; - // (UnionTypeVar* utv, size_t currentIndex) - using SavedIterInfo = std::pair; + // (T* t, size_t currentIndex) + using SavedIterInfo = std::pair; std::deque stack; - std::unordered_set seen; // Only needed to protect the iterator from hanging the thread. + std::unordered_set seen; // Only needed to protect the iterator from hanging the thread. - void advance(); - void descend(); + void advance() + { + while (!stack.empty()) + { + auto& [t, currentIndex] = stack.front(); + ++currentIndex; + + const std::vector& types = getTypes(t); + if (currentIndex >= types.size()) + stack.pop_front(); + else + break; + } + } + + void descend() + { + while (!stack.empty()) + { + auto [current, currentIndex] = stack.front(); + const std::vector& types = getTypes(current); + if (auto inner = get(follow(types[currentIndex]))) + { + // If we're about to descend into a cyclic type, we should skip over this. + // Ideally this should never happen, but alas it does from time to time. :( + if (seen.find(inner) != seen.end()) + advance(); + else + { + seen.insert(inner); + stack.push_front({inner, 0}); + } + + continue; + } + + break; + } + } }; -UnionTypeVarIterator begin(const UnionTypeVar* utv); -UnionTypeVarIterator end(const UnionTypeVar* utv); - using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); -// TEMP: Clip this prototype with FFlag::LuauStringMetatable -std::optional> magicFunctionFormat( - struct TypeChecker& typechecker, const std::shared_ptr& scope, const AstExprCall& expr, ExprResult exprResult); +void attachTag(TypeId ty, const std::string& tagName); +void attachTag(Property& prop, const std::string& tagName); + +bool hasTag(TypeId ty, const std::string& tagName); +bool hasTag(const Property& prop, const std::string& tagName); +bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work. } // namespace Luau diff --git a/Analysis/include/Luau/TypedAllocator.h b/Analysis/include/Luau/TypedAllocator.h index 0ded1489..a5dd17b6 100644 --- a/Analysis/include/Luau/TypedAllocator.h +++ b/Analysis/include/Luau/TypedAllocator.h @@ -10,7 +10,7 @@ namespace Luau { void* pagedAllocate(size_t size); -void pagedDeallocate(void* ptr); +void pagedDeallocate(void* ptr, size_t size); void pagedFreeze(void* ptr, size_t size); void pagedUnfreeze(void* ptr, size_t size); @@ -20,9 +20,15 @@ class TypedAllocator public: TypedAllocator() { - appendBlock(); + currentBlockSize = kBlockSize; } + TypedAllocator(const TypedAllocator&) = delete; + TypedAllocator& operator=(const TypedAllocator&) = delete; + + TypedAllocator(TypedAllocator&&) = default; + TypedAllocator& operator=(TypedAllocator&&) = default; + ~TypedAllocator() { if (frozen) @@ -59,12 +65,12 @@ public: bool empty() const { - return stuff.size() == 1 && currentBlockSize == 0; + return stuff.empty(); } size_t size() const { - return kBlockSize * (stuff.size() - 1) + currentBlockSize; + return stuff.empty() ? 0 : kBlockSize * (stuff.size() - 1) + currentBlockSize; } void clear() @@ -72,7 +78,8 @@ public: if (frozen) unfreeze(); free(); - appendBlock(); + + currentBlockSize = kBlockSize; } void freeze() @@ -106,7 +113,7 @@ private: for (size_t i = 0; i < blockSize; ++i) block[i].~T(); - pagedDeallocate(block); + pagedDeallocate(block, kBlockSizeBytes); } stuff.clear(); diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 10dbf333..c43daa21 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -8,6 +8,8 @@ namespace Luau { +struct Scope; + /** * The 'level' of a TypeVar is an indirect way to talk about the scope that it 'belongs' too. * To start, read http://okmij.org/ftp/ML/generalization.html @@ -24,7 +26,7 @@ struct TypeLevel int level = 0; int subLevel = 0; - // Returns true if the typelevel "this" is "bigger" than rhs + // Returns true if the level of "this" belongs to an equal or larger scope than that of rhs bool subsumes(const TypeLevel& rhs) const { if (level < rhs.level) @@ -38,6 +40,15 @@ struct TypeLevel return false; } + // Returns true if the level of "this" belongs to a larger (not equal) scope than that of rhs + bool subsumesStrict(const TypeLevel& rhs) const + { + if (level == rhs.level && subLevel == rhs.subLevel) + return false; + else + return subsumes(rhs); + } + TypeLevel incr() const { TypeLevel result; @@ -47,6 +58,14 @@ struct TypeLevel } }; +inline TypeLevel max(const TypeLevel& a, const TypeLevel& b) +{ + if (a.subsumes(b)) + return b; + else + return a; +} + inline TypeLevel min(const TypeLevel& a, const TypeLevel& b) { if (a.subsumes(b)) @@ -55,7 +74,9 @@ inline TypeLevel min(const TypeLevel& a, const TypeLevel& b) return b; } -namespace Unifiable +} // namespace Luau + +namespace Luau::Unifiable { using Name = std::string; @@ -63,19 +84,19 @@ using Name = std::string; struct Free { explicit Free(TypeLevel level); - Free(TypeLevel level, bool DEPRECATED_canBeGeneric); + explicit Free(Scope* scope); + explicit Free(Scope* scope, TypeLevel level); int index; TypeLevel level; - // Removed by FFlag::LuauRankNTypes - bool DEPRECATED_canBeGeneric = false; + Scope* scope = nullptr; // True if this free type variable is part of a mutually // recursive type alias whose definitions haven't been // resolved yet. bool forwardedTypeAlias = false; private: - static int nextIndex; + static int DEPRECATED_nextIndex; }; template @@ -95,19 +116,24 @@ struct Generic Generic(); explicit Generic(TypeLevel level); explicit Generic(const Name& name); + explicit Generic(Scope* scope); Generic(TypeLevel level, const Name& name); + Generic(Scope* scope, const Name& name); int index; TypeLevel level; + Scope* scope = nullptr; Name name; - bool explicitName; + bool explicitName = false; private: - static int nextIndex; + static int DEPRECATED_nextIndex; }; struct Error { + // This constructor has to be public, since it's used in TypeVar and TypePack, + // but shouldn't be called directly. Please use errorRecoveryType() instead. Error(); int index; @@ -117,7 +143,6 @@ private: }; template -using Variant = Variant, Generic, Error, Value...>; +using Variant = Luau::Variant, Generic, Error, Value...>; -} // namespace Unifiable -} // namespace Luau +} // namespace Luau::Unifiable diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 0ddc3cc0..af3864ea 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -3,9 +3,13 @@ #include "Luau/Error.h" #include "Luau/Location.h" +#include "Luau/ParseOptions.h" +#include "Luau/Scope.h" +#include "Luau/Substitution.h" #include "Luau/TxnLog.h" -#include "Luau/TypeInfer.h" -#include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header. +#include "Luau/TypeArena.h" +#include "Luau/UnifierSharedState.h" +#include "Normalize.h" #include @@ -18,81 +22,133 @@ enum Variance Invariant }; -struct UnifierCounters +// A substitution which replaces singleton types by their wider types +struct Widen : Substitution { - int recursionCount = 0; - int iterationCount = 0; + Widen(TypeArena* arena, NotNull singletonTypes) + : Substitution(TxnLog::empty(), arena) + , singletonTypes(singletonTypes) + { + } + + NotNull singletonTypes; + + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId ty) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId ty) override; + bool ignoreChildren(TypeId ty) override; + + TypeId operator()(TypeId ty); + TypePackId operator()(TypePackId ty); +}; + +// TODO: Use this more widely. +struct UnifierOptions +{ + bool isFunctionCall = false; }; struct Unifier { TypeArena* const types; + NotNull singletonTypes; + NotNull normalizer; Mode mode; - ScopePtr globalScope; // sigh. Needed solely to get at string's metatable. + NotNull scope; // const Scope maybe TxnLog log; ErrorVec errors; Location location; Variance variance = Covariant; + bool normalize; // Normalize unions and intersections if necessary + bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; - std::shared_ptr counters; - InternalErrorReporter* iceHandler; + UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler); - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters = nullptr); + Unifier( + NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. - ErrorVec canUnify(TypeId superTy, TypeId subTy); - ErrorVec canUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); + ErrorVec canUnify(TypeId subTy, TypeId superTy); + ErrorVec canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); - /** Attempt to unify left with right. + /** Attempt to unify. * Populate the vector errors with any type errors that may arise. * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt. */ - void tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); private: - void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); - void tryUnifyPrimitives(TypeId superTy, TypeId subTy); - void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); - void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); - void tryUnifyFreeTable(TypeId free, TypeId other); - void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection); - void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); - void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); - void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); + void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId superTy); + void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTypeVar* uv, bool cacheEnabled, bool isFunctionCall); + void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionTypeVar* uv); + void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); + void tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, + std::optional error = std::nullopt); + void tryUnifyPrimitives(TypeId subTy, TypeId superTy); + void tryUnifySingletons(TypeId subTy, TypeId superTy); + void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); + void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); + void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed); + void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); + void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); + void tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy); + void tryUnifyNegationWithType(TypeId subTy, TypeId superTy); + + TypePackId tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args); + + TypeId widen(TypeId ty); + TypePackId widen(TypePackId tp); + + TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); + + bool canCacheResult(TypeId subTy, TypeId superTy); + void cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount); public: - void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); + void tryUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); private: - void tryUnify_(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); - void tryUnifyVariadics(TypePackId superTy, TypePackId subTy, bool reversed, int subOffset = 0); + void tryUnify_(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); + void tryUnifyVariadics(TypePackId subTy, TypePackId superTy, bool reversed, int subOffset = 0); - void tryUnifyWithAny(TypeId any, TypeId ty); - void tryUnifyWithAny(TypePackId any, TypePackId ty); + void tryUnifyWithAny(TypeId subTy, TypeId anyTy); + void tryUnifyWithAny(TypePackId subTy, TypePackId anyTp); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); - std::optional findMetatableEntry(TypeId type, std::string entry); + + TxnLog combineLogsIntoIntersection(std::vector logs); + TxnLog combineLogsIntoUnion(std::vector logs); public: - // Report an "infinite type error" if the type "needle" already occurs within "haystack" - void occursCheck(TypeId needle, TypeId haystack); - void occursCheck(std::unordered_set& seen, TypeId needle, TypeId haystack); - void occursCheck(TypePackId needle, TypePackId haystack); - void occursCheck(std::unordered_set& seen, TypePackId needle, TypePackId haystack); + // Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error" + bool occursCheck(TypeId needle, TypeId haystack); + bool occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); + bool occursCheck(TypePackId needle, TypePackId haystack); + bool occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); Unifier makeChildUnifier(); + void reportError(TypeError err); + LUAU_NOINLINE void reportError(Location location, TypeErrorData data); + private: bool isNonstrictMode() const; + TypeMismatch::Context mismatchContext(); void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType); + void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType); [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + + // Available after regular type pack unification errors + std::optional firstPackErrorPos; }; +void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); + } // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h new file mode 100644 index 00000000..d4315d47 --- /dev/null +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -0,0 +1,55 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/Error.h" +#include "Luau/TypeVar.h" +#include "Luau/TypePack.h" + +#include + +namespace Luau +{ +struct InternalErrorReporter; + +struct TypeIdPairHash +{ + size_t hashOne(Luau::TypeId key) const + { + return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9); + } + + size_t operator()(const std::pair& x) const + { + return hashOne(x.first) ^ (hashOne(x.second) << 1); + } +}; + +struct UnifierCounters +{ + int recursionCount = 0; + int recursionLimit = 0; + int iterationCount = 0; + int iterationLimit = 0; +}; + +struct UnifierSharedState +{ + UnifierSharedState(InternalErrorReporter* iceHandler) + : iceHandler(iceHandler) + { + } + + InternalErrorReporter* iceHandler; + + DenseHashMap skipCacheForType{nullptr}; + DenseHashSet, TypeIdPairHash> cachedUnify{{nullptr, nullptr}}; + DenseHashMap, TypeErrorData, TypeIdPairHash> cachedUnifyError{{nullptr, nullptr}}; + + DenseHashSet tempSeenTy{nullptr}; + DenseHashSet tempSeenTp{nullptr}; + + UnifierCounters counters; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index 63d5a65c..016c51f6 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -2,45 +2,15 @@ #pragma once #include "Luau/Common.h" - -#ifndef LUAU_USE_STD_VARIANT -#define LUAU_USE_STD_VARIANT 0 -#endif - -#if LUAU_USE_STD_VARIANT -#include -#else #include #include #include #include -#endif +#include namespace Luau { -#if LUAU_USE_STD_VARIANT -template -using Variant = std::variant; - -template -auto visit(Visitor&& vis, Variant&& var) -{ - // This change resolves the ABI issues with std::variant on libc++; std::visit normally throws bad_variant_access - // but it requires an update to libc++.dylib which ships with macOS 10.14. To work around this, we assert on valueless - // variants since we will never generate them and call into a libc++ function that doesn't throw. - LUAU_ASSERT(!var.valueless_by_exception()); - -#ifdef __APPLE__ - // See https://stackoverflow.com/a/53868971/503215 - return std::__variant_detail::__visitation::__variant::__visit_value(vis, var); -#else - return std::visit(vis, var); -#endif -} - -using std::get_if; -#else template class Variant { @@ -88,13 +58,15 @@ public: constexpr int tid = getTypeId(); typeId = tid; - new (&storage) TT(value); + new (&storage) TT(std::forward(value)); } Variant(const Variant& other) { + static constexpr FnCopy table[sizeof...(Ts)] = {&fnCopy...}; + typeId = other.typeId; - tableCopy[typeId](&storage, &other.storage); + table[typeId](&storage, &other.storage); } Variant(Variant&& other) @@ -126,6 +98,20 @@ public: return *this; } + template + T& emplace(Args&&... args) + { + using TT = std::decay_t; + constexpr int tid = getTypeId(); + static_assert(tid >= 0, "unsupported T"); + + tableDtor[typeId](&storage); + typeId = tid; + new (&storage) TT{std::forward(args)...}; + + return *reinterpret_cast(&storage); + } + template const T* get_if() const { @@ -208,7 +194,6 @@ private: return *static_cast(lhs) == *static_cast(rhs); } - static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy...}; static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove...}; static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor...}; @@ -248,6 +233,8 @@ static void fnVisitV(Visitor& vis, std::conditional_t, const template auto visit(Visitor&& vis, const Variant& var) { + static_assert(std::conjunction_v...>, "visitor must accept every alternative as an argument"); + using Result = std::invoke_result_t::first_alternative>; static_assert(std::conjunction_v>...>, "visitor result type must be consistent between alternatives"); @@ -273,6 +260,8 @@ auto visit(Visitor&& vis, const Variant& var) template auto visit(Visitor&& vis, Variant& var) { + static_assert(std::conjunction_v...>, "visitor must accept every alternative as an argument"); + using Result = std::invoke_result_t::first_alternative&>; static_assert(std::conjunction_v>...>, "visitor result type must be consistent between alternatives"); @@ -294,7 +283,6 @@ auto visit(Visitor&& vis, Variant& var) return res; } } -#endif template inline constexpr bool always_false_v = false; diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index df0bd420..3dcddba1 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -1,8 +1,15 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/TypeVar.h" +#include + +#include "Luau/DenseHash.h" +#include "Luau/RecursionCounter.h" #include "Luau/TypePack.h" +#include "Luau/TypeVar.h" + +LUAU_FASTINT(LuauVisitRecursionLimit) +LUAU_FASTFLAG(LuauCompleteVisitor); namespace Luau { @@ -32,169 +39,365 @@ inline bool hasSeen(std::unordered_set& seen, const void* tv) return !seen.insert(ttv).second; } +inline bool hasSeen(DenseHashSet& seen, const void* tv) +{ + void* ttv = const_cast(tv); + + if (seen.contains(ttv)) + return true; + + seen.insert(ttv); + return false; +} + inline void unsee(std::unordered_set& seen, const void* tv) { void* ttv = const_cast(tv); seen.erase(ttv); } -template -void visit(TypePackId tp, F& f, std::unordered_set& seen); - -template -void visit(TypeId ty, F& f, std::unordered_set& seen) +inline void unsee(DenseHashSet& seen, const void* tv) { - if (visit_detail::hasSeen(seen, ty)) - { - f.cycle(ty); - return; - } - - if (auto btv = get(ty)) - { - if (apply(ty, *btv, seen, f)) - visit(btv->boundTo, f, seen); - } - - else if (auto ftv = get(ty)) - apply(ty, *ftv, seen, f); - - else if (auto gtv = get(ty)) - apply(ty, *gtv, seen, f); - - else if (auto etv = get(ty)) - apply(ty, *etv, seen, f); - - else if (auto ptv = get(ty)) - apply(ty, *ptv, seen, f); - - else if (auto ftv = get(ty)) - { - if (apply(ty, *ftv, seen, f)) - { - visit(ftv->argTypes, f, seen); - visit(ftv->retType, f, seen); - } - } - - else if (auto ttv = get(ty)) - { - if (apply(ty, *ttv, seen, f)) - { - for (auto& [_name, prop] : ttv->props) - visit(prop.type, f, seen); - - if (ttv->indexer) - { - visit(ttv->indexer->indexType, f, seen); - visit(ttv->indexer->indexResultType, f, seen); - } - } - } - - else if (auto mtv = get(ty)) - { - if (apply(ty, *mtv, seen, f)) - { - visit(mtv->table, f, seen); - visit(mtv->metatable, f, seen); - } - } - - else if (auto ctv = get(ty)) - { - if (apply(ty, *ctv, seen, f)) - { - for (const auto& [name, prop] : ctv->props) - visit(prop.type, f, seen); - - if (ctv->parent) - visit(*ctv->parent, f, seen); - - if (ctv->metatable) - visit(*ctv->metatable, f, seen); - } - } - - else if (auto atv = get(ty)) - apply(ty, *atv, seen, f); - - else if (auto utv = get(ty)) - { - if (apply(ty, *utv, seen, f)) - { - for (TypeId optTy : utv->options) - visit(optTy, f, seen); - } - } - - else if (auto itv = get(ty)) - { - if (apply(ty, *itv, seen, f)) - { - for (TypeId partTy : itv->parts) - visit(partTy, f, seen); - } - } - - visit_detail::unsee(seen, ty); + // When DenseHashSet is used for 'visitTypeVarOnce', where don't forget visited elements } -template -void visit(TypePackId tp, F& f, std::unordered_set& seen) -{ - if (visit_detail::hasSeen(seen, tp)) - { - f.cycle(tp); - return; - } - - if (auto btv = get(tp)) - { - if (apply(tp, *btv, seen, f)) - visit(btv->boundTo, f, seen); - } - - else if (auto ftv = get(tp)) - apply(tp, *ftv, seen, f); - - else if (auto gtv = get(tp)) - apply(tp, *gtv, seen, f); - - else if (auto etv = get(tp)) - apply(tp, *etv, seen, f); - - else if (auto pack = get(tp)) - { - apply(tp, *pack, seen, f); - - for (TypeId ty : pack->head) - visit(ty, f, seen); - - if (pack->tail) - visit(*pack->tail, f, seen); - } - else if (auto pack = get(tp)) - { - apply(tp, *pack, seen, f); - visit(pack->ty, f, seen); - } - - visit_detail::unsee(seen, tp); -} } // namespace visit_detail -template -void visitTypeVar(TID ty, F& f, std::unordered_set& seen) +template +struct GenericTypeVarVisitor { - visit_detail::visit(ty, f, seen); -} + using Set = S; -template -void visitTypeVar(TID ty, F& f) + Set seen; + bool skipBoundTypes = false; + int recursionCounter = 0; + + GenericTypeVarVisitor() = default; + + explicit GenericTypeVarVisitor(Set seen, bool skipBoundTypes = false) + : seen(std::move(seen)) + , skipBoundTypes(skipBoundTypes) + { + } + + virtual void cycle(TypeId) {} + virtual void cycle(TypePackId) {} + + virtual bool visit(TypeId ty) + { + return true; + } + virtual bool visit(TypeId ty, const BoundTypeVar& btv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const FreeTypeVar& ftv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const GenericTypeVar& gtv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const ErrorTypeVar& etv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const FunctionTypeVar& ftv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const TableTypeVar& ttv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const MetatableTypeVar& mtv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const ClassTypeVar& ctv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const AnyTypeVar& atv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const UnknownTypeVar& utv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const NeverTypeVar& ntv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const UnionTypeVar& utv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const IntersectionTypeVar& itv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const BlockedTypeVar& btv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const PendingExpansionTypeVar& petv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const SingletonTypeVar& stv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const NegationTypeVar& ntv) + { + return visit(ty); + } + + virtual bool visit(TypePackId tp) + { + return true; + } + virtual bool visit(TypePackId tp, const BoundTypePack& btp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const FreeTypePack& ftp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const GenericTypePack& gtp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const Unifiable::Error& etp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const TypePack& pack) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const VariadicTypePack& vtp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const BlockedTypePack& btp) + { + return visit(tp); + } + + void traverse(TypeId ty) + { + RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit}; + + if (visit_detail::hasSeen(seen, ty)) + { + cycle(ty); + return; + } + + if (auto btv = get(ty)) + { + if (skipBoundTypes) + traverse(btv->boundTo); + else if (visit(ty, *btv)) + traverse(btv->boundTo); + } + else if (auto ftv = get(ty)) + visit(ty, *ftv); + else if (auto gtv = get(ty)) + visit(ty, *gtv); + else if (auto etv = get(ty)) + visit(ty, *etv); + else if (auto ptv = get(ty)) + visit(ty, *ptv); + else if (auto ftv = get(ty)) + { + if (visit(ty, *ftv)) + { + traverse(ftv->argTypes); + traverse(ftv->retTypes); + } + } + else if (auto ttv = get(ty)) + { + // Some visitors want to see bound tables, that's why we traverse the original type + if (skipBoundTypes && ttv->boundTo) + { + traverse(*ttv->boundTo); + } + else if (visit(ty, *ttv)) + { + if (ttv->boundTo) + { + traverse(*ttv->boundTo); + } + else + { + for (auto& [_name, prop] : ttv->props) + traverse(prop.type); + + if (ttv->indexer) + { + traverse(ttv->indexer->indexType); + traverse(ttv->indexer->indexResultType); + } + } + } + } + else if (auto mtv = get(ty)) + { + if (visit(ty, *mtv)) + { + traverse(mtv->table); + traverse(mtv->metatable); + } + } + else if (auto ctv = get(ty)) + { + if (visit(ty, *ctv)) + { + for (const auto& [name, prop] : ctv->props) + traverse(prop.type); + + if (ctv->parent) + traverse(*ctv->parent); + + if (ctv->metatable) + traverse(*ctv->metatable); + } + } + else if (auto atv = get(ty)) + visit(ty, *atv); + else if (auto utv = get(ty)) + { + if (visit(ty, *utv)) + { + for (TypeId optTy : utv->options) + traverse(optTy); + } + } + else if (auto itv = get(ty)) + { + if (visit(ty, *itv)) + { + for (TypeId partTy : itv->parts) + traverse(partTy); + } + } + else if (get(ty)) + { + // Visiting into LazyTypeVar may necessarily cause infinite expansion, so we don't do that on purpose. + // Asserting also makes no sense, because the type _will_ happen here, most likely as a property of some ClassTypeVar + // that doesn't need to be expanded. + } + else if (auto stv = get(ty)) + visit(ty, *stv); + else if (auto btv = get(ty)) + visit(ty, *btv); + else if (auto utv = get(ty)) + visit(ty, *utv); + else if (auto ntv = get(ty)) + visit(ty, *ntv); + else if (auto petv = get(ty)) + { + if (visit(ty, *petv)) + { + for (TypeId a : petv->typeArguments) + traverse(a); + + for (TypePackId a : petv->packArguments) + traverse(a); + } + } + else if (auto ntv = get(ty)) + visit(ty, *ntv); + else if (!FFlag::LuauCompleteVisitor) + return visit_detail::unsee(seen, ty); + else + LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypeId) is not exhaustive!"); + + visit_detail::unsee(seen, ty); + } + + void traverse(TypePackId tp) + { + if (visit_detail::hasSeen(seen, tp)) + { + cycle(tp); + return; + } + + if (auto btv = get(tp)) + { + if (visit(tp, *btv)) + traverse(btv->boundTo); + } + + else if (auto ftv = get(tp)) + visit(tp, *ftv); + + else if (auto gtv = get(tp)) + visit(tp, *gtv); + + else if (auto etv = get(tp)) + visit(tp, *etv); + + else if (auto pack = get(tp)) + { + bool res = visit(tp, *pack); + if (res) + { + for (TypeId ty : pack->head) + traverse(ty); + + if (pack->tail) + traverse(*pack->tail); + } + } + else if (auto pack = get(tp)) + { + bool res = visit(tp, *pack); + if (res) + traverse(pack->ty); + } + else if (auto btp = get(tp)) + visit(tp, *btp); + + else + LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypePackId) is not exhaustive!"); + + visit_detail::unsee(seen, tp); + } +}; + +/** Visit each type under a given type. Skips over cycles and keeps recursion depth under control. + * + * The same type may be visited multiple times if there are multiple distinct paths to it. If this is undesirable, use + * TypeVarOnceVisitor. + */ +struct TypeVarVisitor : GenericTypeVarVisitor> { - std::unordered_set seen; - visit_detail::visit(ty, f, seen); -} + explicit TypeVarVisitor(bool skipBoundTypes = false) + : GenericTypeVarVisitor{{}, skipBoundTypes} + { + } +}; + +/// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it. +struct TypeVarOnceVisitor : GenericTypeVarVisitor> +{ + explicit TypeVarOnceVisitor(bool skipBoundTypes = false) + : GenericTypeVarVisitor{DenseHashSet{nullptr}, skipBoundTypes} + { + } +}; } // namespace Luau diff --git a/Analysis/src/Anyification.cpp b/Analysis/src/Anyification.cpp new file mode 100644 index 00000000..5dd761c2 --- /dev/null +++ b/Analysis/src/Anyification.cpp @@ -0,0 +1,90 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Anyification.h" + +#include "Luau/Common.h" +#include "Luau/Normalize.h" +#include "Luau/TxnLog.h" + +LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) + +namespace Luau +{ + +Anyification::Anyification(TypeArena* arena, NotNull scope, NotNull singletonTypes, InternalErrorReporter* iceHandler, + TypeId anyType, TypePackId anyTypePack) + : Substitution(TxnLog::empty(), arena) + , scope(scope) + , singletonTypes(singletonTypes) + , iceHandler(iceHandler) + , anyType(anyType) + , anyTypePack(anyTypePack) +{ +} + +Anyification::Anyification(TypeArena* arena, const ScopePtr& scope, NotNull singletonTypes, InternalErrorReporter* iceHandler, + TypeId anyType, TypePackId anyTypePack) + : Anyification(arena, NotNull{scope.get()}, singletonTypes, iceHandler, anyType, anyTypePack) +{ +} + +bool Anyification::isDirty(TypeId ty) +{ + if (ty->persistent) + return false; + + if (const TableTypeVar* ttv = log->getMutable(ty)) + return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed); + else if (log->getMutable(ty)) + return true; + else + return false; +} + +bool Anyification::isDirty(TypePackId tp) +{ + if (tp->persistent) + return false; + + if (log->getMutable(tp)) + return true; + else + return false; +} + +TypeId Anyification::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + if (const TableTypeVar* ttv = log->getMutable(ty)) + { + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; + clone.definitionModuleName = ttv->definitionModuleName; + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.tags = ttv->tags; + TypeId res = addType(std::move(clone)); + return res; + } + else + return anyType; +} + +TypePackId Anyification::clean(TypePackId tp) +{ + LUAU_ASSERT(isDirty(tp)); + return anyTypePack; +} + +bool Anyification::ignoreChildren(TypeId ty) +{ + if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + return true; + + return ty->persistent; +} +bool Anyification::ignoreChildren(TypePackId ty) +{ + return ty->persistent; +} + +} // namespace Luau diff --git a/Analysis/src/ApplyTypeFunction.cpp b/Analysis/src/ApplyTypeFunction.cpp new file mode 100644 index 00000000..b293ed3d --- /dev/null +++ b/Analysis/src/ApplyTypeFunction.cpp @@ -0,0 +1,64 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ApplyTypeFunction.h" + +LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) + +namespace Luau +{ + +bool ApplyTypeFunction::isDirty(TypeId ty) +{ + if (typeArguments.count(ty)) + return true; + else if (const FreeTypeVar* ftv = get(ty)) + { + if (ftv->forwardedTypeAlias) + encounteredForwardedType = true; + return false; + } + else + return false; +} + +bool ApplyTypeFunction::isDirty(TypePackId tp) +{ + if (typePackArguments.count(tp)) + return true; + else + return false; +} + +bool ApplyTypeFunction::ignoreChildren(TypeId ty) +{ + if (get(ty)) + return true; + else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + return true; + else + return false; +} + +bool ApplyTypeFunction::ignoreChildren(TypePackId tp) +{ + if (get(tp)) + return true; + else + return false; +} + +TypeId ApplyTypeFunction::clean(TypeId ty) +{ + TypeId& arg = typeArguments[ty]; + LUAU_ASSERT(arg); + return arg; +} + +TypePackId ApplyTypeFunction::clean(TypePackId tp) +{ + TypePackId& arg = typePackArguments[tp]; + LUAU_ASSERT(arg); + return arg; +} + +} // namespace Luau diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp similarity index 78% rename from Analysis/src/JsonEncoder.cpp rename to Analysis/src/AstJsonEncoder.cpp index a1018297..57c8c90b 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -1,8 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/JsonEncoder.h" +#include "Luau/AstJsonEncoder.h" #include "Luau/Ast.h" +#include "Luau/ParseResult.h" #include "Luau/StringUtils.h" +#include "Luau/Common.h" + +#include namespace Luau { @@ -74,6 +78,11 @@ struct AstJsonEncoder : public AstVisitor writeRaw(std::string_view{&c, 1}); } + void writeType(std::string_view propValue) + { + write("type", propValue); + } + template void write(std::string_view propName, const T& value) { @@ -96,9 +105,28 @@ struct AstJsonEncoder : public AstVisitor void write(double d) { - char b[256]; - sprintf(b, "%g", d); - writeRaw(b); + switch (fpclassify(d)) + { + case FP_INFINITE: + if (d < 0) + writeRaw("-Infinity"); + else + writeRaw("Infinity"); + break; + + case FP_NAN: + writeRaw("NaN"); + break; + + case FP_NORMAL: + case FP_SUBNORMAL: + case FP_ZERO: + default: + char b[32]; + snprintf(b, sizeof(b), "%.17g", d); + writeRaw(b); + break; + } } void writeString(std::string_view sv) @@ -110,8 +138,12 @@ struct AstJsonEncoder : public AstVisitor { if (c == '"') writeRaw("\\\""); - else if (c == '\0') - writeRaw("\\\0"); + else if (c == '\\') + writeRaw("\\\\"); + else if (c < ' ') + writeRaw(format("\\u%04x", c)); + else if (c == '\n') + writeRaw("\\n"); else writeRaw(c); } @@ -147,10 +179,21 @@ struct AstJsonEncoder : public AstVisitor { writeRaw(std::to_string(i)); } + void write(std::nullptr_t) + { + writeRaw("null"); + } void write(std::string_view str) { writeString(str); } + void write(std::optional name) + { + if (name) + write(*name); + else + writeRaw("null"); + } void write(AstName name) { writeString(name.value ? name.value : ""); @@ -174,7 +217,17 @@ struct AstJsonEncoder : public AstVisitor void write(AstLocal* local) { - write(local->name); + writeRaw("{"); + bool c = pushComma(); + if (local->annotation != nullptr) + write("luauType", local->annotation); + else + write("luauType", nullptr); + write("name", local->name); + writeType("AstLocal"); + write("location", local->location); + popComma(c); + writeRaw("}"); } void writeNode(AstNode* node) @@ -187,7 +240,7 @@ struct AstJsonEncoder : public AstVisitor { writeRaw("{"); bool c = pushComma(); - write("type", name); + writeType(name); writeNode(node); f(); popComma(c); @@ -261,7 +314,7 @@ struct AstJsonEncoder : public AstVisitor if (comma) writeRaw(","); else - comma = false; + comma = true; write(a); } @@ -311,7 +364,7 @@ struct AstJsonEncoder : public AstVisitor if (node->self) PROP(self); PROP(args); - if (node->hasReturnAnnotation) + if (node->returnAnnotation) PROP(returnAnnotation); PROP(vararg); PROP(varargLocation); @@ -325,10 +378,19 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(const std::optional& typeList) + { + if (typeList) + write(*typeList); + else + writeRaw("null"); + } + void write(const AstTypeList& typeList) { writeRaw("{"); bool c = pushComma(); + writeType("AstTypeList"); write("types", typeList.types); if (typeList.tailType) write("tailType", typeList.tailType); @@ -336,6 +398,30 @@ struct AstJsonEncoder : public AstVisitor writeRaw("}"); } + void write(const AstGenericType& genericType) + { + writeRaw("{"); + bool c = pushComma(); + writeType("AstGenericType"); + write("name", genericType.name); + if (genericType.defaultValue) + write("luauType", genericType.defaultValue); + popComma(c); + writeRaw("}"); + } + + void write(const AstGenericTypePack& genericTypePack) + { + writeRaw("{"); + bool c = pushComma(); + writeType("AstGenericTypePack"); + write("name", genericTypePack.name); + if (genericTypePack.defaultValue) + write("luauType", genericTypePack.defaultValue); + popComma(c); + writeRaw("}"); + } + void write(AstExprTable::Item::Kind kind) { switch (kind) @@ -352,35 +438,46 @@ struct AstJsonEncoder : public AstVisitor void write(const AstExprTable::Item& item) { writeRaw("{"); - bool comma = pushComma(); + bool c = pushComma(); + writeType("AstExprTableItem"); write("kind", item.kind); switch (item.kind) { case AstExprTable::Item::List: - write(item.value); + write("value", item.value); break; default: - write(item.key); - writeRaw(","); - write(item.value); + write("key", item.key); + write("value", item.value); break; } - popComma(comma); + popComma(c); writeRaw("}"); } + void write(class AstExprIfElse* node) + { + writeNode(node, "AstExprIfElse", [&]() { + PROP(condition); + PROP(hasThen); + PROP(trueExpr); + PROP(hasElse); + PROP(falseExpr); + }); + } + + void write(class AstExprInterpString* node) + { + writeNode(node, "AstExprInterpString", [&]() { + PROP(strings); + PROP(expressions); + }); + } + void write(class AstExprTable* node) { writeNode(node, "AstExprTable", [&]() { - bool comma = false; - for (const auto& prop : node->items) - { - if (comma) - writeRaw(","); - else - comma = false; - write(prop); - } + PROP(items); }); } @@ -389,11 +486,11 @@ struct AstJsonEncoder : public AstVisitor switch (op) { case AstExprUnary::Not: - return writeString("not"); + return writeString("Not"); case AstExprUnary::Minus: - return writeString("minus"); + return writeString("Minus"); case AstExprUnary::Len: - return writeString("len"); + return writeString("Len"); } } @@ -492,14 +589,14 @@ struct AstJsonEncoder : public AstVisitor PROP(thenbody); if (node->elsebody) PROP(elsebody); - PROP(hasThen); + write("hasThen", node->thenLocation.has_value()); PROP(hasEnd); }); } void write(class AstStatWhile* node) { - writeNode(node, "AtStatWhile", [&]() { + writeNode(node, "AstStatWhile", [&]() { PROP(condition); PROP(body); PROP(hasDo); @@ -612,6 +709,7 @@ struct AstJsonEncoder : public AstVisitor writeNode(node, "AstStatTypeAlias", [&]() { PROP(name); PROP(generics); + PROP(genericPacks); PROP(type); PROP(exported); }); @@ -641,7 +739,8 @@ struct AstJsonEncoder : public AstVisitor writeRaw("{"); bool c = pushComma(); write("name", prop.name); - write("type", prop.ty); + writeType("AstDeclaredClassProp"); + write("luauType", prop.ty); popComma(c); writeRaw("}"); } @@ -664,13 +763,21 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(struct AstTypeOrPack node) + { + if (node.type) + write(node.type); + else + write(node.typePack); + } + void write(class AstTypeReference* node) { writeNode(node, "AstTypeReference", [&]() { - if (node->hasPrefix) + if (node->prefix) PROP(prefix); PROP(name); - PROP(generics); + PROP(parameters); }); } @@ -680,8 +787,9 @@ struct AstJsonEncoder : public AstVisitor bool c = pushComma(); write("name", prop.name); + writeType("AstTableProp"); write("location", prop.location); - write("type", prop.type); + write("propType", prop.type); popComma(c); writeRaw("}"); @@ -695,6 +803,24 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(struct AstTableIndexer* indexer) + { + if (indexer) + { + writeRaw("{"); + bool c = pushComma(); + write("location", indexer->location); + write("indexType", indexer->indexType); + write("resultType", indexer->resultType); + popComma(c); + writeRaw("}"); + } + else + { + writeRaw("null"); + } + } + void write(class AstTypeFunction* node) { writeNode(node, "AstTypeFunction", [&]() { @@ -734,6 +860,13 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(class AstTypePackExplicit* node) + { + writeNode(node, "AstTypePackExplicit", [&]() { + PROP(typeList); + }); + } + void write(class AstTypePackVariadic* node) { writeNode(node, "AstTypePackVariadic", [&]() { @@ -778,6 +911,18 @@ struct AstJsonEncoder : public AstVisitor return false; } + bool visit(class AstExprIfElse* node) override + { + write(node); + return false; + } + + bool visit(class AstExprInterpString* node) override + { + write(node); + return false; + } + bool visit(class AstExprLocal* node) override { write(node); @@ -1018,6 +1163,12 @@ struct AstJsonEncoder : public AstVisitor return false; } + bool visit(class AstTypePackExplicit* node) override + { + write(node); + return false; + } + bool visit(class AstTypePackVariadic* node) override { write(node); @@ -1029,6 +1180,41 @@ struct AstJsonEncoder : public AstVisitor write(node); return false; } + + void writeComments(std::vector commentLocations) + { + bool commentComma = false; + for (Comment comment : commentLocations) + { + if (commentComma) + { + writeRaw(","); + } + else + { + commentComma = true; + } + writeRaw("{"); + bool c = pushComma(); + switch (comment.type) + { + case Lexeme::Comment: + writeType("Comment"); + break; + case Lexeme::BlockComment: + writeType("BlockComment"); + break; + case Lexeme::BrokenComment: + writeType("BrokenComment"); + break; + default: + break; + } + write("location", comment.location); + popComma(c); + writeRaw("}"); + } + } }; std::string toJson(AstNode* node) @@ -1038,4 +1224,15 @@ std::string toJson(AstNode* node) return encoder.str(); } +std::string toJson(AstNode* node, const std::vector& commentLocations) +{ + AstJsonEncoder encoder; + encoder.writeRaw(R"({"root":)"); + node->visit(&encoder); + encoder.writeRaw(R"(,"commentLocations":[)"); + encoder.writeComments(commentLocations); + encoder.writeRaw("]}"); + return encoder.str(); +} + } // namespace Luau diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index d3de1754..e6e7f3d9 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -2,6 +2,7 @@ #include "Luau/AstQuery.h" #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/ToString.h" @@ -10,12 +11,123 @@ #include +LUAU_FASTFLAG(LuauCompleteTableKeysBetter); + namespace Luau { namespace { +struct AutocompleteNodeFinder : public AstVisitor +{ + const Position pos; + std::vector ancestry; + + explicit AutocompleteNodeFinder(Position pos, AstNode* root) + : pos(pos) + { + } + + bool visit(AstExpr* expr) override + { + if (FFlag::LuauCompleteTableKeysBetter) + { + if (expr->location.begin <= pos && pos <= expr->location.end) + { + ancestry.push_back(expr); + return true; + } + return false; + } + else + { + if (expr->location.begin < pos && pos <= expr->location.end) + { + ancestry.push_back(expr); + return true; + } + return false; + } + } + + bool visit(AstStat* stat) override + { + if (stat->location.begin < pos && pos <= stat->location.end) + { + ancestry.push_back(stat); + return true; + } + return false; + } + + bool visit(AstType* type) override + { + if (type->location.begin < pos && pos <= type->location.end) + { + ancestry.push_back(type); + return true; + } + return false; + } + + bool visit(AstTypeError* type) override + { + // For a missing type, match the whole range including the start position + if (type->isMissing && type->location.containsClosed(pos)) + { + ancestry.push_back(type); + return true; + } + return false; + } + + bool visit(class AstTypePack* typePack) override + { + return true; + } + + bool visit(AstStatBlock* block) override + { + // If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite. + if (ancestry.empty()) + { + ancestry.push_back(block); + return true; + } + + // AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes. + // ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz} + if (!ancestry.empty() && ancestry.back()->is()) + return false; + + // Type annotation error might intersect the block statement when the function header is being written, + // annotation takes priority + if (!ancestry.empty() && ancestry.back()->is()) + return false; + + // If the cursor is at the end of an expression or type and simultaneously at the beginning of a block, + // the expression or type wins out. + // The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to + // be within the block. + if (block->location.begin == pos && !ancestry.empty()) + { + if (ancestry.back()->asExpr() && !ancestry.back()->is()) + return false; + + if (ancestry.back()->asType()) + return false; + } + + if (block->location.begin <= pos && pos <= block->location.end) + { + ancestry.push_back(block); + return true; + } + return false; + } +}; + struct FindNode : public AstVisitor { const Position pos; @@ -70,9 +182,11 @@ struct FindFullAncestry final : public AstVisitor { std::vector nodes; Position pos; + Position documentEnd; - explicit FindFullAncestry(Position pos) + explicit FindFullAncestry(Position pos, Position documentEnd) : pos(pos) + , documentEnd(documentEnd) { } @@ -83,17 +197,38 @@ struct FindFullAncestry final : public AstVisitor nodes.push_back(node); return true; } + + // Edge case: If we ask for the node at the position that is the very end of the document + // return the innermost AST element that ends at that position. + + if (node->location.end == documentEnd && pos >= documentEnd) + { + nodes.push_back(node); + return true; + } + return false; } }; } // namespace +std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos) +{ + AutocompleteNodeFinder finder{pos, source.root}; + source.root->visit(&finder); + return finder.ancestry; +} + std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos) { - FindFullAncestry finder(pos); + const Position end = source.root->location.end; + if (pos > end) + pos = end; + + FindFullAncestry finder(pos, end); source.root->visit(&finder); - return std::move(finder.nodes); + return finder.nodes; } AstNode* findNodeAtPosition(const SourceModule& source, Position pos) @@ -143,8 +278,8 @@ std::optional findTypeAtPosition(const Module& module, const SourceModul { if (auto expr = findExprAtPosition(sourceModule, pos)) { - if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) - return it->second; + if (auto it = module.astTypes.find(expr)) + return *it; } return std::nullopt; @@ -154,8 +289,8 @@ std::optional findExpectedTypeAtPosition(const Module& module, const Sou { if (auto expr = findExprAtPosition(sourceModule, pos)) { - if (auto it = module.astExpectedTypes.find(expr); it != module.astExpectedTypes.end()) - return it->second; + if (auto it = module.astExpectedTypes.find(expr)) + return *it; } return std::nullopt; @@ -192,7 +327,7 @@ std::optional findBindingAtPosition(const Module& module, const SourceM auto iter = currentScope->bindings.find(name); if (iter != currentScope->bindings.end() && iter->second.location.begin <= pos) { - /* Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope */ + // Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope std::optional bindingStatement = findBindingLocalStatement(source, iter->second); if (!bindingStatement || !(*bindingStatement)->location.contains(pos)) return iter->second; @@ -305,6 +440,36 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos) return findVisitor.result; } +static std::optional checkOverloadedDocumentationSymbol( + const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional documentationSymbol) +{ + if (!documentationSymbol) + return std::nullopt; + + // This might be an overloaded function. + if (get(follow(ty))) + { + TypeId matchingOverload = nullptr; + if (parentExpr && parentExpr->is()) + { + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + { + matchingOverload = *it; + } + } + + if (matchingOverload) + { + std::string overloadSymbol = *documentationSymbol + "/overload/"; + // Default toString options are fine for this purpose. + overloadSymbol += toString(matchingOverload); + return overloadSymbol; + } + } + + return documentationSymbol; +} + std::optional getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) { std::vector ancestry = findAstAncestryOfPosition(source, position); @@ -313,54 +478,24 @@ std::optional getDocumentationSymbolAtPosition(const Source AstExpr* parentExpr = ancestry.size() >= 2 ? ancestry[ancestry.size() - 2]->asExpr() : nullptr; if (std::optional binding = findBindingAtPosition(module, source, position)) - { - if (binding->documentationSymbol) - { - // This might be an overloaded function binding. - if (get(follow(binding->typeId))) - { - TypeId matchingOverload = nullptr; - if (parentExpr && parentExpr->is()) - { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr); it != module.astOverloadResolvedTypes.end()) - { - matchingOverload = it->second; - } - } - - if (matchingOverload) - { - std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; - // Default toString options are fine for this purpose. - overloadSymbol += toString(matchingOverload); - return overloadSymbol; - } - } - } - - return binding->documentationSymbol; - } + return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol); if (targetExpr) { if (AstExprIndexName* indexName = targetExpr->as()) { - if (auto it = module.astTypes.find(indexName->expr); it != module.astTypes.end()) + if (auto it = module.astTypes.find(indexName->expr)) { - TypeId parentTy = follow(it->second); + TypeId parentTy = follow(*it); if (const TableTypeVar* ttv = get(parentTy)) { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) - { - return propIt->second.documentationSymbol; - } + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); } else if (const ClassTypeVar* ctv = get(parentTy)) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) - { - return propIt->second.documentationSymbol; - } + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); } } } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index dce92a0c..83a6f021 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,9 +12,7 @@ #include #include -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) -LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); -LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) +LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -22,102 +20,6 @@ static const std::unordered_set kStatementStartingKeywords = { namespace Luau { -struct NodeFinder : public AstVisitor -{ - const Position pos; - std::vector ancestry; - - explicit NodeFinder(Position pos, AstNode* root) - : pos(pos) - { - } - - bool visit(AstExpr* expr) override - { - if (expr->location.begin < pos && pos <= expr->location.end) - { - ancestry.push_back(expr); - return true; - } - return false; - } - - bool visit(AstStat* stat) override - { - if (stat->location.begin < pos && pos <= stat->location.end) - { - ancestry.push_back(stat); - return true; - } - return false; - } - - bool visit(AstType* type) override - { - if (type->location.begin < pos && pos <= type->location.end) - { - ancestry.push_back(type); - return true; - } - return false; - } - - bool visit(AstTypeError* type) override - { - // For a missing type, match the whole range including the start position - if (type->isMissing && type->location.containsClosed(pos)) - { - ancestry.push_back(type); - return true; - } - return false; - } - - bool visit(class AstTypePack* typePack) override - { - return true; - } - - bool visit(AstStatBlock* block) override - { - // If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite. - if (ancestry.empty()) - { - ancestry.push_back(block); - return true; - } - - // AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes. - // ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz} - if (!ancestry.empty() && ancestry.back()->is()) - return false; - - // Type annotation error might intersect the block statement when the function header is being written, - // annotation takes priority - if (!ancestry.empty() && ancestry.back()->is()) - return false; - - // If the cursor is at the end of an expression or type and simultaneously at the beginning of a block, - // the expression or type wins out. - // The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to - // be within the block. - if (block->location.begin == pos && !ancestry.empty()) - { - if (ancestry.back()->asExpr() && !ancestry.back()->is()) - return false; - - if (ancestry.back()->asType()) - return false; - } - - if (block->location.begin <= pos && pos <= block->location.end) - { - ancestry.push_back(block); - return true; - } - return false; - } -}; static bool alreadyHasParens(const std::vector& nodes) { @@ -150,8 +52,12 @@ static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionTyp auto idxExpr = nodes.back()->as(); bool hasImplicitSelf = idxExpr && idxExpr->op == ':'; - auto args = Luau::flatten(func->argTypes); - bool noArgFunction = (args.first.empty() || (hasImplicitSelf && args.first.size() == 1)) && !args.second.has_value(); + auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes); + + if (argVariadicPack.has_value() && isVariadic(*argVariadicPack)) + return ParenthesesRecommendation::CursorInside; + + bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1); return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; } @@ -190,53 +96,94 @@ static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::ve return ParenthesesRecommendation::None; } -static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, TypeId ty) +static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) +{ + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + // Extra care for first function call argument location + // When we don't have anything inside () yet, we also don't have an AST node to base our lookup + if (AstExprCall* exprCall = expr->as()) + { + if (exprCall->args.size == 0 && exprCall->argLocation.contains(position)) + { + auto it = module.astTypes.find(exprCall->func); + + if (!it) + return std::nullopt; + + const FunctionTypeVar* ftv = get(follow(*it)); + + if (!ftv) + return std::nullopt; + + auto [head, tail] = flatten(ftv->argTypes); + unsigned index = exprCall->self ? 1 : 0; + + if (index < head.size()) + return head[index]; + + return std::nullopt; + } + } + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return std::nullopt; + + return *it; +} + +static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, TypeArena* typeArena, NotNull singletonTypes) +{ + InternalErrorReporter iceReporter; + UnifierSharedState unifierState(&iceReporter); + Normalizer normalizer{typeArena, singletonTypes, NotNull{&unifierState}}; + Unifier unifier(NotNull{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant); + + return unifier.canUnify(subTy, superTy).empty(); +} + +static TypeCorrectKind checkTypeCorrectKind( + const Module& module, TypeArena* typeArena, NotNull singletonTypes, AstNode* node, Position position, TypeId ty) { ty = follow(ty); - auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) { - InternalErrorReporter iceReporter; - Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, &iceReporter); + NotNull moduleScope{module.getModuleScope().get()}; - unifier.tryUnify(expectedType, actualType); + auto typeAtPosition = findExpectedTypeAt(module, node, position); - bool ok = unifier.errors.empty(); - unifier.log.rollback(); - return ok; + if (!typeAtPosition) + return TypeCorrectKind::None; + + TypeId expectedType = follow(*typeAtPosition); + + auto checkFunctionType = [typeArena, singletonTypes, moduleScope, &expectedType](const FunctionTypeVar* ftv) { + if (std::optional firstRetTy = first(ftv->retTypes)) + return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, singletonTypes); + + return false; }; - auto expr = node->asExpr(); - if (!expr) - return TypeCorrectKind::None; - - auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) - return TypeCorrectKind::None; - - TypeId expectedType = follow(it->second); - - if (canUnify(expectedType, ty)) - return TypeCorrectKind::Correct; - // We also want to suggest functions that return compatible result - const FunctionTypeVar* ftv = get(ty); - - if (!ftv) - return TypeCorrectKind::None; - - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty()) - return canUnify(expectedType, retHead.front()) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) + if (const FunctionTypeVar* ftv = get(ty); ftv && checkFunctionType(ftv)) { - if (const VariadicTypePack* vtp = get(follow(*retTail))) - return canUnify(expectedType, vtp->ty) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + return TypeCorrectKind::CorrectFunctionResult; + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + for (TypeId id : itv->parts) + { + if (const FunctionTypeVar* ftv = get(id); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; + } + } } - return TypeCorrectKind::None; + return checkTypeMatch(ty, expectedType, NotNull{module.getModuleScope().get()}, typeArena, singletonTypes) ? TypeCorrectKind::Correct + : TypeCorrectKind::None; } enum class PropIndexType @@ -246,45 +193,60 @@ enum class PropIndexType Key, }; -static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId ty, PropIndexType indexType, const std::vector& nodes, - AutocompleteEntryMap& result, std::unordered_set& seen, std::optional containingClass = std::nullopt) +static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull singletonTypes, TypeId rootTy, TypeId ty, + PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, + std::optional containingClass = std::nullopt) { + rootTy = follow(rootTy); ty = follow(ty); if (seen.count(ty)) return; seen.insert(ty); - auto isWrongIndexer = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { + auto isWrongIndexer = [typeArena, singletonTypes, &module, rootTy, indexType](Luau::TypeId type) { if (indexType == PropIndexType::Key) return false; - bool colonIndex = indexType == PropIndexType::Colon; + bool calledWithSelf = indexType == PropIndexType::Colon; + + auto isCompatibleCall = [typeArena, singletonTypes, &module, rootTy, calledWithSelf](const FunctionTypeVar* ftv) { + // Strong match with definition is a success + if (calledWithSelf == ftv->hasSelf) + return true; + + // Calls on classes require strict match between how function is declared and how it's called + if (get(rootTy)) + return false; + + // When called with ':', but declared without 'self', it is invalid if a function has incompatible first argument or no arguments at all + // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible + if (std::optional firstArgTy = first(ftv->argTypes)) + { + if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena, singletonTypes)) + return calledWithSelf; + } + + return !calledWithSelf; + }; if (const FunctionTypeVar* ftv = get(type)) + return !isCompatibleCall(ftv); + + // For intersections, any part that is successful makes the whole call successful + if (const IntersectionTypeVar* itv = get(type)) { - return useStrictFunctionIndexers ? colonIndex != ftv->hasSelf : false; - } - else if (const IntersectionTypeVar* itv = get(type)) - { - bool allHaveSelf = true; for (auto subType : itv->parts) { if (const FunctionTypeVar* ftv = get(Luau::follow(subType))) { - allHaveSelf &= ftv->hasSelf; - } - else - { - return colonIndex; + if (isCompatibleCall(ftv)) + return false; } } - return useStrictFunctionIndexers ? colonIndex != allHaveSelf : false; - } - else - { - return colonIndex; } + + return calledWithSelf; }; auto fillProps = [&](const ClassTypeVar::Props& props) { @@ -292,11 +254,12 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId { // We are walking up the class hierarchy, so if we encounter a property that we have // already populated, it takes precedence over the property we found just now. - if (result.count(name) == 0 && name != Parser::errorName) + if (result.count(name) == 0 && name != kParseNameError) { Luau::TypeId type = Luau::follow(prop.type); - TypeCorrectKind typeCorrect = - indexType == PropIndexType::Key ? TypeCorrectKind::Correct : checkTypeCorrectKind(module, typeArena, nodes.back(), type); + TypeCorrectKind typeCorrect = indexType == PropIndexType::Key + ? TypeCorrectKind::Correct + : checkTypeCorrectKind(module, typeArena, singletonTypes, nodes.back(), {{}, {}}, type); ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); @@ -316,35 +279,39 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } }; + auto fillMetatableProps = [&](const TableTypeVar* mtable) { + auto indexIt = mtable->props.find("__index"); + if (indexIt != mtable->props.end()) + { + TypeId followed = follow(indexIt->second.type); + if (get(followed) || get(followed)) + { + autocompleteProps(module, typeArena, singletonTypes, rootTy, followed, indexType, nodes, result, seen); + } + else if (auto indexFunction = get(followed)) + { + std::optional indexFunctionResult = first(indexFunction->retTypes); + if (indexFunctionResult) + autocompleteProps(module, typeArena, singletonTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); + } + } + }; + if (auto cls = get(ty)) { containingClass = containingClass.value_or(cls); fillProps(cls->props); if (cls->parent) - autocompleteProps(module, typeArena, *cls->parent, indexType, nodes, result, seen, cls); + autocompleteProps(module, typeArena, singletonTypes, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); } else if (auto tbl = get(ty)) fillProps(tbl->props); else if (auto mt = get(ty)) { - autocompleteProps(module, typeArena, mt->table, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, mt->table, indexType, nodes, result, seen); - auto mtable = get(mt->metatable); - if (!mtable) - return; - - auto indexIt = mtable->props.find("__index"); - if (indexIt != mtable->props.end()) - { - if (get(indexIt->second.type) || get(indexIt->second.type)) - autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen); - else if (auto indexFunction = get(indexIt->second.type)) - { - std::optional indexFunctionResult = first(indexFunction->retType); - if (indexFunctionResult) - autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); - } - } + if (auto mtable = get(mt->metatable)) + fillMetatableProps(mtable); } else if (auto i = get(ty)) { @@ -354,7 +321,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryMap inner; std::unordered_set innerSeen = seen; - autocompleteProps(module, typeArena, ty, indexType, nodes, inner, innerSeen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, ty, indexType, nodes, inner, innerSeen); for (auto& pair : inner) result.insert(pair); @@ -368,52 +335,31 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId while (iter != endIter) { - if (FFlag::LuauAddMissingFollow) - { - if (isNil(*iter)) - ++iter; - else - break; - } + if (isNil(*iter)) + ++iter; else - { - if (auto primTy = Luau::get(*iter); primTy && primTy->type == PrimitiveTypeVar::NilType) - ++iter; - else - break; - } + break; } if (iter == endIter) return; - autocompleteProps(module, typeArena, *iter, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, *iter, indexType, nodes, result, seen); ++iter; while (iter != endIter) { AutocompleteEntryMap inner; - std::unordered_set innerSeen = seen; + std::unordered_set innerSeen; - if (FFlag::LuauAddMissingFollow) + if (isNil(*iter)) { - if (isNil(*iter)) - { - ++iter; - continue; - } - } - else - { - if (auto innerPrimTy = Luau::get(*iter); innerPrimTy && innerPrimTy->type == PrimitiveTypeVar::NilType) - { - ++iter; - continue; - } + ++iter; + continue; } - autocompleteProps(module, typeArena, *iter, indexType, nodes, inner, innerSeen); + autocompleteProps(module, typeArena, singletonTypes, rootTy, *iter, indexType, nodes, inner, innerSeen); std::unordered_set toRemove; @@ -430,6 +376,18 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId ++iter; } } + else if (auto pt = get(ty)) + { + if (pt->metatable) + { + if (auto mtable = get(*pt->metatable)) + fillMetatableProps(mtable); + } + } + else if (get(get(ty))) + { + autocompleteProps(module, typeArena, singletonTypes, rootTy, singletonTypes->stringType, indexType, nodes, result, seen); + } } static void autocompleteKeywords( @@ -453,18 +411,18 @@ static void autocompleteKeywords( } } -static void autocompleteProps( - const Module& module, TypeArena* typeArena, TypeId ty, PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result) +static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNull singletonTypes, TypeId ty, PropIndexType indexType, + const std::vector& nodes, AutocompleteEntryMap& result) { std::unordered_set seen; - autocompleteProps(module, typeArena, ty, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, singletonTypes, ty, ty, indexType, nodes, result, seen); } -AutocompleteEntryMap autocompleteProps( - const Module& module, TypeArena* typeArena, TypeId ty, PropIndexType indexType, const std::vector& nodes) +AutocompleteEntryMap autocompleteProps(const Module& module, TypeArena* typeArena, NotNull singletonTypes, TypeId ty, + PropIndexType indexType, const std::vector& nodes) { AutocompleteEntryMap result; - autocompleteProps(module, typeArena, ty, indexType, nodes, result); + autocompleteProps(module, typeArena, singletonTypes, ty, indexType, nodes, result); return result; } @@ -486,6 +444,31 @@ AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position posi return result; } +static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AutocompleteEntryMap& result) +{ + auto formatKey = [addQuotes](const std::string& key) { + if (addQuotes) + return "\"" + escape(key) + "\""; + + return escape(key); + }; + + ty = follow(ty); + + if (auto ss = get(get(ty))) + { + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + else if (auto uty = get(ty)) + { + for (auto el : uty) + { + if (auto ss = get(get(el))) + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + } +}; + static bool canSuggestInferredType(ScopePtr scope, TypeId ty) { ty = follow(ty); @@ -495,7 +478,7 @@ static bool canSuggestInferredType(ScopePtr scope, TypeId ty) return false; // No syntax for unnamed tables with a metatable - if (const MetatableTypeVar* mtv = get(ty)) + if (get(ty)) return false; if (const TableTypeVar* ttv = get(ty)) @@ -569,7 +552,7 @@ static std::optional findTypeElementAt(AstType* astType, TypeId ty, Posi if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position)) return element; - if (auto element = findTypeElementAt(type->returnTypes, ftv->retType, position)) + if (auto element = findTypeElementAt(type->returnTypes, ftv->retTypes, position)) return element; } @@ -675,19 +658,16 @@ std::optional returnFirstNonnullOptionOfType(const UnionTypeVar* utv) return ret; } -static std::optional functionIsExpectedAt(const Module& module, AstNode* node) +static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) { - auto expr = node->asExpr(); - if (!expr) + auto typeAtPosition = findExpectedTypeAt(module, node, position); + + if (!typeAtPosition) return std::nullopt; - auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) - return std::nullopt; + TypeId expectedType = follow(*typeAtPosition); - TypeId expectedType = follow(it->second); - - if (const FunctionTypeVar* ftv = get(expectedType)) + if (get(expectedType)) return true; if (const IntersectionTypeVar* itv = get(expectedType)) @@ -736,7 +716,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi } AstNode* parent = nullptr; - AstType* topType = nullptr; + AstType* topType = nullptr; // TODO: rename? for (auto it = ancestry.rbegin(), e = ancestry.rend(); it != e; ++it) { @@ -784,11 +764,11 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi if (AstExprCall* exprCall = expr->as()) { - if (auto it = module.astTypes.find(exprCall->func); it != module.astTypes.end()) + if (auto it = module.astTypes.find(exprCall->func)) { - if (const FunctionTypeVar* ftv = get(follow(it->second))) + if (const FunctionTypeVar* ftv = get(follow(*it))) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos)) inferredType = *ty; } } @@ -798,8 +778,8 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi if (tailPos != 0) break; - if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) - inferredType = it->second; + if (auto it = module.astTypes.find(expr)) + inferredType = *it; } if (inferredType) @@ -815,10 +795,10 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionTypeVar* { auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return nullptr; - TypeId ty = follow(it->second); + TypeId ty = follow(*it); if (const FunctionTypeVar* ftv = get(ty)) return ftv; @@ -869,15 +849,18 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi } } - for (size_t i = 0; i < node->returnAnnotation.types.size; i++) + if (!node->returnAnnotation) + return result; + + for (size_t i = 0; i < node->returnAnnotation->types.size; i++) { - AstType* ret = node->returnAnnotation.types.data[i]; + AstType* ret = node->returnAnnotation->types.data[i]; if (ret->location.containsClosed(position)) { if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, i)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); } @@ -886,7 +869,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi } } - if (AstTypePack* retTp = node->returnAnnotation.tailType) + if (AstTypePack* retTp = node->returnAnnotation->tailType) { if (auto variadic = retTp->as()) { @@ -894,7 +877,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, ~0u)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); } } @@ -985,13 +968,28 @@ T* extractStat(const std::vector& ancestry) if (!parent) return nullptr; - if (T* t = parent->as(); t && parent->is()) - return t; - AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr; AstNode* greatGrandParent = ancestry.size() >= 4 ? ancestry.rbegin()[3] : nullptr; - if (!grandParent || !greatGrandParent) - return nullptr; + + if (FFlag::LuauCompleteTableKeysBetter) + { + if (!grandParent) + return nullptr; + + if (T* t = parent->as(); t && grandParent->is()) + return t; + + if (!greatGrandParent) + return nullptr; + } + else + { + if (T* t = parent->as(); t && parent->is()) + return t; + + if (!grandParent || !greatGrandParent) + return nullptr; + } if (T* t = greatGrandParent->as(); t && grandParent->is() && parent->is() && isIdentifier(node)) return t; @@ -999,10 +997,13 @@ T* extractStat(const std::vector& ancestry) return nullptr; } -static bool isBindingLegalAtCurrentPosition(const Binding& binding, Position pos) +static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding& binding, Position pos) { - // Default Location used for global bindings, which are always legal. - return binding.location == Location() || binding.location.end < pos; + if (symbol.local) + return binding.location.end < pos; + + // Builtin globals have an empty location; for defined globals, we want pos to be outside of the definition range to suggest it + return binding.location == Location() || !binding.location.containsClosed(pos); } static AutocompleteEntryMap autocompleteStatement( @@ -1023,7 +1024,7 @@ static AutocompleteEntryMap autocompleteStatement( { for (const auto& [name, binding] : scope->bindings) { - if (!isBindingLegalAtCurrentPosition(binding, position)) + if (!isBindingLegalAtCurrentPosition(name, binding, position)) continue; std::string n = toString(name); @@ -1057,7 +1058,7 @@ static AutocompleteEntryMap autocompleteStatement( AstNode* parent = ancestry.rbegin()[1]; if (AstStatIf* statIf = parent->as()) { - if (!statIf->elsebody || (statIf->hasElse && statIf->elseLocation.containsClosed(position))) + if (!statIf->elsebody || (statIf->elseLocation && statIf->elseLocation->containsClosed(position))) { result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); @@ -1085,8 +1086,7 @@ static AutocompleteEntryMap autocompleteStatement( return result; } -// Returns true if completions were generated (completions will be inserted into 'outResult') -// Returns false if no completions were generated +// Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`) static bool autocompleteIfElseExpression( const AstNode* node, const std::vector& ancestry, const Position& position, AutocompleteEntryMap& outResult) { @@ -1094,6 +1094,13 @@ static bool autocompleteIfElseExpression( if (!parent) return false; + if (node->is()) + { + // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else + // expression. + return true; + } + AstExprIfElse* ifElseExpr = parent->as(); if (!ifElseExpr || ifElseExpr->condition->location.containsClosed(position)) { @@ -1120,8 +1127,8 @@ static bool autocompleteIfElseExpression( } } -static void autocompleteExpression(const SourceModule& sourceModule, const Module& module, const TypeChecker& typeChecker, TypeArena* typeArena, - const std::vector& ancestry, Position position, AutocompleteEntryMap& result) +static AutocompleteContext autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull singletonTypes, + TypeArena* typeArena, const std::vector& ancestry, Position position, AutocompleteEntryMap& result) { LUAU_ASSERT(!ancestry.empty()); @@ -1129,14 +1136,13 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul if (node->is()) { - auto it = module.astTypes.find(node->asExpr()); - if (it != module.astTypes.end()) - autocompleteProps(module, typeArena, it->second, PropIndexType::Point, ancestry, result); + if (auto it = module.astTypes.find(node->asExpr())) + autocompleteProps(module, typeArena, singletonTypes, *it, PropIndexType::Point, ancestry, result); } - else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result)) - return; + else if (autocompleteIfElseExpression(node, ancestry, position, result)) + return AutocompleteContext::Keyword; else if (node->is()) - return; + return AutocompleteContext::Unknown; else { // This is inefficient. :( @@ -1146,7 +1152,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul { for (const auto& [name, binding] : scope->bindings) { - if (!isBindingLegalAtCurrentPosition(binding, position)) + if (!isBindingLegalAtCurrentPosition(name, binding, position)) continue; if (isBeingDefined(ancestry, name)) @@ -1155,7 +1161,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul std::string n = toString(name); if (!result.count(n)) { - TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, node, binding.typeId); + TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, binding.typeId); result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, typeCorrect, std::nullopt, std::nullopt, binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, typeCorrect)}; @@ -1165,26 +1171,32 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul scope = scope->parent; } - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, typeChecker.nilType); - TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, typeChecker.booleanType); - TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, singletonTypes->nilType); + TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, singletonTypes->trueType); + TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, singletonTypes, node, position, singletonTypes->falseType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - if (FFlag::LuauIfElseExpressionAnalysisSupport) - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; - result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; - result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["true"] = {AutocompleteEntryKind::Keyword, singletonTypes->booleanType, false, false, correctForTrue}; + result["false"] = {AutocompleteEntryKind::Keyword, singletonTypes->booleanType, false, false, correctForFalse}; + result["nil"] = {AutocompleteEntryKind::Keyword, singletonTypes->nilType, false, false, correctForNil}; result["not"] = {AutocompleteEntryKind::Keyword}; result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + + if (auto ty = findExpectedTypeAt(module, node, position)) + autocompleteStringSingleton(*ty, true, result); } + + return AutocompleteContext::Expression; } -static AutocompleteEntryMap autocompleteExpression(const SourceModule& sourceModule, const Module& module, const TypeChecker& typeChecker, +static AutocompleteResult autocompleteExpression(const SourceModule& sourceModule, const Module& module, NotNull singletonTypes, TypeArena* typeArena, const std::vector& ancestry, Position position) { AutocompleteEntryMap result; - autocompleteExpression(sourceModule, module, typeChecker, typeArena, ancestry, position, result); - return result; + AutocompleteContext context = autocompleteExpression(sourceModule, module, singletonTypes, typeArena, ancestry, position, result); + return {result, ancestry, context}; } static std::optional getMethodContainingClass(const ModulePtr& module, AstExpr* funcExpr) @@ -1203,13 +1215,13 @@ static std::optional getMethodContainingClass(const ModuleP return std::nullopt; } - auto parentIter = module->astTypes.find(parentExpr); - if (parentIter == module->astTypes.end()) + auto parentIt = module->astTypes.find(parentExpr); + if (!parentIt) { return std::nullopt; } - Luau::TypeId parentType = Luau::follow(parentIter->second); + Luau::TypeId parentType = Luau::follow(*parentIt); if (auto parentClass = Luau::get(parentType)) { @@ -1224,6 +1236,31 @@ static std::optional getMethodContainingClass(const ModuleP return std::nullopt; } +static bool stringPartOfInterpString(const AstNode* node, Position position) +{ + const AstExprInterpString* interpString = node->as(); + if (!interpString) + { + return false; + } + + for (const AstExpr* expression : interpString->expressions) + { + if (expression->location.containsClosed(position)) + { + return false; + } + } + + return true; +} + +static bool isSimpleInterpolatedString(const AstNode* node) +{ + const AstExprInterpString* interpString = node->as(); + return interpString != nullptr && interpString->expressions.size == 0; +} + static std::optional autocompleteStringParams(const SourceModule& sourceModule, const ModulePtr& module, const std::vector& nodes, Position position, StringCompletionCallback callback) { @@ -1232,7 +1269,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - if (!nodes.back()->is()) + if (!nodes.back()->is() && !isSimpleInterpolatedString(nodes.back()) && !nodes.back()->is()) { return std::nullopt; } @@ -1250,8 +1287,8 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - auto iter = module->astTypes.find(candidate->func); - if (iter == module->astTypes.end()) + auto it = module->astTypes.find(candidate->func); + if (!it) { return std::nullopt; } @@ -1267,7 +1304,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; }; - auto followedId = Luau::follow(iter->second); + auto followedId = Luau::follow(*it); if (auto functionType = Luau::get(followedId)) { return performCallback(functionType); @@ -1290,80 +1327,77 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } -static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, const TypeChecker& typeChecker, - TypeArena* typeArena, Position position, StringCompletionCallback callback) +static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, NotNull singletonTypes, + Scope* globalScope, Position position, StringCompletionCallback callback) { if (isWithinComment(sourceModule, position)) return {}; - NodeFinder finder{position, sourceModule.root}; - sourceModule.root->visit(&finder); - LUAU_ASSERT(!finder.ancestry.empty()); - AstNode* node = finder.ancestry.back(); + TypeArena typeArena; + + std::vector ancestry = findAncestryAtPositionForAutocomplete(sourceModule, position); + LUAU_ASSERT(!ancestry.empty()); + AstNode* node = ancestry.back(); AstExprConstantNil dummy{Location{}}; - AstNode* parent = finder.ancestry.size() >= 2 ? finder.ancestry.rbegin()[1] : &dummy; + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; // If we are inside a body of a function that doesn't have a completed argument list, ignore the body node if (auto exprFunction = parent->as(); exprFunction && !exprFunction->argLocation && node == exprFunction->body) { - finder.ancestry.pop_back(); + ancestry.pop_back(); - node = finder.ancestry.back(); - parent = finder.ancestry.size() >= 2 ? finder.ancestry.rbegin()[1] : &dummy; + node = ancestry.back(); + parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; } if (auto indexName = node->as()) { auto it = module->astTypes.find(indexName->expr); - if (it == module->astTypes.end()) + if (!it) return {}; - TypeId ty = follow(it->second); + TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - if (isString(ty)) - return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry), - finder.ancestry}; - else - return {autocompleteProps(*module, typeArena, ty, indexType, finder.ancestry), finder.ancestry}; + return {autocompleteProps(*module, &typeArena, singletonTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; } else if (auto typeReference = node->as()) { - if (typeReference->hasPrefix) - return {autocompleteModuleTypes(*module, position, typeReference->prefix.value), finder.ancestry}; + if (typeReference->prefix) + return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), ancestry, AutocompleteContext::Type}; else - return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry}; + return {autocompleteTypeNames(*module, position, ancestry), ancestry, AutocompleteContext::Type}; } else if (node->is()) { - return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry}; + return {autocompleteTypeNames(*module, position, ancestry), ancestry, AutocompleteContext::Type}; } else if (AstStatLocal* statLocal = node->as()) { - if (statLocal->vars.size == 1 && (!statLocal->hasEqualsSign || position < statLocal->equalsSignLocation.begin)) - return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; - else if (statLocal->hasEqualsSign && position >= statLocal->equalsSignLocation.end) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) + return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Unknown}; + else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); else return {}; } - else if (AstStatFor* statFor = extractStat(finder.ancestry)) + else if (AstStatFor* statFor = extractStat(ancestry)) { if (!statFor->hasDo || position < statFor->doLocation.begin) { if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || (statFor->step && statFor->step->location.containsClosed(position))) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); return {}; } - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; } else if (AstStatForIn* statForIn = parent->as(); statForIn && (node->is() || isIdentifier(node))) @@ -1371,7 +1405,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (!statForIn->hasIn || position <= statForIn->inLocation.begin) { AstLocal* lastName = statForIn->vars.data[statForIn->vars.size - 1]; - if (lastName->name == Parser::errorName || lastName->location.containsClosed(position)) + if (lastName->name == kParseNameError || lastName->location.containsClosed(position)) { // Here we are either working with a missing binding (as would be the case in a bare "for" keyword) or // the cursor is still touching a binding name. The user is still typing a new name, so we should not offer @@ -1379,7 +1413,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {}; } - return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; } if (!statForIn->hasDo || position <= statForIn->doLocation.begin) @@ -1388,68 +1422,89 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; if (lastExpr->location.containsClosed(position)) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); if (position > lastExpr->location.end) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; return {}; // Not sure what this means } } - else if (AstStatForIn* statForIn = extractStat(finder.ancestry)) + else if (AstStatForIn* statForIn = extractStat(ancestry)) { // The AST looks a bit differently if the cursor is at a position where only the "do" keyword is allowed. // ex "for f in f do" if (!statForIn->hasDo) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; } else if (AstStatWhile* statWhile = parent->as(); node->is() && statWhile) { if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; if (!statWhile->hasDo || position < statWhile->doLocation.begin) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); if (statWhile->hasDo && position > statWhile->doLocation.end) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; } - else if (AstStatWhile* statWhile = extractStat(finder.ancestry); statWhile && !statWhile->hasDo) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + else if (AstStatWhile* statWhile = extractStat(ancestry); statWhile && !statWhile->hasDo) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - else if (AstStatIf* statIf = node->as(); FFlag::ElseElseIfCompletionImprovements && statIf && !statIf->hasElse) + else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) { return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, - finder.ancestry}; + ancestry, AutocompleteContext::Keyword}; } else if (AstStatIf* statIf = parent->as(); statIf && node->is()) { if (statIf->condition->is()) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; - else if (!statIf->hasThen || statIf->thenLocation.containsClosed(position)) - return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); + else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; } - else if (AstStatIf* statIf = extractStat(finder.ancestry); - statIf && (!statIf->hasThen || statIf->thenLocation.containsClosed(position))) - return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + else if (AstStatIf* statIf = extractStat(ancestry); + statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position))) + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; - else if (AstStatRepeat* statRepeat = extractStat(finder.ancestry); statRepeat) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; - else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is())) + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); + else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; + else if (AstExprTable* exprTable = parent->as(); + exprTable && (node->is() || node->is() || node->is())) { for (const auto& [kind, key, value] : exprTable->items) { // If item doesn't have a key, maybe the value is actually the key if (key ? key == node : node->is() && value == node) { - if (auto it = module->astExpectedTypes.find(exprTable); it != module->astExpectedTypes.end()) + if (auto it = module->astExpectedTypes.find(exprTable)) { - auto result = autocompleteProps(*module, typeArena, it->second, PropIndexType::Key, finder.ancestry); + auto result = autocompleteProps(*module, &typeArena, singletonTypes, *it, PropIndexType::Key, ancestry); + + if (FFlag::LuauCompleteTableKeysBetter) + { + if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*nodeIt, !node->is(), result); + + if (!key) + { + // If there is "no key," it may be that the user + // intends for the current token to be the key, but + // has yet to type the `=` sign. + // + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) + { + autocompleteStringSingleton(ttv->indexer->indexType, false, result); + } + } + } // Remove keys that are already completed for (const auto& item : exprTable->items) @@ -1461,11 +1516,11 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M result.erase(std::string(stringKey->value.data, stringKey->value.size)); } - // If we know for sure that a key is being written, do not offer general epxression suggestions + // If we know for sure that a key is being written, do not offer general expression suggestions if (!key) - autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position, result); + autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position, result); - return {result, finder.ancestry}; + return {result, ancestry, AutocompleteContext::Property}; } break; @@ -1473,36 +1528,53 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } } else if (isIdentifier(node) && (parent->is() || parent->is())) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - if (std::optional ret = autocompleteStringParams(sourceModule, module, finder.ancestry, position, callback)) + if (std::optional ret = autocompleteStringParams(sourceModule, module, ancestry, position, callback)) { - return {*ret, finder.ancestry}; + return {*ret, ancestry, AutocompleteContext::String}; } - else if (node->is()) + else if (node->is() || isSimpleInterpolatedString(node)) { - if (finder.ancestry.size() >= 2) + AutocompleteEntryMap result; + + if (auto it = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*it, false, result); + + if (ancestry.size() >= 2) { - if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) + if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) { - if (auto it = module->astTypes.find(idxExpr->expr); it != module->astTypes.end()) + if (auto it = module->astTypes.find(idxExpr->expr)) + autocompleteProps(*module, &typeArena, singletonTypes, follow(*it), PropIndexType::Point, ancestry, result); + } + else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) + { + if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { - return {autocompleteProps(*module, typeArena, follow(it->second), PropIndexType::Point, finder.ancestry), finder.ancestry}; + if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) + autocompleteStringSingleton(*it, false, result); } } } - return {}; + + return {result, ancestry, AutocompleteContext::String}; + } + else if (stringPartOfInterpString(node, position)) + { + // We're not a simple interpolated string, we're something like `a{"b"}@1`, and we + // can't know what to format to + AutocompleteEntryMap map; + return {map, ancestry, AutocompleteContext::String}; } if (node->is()) - { return {}; - } if (node->asExpr()) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); else if (node->asStat()) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; return {}; } @@ -1511,54 +1583,23 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName { // FIXME: We can improve performance here by parsing without checking. // The old type graph is probably fine. (famous last words!) - // FIXME: We don't need to typecheck for script analysis here, just for autocomplete. - frontend.check(moduleName); + FrontendOptions opts; + opts.forAutocomplete = true; + frontend.check(moduleName, opts); const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) return {}; - TypeChecker& typeChecker = - (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - ModulePtr module = - (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + ModulePtr module = frontend.moduleResolverForAutocomplete.getModule(moduleName); if (!module) return {}; - AutocompleteResult autocompleteResult = autocomplete(*sourceModule, module, typeChecker, &frontend.arenaForAutocomplete, position, callback); + NotNull singletonTypes = frontend.singletonTypes; + Scope* globalScope = frontend.typeCheckerForAutocomplete.globalScope.get(); - frontend.arenaForAutocomplete.clear(); - - return autocompleteResult; -} - -OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback) -{ - auto sourceModule = std::make_unique(); - ParseOptions parseOptions; - parseOptions.captureComments = true; - ParseResult result = Parser::parse(source.data(), source.size(), *sourceModule->names, *sourceModule->allocator, parseOptions); - - if (!result.root) - return {AutocompleteResult{}, {}, nullptr}; - - sourceModule->name = "FRAGMENT_SCRIPT"; - sourceModule->root = result.root; - sourceModule->mode = Mode::Strict; - sourceModule->commentLocations = std::move(result.commentLocations); - - TypeChecker& typeChecker = - (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - - ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); - - OwningAutocompleteResult autocompleteResult = { - autocomplete(*sourceModule, module, typeChecker, &frontend.arenaForAutocomplete, position, callback), std::move(module), - std::move(sourceModule)}; - - frontend.arenaForAutocomplete.clear(); + AutocompleteResult autocompleteResult = autocomplete(*sourceModule, module, singletonTypes, globalScope, position, callback); return autocompleteResult; } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 68ad5ac9..612812c5 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -1,17 +1,26 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +#include "Luau/Ast.h" #include "Luau/Frontend.h" #include "Luau/Symbol.h" #include "Luau/Common.h" #include "Luau/ToString.h" +#include "Luau/ConstraintSolver.h" +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypePack.h" +#include "Luau/TypeVar.h" +#include "Luau/TypeUtils.h" #include -LUAU_FASTFLAG(LuauParseGenericFunctions) -LUAU_FASTFLAG(LuauGenericFunctions) -LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauStringMetatable) +LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauBuiltInMetatableNoBadSynthetic, false) +LUAU_FASTFLAG(LuauOptionalNextKey) +LUAU_FASTFLAG(LuauReportShadowedTypeAlias) +LUAU_FASTFLAG(LuauNewLibraryTypeNames) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -22,16 +31,23 @@ LUAU_FASTFLAG(LuauStringMetatable) namespace Luau { -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +static std::optional> magicFunctionSelect( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionSetMetaTable( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionAssert( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionPack( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionRequire( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + + +static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); +static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); +static bool dcrMagicFunctionPack(MagicFunctionCallContext context); + +static std::vector dcrMagicRefinementAssert(const MagicRefinementContext& context); TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -43,6 +59,11 @@ TypeId makeIntersection(TypeArena& arena, std::vector&& types) return arena.addType(IntersectionTypeVar{std::move(types)}); } +TypeId makeOption(Frontend& frontend, TypeArena& arena, TypeId t) +{ + return makeUnion(arena, {frontend.typeChecker.nilType, t}); +} + TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t) { return makeUnion(arena, {typeChecker.nilType, t}); @@ -106,16 +127,20 @@ void attachMagicFunction(TypeId ty, MagicFunction fn) LUAU_ASSERT(!"Got a non functional type"); } -void attachFunctionTag(TypeId ty, std::string tag) +void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn) { if (auto ftv = getMutable(ty)) - { - ftv->tags.emplace_back(std::move(tag)); - } + ftv->dcrMagicFunction = fn; + else + LUAU_ASSERT(!"Got a non functional type"); +} + +void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn) +{ + if (auto ftv = getMutable(ty)) + ftv->dcrMagicRefinement = fn; else - { LUAU_ASSERT(!"Got a non functional type"); - } } Property makeProperty(TypeId ty, std::optional documentationSymbol) @@ -130,34 +155,50 @@ Property makeProperty(TypeId ty, std::optional documentationSymbol) }; } +void addGlobalBinding(Frontend& frontend, const std::string& name, TypeId ty, const std::string& packageName) +{ + addGlobalBinding(frontend, frontend.getGlobalScope(), name, ty, packageName); +} + +void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); + void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName) { addGlobalBinding(typeChecker, typeChecker.globalScope, name, ty, packageName); } +void addGlobalBinding(Frontend& frontend, const std::string& name, Binding binding) +{ + addGlobalBinding(frontend, frontend.getGlobalScope(), name, binding); +} + void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding) { addGlobalBinding(typeChecker, typeChecker.globalScope, name, binding); } +void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) +{ + std::string documentationSymbol = packageName + "/global/" + name; + addGlobalBinding(frontend, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); +} + void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) { std::string documentationSymbol = packageName + "/global/" + name; addGlobalBinding(typeChecker, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); } +void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, Binding binding) +{ + addGlobalBinding(frontend.typeChecker, scope, name, binding); +} + void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding) { scope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = binding; } -TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name) -{ - auto t = tryGetGlobalBinding(typeChecker, name); - LUAU_ASSERT(t.has_value()); - return t->typeId; -} - std::optional tryGetGlobalBinding(TypeChecker& typeChecker, const std::string& name) { AstName astName = typeChecker.globalNames.names->getOrAdd(name.c_str()); @@ -168,6 +209,23 @@ std::optional tryGetGlobalBinding(TypeChecker& typeChecker, const std:: return std::nullopt; } +TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name) +{ + auto t = tryGetGlobalBinding(typeChecker, name); + LUAU_ASSERT(t.has_value()); + return t->typeId; +} + +TypeId getGlobalBinding(Frontend& frontend, const std::string& name) +{ + return getGlobalBinding(frontend.typeChecker, name); +} + +std::optional tryGetGlobalBinding(Frontend& frontend, const std::string& name) +{ + return tryGetGlobalBinding(frontend.typeChecker, name); +} + Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name) { AstName astName = typeChecker.globalNames.names->get(name.c_str()); @@ -189,380 +247,81 @@ void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::strin } } -void registerBuiltinTypes(TypeChecker& typeChecker) +void registerBuiltinTypes(Frontend& frontend) +{ + frontend.getGlobalScope()->addBuiltinTypeBinding("any", TypeFun{{}, frontend.singletonTypes->anyType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("nil", TypeFun{{}, frontend.singletonTypes->nilType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("number", TypeFun{{}, frontend.singletonTypes->numberType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("string", TypeFun{{}, frontend.singletonTypes->stringType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("boolean", TypeFun{{}, frontend.singletonTypes->booleanType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("thread", TypeFun{{}, frontend.singletonTypes->threadType}); + if (FFlag::LuauUnknownAndNeverType) + { + frontend.getGlobalScope()->addBuiltinTypeBinding("unknown", TypeFun{{}, frontend.singletonTypes->unknownType}); + frontend.getGlobalScope()->addBuiltinTypeBinding("never", TypeFun{{}, frontend.singletonTypes->neverType}); + } +} + +void registerBuiltinGlobals(TypeChecker& typeChecker) { LUAU_ASSERT(!typeChecker.globalTypes.typeVars.isFrozen()); LUAU_ASSERT(!typeChecker.globalTypes.typePacks.isFrozen()); - TypeId numberType = typeChecker.numberType; - TypeId booleanType = typeChecker.booleanType; TypeId nilType = typeChecker.nilType; - TypeId stringType = typeChecker.stringType; - TypeId threadType = typeChecker.threadType; - TypeId anyType = typeChecker.anyType; TypeArena& arena = typeChecker.globalTypes; - - TypeId optionalNumber = makeOption(typeChecker, arena, numberType); - TypeId optionalString = makeOption(typeChecker, arena, stringType); - TypeId optionalBoolean = makeOption(typeChecker, arena, booleanType); - - TypeId stringOrNumber = makeUnion(arena, {stringType, numberType}); - - TypePackId emptyPack = arena.addTypePack({}); - TypePackId oneNumberPack = arena.addTypePack({numberType}); - TypePackId oneStringPack = arena.addTypePack({stringType}); - TypePackId oneBooleanPack = arena.addTypePack({booleanType}); - TypePackId oneAnyPack = arena.addTypePack({anyType}); - - TypePackId anyTypePack = typeChecker.anyTypePack; - - TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); - TypePackId stringVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{stringType}}); - TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); - - TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{ - listOfAtLeastOneNumber, - oneNumberPack, - }); - - TypeId listOfAtLeastZeroNumbersToNumberType = arena.addType(FunctionTypeVar{numberVariadicList, oneNumberPack}); - - TypeId stringToAnyMap = arena.addType(TableTypeVar{{}, TableIndexer(stringType, anyType), typeChecker.globalScope->level}); + NotNull singletonTypes = typeChecker.singletonTypes; LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); LUAU_ASSERT(loadResult.success); - TypeId mathLibType = getGlobalBinding(typeChecker, "math"); - if (TableTypeVar* ttv = getMutable(mathLibType)) - { - ttv->props["min"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.min"); - ttv->props["max"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.max"); - } - - TypeId bit32LibType = getGlobalBinding(typeChecker, "bit32"); - if (TableTypeVar* ttv = getMutable(bit32LibType)) - { - ttv->props["band"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.band"); - ttv->props["bor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bor"); - ttv->props["bxor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bxor"); - ttv->props["btest"] = makeProperty(arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneBooleanPack}), "@luau/global/bit32.btest"); - } - - TypeId anyFunction = arena.addType(FunctionTypeVar{anyTypePack, anyTypePack}); - TypeId genericK = arena.addType(GenericTypeVar{"K"}); TypeId genericV = arena.addType(GenericTypeVar{"V"}); - TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); + TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); - if (FFlag::LuauStringMetatable) + std::optional stringMetatableTy = getMetatable(singletonTypes->stringType, singletonTypes); + LUAU_ASSERT(stringMetatableTy); + const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); + LUAU_ASSERT(stringMetatableTable); + + auto it = stringMetatableTable->props.find("__index"); + LUAU_ASSERT(it != stringMetatableTable->props.end()); + + addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); + + if (FFlag::LuauOptionalNextKey) { - std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); - LUAU_ASSERT(stringMetatableTy); - const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); - LUAU_ASSERT(stringMetatableTable); + // next(t: Table, i: K?) -> (K?, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(typeChecker, arena, genericK), genericV}}); + addGlobalBinding(typeChecker, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); - auto it = stringMetatableTable->props.find("__index"); - LUAU_ASSERT(it != stringMetatableTable->props.end()); + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId stringLib = it->second.type; - addGlobalBinding(typeChecker, "string", stringLib, "@luau"); - } + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, nextRetsTypePack}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); - if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions) - { - if (!FFlag::LuauStringMetatable) - { - TypeId stringLibTy = getGlobalBinding(typeChecker, "string"); - TableTypeVar* stringLib = getMutable(stringLibTy); - TypeId replArgType = makeUnion( - arena, {stringType, - arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)), - makeFunction(arena, std::nullopt, {stringType}, {stringType})}); - TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType}); - - stringLib->props["gsub"] = makeProperty(gsubFunc, "@luau/global/string.gsub"); - } + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding( + typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); } else { - if (!FFlag::LuauStringMetatable) - { - TypeId stringToStringType = makeFunction(arena, std::nullopt, {stringType}, {stringType}); + // next(t: Table, i: K?) -> (K, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); + addGlobalBinding(typeChecker, "next", + arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); - TypeId gmatchFunc = makeFunction(arena, stringType, {stringType}, {arena.addType(FunctionTypeVar{emptyPack, stringVariadicList})}); + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId replArgType = makeUnion( - arena, {stringType, - arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)), - makeFunction(arena, std::nullopt, {stringType}, {stringType})}); - TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType}); + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); - TypeId formatFn = arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}); - - TableTypeVar::Props stringLib = { - // FIXME string.byte "can" return a pack of numbers, but only if 2nd or 3rd arguments were supplied - {"byte", {makeFunction(arena, stringType, {optionalNumber, optionalNumber}, {optionalNumber})}}, - // FIXME char takes a variadic pack of numbers - {"char", {makeFunction(arena, std::nullopt, {numberType, optionalNumber, optionalNumber, optionalNumber}, {stringType})}}, - {"find", {makeFunction(arena, stringType, {stringType, optionalNumber, optionalBoolean}, {optionalNumber, optionalNumber})}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(arena, stringType, {}, {numberType})}}, - {"lower", {stringToStringType}}, - {"match", {makeFunction(arena, stringType, {stringType, optionalNumber}, {optionalString})}}, - {"rep", {makeFunction(arena, stringType, {numberType}, {stringType})}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(arena, stringType, {numberType, optionalNumber}, {stringType})}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(arena, stringType, {stringType, optionalString}, - {arena.addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, typeChecker.globalScope->level})})}}, - {"pack", {arena.addType(FunctionTypeVar{ - arena.addTypePack(TypePack{{stringType}, anyTypePack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(arena, stringType, {}, {numberType})}}, - {"unpack", {arena.addType(FunctionTypeVar{ - arena.addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - anyTypePack, - })}}, - }; - - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - addGlobalBinding(typeChecker, "string", - arena.addType(TableTypeVar{stringLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - } - - TableTypeVar::Props debugLib{ - {"info", {makeIntersection(arena, - { - arena.addType(FunctionTypeVar{arena.addTypePack({typeChecker.threadType, numberType, stringType}), anyTypePack}), - arena.addType(FunctionTypeVar{arena.addTypePack({numberType, stringType}), anyTypePack}), - arena.addType(FunctionTypeVar{arena.addTypePack({anyFunction, stringType}), anyTypePack}), - })}}, - {"traceback", {makeIntersection(arena, - { - makeFunction(arena, std::nullopt, {optionalString, optionalNumber}, {stringType}), - makeFunction(arena, std::nullopt, {typeChecker.threadType, optionalString, optionalNumber}, {stringType}), - })}}, - }; - - assignPropDocumentationSymbols(debugLib, "@luau/global/debug"); - addGlobalBinding(typeChecker, "debug", - arena.addType(TableTypeVar{debugLib, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}), "@luau"); - - TableTypeVar::Props utf8Lib = { - {"char", {arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneStringPack})}}, // FIXME - {"charpattern", {stringType}}, - {"codes", {makeFunction(arena, std::nullopt, {stringType}, - {makeFunction(arena, std::nullopt, {stringType, numberType}, {numberType, numberType}), stringType, numberType})}}, - {"codepoint", - {arena.addType(FunctionTypeVar{arena.addTypePack({stringType, optionalNumber, optionalNumber}), listOfAtLeastOneNumber})}}, // FIXME - {"len", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, {optionalNumber, numberType})}}, - {"offset", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, {numberType})}}, - {"nfdnormalize", {makeFunction(arena, std::nullopt, {stringType}, {stringType})}}, - {"graphemes", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, - {makeFunction(arena, std::nullopt, {}, {numberType, numberType})})}}, - {"nfcnormalize", {makeFunction(arena, std::nullopt, {stringType}, {stringType})}}, - }; - - assignPropDocumentationSymbols(utf8Lib, "@luau/global/utf8"); + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) addGlobalBinding( - typeChecker, "utf8", arena.addType(TableTypeVar{utf8Lib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - - TypeId optionalV = makeOption(typeChecker, arena, genericV); - - TypeId arrayOfV = arena.addType(TableTypeVar{{}, TableIndexer(numberType, genericV), typeChecker.globalScope->level}); - - TypePackId unpackArgsPack = arena.addTypePack(TypePack{{arrayOfV, optionalNumber, optionalNumber}}); - TypePackId unpackReturnPack = arena.addTypePack(TypePack{{}, anyTypePack}); - TypeId unpackFunc = arena.addType(FunctionTypeVar{{genericV}, {}, unpackArgsPack, unpackReturnPack}); - - TypeId packResult = arena.addType(TableTypeVar{ - TableTypeVar::Props{{"n", {numberType}}}, TableIndexer{numberType, numberType}, typeChecker.globalScope->level, TableState::Sealed}); - TypePackId packArgsPack = arena.addTypePack(TypePack{{}, anyTypePack}); - TypePackId packReturnPack = arena.addTypePack(TypePack{{packResult}}); - - TypeId comparator = makeFunction(arena, std::nullopt, {genericV, genericV}, {booleanType}); - TypeId optionalComparator = makeOption(typeChecker, arena, comparator); - - TypeId packFn = arena.addType(FunctionTypeVar(packArgsPack, packReturnPack)); - - TableTypeVar::Props tableLib = { - {"concat", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalString, optionalNumber, optionalNumber}, {stringType})}}, - {"insert", {makeIntersection(arena, {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, genericV}, {}), - makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, numberType, genericV}, {})})}}, - {"maxn", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV}, {numberType})}}, - {"remove", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalNumber}, {optionalV})}}, - {"sort", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalComparator}, {})}}, - {"create", {makeFunction(arena, std::nullopt, {genericV}, {}, {numberType, optionalV}, {arrayOfV})}}, - {"find", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, genericV, optionalNumber}, {optionalNumber})}}, - - {"unpack", {unpackFunc}}, // FIXME - {"pack", {packFn}}, - - // Lua 5.0 compat - {"getn", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV}, {numberType})}}, - {"foreach", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, - {mapOfKtoV, makeFunction(arena, std::nullopt, {genericK, genericV}, {})}, {})}}, - {"foreachi", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, makeFunction(arena, std::nullopt, {genericV}, {})}, {})}}, - - // backported from Lua 5.3 - {"move", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, numberType, numberType, numberType, arrayOfV}, {})}}, - - // added in Luau (borrowed from LuaJIT) - {"clear", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {})}}, - - {"freeze", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {mapOfKtoV})}}, - {"isfrozen", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {booleanType})}}, - }; - - assignPropDocumentationSymbols(tableLib, "@luau/global/table"); - addGlobalBinding( - typeChecker, "table", arena.addType(TableTypeVar{tableLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - - TableTypeVar::Props coroutineLib = { - {"create", {makeFunction(arena, std::nullopt, {anyFunction}, {threadType})}}, - {"resume", {arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{threadType}, anyTypePack}), anyTypePack})}}, - {"running", {makeFunction(arena, std::nullopt, {}, {threadType})}}, - {"status", {makeFunction(arena, std::nullopt, {threadType}, {stringType})}}, - {"wrap", {makeFunction( - arena, std::nullopt, {anyFunction}, {anyType})}}, // FIXME this technically returns a function, but we can't represent this - // atm since it can be called with different arg types at different times - {"yield", {arena.addType(FunctionTypeVar{anyTypePack, anyTypePack})}}, - {"isyieldable", {makeFunction(arena, std::nullopt, {}, {booleanType})}}, - }; - - assignPropDocumentationSymbols(coroutineLib, "@luau/global/coroutine"); - addGlobalBinding(typeChecker, "coroutine", - arena.addType(TableTypeVar{coroutineLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - - TypeId genericT = arena.addType(GenericTypeVar{"T"}); - TypeId genericR = arena.addType(GenericTypeVar{"R"}); - - // assert returns all arguments - TypePackId assertArgs = arena.addTypePack({genericT, optionalString}); - TypePackId assertRets = arena.addTypePack({genericT}); - addGlobalBinding(typeChecker, "assert", arena.addType(FunctionTypeVar{assertArgs, assertRets}), "@luau"); - - addGlobalBinding(typeChecker, "print", arena.addType(FunctionTypeVar{anyTypePack, emptyPack}), "@luau"); - - addGlobalBinding(typeChecker, "type", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); - addGlobalBinding(typeChecker, "typeof", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); - - addGlobalBinding(typeChecker, "error", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT, optionalNumber}, {}), "@luau"); - - addGlobalBinding(typeChecker, "tostring", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); - addGlobalBinding( - typeChecker, "tonumber", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT, optionalNumber}, {numberType}), "@luau"); - - addGlobalBinding( - typeChecker, "rawequal", makeFunction(arena, std::nullopt, {genericT, genericR}, {}, {genericT, genericR}, {booleanType}), "@luau"); - addGlobalBinding( - typeChecker, "rawget", makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV, genericK}, {genericV}), "@luau"); - addGlobalBinding(typeChecker, "rawset", - makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV, genericK, genericV}, {mapOfKtoV}), "@luau"); - - TypePackId genericTPack = arena.addTypePack({genericT}); - TypePackId genericRPack = arena.addTypePack({genericR}); - TypeId genericArgsToReturnFunction = arena.addType( - FunctionTypeVar{{genericT, genericR}, {}, arena.addTypePack(TypePack{{}, genericTPack}), arena.addTypePack(TypePack{{}, genericRPack})}); - - TypeId setfenvArgType = makeUnion(arena, {numberType, genericArgsToReturnFunction}); - TypeId setfenvReturnType = makeOption(typeChecker, arena, genericArgsToReturnFunction); - addGlobalBinding(typeChecker, "setfenv", makeFunction(arena, std::nullopt, {setfenvArgType, stringToAnyMap}, {setfenvReturnType}), "@luau"); - - TypePackId ipairsArgsTypePack = arena.addTypePack({arrayOfV}); - - TypeId ipairsNextFunctionType = arena.addType( - FunctionTypeVar{{genericK, genericV}, {}, arena.addTypePack({arrayOfV, numberType}), arena.addTypePack({numberType, genericV})}); - - // ipairs returns 'next, Array, 0' so we would need type-level primitives and change to - // again, we have a direct reference to 'next' because ipairs returns it - // ipairs(t: Array) -> ((Array) -> (number, V), Array, 0) - TypePackId ipairsReturnTypePack = arena.addTypePack(TypePack{{ipairsNextFunctionType, arrayOfV, numberType}}); - - // ipairs(t: Array) -> ((Array) -> (number, V), Array, number) - addGlobalBinding(typeChecker, "ipairs", arena.addType(FunctionTypeVar{{genericV}, {}, ipairsArgsTypePack, ipairsReturnTypePack}), "@luau"); - - TypePackId pcallArg0FnArgs = arena.addTypePack(TypePackVar{GenericTypeVar{"A"}}); - TypePackId pcallArg0FnRet = arena.addTypePack(TypePackVar{GenericTypeVar{"R"}}); - TypeId pcallArg0 = arena.addType(FunctionTypeVar{pcallArg0FnArgs, pcallArg0FnRet}); - TypePackId pcallArgsTypePack = arena.addTypePack(TypePack{{pcallArg0}, pcallArg0FnArgs}); - - TypePackId pcallReturnTypePack = arena.addTypePack(TypePack{{booleanType}, pcallArg0FnRet}); - - // pcall(f: (A...) -> R..., args: A...) -> boolean, R... - addGlobalBinding(typeChecker, "pcall", - arena.addType(FunctionTypeVar{{}, {pcallArg0FnArgs, pcallArg0FnRet}, pcallArgsTypePack, pcallReturnTypePack}), "@luau"); - - // errors thrown by the function 'f' are propagated onto the function 'err' that accepts it. - // and either 'f' or 'err' are valid results of this xpcall - // if 'err' did throw an error, then it returns: false, "error in error handling" - // TODO: the above is not represented (nor representable) in the type annotation below. - // - // The real type of xpcall is as such: (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, - // R2...) - TypePackId genericAPack = arena.addTypePack(TypePackVar{GenericTypeVar{"A"}}); - TypePackId genericR1Pack = arena.addTypePack(TypePackVar{GenericTypeVar{"R1"}}); - TypePackId genericR2Pack = arena.addTypePack(TypePackVar{GenericTypeVar{"R2"}}); - - TypeId genericE = arena.addType(GenericTypeVar{"E"}); - - TypeId xpcallFArg = arena.addType(FunctionTypeVar{genericAPack, genericR1Pack}); - TypeId xpcallErrArg = arena.addType(FunctionTypeVar{arena.addTypePack({genericE}), genericR2Pack}); - - TypePackId xpcallArgsPack = arena.addTypePack({{xpcallFArg, xpcallErrArg}, genericAPack}); - TypePackId xpcallRetPack = arena.addTypePack({{booleanType}, genericR1Pack}); // FIXME - - addGlobalBinding(typeChecker, "xpcall", - arena.addType(FunctionTypeVar{{genericE}, {genericAPack, genericR1Pack, genericR2Pack}, xpcallArgsPack, xpcallRetPack}), "@luau"); - - addGlobalBinding(typeChecker, "unpack", unpackFunc, "@luau"); - - TypePackId selectArgsTypePack = arena.addTypePack(TypePack{ - {stringOrNumber}, - anyTypePack // FIXME? select() is tricky. - }); - - addGlobalBinding(typeChecker, "select", arena.addType(FunctionTypeVar{selectArgsTypePack, anyTypePack}), "@luau"); - - // TODO: not completely correct. loadstring's return type should be a function or (nil, string) - TypeId loadstringFunc = arena.addType(FunctionTypeVar{anyTypePack, oneAnyPack}); - - addGlobalBinding(typeChecker, "loadstring", - makeFunction(arena, std::nullopt, {stringType, optionalString}, - { - makeOption(typeChecker, arena, loadstringFunc), - makeOption(typeChecker, arena, stringType), - }), - "@luau"); - - // a userdata object is "roughly" the same as a sealed empty table - // except `type(newproxy(false))` evaluates to "userdata" so we may need another special type here too. - // another important thing to note: the value passed in conditionally creates an empty metatable, and you have to use getmetatable, NOT - // setmetatable. - // TODO: change this to something Luau can understand how to reject `setmetatable(newproxy(false or true), {})`. - TypeId sealedTable = arena.addType(TableTypeVar(TableState::Sealed, typeChecker.globalScope->level)); - addGlobalBinding(typeChecker, "newproxy", makeFunction(arena, std::nullopt, {optionalBoolean}, {sealedTable}), "@luau"); + typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); } - // next(t: Table, i: K | nil) -> (K, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); - addGlobalBinding(typeChecker, "next", - arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); - - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - - TypeId pairsNext = (FFlag::LuauRankNTypes ? arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}) - : getGlobalBinding(typeChecker, "next")); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); - - // NOTE we are missing 'i: K | nil' argument in the first return types' argument. - // pairs(t: Table) -> ((Table) -> (K, V), Table, nil) - addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); - TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); TableTypeVar tab{TableState::Generic, typeChecker.globalScope->level}; @@ -572,14 +331,14 @@ void registerBuiltinTypes(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - // setmetatable({ @metatable MT }, MT) -> { @metatable MT } // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } addGlobalBinding(typeChecker, "setmetatable", arena.addType( FunctionTypeVar{ {genericMT}, {}, - arena.addTypePack(TypePack{{tableMetaMT, genericMT}}), + arena.addTypePack(TypePack{{FFlag::LuauUnknownAndNeverType ? tabTy : tableMetaMT, genericMT}}), arena.addTypePack(TypePack{{tableMetaMT}}) } ), "@luau" @@ -591,26 +350,159 @@ void registerBuiltinTypes(TypeChecker& typeChecker) persist(pair.second.typeId); if (TableTypeVar* ttv = getMutable(pair.second.typeId)) - ttv->name = toString(pair.first); + { + if (!ttv->name) + { + if (FFlag::LuauNewLibraryTypeNames) + ttv->name = "typeof(" + toString(pair.first) + ")"; + else + ttv->name = toString(pair.first); + } + } } attachMagicFunction(getGlobalBinding(typeChecker, "assert"), magicFunctionAssert); attachMagicFunction(getGlobalBinding(typeChecker, "setmetatable"), magicFunctionSetMetaTable); attachMagicFunction(getGlobalBinding(typeChecker, "select"), magicFunctionSelect); + attachDcrMagicFunction(getGlobalBinding(typeChecker, "select"), dcrMagicFunctionSelect); - auto tableLib = getMutable(getGlobalBinding(typeChecker, "table")); - attachMagicFunction(tableLib->props["pack"].type, magicFunctionPack); + if (TableTypeVar* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + { + // tabTy is a generic table type which we can't express via declaration syntax yet + ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); + ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); - auto stringLib = getMutable(getGlobalBinding(typeChecker, "string")); - attachMagicFunction(stringLib->props["format"].type, magicFunctionFormat); + attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); + attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); + } attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); + attachDcrMagicFunction(getGlobalBinding(typeChecker, "require"), dcrMagicFunctionRequire); } -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +void registerBuiltinGlobals(Frontend& frontend) { - auto [paramPack, _predicates] = exprResult; + LUAU_ASSERT(!frontend.globalTypes.typeVars.isFrozen()); + LUAU_ASSERT(!frontend.globalTypes.typePacks.isFrozen()); + + if (FFlag::LuauReportShadowedTypeAlias) + registerBuiltinTypes(frontend); + + TypeArena& arena = frontend.globalTypes; + NotNull singletonTypes = frontend.singletonTypes; + + LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(getBuiltinDefinitionSource(), "@luau"); + LUAU_ASSERT(loadResult.success); + + TypeId genericK = arena.addType(GenericTypeVar{"K"}); + TypeId genericV = arena.addType(GenericTypeVar{"V"}); + TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), frontend.getGlobalScope()->level, TableState::Generic}); + + std::optional stringMetatableTy = getMetatable(singletonTypes->stringType, singletonTypes); + LUAU_ASSERT(stringMetatableTy); + const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); + LUAU_ASSERT(stringMetatableTable); + + auto it = stringMetatableTable->props.find("__index"); + LUAU_ASSERT(it != stringMetatableTable->props.end()); + + addGlobalBinding(frontend, "string", it->second.type, "@luau"); + + if (FFlag::LuauOptionalNextKey) + { + // next(t: Table, i: K?) -> (K?, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(frontend, arena, genericK), genericV}}); + addGlobalBinding(frontend, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); + + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, nextRetsTypePack}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); + + // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) + addGlobalBinding( + frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } + else + { + // next(t: Table, i: K?) -> (K, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); + addGlobalBinding(frontend, "next", + arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); + + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding( + frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } + + TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); + + TableTypeVar tab{TableState::Generic, frontend.getGlobalScope()->level}; + TypeId tabTy = arena.addType(tab); + + TypeId tableMetaMT = arena.addType(MetatableTypeVar{tabTy, genericMT}); + + addGlobalBinding(frontend, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); + + // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } + addGlobalBinding(frontend, "setmetatable", + arena.addType( + FunctionTypeVar{ + {genericMT}, + {}, + arena.addTypePack(TypePack{{FFlag::LuauUnknownAndNeverType ? tabTy : tableMetaMT, genericMT}}), + arena.addTypePack(TypePack{{tableMetaMT}}) + } + ), "@luau" + ); + // clang-format on + + for (const auto& pair : frontend.getGlobalScope()->bindings) + { + persist(pair.second.typeId); + + if (TableTypeVar* ttv = getMutable(pair.second.typeId)) + { + if (!ttv->name) + { + if (FFlag::LuauNewLibraryTypeNames) + ttv->name = "typeof(" + toString(pair.first) + ")"; + else + ttv->name = toString(pair.first); + } + } + } + + attachMagicFunction(getGlobalBinding(frontend, "assert"), magicFunctionAssert); + attachDcrMagicRefinement(getGlobalBinding(frontend, "assert"), dcrMagicRefinementAssert); + attachMagicFunction(getGlobalBinding(frontend, "setmetatable"), magicFunctionSetMetaTable); + attachMagicFunction(getGlobalBinding(frontend, "select"), magicFunctionSelect); + attachDcrMagicFunction(getGlobalBinding(frontend, "select"), dcrMagicFunctionSelect); + + if (TableTypeVar* ttv = getMutable(getGlobalBinding(frontend, "table"))) + { + // tabTy is a generic table type which we can't express via declaration syntax yet + ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); + ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + + attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); + } + + attachMagicFunction(getGlobalBinding(frontend, "require"), magicFunctionRequire); + attachDcrMagicFunction(getGlobalBinding(frontend, "require"), dcrMagicFunctionRequire); +} + +static std::optional> magicFunctionSelect( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; (void)scope; @@ -631,10 +523,10 @@ static std::optional> magicFunctionSelect( if (size_t(offset) < v.size()) { std::vector result(v.begin() + offset, v.end()); - return ExprResult{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})}; + return WithPredicate{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})}; } else if (tail) - return ExprResult{*tail}; + return WithPredicate{*tail}; } typechecker.reportError(TypeError{arg1->location, GenericError{"bad argument #1 to select (index out of range)"}}); @@ -642,16 +534,67 @@ static std::optional> magicFunctionSelect( else if (AstExprConstantString* str = arg1->as()) { if (str->value.size == 1 && str->value.data[0] == '#') - return ExprResult{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})}; + return WithPredicate{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})}; } return std::nullopt; } -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) { - auto [paramPack, _predicates] = exprResult; + if (context.callSite->args.size <= 0) + { + context.solver->reportError(TypeError{context.callSite->location, GenericError{"select should take 1 or more arguments"}}); + return false; + } + + AstExpr* arg1 = context.callSite->args.data[0]; + + if (AstExprConstantNumber* num = arg1->as()) + { + const auto& [v, tail] = flatten(context.arguments); + + int offset = int(num->value); + if (offset > 0) + { + if (size_t(offset) < v.size()) + { + std::vector res(v.begin() + offset, v.end()); + TypePackId resTypePack = context.solver->arena->addTypePack({std::move(res), tail}); + asMutable(context.result)->ty.emplace(resTypePack); + } + else if (tail) + asMutable(context.result)->ty.emplace(*tail); + + return true; + } + + return false; + } + + if (AstExprConstantString* str = arg1->as()) + { + if (str->value.size == 1 && str->value.data[0] == '#') + { + TypePackId numberTypePack = context.solver->arena->addTypePack({context.solver->singletonTypes->numberType}); + asMutable(context.result)->ty.emplace(numberTypePack); + return true; + } + } + + return false; +} + +static std::optional> magicFunctionSetMetaTable( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; + + if (FFlag::LuauUnknownAndNeverType) + { + if (size(paramPack) < 2 && finite(paramPack)) + return std::nullopt; + } TypeArena& arena = typechecker.currentModule->internalTypes; @@ -660,6 +603,12 @@ static std::optional> magicFunctionSetMetaTable( TypeId target = follow(expectedArgs[0]); TypeId mt = follow(expectedArgs[1]); + if (FFlag::LuauUnknownAndNeverType) + { + typechecker.tablify(target); + typechecker.tablify(mt); + } + if (const auto& tab = get(target)) { if (target->persistent) @@ -668,7 +617,8 @@ static std::optional> magicFunctionSetMetaTable( } else { - typechecker.tablify(mt); + if (!FFlag::LuauUnknownAndNeverType) + typechecker.tablify(mt); const TableTypeVar* mtTtv = get(mt); MetatableTypeVar mtv{target, mt}; @@ -679,20 +629,31 @@ static std::optional> magicFunctionSetMetaTable( if (tableName == metatableName) mtv.syntheticName = tableName; - else + else if (!FFlag::LuauBuiltInMetatableNoBadSynthetic) mtv.syntheticName = "{ @metatable: " + metatableName + ", " + tableName + " }"; } TypeId mtTy = arena.addType(mtv); - AstExpr* targetExpr = expr.args.data[0]; - if (AstExprLocal* targetLocal = targetExpr->as()) + if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1) { - const Name targetName(targetLocal->local->name.value); - scope->bindings[targetLocal->local] = Binding{mtTy, expr.location}; + if (FFlag::LuauUnknownAndNeverType) + return std::nullopt; + else + return WithPredicate{}; } - return ExprResult{arena.addTypePack({mtTy})}; + if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self) + { + AstExpr* targetExpr = expr.args.data[0]; + if (AstExprLocal* targetLocal = targetExpr->as()) + { + const Name targetName(targetLocal->local->name.value); + scope->bindings[targetLocal->local] = Binding{mtTy, expr.location}; + } + } + + return WithPredicate{arena.addTypePack({mtTy})}; } } else if (get(target) || get(target) || isTableIntersection(target)) @@ -703,26 +664,62 @@ static std::optional> magicFunctionSetMetaTable( typechecker.reportError(TypeError{expr.location, GenericError{"setmetatable should take a table"}}); } - return ExprResult{arena.addTypePack({target})}; + return WithPredicate{arena.addTypePack({target})}; } -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionAssert( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, predicates] = exprResult; + auto [paramPack, predicates] = withPredicate; - if (expr.args.size < 1) - return ExprResult{paramPack}; + TypeArena& arena = typechecker.currentModule->internalTypes; - typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); + auto [head, tail] = flatten(paramPack); + if (head.empty() && tail) + { + std::optional fst = first(*tail); + if (!fst) + return WithPredicate{paramPack}; + head.push_back(*fst); + } - return ExprResult{paramPack}; + typechecker.resolve(predicates, scope, true); + + if (head.size() > 0) + { + auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true, typechecker.singletonTypes->nilType); + if (FFlag::LuauUnknownAndNeverType) + { + if (get(*ty)) + head = {*ty}; + else + head[0] = *ty; + } + else + { + if (!ty) + head = {typechecker.nilType}; + else + head[0] = *ty; + } + } + + return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; } -static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::vector dcrMagicRefinementAssert(const MagicRefinementContext& ctx) { - auto [paramPack, _predicates] = exprResult; + if (ctx.argumentConnectives.empty()) + return {}; + + ctx.cgb->applyRefinements(ctx.scope, ctx.callSite->location, ctx.argumentConnectives[0]); + return {}; +} + +static std::optional> magicFunctionPack( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -739,7 +736,7 @@ static std::optional> magicFunctionPack( options.push_back(vtp->ty); } - options = typechecker.reduceUnion(options); + options = reduceUnion(options); // table.pack() -> {| n: number, [number]: nil |} // table.pack(1) -> {| n: number, [number]: number |} @@ -755,7 +752,47 @@ static std::optional> magicFunctionPack( TypeId packedTable = arena.addType( TableTypeVar{{{"n", {typechecker.numberType}}}, TableIndexer(typechecker.numberType, result), scope->level, TableState::Sealed}); - return ExprResult{arena.addTypePack({packedTable})}; + return WithPredicate{arena.addTypePack({packedTable})}; +} + +static bool dcrMagicFunctionPack(MagicFunctionCallContext context) +{ + + TypeArena* arena = context.solver->arena; + + const auto& [paramTypes, paramTail] = flatten(context.arguments); + + std::vector options; + options.reserve(paramTypes.size()); + for (auto type : paramTypes) + options.push_back(type); + + if (paramTail) + { + if (const VariadicTypePack* vtp = get(*paramTail)) + options.push_back(vtp->ty); + } + + options = reduceUnion(options); + + // table.pack() -> {| n: number, [number]: nil |} + // table.pack(1) -> {| n: number, [number]: number |} + // table.pack(1, "foo") -> {| n: number, [number]: number | string |} + TypeId result = nullptr; + if (options.empty()) + result = context.solver->singletonTypes->nilType; + else if (options.size() == 1) + result = options[0]; + else + result = arena->addType(UnionTypeVar{std::move(options)}); + + TypeId numberType = context.solver->singletonTypes->numberType; + TypeId packedTable = arena->addType(TableTypeVar{{{"n", {numberType}}}, TableIndexer(numberType, result), {}, TableState::Sealed}); + + TypePackId tableTypePack = arena->addTypePack({packedTable}); + asMutable(context.result)->ty.emplace(tableTypePack); + + return true; } static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) @@ -780,8 +817,8 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) return good; } -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionRequire( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { TypeArena& arena = typechecker.currentModule->internalTypes; @@ -791,15 +828,58 @@ static std::optional> magicFunctionRequire( return std::nullopt; } - AstExpr* require = expr.args.data[0]; - - if (!checkRequirePath(typechecker, require)) + if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; - if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require)) - return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; + if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, expr)) + return WithPredicate{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; return std::nullopt; } +static bool checkRequirePathDcr(NotNull solver, AstExpr* expr) +{ + // require(foo.parent.bar) will technically work, but it depends on legacy goop that + // Luau does not and could not support without a bunch of work. It's deprecated anyway, so + // we'll warn here if we see it. + bool good = true; + AstExprIndexName* indexExpr = expr->as(); + + while (indexExpr) + { + if (indexExpr->index == "parent") + { + solver->reportError(DeprecatedApiUsed{"parent", "Parent"}, indexExpr->indexLocation); + good = false; + } + + indexExpr = indexExpr->expr->as(); + } + + return good; +} + +static bool dcrMagicFunctionRequire(MagicFunctionCallContext context) +{ + if (context.callSite->args.size != 1) + { + context.solver->reportError(GenericError{"require takes 1 argument"}, context.callSite->location); + return false; + } + + if (!checkRequirePathDcr(context.solver, context.callSite->args.data[0])) + return false; + + if (auto moduleInfo = context.solver->moduleResolver->resolveModuleInfo(context.solver->currentModuleName, *context.callSite)) + { + TypeId moduleType = context.solver->resolveModule(*moduleInfo, context.callSite->location); + TypePackId moduleResult = context.solver->arena->addTypePack({moduleType}); + asMutable(context.result)->ty.emplace(moduleResult); + + return true; + } + + return false; +} + } // namespace Luau diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp new file mode 100644 index 00000000..86e1c7fc --- /dev/null +++ b/Analysis/src/Clone.cpp @@ -0,0 +1,504 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Clone.h" + +#include "Luau/RecursionCounter.h" +#include "Luau/TxnLog.h" +#include "Luau/TypePack.h" +#include "Luau/Unifiable.h" + +LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) +LUAU_FASTFLAG(LuauClonePublicInterfaceLess) + +LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) + +namespace Luau +{ + +namespace +{ + +struct TypePackCloner; + +/* + * Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set. + * They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage. + */ + +struct TypeCloner +{ + TypeCloner(TypeArena& dest, TypeId typeId, CloneState& cloneState) + : dest(dest) + , typeId(typeId) + , seenTypes(cloneState.seenTypes) + , seenTypePacks(cloneState.seenTypePacks) + , cloneState(cloneState) + { + } + + TypeArena& dest; + TypeId typeId; + SeenTypes& seenTypes; + SeenTypePacks& seenTypePacks; + CloneState& cloneState; + + template + void defaultClone(const T& t); + + void operator()(const Unifiable::Free& t); + void operator()(const Unifiable::Generic& t); + void operator()(const Unifiable::Bound& t); + void operator()(const Unifiable::Error& t); + void operator()(const BlockedTypeVar& t); + void operator()(const PendingExpansionTypeVar& t); + void operator()(const PrimitiveTypeVar& t); + void operator()(const SingletonTypeVar& t); + void operator()(const FunctionTypeVar& t); + void operator()(const TableTypeVar& t); + void operator()(const MetatableTypeVar& t); + void operator()(const ClassTypeVar& t); + void operator()(const AnyTypeVar& t); + void operator()(const UnionTypeVar& t); + void operator()(const IntersectionTypeVar& t); + void operator()(const LazyTypeVar& t); + void operator()(const UnknownTypeVar& t); + void operator()(const NeverTypeVar& t); + void operator()(const NegationTypeVar& t); +}; + +struct TypePackCloner +{ + TypeArena& dest; + TypePackId typePackId; + SeenTypes& seenTypes; + SeenTypePacks& seenTypePacks; + CloneState& cloneState; + + TypePackCloner(TypeArena& dest, TypePackId typePackId, CloneState& cloneState) + : dest(dest) + , typePackId(typePackId) + , seenTypes(cloneState.seenTypes) + , seenTypePacks(cloneState.seenTypePacks) + , cloneState(cloneState) + { + } + + template + void defaultClone(const T& t) + { + TypePackId cloned = dest.addTypePack(TypePackVar{t}); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const Unifiable::Free& t) + { + defaultClone(t); + } + void operator()(const Unifiable::Generic& t) + { + defaultClone(t); + } + void operator()(const Unifiable::Error& t) + { + defaultClone(t); + } + + void operator()(const BlockedTypePack& t) + { + defaultClone(t); + } + + // While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter. + // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. + void operator()(const Unifiable::Bound& t) + { + TypePackId cloned = clone(t.boundTo, dest, cloneState); + if (FFlag::DebugLuauCopyBeforeNormalizing) + cloned = dest.addTypePack(TypePackVar{BoundTypePack{cloned}}); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const VariadicTypePack& t) + { + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, cloneState), /*hidden*/ t.hidden}}); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const TypePack& t) + { + TypePackId cloned = dest.addTypePack(TypePack{}); + TypePack* destTp = getMutable(cloned); + LUAU_ASSERT(destTp != nullptr); + seenTypePacks[typePackId] = cloned; + + for (TypeId ty : t.head) + destTp->head.push_back(clone(ty, dest, cloneState)); + + if (t.tail) + destTp->tail = clone(*t.tail, dest, cloneState); + } +}; + +template +void TypeCloner::defaultClone(const T& t) +{ + TypeId cloned = dest.addType(t); + seenTypes[typeId] = cloned; +} + +void TypeCloner::operator()(const Unifiable::Free& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const Unifiable::Generic& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const Unifiable::Bound& t) +{ + TypeId boundTo = clone(t.boundTo, dest, cloneState); + if (FFlag::DebugLuauCopyBeforeNormalizing) + boundTo = dest.addType(BoundTypeVar{boundTo}); + seenTypes[typeId] = boundTo; +} + +void TypeCloner::operator()(const Unifiable::Error& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const BlockedTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const PendingExpansionTypeVar& t) +{ + TypeId res = dest.addType(PendingExpansionTypeVar{t.prefix, t.name, t.typeArguments, t.packArguments}); + PendingExpansionTypeVar* petv = getMutable(res); + LUAU_ASSERT(petv); + + seenTypes[typeId] = res; + + std::vector typeArguments; + for (TypeId arg : t.typeArguments) + typeArguments.push_back(clone(arg, dest, cloneState)); + + std::vector packArguments; + for (TypePackId arg : t.packArguments) + packArguments.push_back(clone(arg, dest, cloneState)); + + petv->typeArguments = std::move(typeArguments); + petv->packArguments = std::move(packArguments); +} + +void TypeCloner::operator()(const PrimitiveTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const SingletonTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const FunctionTypeVar& t) +{ + // FISHY: We always erase the scope when we clone things. clone() was + // originally written so that we could copy a module's type surface into an + // export arena. This probably dates to that. + TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); + FunctionTypeVar* ftv = getMutable(result); + LUAU_ASSERT(ftv != nullptr); + + seenTypes[typeId] = result; + + for (TypeId generic : t.generics) + ftv->generics.push_back(clone(generic, dest, cloneState)); + + for (TypePackId genericPack : t.genericPacks) + ftv->genericPacks.push_back(clone(genericPack, dest, cloneState)); + + ftv->tags = t.tags; + ftv->argTypes = clone(t.argTypes, dest, cloneState); + ftv->argNames = t.argNames; + ftv->retTypes = clone(t.retTypes, dest, cloneState); + ftv->hasNoGenerics = t.hasNoGenerics; +} + +void TypeCloner::operator()(const TableTypeVar& t) +{ + // If table is now bound to another one, we ignore the content of the original + if (!FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) + { + TypeId boundTo = clone(*t.boundTo, dest, cloneState); + seenTypes[typeId] = boundTo; + return; + } + + TypeId result = dest.addType(TableTypeVar{}); + TableTypeVar* ttv = getMutable(result); + LUAU_ASSERT(ttv != nullptr); + + *ttv = t; + + seenTypes[typeId] = result; + + ttv->level = TypeLevel{0, 0}; + + if (FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) + ttv->boundTo = clone(*t.boundTo, dest, cloneState); + + for (const auto& [name, prop] : t.props) + ttv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + + if (t.indexer) + ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)}; + + for (TypeId& arg : ttv->instantiatedTypeParams) + arg = clone(arg, dest, cloneState); + + for (TypePackId& arg : ttv->instantiatedTypePackParams) + arg = clone(arg, dest, cloneState); + + ttv->definitionModuleName = t.definitionModuleName; + ttv->tags = t.tags; +} + +void TypeCloner::operator()(const MetatableTypeVar& t) +{ + TypeId result = dest.addType(MetatableTypeVar{}); + MetatableTypeVar* mtv = getMutable(result); + seenTypes[typeId] = result; + + mtv->table = clone(t.table, dest, cloneState); + mtv->metatable = clone(t.metatable, dest, cloneState); +} + +void TypeCloner::operator()(const ClassTypeVar& t) +{ + TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData, t.definitionModuleName}); + ClassTypeVar* ctv = getMutable(result); + + seenTypes[typeId] = result; + + for (const auto& [name, prop] : t.props) + ctv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + + if (t.parent) + ctv->parent = clone(*t.parent, dest, cloneState); + + if (t.metatable) + ctv->metatable = clone(*t.metatable, dest, cloneState); +} + +void TypeCloner::operator()(const AnyTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const UnionTypeVar& t) +{ + std::vector options; + options.reserve(t.options.size()); + + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, cloneState)); + + TypeId result = dest.addType(UnionTypeVar{std::move(options)}); + seenTypes[typeId] = result; +} + +void TypeCloner::operator()(const IntersectionTypeVar& t) +{ + TypeId result = dest.addType(IntersectionTypeVar{}); + seenTypes[typeId] = result; + + IntersectionTypeVar* option = getMutable(result); + LUAU_ASSERT(option != nullptr); + + for (TypeId ty : t.parts) + option->parts.push_back(clone(ty, dest, cloneState)); +} + +void TypeCloner::operator()(const LazyTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const UnknownTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const NeverTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const NegationTypeVar& t) +{ + TypeId result = dest.addType(AnyTypeVar{}); + seenTypes[typeId] = result; + + TypeId ty = clone(t.ty, dest, cloneState); + asMutable(result)->ty = NegationTypeVar{ty}; +} + +} // anonymous namespace + +TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) +{ + if (tp->persistent) + return tp; + + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + + TypePackId& res = cloneState.seenTypePacks[tp]; + + if (res == nullptr) + { + TypePackCloner cloner{dest, tp, cloneState}; + Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. + } + + return res; +} + +TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) +{ + if (typeId->persistent) + return typeId; + + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + + TypeId& res = cloneState.seenTypes[typeId]; + + if (res == nullptr) + { + TypeCloner cloner{dest, typeId, cloneState}; + Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. + + // Persistent types are not being cloned and we get the original type back which might be read-only + if (!res->persistent) + { + asMutable(res)->documentationSymbol = typeId->documentationSymbol; + } + } + + return res; +} + +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) +{ + TypeFun result; + + for (auto param : typeFun.typeParams) + { + TypeId ty = clone(param.ty, dest, cloneState); + std::optional defaultValue; + + if (param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, cloneState); + + result.typeParams.push_back({ty, defaultValue}); + } + + for (auto param : typeFun.typePackParams) + { + TypePackId tp = clone(param.tp, dest, cloneState); + std::optional defaultValue; + + if (param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, cloneState); + + result.typePackParams.push_back({tp, defaultValue}); + } + + result.type = clone(typeFun.type, dest, cloneState); + + return result; +} + +TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone) +{ + ty = log->follow(ty); + + TypeId result = ty; + + if (auto pty = log->pending(ty)) + ty = &pty->pending; + + if (const FunctionTypeVar* ftv = get(ty)) + { + FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; + clone.generics = ftv->generics; + clone.genericPacks = ftv->genericPacks; + clone.magicFunction = ftv->magicFunction; + clone.dcrMagicFunction = ftv->dcrMagicFunction; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + result = dest.addType(std::move(clone)); + } + else if (const TableTypeVar* ttv = get(ty)) + { + LUAU_ASSERT(!ttv->boundTo); + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state}; + clone.definitionModuleName = ttv->definitionModuleName; + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.instantiatedTypeParams = ttv->instantiatedTypeParams; + clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; + clone.tags = ttv->tags; + result = dest.addType(std::move(clone)); + } + else if (const MetatableTypeVar* mtv = get(ty)) + { + MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; + clone.syntheticName = mtv->syntheticName; + result = dest.addType(std::move(clone)); + } + else if (const UnionTypeVar* utv = get(ty)) + { + UnionTypeVar clone; + clone.options = utv->options; + result = dest.addType(std::move(clone)); + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + IntersectionTypeVar clone; + clone.parts = itv->parts; + result = dest.addType(std::move(clone)); + } + else if (const PendingExpansionTypeVar* petv = get(ty)) + { + PendingExpansionTypeVar clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments}; + result = dest.addType(std::move(clone)); + } + else if (const ClassTypeVar* ctv = get(ty); FFlag::LuauClonePublicInterfaceLess && ctv && alwaysClone) + { + ClassTypeVar clone{ctv->name, ctv->props, ctv->parent, ctv->metatable, ctv->tags, ctv->userData, ctv->definitionModuleName}; + result = dest.addType(std::move(clone)); + } + else if (FFlag::LuauClonePublicInterfaceLess && alwaysClone) + { + result = dest.addType(*ty); + } + else if (const NegationTypeVar* ntv = get(ty)) + { + result = dest.addType(NegationTypeVar{ntv->ty}); + } + else + return result; + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; +} + +TypeId shallowClone(TypeId ty, NotNull dest) +{ + return shallowClone(ty, *dest, TxnLog::empty()); +} + +} // namespace Luau diff --git a/Analysis/src/Config.cpp b/Analysis/src/Config.cpp index d9fc44f8..00ca7b16 100644 --- a/Analysis/src/Config.cpp +++ b/Analysis/src/Config.cpp @@ -1,18 +1,21 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Config.h" -#include "Luau/Parser.h" +#include "Luau/Lexer.h" #include "Luau/StringUtils.h" -namespace +LUAU_FASTFLAGVARIABLE(LuauEnableNonstrictByDefaultForLuauConfig, false) + +namespace Luau { using Error = std::optional; -} - -namespace Luau +Config::Config() + : mode(FFlag::LuauEnableNonstrictByDefaultForLuauConfig ? Mode::Nonstrict : Mode::NoCheck) { + enabledLint.setDefaults(); +} static Error parseBoolean(bool& result, const std::string& value) { diff --git a/Analysis/src/Connective.cpp b/Analysis/src/Connective.cpp new file mode 100644 index 00000000..114b5f2f --- /dev/null +++ b/Analysis/src/Connective.cpp @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Connective.h" + +namespace Luau +{ + +ConnectiveId ConnectiveArena::negation(ConnectiveId connective) +{ + return NotNull{allocator.allocate(Negation{connective})}; +} + +ConnectiveId ConnectiveArena::conjunction(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Conjunction{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::disjunction(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Disjunction{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::equivalence(ConnectiveId lhs, ConnectiveId rhs) +{ + return NotNull{allocator.allocate(Equivalence{lhs, rhs})}; +} + +ConnectiveId ConnectiveArena::proposition(DefId def, TypeId discriminantTy) +{ + return NotNull{allocator.allocate(Proposition{def, discriminantTy})}; +} + +} // namespace Luau diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp new file mode 100644 index 00000000..3a6417dc --- /dev/null +++ b/Analysis/src/Constraint.cpp @@ -0,0 +1,15 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Constraint.h" + +namespace Luau +{ + +Constraint::Constraint(NotNull scope, const Location& location, ConstraintV&& c) + : scope(scope) + , location(location) + , c(std::move(c)) +{ +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp new file mode 100644 index 00000000..f0bd958c --- /dev/null +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -0,0 +1,2329 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ConstraintGraphBuilder.h" + +#include "Luau/Ast.h" +#include "Luau/Common.h" +#include "Luau/Constraint.h" +#include "Luau/DcrLogger.h" +#include "Luau/ModuleResolver.h" +#include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" +#include "Luau/TypeUtils.h" +#include "Luau/TypeVar.h" + +LUAU_FASTINT(LuauCheckRecursionLimit); +LUAU_FASTFLAG(DebugLuauLogSolverToJson); +LUAU_FASTFLAG(DebugLuauMagicTypes); + +namespace Luau +{ + +const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp + +static std::optional matchRequire(const AstExprCall& call) +{ + const char* require = "require"; + + if (call.args.size != 1) + return std::nullopt; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != require) + return std::nullopt; + + if (call.args.size != 1) + return std::nullopt; + + return call.args.data[0]; +} + +static bool matchSetmetatable(const AstExprCall& call) +{ + const char* smt = "setmetatable"; + + if (call.args.size != 2) + return false; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != smt) + return false; + + return true; +} + +struct TypeGuard +{ + bool isTypeof; + AstExpr* target; + std::string type; +}; + +static std::optional matchTypeGuard(const AstExprBinary* binary) +{ + if (binary->op != AstExprBinary::CompareEq && binary->op != AstExprBinary::CompareNe) + return std::nullopt; + + AstExpr* left = binary->left; + AstExpr* right = binary->right; + if (right->is()) + std::swap(left, right); + + if (!right->is()) + return std::nullopt; + + AstExprCall* call = left->as(); + AstExprConstantString* string = right->as(); + if (!call || !string) + return std::nullopt; + + AstExprGlobal* callee = call->func->as(); + if (!callee) + return std::nullopt; + + if (callee->name != "type" && callee->name != "typeof") + return std::nullopt; + + if (call->args.size != 1) + return std::nullopt; + + return TypeGuard{ + /*isTypeof*/ callee->name == "typeof", + /*target*/ call->args.data[0], + /*type*/ std::string(string->value.data, string->value.size), + }; +} + +namespace +{ + +struct Checkpoint +{ + size_t offset; +}; + +Checkpoint checkpoint(const ConstraintGraphBuilder* cgb) +{ + return Checkpoint{cgb->constraints.size()}; +} + +template +void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const ConstraintGraphBuilder* cgb, F f) +{ + for (size_t i = start.offset; i < end.offset; ++i) + f(cgb->constraints[i]); +} + +} // namespace + +ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, + NotNull moduleResolver, NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, + DcrLogger* logger, NotNull dfg) + : moduleName(moduleName) + , module(module) + , singletonTypes(singletonTypes) + , arena(arena) + , rootScope(nullptr) + , dfg(dfg) + , moduleResolver(moduleResolver) + , ice(ice) + , globalScope(globalScope) + , logger(logger) +{ + if (FFlag::DebugLuauLogSolverToJson) + LUAU_ASSERT(logger); + + LUAU_ASSERT(module); +} + +TypeId ConstraintGraphBuilder::freshType(const ScopePtr& scope) +{ + return arena->addType(FreeTypeVar{scope.get()}); +} + +TypePackId ConstraintGraphBuilder::freshTypePack(const ScopePtr& scope) +{ + FreeTypePack f{scope.get()}; + return arena->addTypePack(TypePackVar{std::move(f)}); +} + +ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& parent) +{ + auto scope = std::make_shared(parent); + scopes.emplace_back(node->location, scope); + + scope->returnType = parent->returnType; + scope->varargPack = parent->varargPack; + + parent->children.push_back(NotNull{scope.get()}); + module->astScopes[node] = scope.get(); + + return scope; +} + +NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) +{ + return NotNull{constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; +} + +NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) +{ + return NotNull{constraints.emplace_back(std::move(c)).get()}; +} + +static void unionRefinements(const std::unordered_map& lhs, const std::unordered_map& rhs, + std::unordered_map& dest, NotNull arena) +{ + for (auto [def, ty] : lhs) + { + auto rhsIt = rhs.find(def); + if (rhsIt == rhs.end()) + continue; + + std::vector discriminants{{ty, rhsIt->second}}; + + if (auto destIt = dest.find(def); destIt != dest.end()) + discriminants.push_back(destIt->second); + + dest[def] = arena->addType(UnionTypeVar{std::move(discriminants)}); + } +} + +static void computeRefinement(const ScopePtr& scope, ConnectiveId connective, std::unordered_map* refis, bool sense, + NotNull arena, bool eq, std::vector* constraints) +{ + using RefinementMap = std::unordered_map; + + if (!connective) + return; + else if (auto negation = get(connective)) + return computeRefinement(scope, negation->connective, refis, !sense, arena, eq, constraints); + else if (auto conjunction = get(connective)) + { + RefinementMap lhsRefis; + RefinementMap rhsRefis; + + computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints); + computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints); + + if (!sense) + unionRefinements(lhsRefis, rhsRefis, *refis, arena); + } + else if (auto disjunction = get(connective)) + { + RefinementMap lhsRefis; + RefinementMap rhsRefis; + + computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints); + computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints); + + if (sense) + unionRefinements(lhsRefis, rhsRefis, *refis, arena); + } + else if (auto equivalence = get(connective)) + { + computeRefinement(scope, equivalence->lhs, refis, sense, arena, true, constraints); + computeRefinement(scope, equivalence->rhs, refis, sense, arena, true, constraints); + } + else if (auto proposition = get(connective)) + { + TypeId discriminantTy = proposition->discriminantTy; + if (!sense && !eq) + discriminantTy = arena->addType(NegationTypeVar{proposition->discriminantTy}); + else if (eq) + { + discriminantTy = arena->addType(BlockedTypeVar{}); + constraints->push_back(SingletonOrTopTypeConstraint{discriminantTy, proposition->discriminantTy, !sense}); + } + + if (auto it = refis->find(proposition->def); it != refis->end()) + (*refis)[proposition->def] = arena->addType(IntersectionTypeVar{{discriminantTy, it->second}}); + else + (*refis)[proposition->def] = discriminantTy; + } +} + +static std::pair computeDiscriminantType(NotNull arena, const ScopePtr& scope, DefId def, TypeId discriminantTy) +{ + LUAU_ASSERT(get(def)); + + while (const Cell* current = get(def)) + { + if (!current->field) + break; + + TableTypeVar::Props props{{current->field->propName, Property{discriminantTy}}}; + discriminantTy = arena->addType(TableTypeVar{std::move(props), std::nullopt, TypeLevel{}, scope.get(), TableState::Sealed}); + + def = current->field->parent; + current = get(def); + } + + return {def, discriminantTy}; +} + +void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective) +{ + if (!connective) + return; + + std::unordered_map refinements; + std::vector constraints; + computeRefinement(scope, connective, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints); + + for (auto [def, discriminantTy] : refinements) + { + auto [def2, discriminantTy2] = computeDiscriminantType(arena, scope, def, discriminantTy); + std::optional defTy = scope->lookup(def2); + if (!defTy) + ice->ice("Every DefId must map to a type!"); + + TypeId resultTy = arena->addType(IntersectionTypeVar{{*defTy, discriminantTy2}}); + scope->dcrRefinements[def2] = resultTy; + } + + for (auto& c : constraints) + addConstraint(scope, location, c); +} + +void ConstraintGraphBuilder::visit(AstStatBlock* block) +{ + LUAU_ASSERT(scopes.empty()); + LUAU_ASSERT(rootScope == nullptr); + ScopePtr scope = std::make_shared(globalScope); + rootScope = scope.get(); + scopes.emplace_back(block->location, scope); + module->astScopes[block] = NotNull{scope.get()}; + + rootScope->returnType = freshTypePack(scope); + + prepopulateGlobalScope(scope, block); + + visitBlockWithoutChildScope(scope, block); +} + +void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(block->location); + return; + } + + std::unordered_map aliasDefinitionLocations; + + // In order to enable mutually-recursive type aliases, we need to + // populate the type bindings before we actually check any of the + // alias statements. Since we're not ready to actually resolve + // any of the annotations, we just use a fresh type for now. + for (AstStat* stat : block->body) + { + if (auto alias = stat->as()) + { + if (scope->privateTypeBindings.count(alias->name.value) != 0) + { + auto it = aliasDefinitionLocations.find(alias->name.value); + LUAU_ASSERT(it != aliasDefinitionLocations.end()); + reportError(alias->location, DuplicateTypeDefinition{alias->name.value, it->second}); + continue; + } + + bool hasGenerics = alias->generics.size > 0 || alias->genericPacks.size > 0; + + ScopePtr defnScope = scope; + if (hasGenerics) + { + defnScope = childScope(alias, scope); + } + + TypeId initialType = freshType(scope); + TypeFun initialFun = TypeFun{initialType}; + + for (const auto& [name, gen] : createGenerics(defnScope, alias->generics)) + { + initialFun.typeParams.push_back(gen); + defnScope->privateTypeBindings[name] = TypeFun{gen.ty}; + } + + for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks)) + { + initialFun.typePackParams.push_back(genPack); + defnScope->privateTypePackBindings[name] = genPack.tp; + } + + scope->privateTypeBindings[alias->name.value] = std::move(initialFun); + astTypeAliasDefiningScopes[alias] = defnScope; + aliasDefinitionLocations[alias->name.value] = alias->location; + } + } + + for (AstStat* stat : block->body) + visit(scope, stat); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) +{ + RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit}; + + if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto f = stat->as()) + visit(scope, f); + else if (auto f = stat->as()) + visit(scope, f); + else if (auto r = stat->as()) + visit(scope, r); + else if (auto a = stat->as()) + visit(scope, a); + else if (auto a = stat->as()) + visit(scope, a); + else if (auto e = stat->as()) + checkPack(scope, e->expr); + else if (auto i = stat->as()) + visit(scope, i); + else if (auto a = stat->as()) + visit(scope, a); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); + else if (auto s = stat->as()) + visit(scope, s); + else + LUAU_ASSERT(0); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) +{ + std::vector varTypes; + varTypes.reserve(local->vars.size); + + for (AstLocal* local : local->vars) + { + TypeId ty = nullptr; + + if (local->annotation) + ty = resolveType(scope, local->annotation, /* topLevel */ true); + + varTypes.push_back(ty); + } + + for (size_t i = 0; i < local->values.size; ++i) + { + AstExpr* value = local->values.data[i]; + const bool hasAnnotation = i < local->vars.size && nullptr != local->vars.data[i]->annotation; + + if (value->is()) + { + // HACK: we leave nil-initialized things floating under the + // assumption that they will later be populated. + // + // See the test TypeInfer/infer_locals_with_nil_value. Better flow + // awareness should make this obsolete. + + if (!varTypes[i]) + varTypes[i] = freshType(scope); + } + // Only function calls and vararg expressions can produce packs. All + // other expressions produce exactly one value. + else if (i != local->values.size - 1 || (!value->is() && !value->is())) + { + std::optional expectedType; + if (hasAnnotation) + expectedType = varTypes.at(i); + + TypeId exprType = check(scope, value, expectedType).ty; + if (i < varTypes.size()) + { + if (varTypes[i]) + addConstraint(scope, local->location, SubtypeConstraint{exprType, varTypes[i]}); + else + varTypes[i] = exprType; + } + } + else + { + std::vector expectedTypes; + if (hasAnnotation) + expectedTypes.insert(begin(expectedTypes), begin(varTypes) + i, end(varTypes)); + + TypePackId exprPack = checkPack(scope, value, expectedTypes).tp; + + if (i < local->vars.size) + { + TypePack packTypes = extendTypePack(*arena, singletonTypes, exprPack, varTypes.size() - i); + + // fill out missing values in varTypes with values from exprPack + for (size_t j = i; j < varTypes.size(); ++j) + { + if (!varTypes[j]) + { + if (j - i < packTypes.head.size()) + varTypes[j] = packTypes.head[j - i]; + else + varTypes[j] = freshType(scope); + } + } + + std::vector tailValues{varTypes.begin() + i, varTypes.end()}; + TypePackId tailPack = arena->addTypePack(std::move(tailValues)); + addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); + } + } + } + + for (size_t i = 0; i < local->vars.size; ++i) + { + AstLocal* l = local->vars.data[i]; + Location location = l->location; + + if (!varTypes[i]) + varTypes[i] = freshType(scope); + + scope->bindings[l] = Binding{varTypes[i], location}; + + // HACK: In the greedy solver, we say the type state of a variable is the type annotation itself, but + // the actual type state is the corresponding initializer expression (if it exists) or nil otherwise. + if (auto def = dfg->getDef(l)) + scope->dcrRefinements[*def] = varTypes[i]; + } + + if (local->values.size > 0) + { + // To correctly handle 'require', we need to import the exported type bindings into the variable 'namespace'. + for (size_t i = 0; i < local->values.size && i < local->vars.size; ++i) + { + const AstExprCall* call = local->values.data[i]->as(); + if (!call) + continue; + + if (auto maybeRequire = matchRequire(*call)) + { + AstExpr* require = *maybeRequire; + + if (auto moduleInfo = moduleResolver->resolveModuleInfo(moduleName, *require)) + { + const Name name{local->vars.data[i]->name.value}; + + if (ModulePtr module = moduleResolver->getModule(moduleInfo->name)) + scope->importedTypeBindings[name] = module->getModuleScope()->exportedTypeBindings; + } + } + } + } +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) +{ + auto checkNumber = [&](AstExpr* expr) { + if (!expr) + return; + + TypeId t = check(scope, expr).ty; + addConstraint(scope, expr->location, SubtypeConstraint{t, singletonTypes->numberType}); + }; + + checkNumber(for_->from); + checkNumber(for_->to); + checkNumber(for_->step); + + ScopePtr forScope = childScope(for_, scope); + forScope->bindings[for_->var] = Binding{singletonTypes->numberType, for_->var->location}; + + visit(forScope, for_->body); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) +{ + ScopePtr loopScope = childScope(forIn, scope); + + TypePackId iterator = checkPack(scope, forIn->values).tp; + + std::vector variableTypes; + variableTypes.reserve(forIn->vars.size); + for (AstLocal* var : forIn->vars) + { + TypeId ty = freshType(loopScope); + loopScope->bindings[var] = Binding{ty, var->location}; + variableTypes.push_back(ty); + + if (auto def = dfg->getDef(var)) + loopScope->dcrRefinements[*def] = ty; + } + + // It is always ok to provide too few variables, so we give this pack a free tail. + TypePackId variablePack = arena->addTypePack(std::move(variableTypes), arena->addTypePack(FreeTypePack{loopScope.get()})); + + addConstraint(loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack}); + + visit(loopScope, forIn->body); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatWhile* while_) +{ + check(scope, while_->condition); + + ScopePtr whileScope = childScope(while_, scope); + + visit(whileScope, while_->body); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) +{ + ScopePtr repeatScope = childScope(repeat, scope); + + visit(repeatScope, repeat->body); + + // The condition does indeed have access to bindings from within the body of + // the loop. + check(repeatScope, repeat->condition); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* function) +{ + // Local + // Global + // Dotted path + // Self? + + TypeId functionType = nullptr; + auto ty = scope->lookup(function->name); + LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. + + functionType = arena->addType(BlockedTypeVar{}); + scope->bindings[function->name] = Binding{functionType, function->name->location}; + + FunctionSignature sig = checkFunctionSignature(scope, function->func); + sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; + + Checkpoint start = checkpoint(this); + checkFunctionBody(sig.bodyScope, function->func); + Checkpoint end = checkpoint(this); + + NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; + std::unique_ptr c = + std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); + + forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { + c->dependencies.push_back(NotNull{constraint.get()}); + }); + + addConstraint(scope, std::move(c)); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* function) +{ + // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. + // With or without self + + TypeId generalizedType = arena->addType(BlockedTypeVar{}); + + FunctionSignature sig = checkFunctionSignature(scope, function->func); + + if (AstExprLocal* localName = function->name->as()) + { + std::optional existingFunctionTy = scope->lookup(localName->local); + if (existingFunctionTy) + { + addConstraint(scope, function->name->location, SubtypeConstraint{generalizedType, *existingFunctionTy}); + + Symbol sym{localName->local}; + std::optional def = dfg->getDef(sym); + LUAU_ASSERT(def); + scope->bindings[sym].typeId = generalizedType; + scope->dcrRefinements[*def] = generalizedType; + } + else + scope->bindings[localName->local] = Binding{generalizedType, localName->location}; + + sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location}; + } + else if (AstExprGlobal* globalName = function->name->as()) + { + std::optional existingFunctionTy = scope->lookup(globalName->name); + if (!existingFunctionTy) + ice->ice("prepopulateGlobalScope did not populate a global name", globalName->location); + + generalizedType = *existingFunctionTy; + + sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; + } + else if (AstExprIndexName* indexName = function->name->as()) + { + TypeId containingTableType = check(scope, indexName->expr).ty; + + // TODO look into stack utilization. This is probably ok because it scales with AST depth. + TypeId prospectiveTableType = arena->addType(TableTypeVar{TableState::Unsealed, TypeLevel{}, scope.get()}); + + NotNull prospectiveTable{getMutable(prospectiveTableType)}; + + Property& prop = prospectiveTable->props[indexName->index.value]; + prop.type = generalizedType; + prop.location = function->name->location; + + addConstraint(scope, indexName->location, SubtypeConstraint{containingTableType, prospectiveTableType}); + } + else if (AstExprError* err = function->name->as()) + { + generalizedType = singletonTypes->errorRecoveryType(); + } + + if (generalizedType == nullptr) + ice->ice("generalizedType == nullptr", function->location); + + Checkpoint start = checkpoint(this); + checkFunctionBody(sig.bodyScope, function->func); + Checkpoint end = checkpoint(this); + + NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; + std::unique_ptr c = + std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); + + forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { + c->dependencies.push_back(NotNull{constraint.get()}); + }); + + addConstraint(scope, std::move(c)); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) +{ + // At this point, the only way scope->returnType should have anything + // interesting in it is if the function has an explicit return annotation. + // If this is the case, then we can expect that the return expression + // conforms to that. + std::vector expectedTypes; + for (TypeId ty : scope->returnType) + expectedTypes.push_back(ty); + + TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp; + addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType}); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) +{ + ScopePtr innerScope = childScope(block, scope); + + visitBlockWithoutChildScope(innerScope, block); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) +{ + TypePackId varPackId = checkLValues(scope, assign->vars); + + TypePack expectedTypes = extendTypePack(*arena, singletonTypes, varPackId, assign->values.size); + TypePackId valuePack = checkPack(scope, assign->values, expectedTypes.head).tp; + + addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) +{ + // Synthesize A = A op B from A op= B and then build constraints for that instead. + + AstExprBinary exprBinary{assign->location, assign->op, assign->var, assign->value}; + AstExpr* exprBinaryPtr = &exprBinary; + + AstArray vars{&assign->var, 1}; + AstArray values{&exprBinaryPtr, 1}; + AstStatAssign syntheticAssign{assign->location, vars, values}; + + visit(scope, &syntheticAssign); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) +{ + ScopePtr condScope = childScope(ifStatement->condition, scope); + auto [_, connective] = check(condScope, ifStatement->condition, std::nullopt); + + ScopePtr thenScope = childScope(ifStatement->thenbody, scope); + applyRefinements(thenScope, Location{}, connective); + visit(thenScope, ifStatement->thenbody); + + if (ifStatement->elsebody) + { + ScopePtr elseScope = childScope(ifStatement->elsebody, scope); + applyRefinements(elseScope, Location{}, connectiveArena.negation(connective)); + visit(elseScope, ifStatement->elsebody); + } +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alias) +{ + auto bindingIt = scope->privateTypeBindings.find(alias->name.value); + ScopePtr* defnIt = astTypeAliasDefiningScopes.find(alias); + // These will be undefined if the alias was a duplicate definition, in which + // case we just skip over it. + if (bindingIt == scope->privateTypeBindings.end() || defnIt == nullptr) + { + return; + } + + ScopePtr resolvingScope = *defnIt; + TypeId ty = resolveType(resolvingScope, alias->type, /* topLevel */ true); + + if (alias->exported) + { + Name typeName(alias->name.value); + scope->exportedTypeBindings[typeName] = TypeFun{ty}; + } + + LUAU_ASSERT(get(bindingIt->second.type)); + + // Rather than using a subtype constraint, we instead directly bind + // the free type we generated in the first pass to the resolved type. + // This prevents a case where you could cause another constraint to + // bind the free alias type to an unrelated type, causing havoc. + asMutable(bindingIt->second.type)->ty.emplace(ty); + + addConstraint(scope, alias->location, NameConstraint{ty, alias->name.value}); +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) +{ + LUAU_ASSERT(global->type); + + TypeId globalTy = resolveType(scope, global->type); + Name globalName(global->name.value); + + module->declaredGlobals[globalName] = globalTy; + scope->bindings[global->name] = Binding{globalTy, global->location}; +} + +static bool isMetamethod(const Name& name) +{ + return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || + name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || + name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) +{ + std::optional superTy = std::nullopt; + if (declaredClass->superName) + { + Name superName = Name(declaredClass->superName->value); + std::optional lookupType = scope->lookupType(superName); + + if (!lookupType) + { + reportError(declaredClass->location, UnknownSymbol{superName, UnknownSymbol::Type}); + return; + } + + // We don't have generic classes, so this assertion _should_ never be hit. + LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); + superTy = lookupType->type; + + if (!get(follow(*superTy))) + { + reportError(declaredClass->location, + GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass->name.value)}); + + return; + } + } + + Name className(declaredClass->name.value); + + TypeId classTy = arena->addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, moduleName)); + ClassTypeVar* ctv = getMutable(classTy); + + TypeId metaTy = arena->addType(TableTypeVar{TableState::Sealed, scope->level, scope.get()}); + TableTypeVar* metatable = getMutable(metaTy); + + ctv->metatable = metaTy; + + scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; + + for (const AstDeclaredClassProp& prop : declaredClass->props) + { + Name propName(prop.name.value); + TypeId propTy = resolveType(scope, prop.ty); + + bool assignToMetatable = isMetamethod(propName); + + // Function types always take 'self', but this isn't reflected in the + // parsed annotation. Add it here. + if (prop.isMethod) + { + if (FunctionTypeVar* ftv = getMutable(propTy)) + { + ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); + ftv->argTypes = arena->addTypePack(TypePack{{classTy}, ftv->argTypes}); + + ftv->hasSelf = true; + } + } + + if (ctv->props.count(propName) == 0) + { + if (assignToMetatable) + metatable->props[propName] = {propTy}; + else + ctv->props[propName] = {propTy}; + } + else + { + TypeId currentTy = assignToMetatable ? metatable->props[propName].type : ctv->props[propName].type; + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionTypeVar* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = arena->addType(IntersectionTypeVar{std::move(options)}); + + if (assignToMetatable) + metatable->props[propName] = {newItv}; + else + ctv->props[propName] = {newItv}; + } + else if (get(currentTy)) + { + TypeId intersection = arena->addType(IntersectionTypeVar{{currentTy, propTy}}); + + if (assignToMetatable) + metatable->props[propName] = {intersection}; + else + ctv->props[propName] = {intersection}; + } + else + { + reportError(declaredClass->location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } + } + } +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global) +{ + std::vector> generics = createGenerics(scope, global->generics); + std::vector> genericPacks = createGenericPacks(scope, global->genericPacks); + + std::vector genericTys; + genericTys.reserve(generics.size()); + for (auto& [name, generic] : generics) + { + genericTys.push_back(generic.ty); + scope->privateTypeBindings[name] = TypeFun{generic.ty}; + } + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + for (auto& [name, generic] : genericPacks) + { + genericTps.push_back(generic.tp); + scope->privateTypePackBindings[name] = generic.tp; + } + + ScopePtr funScope = scope; + if (!generics.empty() || !genericPacks.empty()) + funScope = childScope(global, scope); + + TypePackId paramPack = resolveTypePack(funScope, global->params); + TypePackId retPack = resolveTypePack(funScope, global->retTypes); + TypeId fnType = arena->addType(FunctionTypeVar{TypeLevel{}, funScope.get(), std::move(genericTys), std::move(genericTps), paramPack, retPack}); + FunctionTypeVar* ftv = getMutable(fnType); + + ftv->argNames.reserve(global->paramNames.size); + for (const auto& el : global->paramNames) + ftv->argNames.push_back(FunctionArgument{el.first.value, el.second}); + + Name fnName(global->name.value); + + module->declaredGlobals[fnName] = fnType; + scope->bindings[global->name] = Binding{fnType, global->location}; +} + +void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) +{ + for (AstStat* stat : error->statements) + visit(scope, stat); + for (AstExpr* expr : error->expressions) + check(scope, expr); +} + +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray exprs, const std::vector& expectedTypes) +{ + std::vector head; + std::optional tail; + + for (size_t i = 0; i < exprs.size; ++i) + { + AstExpr* expr = exprs.data[i]; + if (i < exprs.size - 1) + { + std::optional expectedType; + if (i < expectedTypes.size()) + expectedType = expectedTypes[i]; + head.push_back(check(scope, expr).ty); + } + else + { + std::vector expectedTailTypes; + if (i < expectedTypes.size()) + expectedTailTypes.assign(begin(expectedTypes) + i, end(expectedTypes)); + tail = checkPack(scope, expr, expectedTailTypes).tp; + } + } + + if (head.empty() && tail) + return InferencePack{*tail}; + else + return InferencePack{arena->addTypePack(TypePack{std::move(head), tail})}; +} + +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector& expectedTypes) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return InferencePack{singletonTypes->errorRecoveryTypePack()}; + } + + InferencePack result; + + if (AstExprCall* call = expr->as()) + result = checkPack(scope, call, expectedTypes); + else if (AstExprVarargs* varargs = expr->as()) + { + if (scope->varargPack) + result = InferencePack{*scope->varargPack}; + else + result = InferencePack{singletonTypes->errorRecoveryTypePack()}; + } + else + { + std::optional expectedType; + if (!expectedTypes.empty()) + expectedType = expectedTypes[0]; + TypeId t = check(scope, expr, expectedType).ty; + result = InferencePack{arena->addTypePack({t})}; + } + + LUAU_ASSERT(result.tp); + astTypePacks[expr] = result.tp; + return result; +} + +InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes) +{ + std::vector exprArgs; + if (call->self) + { + AstExprIndexName* indexExpr = call->func->as(); + if (!indexExpr) + ice->ice("method call expression has no 'self'"); + + exprArgs.push_back(indexExpr->expr); + } + exprArgs.insert(exprArgs.end(), call->args.begin(), call->args.end()); + + Checkpoint startCheckpoint = checkpoint(this); + TypeId fnType = check(scope, call->func).ty; + Checkpoint fnEndCheckpoint = checkpoint(this); + + TypePackId expectedArgPack = arena->freshTypePack(scope.get()); + TypePackId expectedRetPack = arena->freshTypePack(scope.get()); + TypeId expectedFunctionType = arena->addType(FunctionTypeVar{expectedArgPack, expectedRetPack}); + + TypeId instantiatedFnType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, call->location, InstantiationConstraint{instantiatedFnType, fnType}); + + NotNull extractArgsConstraint = addConstraint(scope, call->location, SubtypeConstraint{instantiatedFnType, expectedFunctionType}); + + // Fully solve fnType, then extract its argument list as expectedArgPack. + forEachConstraint(startCheckpoint, fnEndCheckpoint, this, [extractArgsConstraint](const ConstraintPtr& constraint) { + extractArgsConstraint->dependencies.emplace_back(constraint.get()); + }); + + const AstExpr* lastArg = exprArgs.size() ? exprArgs[exprArgs.size() - 1] : nullptr; + const bool needTail = lastArg && (lastArg->is() || lastArg->is()); + + TypePack expectedArgs; + + if (!needTail) + expectedArgs = extendTypePack(*arena, singletonTypes, expectedArgPack, exprArgs.size()); + else + expectedArgs = extendTypePack(*arena, singletonTypes, expectedArgPack, exprArgs.size() - 1); + + std::vector args; + std::optional argTail; + std::vector argumentConnectives; + + Checkpoint argCheckpoint = checkpoint(this); + + for (size_t i = 0; i < exprArgs.size(); ++i) + { + AstExpr* arg = exprArgs[i]; + std::optional expectedType; + if (i < expectedArgs.head.size()) + expectedType = expectedArgs.head[i]; + + if (i == 0 && call->self) + { + // The self type has already been computed as a side effect of + // computing fnType. If computing that did not cause us to exceed a + // recursion limit, we can fetch it from astTypes rather than + // recomputing it. + TypeId* selfTy = astTypes.find(exprArgs[0]); + if (selfTy) + args.push_back(*selfTy); + else + args.push_back(arena->freshType(scope.get())); + } + else if (i < exprArgs.size() - 1 || !(arg->is() || arg->is())) + { + auto [ty, connective] = check(scope, arg, expectedType); + args.push_back(ty); + argumentConnectives.push_back(connective); + } + else + argTail = checkPack(scope, arg, {}).tp; // FIXME? not sure about expectedTypes here + } + + Checkpoint argEndCheckpoint = checkpoint(this); + + // Do not solve argument constraints until after we have extracted the + // expected types from the callable. + forEachConstraint(argCheckpoint, argEndCheckpoint, this, [extractArgsConstraint](const ConstraintPtr& constraint) { + constraint->dependencies.push_back(extractArgsConstraint); + }); + + std::vector returnConnectives; + if (auto ftv = get(follow(fnType)); ftv && ftv->dcrMagicRefinement) + { + MagicRefinementContext ctx{scope, NotNull{this}, dfg, NotNull{&connectiveArena}, std::move(argumentConnectives), call}; + returnConnectives = ftv->dcrMagicRefinement(ctx); + } + + if (matchSetmetatable(*call)) + { + TypePack argTailPack; + if (argTail && args.size() < 2) + argTailPack = extendTypePack(*arena, singletonTypes, *argTail, 2 - args.size()); + + LUAU_ASSERT(args.size() + argTailPack.head.size() == 2); + + TypeId target = args.size() > 0 ? args[0] : argTailPack.head[0]; + TypeId mt = args.size() > 1 ? args[1] : argTailPack.head[args.size() == 0 ? 1 : 0]; + + AstExpr* targetExpr = call->args.data[0]; + + MetatableTypeVar mtv{target, mt}; + TypeId resultTy = arena->addType(mtv); + + if (AstExprLocal* targetLocal = targetExpr->as()) + scope->bindings[targetLocal->local].typeId = resultTy; + + return InferencePack{arena->addTypePack({resultTy}), std::move(returnConnectives)}; + } + else + { + astOriginalCallTypes[call->func] = fnType; + + TypeId instantiatedType = arena->addType(BlockedTypeVar{}); + // TODO: How do expectedTypes play into this? Do they? + TypePackId rets = arena->addTypePack(BlockedTypePack{}); + TypePackId argPack = arena->addTypePack(TypePack{args, argTail}); + FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); + TypeId inferredFnType = arena->addType(ftv); + + unqueuedConstraints.push_back( + std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); + NotNull ic(unqueuedConstraints.back().get()); + + unqueuedConstraints.push_back( + std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{instantiatedType, inferredFnType})); + NotNull sc(unqueuedConstraints.back().get()); + + NotNull fcc = addConstraint(scope, call->func->location, + FunctionCallConstraint{ + {ic, sc}, + fnType, + argPack, + rets, + call, + }); + + // We force constraints produced by checking function arguments to wait + // until after we have resolved the constraint on the function itself. + // This ensures, for instance, that we start inferring the contents of + // lambdas under the assumption that their arguments and return types + // will be compatible with the enclosing function call. + forEachConstraint(fnEndCheckpoint, argEndCheckpoint, this, [fcc](const ConstraintPtr& constraint) { + fcc->dependencies.emplace_back(constraint.get()); + }); + + return InferencePack{rets, std::move(returnConnectives)}; + } +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType, bool forceSingleton) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return Inference{singletonTypes->errorRecoveryType()}; + } + + Inference result; + + if (auto group = expr->as()) + result = check(scope, group->expr, expectedType, forceSingleton); + else if (auto stringExpr = expr->as()) + result = check(scope, stringExpr, expectedType, forceSingleton); + else if (expr->is()) + result = Inference{singletonTypes->numberType}; + else if (auto boolExpr = expr->as()) + result = check(scope, boolExpr, expectedType, forceSingleton); + else if (expr->is()) + result = Inference{singletonTypes->nilType}; + else if (auto local = expr->as()) + result = check(scope, local); + else if (auto global = expr->as()) + result = check(scope, global); + else if (expr->is()) + result = flattenPack(scope, expr->location, checkPack(scope, expr)); + else if (auto call = expr->as()) + { + std::vector expectedTypes; + if (expectedType) + expectedTypes.push_back(*expectedType); + result = flattenPack(scope, expr->location, checkPack(scope, call, expectedTypes)); // TODO: needs predicates too + } + else if (auto a = expr->as()) + { + Checkpoint startCheckpoint = checkpoint(this); + FunctionSignature sig = checkFunctionSignature(scope, a, expectedType); + checkFunctionBody(sig.bodyScope, a); + Checkpoint endCheckpoint = checkpoint(this); + + TypeId generalizedTy = arena->addType(BlockedTypeVar{}); + NotNull gc = addConstraint(scope, expr->location, GeneralizationConstraint{generalizedTy, sig.signature}); + + forEachConstraint(startCheckpoint, endCheckpoint, this, [gc](const ConstraintPtr& constraint) { + gc->dependencies.emplace_back(constraint.get()); + }); + + return Inference{generalizedTy}; + } + else if (auto indexName = expr->as()) + result = check(scope, indexName); + else if (auto indexExpr = expr->as()) + result = check(scope, indexExpr); + else if (auto table = expr->as()) + result = check(scope, table, expectedType); + else if (auto unary = expr->as()) + result = check(scope, unary); + else if (auto binary = expr->as()) + result = check(scope, binary, expectedType); + else if (auto ifElse = expr->as()) + result = check(scope, ifElse, expectedType); + else if (auto typeAssert = expr->as()) + result = check(scope, typeAssert); + else if (auto err = expr->as()) + { + // Open question: Should we traverse into this? + for (AstExpr* subExpr : err->expressions) + check(scope, subExpr); + + result = Inference{singletonTypes->errorRecoveryType()}; + } + else + { + LUAU_ASSERT(0); + result = Inference{freshType(scope)}; + } + + LUAU_ASSERT(result.ty); + astTypes[expr] = result.ty; + return result; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton) +{ + if (forceSingleton) + return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})}; + + if (expectedType) + { + const TypeId expectedTy = follow(*expectedType); + if (get(expectedTy) || get(expectedTy)) + { + TypeId ty = arena->addType(BlockedTypeVar{}); + TypeId singletonType = arena->addType(SingletonTypeVar(StringSingleton{std::string(string->value.data, string->value.size)})); + addConstraint(scope, string->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, singletonTypes->stringType}); + return Inference{ty}; + } + else if (maybeSingleton(expectedTy)) + return Inference{arena->addType(SingletonTypeVar{StringSingleton{std::string{string->value.data, string->value.size}}})}; + + return Inference{singletonTypes->stringType}; + } + + return Inference{singletonTypes->stringType}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBool* boolExpr, std::optional expectedType, bool forceSingleton) +{ + const TypeId singletonType = boolExpr->value ? singletonTypes->trueType : singletonTypes->falseType; + if (forceSingleton) + return Inference{singletonType}; + + if (expectedType) + { + const TypeId expectedTy = follow(*expectedType); + + if (get(expectedTy) || get(expectedTy)) + { + TypeId ty = arena->addType(BlockedTypeVar{}); + addConstraint(scope, boolExpr->location, PrimitiveTypeConstraint{ty, expectedTy, singletonType, singletonTypes->booleanType}); + return Inference{ty}; + } + else if (maybeSingleton(expectedTy)) + return Inference{singletonType}; + + return Inference{singletonTypes->booleanType}; + } + + return Inference{singletonTypes->booleanType}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) +{ + std::optional resultTy; + auto def = dfg->getDef(local); + if (def) + resultTy = scope->lookup(*def); + + if (!resultTy) + { + if (auto ty = scope->lookup(local->local)) + resultTy = *ty; + } + + if (!resultTy) + return Inference{singletonTypes->errorRecoveryType()}; // TODO: replace with ice, locals should never exist before its definition. + + if (def) + return Inference{*resultTy, connectiveArena.proposition(*def, singletonTypes->truthyType)}; + else + return Inference{*resultTy}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* global) +{ + if (std::optional ty = scope->lookup(global->name)) + return Inference{*ty}; + + /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any + * global that is not already in-scope is definitely an unknown symbol. + */ + reportError(global->location, UnknownSymbol{global->name.value}); + return Inference{singletonTypes->errorRecoveryType()}; +} + +static std::optional lookupProp(TypeId ty, const std::string& propName, NotNull arena) +{ + ty = follow(ty); + + if (auto ctv = get(ty)) + { + if (auto prop = lookupClassProp(ctv, propName)) + return prop->type; + } + else if (auto ttv = get(ty)) + { + if (auto it = ttv->props.find(propName); it != ttv->props.end()) + return it->second.type; + } + else if (auto utv = get(ty)) + { + std::vector types; + + for (TypeId ty : utv) + { + if (auto prop = lookupProp(ty, propName, arena)) + { + if (std::find(begin(types), end(types), *prop) == end(types)) + types.push_back(*prop); + } + else + return std::nullopt; + } + + if (types.size() == 1) + return types[0]; + else + return arena->addType(IntersectionTypeVar{std::move(types)}); + } + else if (auto utv = get(ty)) + { + std::vector types; + + for (TypeId ty : utv) + { + if (auto prop = lookupProp(ty, propName, arena)) + { + if (std::find(begin(types), end(types), *prop) == end(types)) + types.push_back(*prop); + } + else + return std::nullopt; + } + + if (types.size() == 1) + return types[0]; + else + return arena->addType(UnionTypeVar{std::move(types)}); + } + + return std::nullopt; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) +{ + TypeId obj = check(scope, indexName->expr).ty; + + // HACK: We need to return the actual type for type refinements so that it can invoke the dcrMagicRefinement function. + TypeId result; + if (auto prop = lookupProp(obj, indexName->index.value, arena)) + result = *prop; + else + result = freshType(scope); + + std::optional def = dfg->getDef(indexName); + if (def) + { + if (auto ty = scope->lookup(*def)) + return Inference{*ty, connectiveArena.proposition(*def, singletonTypes->truthyType)}; + else + scope->dcrRefinements[*def] = result; + } + + TableTypeVar::Props props{{indexName->index.value, Property{result}}}; + const std::optional indexer; + TableTypeVar ttv{std::move(props), indexer, TypeLevel{}, scope.get(), TableState::Free}; + + TypeId expectedTableType = arena->addType(std::move(ttv)); + + addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType}); + + if (def) + return Inference{result, connectiveArena.proposition(*def, singletonTypes->truthyType)}; + else + return Inference{result}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr) +{ + TypeId obj = check(scope, indexExpr->expr).ty; + TypeId indexType = check(scope, indexExpr->index).ty; + + TypeId result = freshType(scope); + + TableIndexer indexer{indexType, result}; + TypeId tableType = + arena->addType(TableTypeVar{TableTypeVar::Props{}, TableIndexer{indexType, result}, TypeLevel{}, scope.get(), TableState::Free}); + + addConstraint(scope, indexExpr->expr->location, SubtypeConstraint{obj, tableType}); + + return Inference{result}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) +{ + auto [operandType, connective] = check(scope, unary->expr); + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, unary->location, UnaryConstraint{unary->op, operandType, resultType}); + + if (unary->op == AstExprUnary::Not) + return Inference{resultType, connectiveArena.negation(connective)}; + else + return Inference{resultType}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) +{ + auto [leftType, rightType, connective] = checkBinary(scope, binary, expectedType); + + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType, binary, &astOriginalCallTypes, &astOverloadResolvedTypes}); + return Inference{resultType, std::move(connective)}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType) +{ + ScopePtr condScope = childScope(ifElse->condition, scope); + auto [_, connective] = check(scope, ifElse->condition); + + ScopePtr thenScope = childScope(ifElse->trueExpr, scope); + applyRefinements(thenScope, ifElse->trueExpr->location, connective); + TypeId thenType = check(thenScope, ifElse->trueExpr, expectedType).ty; + + ScopePtr elseScope = childScope(ifElse->falseExpr, scope); + applyRefinements(elseScope, ifElse->falseExpr->location, connectiveArena.negation(connective)); + TypeId elseType = check(elseScope, ifElse->falseExpr, expectedType).ty; + + if (ifElse->hasElse) + { + TypeId resultType = expectedType ? *expectedType : freshType(scope); + addConstraint(scope, ifElse->trueExpr->location, SubtypeConstraint{thenType, resultType}); + addConstraint(scope, ifElse->falseExpr->location, SubtypeConstraint{elseType, resultType}); + return Inference{resultType}; + } + + return Inference{thenType}; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) +{ + check(scope, typeAssert->expr, std::nullopt); + return Inference{resolveType(scope, typeAssert->annotation)}; +} + +std::tuple ConstraintGraphBuilder::checkBinary( + const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) +{ + if (binary->op == AstExprBinary::And) + { + auto [leftType, leftConnective] = check(scope, binary->left, expectedType); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, leftConnective); + auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, connectiveArena.conjunction(leftConnective, rightConnective)}; + } + else if (binary->op == AstExprBinary::Or) + { + auto [leftType, leftConnective] = check(scope, binary->left, expectedType); + + ScopePtr rightScope = childScope(binary->right, scope); + applyRefinements(rightScope, binary->right->location, connectiveArena.negation(leftConnective)); + auto [rightType, rightConnective] = check(rightScope, binary->right, expectedType); + + return {leftType, rightType, connectiveArena.disjunction(leftConnective, rightConnective)}; + } + else if (auto typeguard = matchTypeGuard(binary)) + { + TypeId leftType = check(scope, binary->left).ty; + TypeId rightType = check(scope, binary->right).ty; + + std::optional def = dfg->getDef(typeguard->target); + if (!def) + return {leftType, rightType, nullptr}; + + TypeId discriminantTy = singletonTypes->neverType; + if (typeguard->type == "nil") + discriminantTy = singletonTypes->nilType; + else if (typeguard->type == "string") + discriminantTy = singletonTypes->stringType; + else if (typeguard->type == "number") + discriminantTy = singletonTypes->numberType; + else if (typeguard->type == "boolean") + discriminantTy = singletonTypes->threadType; + else if (typeguard->type == "table") + discriminantTy = singletonTypes->neverType; // TODO: replace with top table type + else if (typeguard->type == "function") + discriminantTy = singletonTypes->functionType; + else if (typeguard->type == "userdata") + { + // For now, we don't really care about being accurate with userdata if the typeguard was using typeof + discriminantTy = singletonTypes->neverType; // TODO: replace with top class type + } + else if (!typeguard->isTypeof && typeguard->type == "vector") + discriminantTy = singletonTypes->neverType; // TODO: figure out a way to deal with this quirky type + else if (!typeguard->isTypeof) + discriminantTy = singletonTypes->neverType; + else if (auto typeFun = globalScope->lookupType(typeguard->type); typeFun && typeFun->typeParams.empty() && typeFun->typePackParams.empty()) + { + TypeId ty = follow(typeFun->type); + + // We're only interested in the root class of any classes. + if (auto ctv = get(ty); !ctv || !ctv->parent) + discriminantTy = ty; + } + + ConnectiveId proposition = connectiveArena.proposition(*def, discriminantTy); + if (binary->op == AstExprBinary::CompareEq) + return {leftType, rightType, proposition}; + else if (binary->op == AstExprBinary::CompareNe) + return {leftType, rightType, connectiveArena.negation(proposition)}; + else + ice->ice("matchTypeGuard should only return a Some under `==` or `~=`!"); + } + else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) + { + TypeId leftType = check(scope, binary->left, expectedType, true).ty; + TypeId rightType = check(scope, binary->right, expectedType, true).ty; + + ConnectiveId leftConnective = nullptr; + if (auto def = dfg->getDef(binary->left)) + leftConnective = connectiveArena.proposition(*def, rightType); + + ConnectiveId rightConnective = nullptr; + if (auto def = dfg->getDef(binary->right)) + rightConnective = connectiveArena.proposition(*def, leftType); + + if (binary->op == AstExprBinary::CompareNe) + { + leftConnective = connectiveArena.negation(leftConnective); + rightConnective = connectiveArena.negation(rightConnective); + } + + return {leftType, rightType, connectiveArena.equivalence(leftConnective, rightConnective)}; + } + else + { + TypeId leftType = check(scope, binary->left, expectedType).ty; + TypeId rightType = check(scope, binary->right, expectedType).ty; + return {leftType, rightType, nullptr}; + } +} + +TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) +{ + std::vector types; + types.reserve(exprs.size); + + for (size_t i = 0; i < exprs.size; ++i) + { + AstExpr* const expr = exprs.data[i]; + types.push_back(checkLValue(scope, expr)); + } + + return arena->addTypePack(std::move(types)); +} + +/** + * If the expr is a dotted set of names, and if the root symbol refers to an + * unsealed table, return that table type, plus the indeces that follow as a + * vector. + */ +static std::optional>> extractDottedName(AstExpr* expr) +{ + std::vector names; + + while (expr) + { + if (auto global = expr->as()) + { + std::reverse(begin(names), end(names)); + return std::pair{global->name, std::move(names)}; + } + else if (auto local = expr->as()) + { + std::reverse(begin(names), end(names)); + return std::pair{local->local, std::move(names)}; + } + else if (auto indexName = expr->as()) + { + names.push_back(indexName->index.value); + expr = indexName->expr; + } + else + return std::nullopt; + } + + return std::nullopt; +} + +/** + * This function is mostly about identifying properties that are being inserted into unsealed tables. + * + * If expr has the form name.a.b.c + */ +TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) +{ + if (auto indexExpr = expr->as()) + { + if (auto constantString = indexExpr->index->as()) + { + AstName syntheticIndex{constantString->value.data}; + AstExprIndexName synthetic{ + indexExpr->location, indexExpr->expr, syntheticIndex, constantString->location, indexExpr->expr->location.end, '.'}; + return checkLValue(scope, &synthetic); + } + } + else if (!expr->is()) + return check(scope, expr).ty; + + auto dottedPath = extractDottedName(expr); + if (!dottedPath) + return check(scope, expr).ty; + const auto [sym, segments] = std::move(*dottedPath); + + LUAU_ASSERT(!segments.empty()); + + auto lookupResult = scope->lookupEx(sym); + if (!lookupResult) + return check(scope, expr).ty; + const auto [subjectType, symbolScope] = std::move(*lookupResult); + + TypeId propTy = freshType(scope); + + std::vector segmentStrings(begin(segments), end(segments)); + + TypeId updatedType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), propTy}); + + std::optional def = dfg->getDef(sym); + LUAU_ASSERT(def); + symbolScope->bindings[sym].typeId = updatedType; + symbolScope->dcrRefinements[*def] = updatedType; + + astTypes[expr] = propTy; + + return propTy; +} + +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType) +{ + TypeId ty = arena->addType(TableTypeVar{}); + TableTypeVar* ttv = getMutable(ty); + LUAU_ASSERT(ttv); + + ttv->state = TableState::Unsealed; + ttv->scope = scope.get(); + + auto createIndexer = [this, scope, ttv](const Location& location, TypeId currentIndexType, TypeId currentResultType) { + if (!ttv->indexer) + { + TypeId indexType = this->freshType(scope); + TypeId resultType = this->freshType(scope); + ttv->indexer = TableIndexer{indexType, resultType}; + } + + addConstraint(scope, location, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}); + addConstraint(scope, location, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}); + }; + + for (const AstExprTable::Item& item : expr->items) + { + std::optional expectedValueType; + + if (item.key && expectedType) + { + if (auto stringKey = item.key->as()) + { + ErrorVec errorVec; + std::optional propTy = + findTablePropertyRespectingMeta(singletonTypes, errorVec, follow(*expectedType), stringKey->value.data, item.value->location); + if (propTy) + expectedValueType = propTy; + else + { + expectedValueType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, item.value->location, HasPropConstraint{*expectedValueType, *expectedType, stringKey->value.data}); + } + } + } + + TypeId itemTy = check(scope, item.value, expectedValueType).ty; + + if (item.key) + { + // Even though we don't need to use the type of the item's key if + // it's a string constant, we still want to check it to populate + // astTypes. + TypeId keyTy = check(scope, item.key).ty; + + if (AstExprConstantString* key = item.key->as()) + { + ttv->props[key->value.begin()] = {itemTy}; + } + else + { + createIndexer(item.key->location, keyTy, itemTy); + } + } + else + { + TypeId numberType = singletonTypes->numberType; + // FIXME? The location isn't quite right here. Not sure what is + // right. + createIndexer(item.value->location, numberType, itemTy); + } + } + + return Inference{ty}; +} + +ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature( + const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType) +{ + ScopePtr signatureScope = nullptr; + ScopePtr bodyScope = nullptr; + TypePackId returnType = nullptr; + + std::vector genericTypes; + std::vector genericTypePacks; + + if (expectedType) + expectedType = follow(*expectedType); + + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + + signatureScope = childScope(fn, parent); + + // We need to assign returnType before creating bodyScope so that the + // return type gets propogated to bodyScope. + returnType = freshTypePack(signatureScope); + signatureScope->returnType = returnType; + + bodyScope = childScope(fn->body, signatureScope); + + if (hasGenerics) + { + std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); + + // We do not support default values on function generics, so we only + // care about the types involved. + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + signatureScope->privateTypeBindings[name] = TypeFun{g.ty}; + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + signatureScope->privateTypePackBindings[name] = g.tp; + } + + // Local variable works around an odd gcc 11.3 warning: may be used uninitialized + std::optional none = std::nullopt; + expectedType = none; + } + + std::vector argTypes; + TypePack expectedArgPack; + + const FunctionTypeVar* expectedFunction = expectedType ? get(*expectedType) : nullptr; + + if (expectedFunction) + { + expectedArgPack = extendTypePack(*arena, singletonTypes, expectedFunction->argTypes, fn->args.size); + + genericTypes = expectedFunction->generics; + genericTypePacks = expectedFunction->genericPacks; + } + + for (size_t i = 0; i < fn->args.size; ++i) + { + AstLocal* local = fn->args.data[i]; + + TypeId t = freshType(signatureScope); + argTypes.push_back(t); + signatureScope->bindings[local] = Binding{t, local->location}; + + TypeId annotationTy = t; + + if (local->annotation) + { + annotationTy = resolveType(signatureScope, local->annotation, /* topLevel */ true); + addConstraint(signatureScope, local->annotation->location, SubtypeConstraint{t, annotationTy}); + } + else if (i < expectedArgPack.head.size()) + { + addConstraint(signatureScope, local->location, SubtypeConstraint{t, expectedArgPack.head[i]}); + } + + // HACK: This is the one case where the type of the definition will diverge from the type of the binding. + // We need to do this because there are cases where type refinements needs to have the information available + // at constraint generation time. + if (auto def = dfg->getDef(local)) + signatureScope->dcrRefinements[*def] = annotationTy; + } + + TypePackId varargPack = nullptr; + + if (fn->vararg) + { + if (fn->varargAnnotation) + { + TypePackId annotationType = resolveTypePack(signatureScope, fn->varargAnnotation); + varargPack = annotationType; + } + else if (expectedArgPack.tail && get(*expectedArgPack.tail)) + varargPack = *expectedArgPack.tail; + else + varargPack = singletonTypes->anyTypePack; + + signatureScope->varargPack = varargPack; + bodyScope->varargPack = varargPack; + } + else + { + varargPack = arena->addTypePack(VariadicTypePack{singletonTypes->anyType, /*hidden*/ true}); + // We do not add to signatureScope->varargPack because ... is not valid + // in functions without an explicit ellipsis. + + signatureScope->varargPack = std::nullopt; + bodyScope->varargPack = std::nullopt; + } + + LUAU_ASSERT(nullptr != varargPack); + + // If there is both an annotation and an expected type, the annotation wins. + // Type checking will sort out any discrepancies later. + if (fn->returnAnnotation) + { + TypePackId annotatedRetType = resolveTypePack(signatureScope, *fn->returnAnnotation); + + // We bind the annotated type directly here so that, when we need to + // generate constraints for return types, we have a guarantee that we + // know the annotated return type already, if one was provided. + LUAU_ASSERT(get(returnType)); + asMutable(returnType)->ty.emplace(annotatedRetType); + } + else if (expectedFunction) + { + asMutable(returnType)->ty.emplace(expectedFunction->retTypes); + } + + // TODO: Preserve argument names in the function's type. + + FunctionTypeVar actualFunction{TypeLevel{}, parent.get(), arena->addTypePack(argTypes, varargPack), returnType}; + actualFunction.hasNoGenerics = !hasGenerics; + actualFunction.generics = std::move(genericTypes); + actualFunction.genericPacks = std::move(genericTypePacks); + + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); + LUAU_ASSERT(actualFunctionType); + astTypes[fn] = actualFunctionType; + + if (expectedType && get(*expectedType)) + { + asMutable(*expectedType)->ty.emplace(actualFunctionType); + } + + return { + /* signature */ actualFunctionType, + /* signatureScope */ signatureScope, + /* bodyScope */ bodyScope, + }; +} + +void ConstraintGraphBuilder::checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn) +{ + visitBlockWithoutChildScope(scope, fn->body); + + // If it is possible for execution to reach the end of the function, the return type must be compatible with () + + if (nullptr != getFallthrough(fn->body)) + { + TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever + addConstraint(scope, fn->location, PackSubtypeConstraint{scope->returnType, empty}); + } +} + +TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, bool topLevel) +{ + TypeId result = nullptr; + + if (auto ref = ty->as()) + { + if (FFlag::DebugLuauMagicTypes) + { + if (ref->name == "_luau_ice") + ice->ice("_luau_ice encountered", ty->location); + else if (ref->name == "_luau_print") + { + if (ref->parameters.size != 1 || !ref->parameters.data[0].type) + { + reportError(ty->location, GenericError{"_luau_print requires one generic parameter"}); + return singletonTypes->errorRecoveryType(); + } + else + return resolveType(scope, ref->parameters.data[0].type, topLevel); + } + } + + std::optional alias; + + if (ref->prefix.has_value()) + { + alias = scope->lookupImportedType(ref->prefix->value, ref->name.value); + } + else + { + alias = scope->lookupType(ref->name.value); + } + + if (alias.has_value()) + { + // If the alias is not generic, we don't need to set up a blocked + // type and an instantiation constraint. + if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty()) + { + result = alias->type; + } + else + { + std::vector parameters; + std::vector packParameters; + + for (const AstTypeOrPack& p : ref->parameters) + { + // We do not enforce the ordering of types vs. type packs here; + // that is done in the parser. + if (p.type) + { + parameters.push_back(resolveType(scope, p.type)); + } + else if (p.typePack) + { + packParameters.push_back(resolveTypePack(scope, p.typePack)); + } + else + { + // This indicates a parser bug: one of these two pointers + // should be set. + LUAU_ASSERT(false); + } + } + + result = arena->addType(PendingExpansionTypeVar{ref->prefix, ref->name, parameters, packParameters}); + + if (topLevel) + { + addConstraint(scope, ty->location, TypeAliasExpansionConstraint{/* target */ result}); + } + } + } + else + { + std::string typeName; + if (ref->prefix) + typeName = std::string(ref->prefix->value) + "."; + typeName += ref->name.value; + + result = singletonTypes->errorRecoveryType(); + } + } + else if (auto tab = ty->as()) + { + TableTypeVar::Props props; + std::optional indexer; + + for (const AstTableProp& prop : tab->props) + { + std::string name = prop.name.value; + // TODO: Recursion limit. + TypeId propTy = resolveType(scope, prop.type); + // TODO: Fill in location. + props[name] = {propTy}; + } + + if (tab->indexer) + { + // TODO: Recursion limit. + indexer = TableIndexer{ + resolveType(scope, tab->indexer->indexType), + resolveType(scope, tab->indexer->resultType), + }; + } + + result = arena->addType(TableTypeVar{props, indexer, scope->level, scope.get(), TableState::Sealed}); + } + else if (auto fn = ty->as()) + { + // TODO: Recursion limit. + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + ScopePtr signatureScope = nullptr; + + std::vector genericTypes; + std::vector genericTypePacks; + + // If we don't have generics, we do not need to generate a child scope + // for the generic bindings to live on. + if (hasGenerics) + { + signatureScope = childScope(fn, scope); + + std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); + + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + signatureScope->privateTypeBindings[name] = TypeFun{g.ty}; + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + signatureScope->privateTypePackBindings[name] = g.tp; + } + } + else + { + // To eliminate the need to branch on hasGenerics below, we say that + // the signature scope is the parent scope if we don't have + // generics. + signatureScope = scope; + } + + TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes); + TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes); + + // TODO: FunctionTypeVar needs a pointer to the scope so that we know + // how to quantify/instantiate it. + FunctionTypeVar ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; + + // This replicates the behavior of the appropriate FunctionTypeVar + // constructors. + ftv.hasNoGenerics = !hasGenerics; + ftv.generics = std::move(genericTypes); + ftv.genericPacks = std::move(genericTypePacks); + + ftv.argNames.reserve(fn->argNames.size); + for (const auto& el : fn->argNames) + { + if (el) + { + const auto& [name, location] = *el; + ftv.argNames.push_back(FunctionArgument{name.value, location}); + } + else + { + ftv.argNames.push_back(std::nullopt); + } + } + + result = arena->addType(std::move(ftv)); + } + else if (auto tof = ty->as()) + { + // TODO: Recursion limit. + TypeId exprType = check(scope, tof->expr).ty; + result = exprType; + } + else if (auto unionAnnotation = ty->as()) + { + std::vector parts; + for (AstType* part : unionAnnotation->types) + { + // TODO: Recursion limit. + parts.push_back(resolveType(scope, part, topLevel)); + } + + result = arena->addType(UnionTypeVar{parts}); + } + else if (auto intersectionAnnotation = ty->as()) + { + std::vector parts; + for (AstType* part : intersectionAnnotation->types) + { + // TODO: Recursion limit. + parts.push_back(resolveType(scope, part, topLevel)); + } + + result = arena->addType(IntersectionTypeVar{parts}); + } + else if (auto boolAnnotation = ty->as()) + { + result = arena->addType(SingletonTypeVar(BooleanSingleton{boolAnnotation->value})); + } + else if (auto stringAnnotation = ty->as()) + { + result = arena->addType(SingletonTypeVar(StringSingleton{std::string(stringAnnotation->value.data, stringAnnotation->value.size)})); + } + else if (ty->is()) + { + result = singletonTypes->errorRecoveryType(); + } + else + { + LUAU_ASSERT(0); + result = singletonTypes->errorRecoveryType(); + } + + astResolvedTypes[ty] = result; + return result; +} + +TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTypePack* tp) +{ + TypePackId result; + if (auto expl = tp->as()) + { + result = resolveTypePack(scope, expl->typeList); + } + else if (auto var = tp->as()) + { + TypeId ty = resolveType(scope, var->variadicType); + result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); + } + else if (auto gen = tp->as()) + { + if (std::optional lookup = scope->lookupPack(gen->genericName.value)) + { + result = *lookup; + } + else + { + reportError(tp->location, UnknownSymbol{gen->genericName.value, UnknownSymbol::Context::Type}); + result = singletonTypes->errorRecoveryTypePack(); + } + } + else + { + LUAU_ASSERT(0); + result = singletonTypes->errorRecoveryTypePack(); + } + + astResolvedTypePacks[tp] = result; + return result; +} + +TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const AstTypeList& list) +{ + std::vector head; + + for (AstType* headTy : list.types) + { + head.push_back(resolveType(scope, headTy)); + } + + std::optional tail = std::nullopt; + if (list.tailType) + { + tail = resolveTypePack(scope, list.tailType); + } + + return arena->addTypePack(TypePack{head, tail}); +} + +std::vector> ConstraintGraphBuilder::createGenerics(const ScopePtr& scope, AstArray generics) +{ + std::vector> result; + for (const auto& generic : generics) + { + TypeId genericTy = arena->addType(GenericTypeVar{scope.get(), generic.name.value}); + std::optional defaultTy = std::nullopt; + + if (generic.defaultValue) + defaultTy = resolveType(scope, generic.defaultValue); + + result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); + } + + return result; +} + +std::vector> ConstraintGraphBuilder::createGenericPacks( + const ScopePtr& scope, AstArray generics) +{ + std::vector> result; + for (const auto& generic : generics) + { + TypePackId genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic.name.value}}); + std::optional defaultTy = std::nullopt; + + if (generic.defaultValue) + defaultTy = resolveTypePack(scope, generic.defaultValue); + + result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); + } + + return result; +} + +Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, InferencePack pack) +{ + const auto& [tp, connectives] = pack; + ConnectiveId connective = nullptr; + if (!connectives.empty()) + connective = connectives[0]; + + if (auto f = first(tp)) + return Inference{*f, connective}; + + TypeId typeResult = freshType(scope); + TypePack onePack{{typeResult}, freshTypePack(scope)}; + TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); + + addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); + + return Inference{typeResult, connective}; +} + +void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) +{ + errors.push_back(TypeError{location, moduleName, std::move(err)}); + + if (FFlag::DebugLuauLogSolverToJson) + logger->captureGenerationError(errors.back()); +} + +void ConstraintGraphBuilder::reportCodeTooComplex(Location location) +{ + errors.push_back(TypeError{location, moduleName, CodeTooComplex{}}); + + if (FFlag::DebugLuauLogSolverToJson) + logger->captureGenerationError(errors.back()); +} + +struct GlobalPrepopulator : AstVisitor +{ + const NotNull globalScope; + const NotNull arena; + + GlobalPrepopulator(NotNull globalScope, NotNull arena) + : globalScope(globalScope) + , arena(arena) + { + } + + bool visit(AstStatFunction* function) override + { + if (AstExprGlobal* g = function->name->as()) + globalScope->bindings[g->name] = Binding{arena->addType(BlockedTypeVar{})}; + + return true; + } +}; + +void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program) +{ + GlobalPrepopulator gp{NotNull{globalScope.get()}, arena}; + + program->visit(&gp); +} + +std::vector> borrowConstraints(const std::vector& constraints) +{ + std::vector> result; + result.reserve(constraints.size()); + + for (const auto& c : constraints) + result.emplace_back(c.get()); + + return result; +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp new file mode 100644 index 00000000..d73c14a6 --- /dev/null +++ b/Analysis/src/ConstraintSolver.cpp @@ -0,0 +1,1970 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Anyification.h" +#include "Luau/ApplyTypeFunction.h" +#include "Luau/Clone.h" +#include "Luau/ConstraintSolver.h" +#include "Luau/DcrLogger.h" +#include "Luau/Instantiation.h" +#include "Luau/Location.h" +#include "Luau/Metamethods.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Quantify.h" +#include "Luau/ToString.h" +#include "Luau/TypeUtils.h" +#include "Luau/TypeVar.h" +#include "Luau/Unifier.h" +#include "Luau/VisitTypeVar.h" + +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); + +namespace Luau +{ + +[[maybe_unused]] static void dumpBindings(NotNull scope, ToStringOptions& opts) +{ + for (const auto& [k, v] : scope->bindings) + { + auto d = toString(v.typeId, opts); + printf("\t%s : %s\n", k.c_str(), d.c_str()); + } + + for (NotNull child : scope->children) + dumpBindings(child, opts); +} + +static std::pair, std::vector> saturateArguments(TypeArena* arena, NotNull singletonTypes, + const TypeFun& fn, const std::vector& rawTypeArguments, const std::vector& rawPackArguments) +{ + std::vector saturatedTypeArguments; + std::vector extraTypes; + std::vector saturatedPackArguments; + + for (size_t i = 0; i < rawTypeArguments.size(); ++i) + { + TypeId ty = rawTypeArguments[i]; + + if (i < fn.typeParams.size()) + saturatedTypeArguments.push_back(ty); + else + extraTypes.push_back(ty); + } + + // If we collected extra types, put them in a type pack now. This case is + // mutually exclusive with the type pack -> type conversion we do below: + // extraTypes will only have elements in it if we have more types than we + // have parameter slots for them to go into. + if (!extraTypes.empty()) + { + saturatedPackArguments.push_back(arena->addTypePack(extraTypes)); + } + + for (size_t i = 0; i < rawPackArguments.size(); ++i) + { + TypePackId tp = rawPackArguments[i]; + + // If we are short on regular type saturatedTypeArguments and we have a single + // element type pack, we can decompose that to the type it contains and + // use that as a type parameter. + if (saturatedTypeArguments.size() < fn.typeParams.size() && size(tp) == 1 && finite(tp) && first(tp) && saturatedPackArguments.empty()) + { + saturatedTypeArguments.push_back(*first(tp)); + } + else + { + saturatedPackArguments.push_back(tp); + } + } + + size_t typesProvided = saturatedTypeArguments.size(); + size_t typesRequired = fn.typeParams.size(); + + size_t packsProvided = saturatedPackArguments.size(); + size_t packsRequired = fn.typePackParams.size(); + + // Extra types should be accumulated in extraTypes, not saturatedTypeArguments. Extra + // packs will be accumulated in saturatedPackArguments, so we don't have an + // assertion for that. + LUAU_ASSERT(typesProvided <= typesRequired); + + // If we didn't provide enough types, but we did provide a type pack, we + // don't want to use defaults. The rationale for this is that if the user + // provides a pack but doesn't provide enough types, we want to report an + // error, rather than simply using the default saturatedTypeArguments, if they exist. If + // they did provide enough types, but not enough packs, we of course want to + // use the default packs. + bool needsDefaults = (typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired); + + if (needsDefaults) + { + // Default types can reference earlier types. It's legal to write + // something like + // type T = (A, B) -> number + // and we need to respect that. We use an ApplyTypeFunction for this. + ApplyTypeFunction atf{arena}; + + for (size_t i = 0; i < typesProvided; ++i) + atf.typeArguments[fn.typeParams[i].ty] = saturatedTypeArguments[i]; + + for (size_t i = typesProvided; i < typesRequired; ++i) + { + TypeId defaultTy = fn.typeParams[i].defaultValue.value_or(nullptr); + + // We will fill this in with the error type later. + if (!defaultTy) + break; + + TypeId instantiatedDefault = atf.substitute(defaultTy).value_or(singletonTypes->errorRecoveryType()); + atf.typeArguments[fn.typeParams[i].ty] = instantiatedDefault; + saturatedTypeArguments.push_back(instantiatedDefault); + } + + for (size_t i = 0; i < packsProvided; ++i) + { + atf.typePackArguments[fn.typePackParams[i].tp] = saturatedPackArguments[i]; + } + + for (size_t i = packsProvided; i < packsRequired; ++i) + { + TypePackId defaultTp = fn.typePackParams[i].defaultValue.value_or(nullptr); + + // We will fill this in with the error type pack later. + if (!defaultTp) + break; + + TypePackId instantiatedDefault = atf.substitute(defaultTp).value_or(singletonTypes->errorRecoveryTypePack()); + atf.typePackArguments[fn.typePackParams[i].tp] = instantiatedDefault; + saturatedPackArguments.push_back(instantiatedDefault); + } + } + + // If we didn't create an extra type pack from overflowing parameter packs, + // and we're still missing a type pack, plug in an empty type pack as the + // value of the empty packs. + if (extraTypes.empty() && saturatedPackArguments.size() + 1 == fn.typePackParams.size()) + { + saturatedPackArguments.push_back(arena->addTypePack({})); + } + + // We need to have _something_ when we substitute the generic saturatedTypeArguments, + // even if they're missing, so we use the error type as a filler. + for (size_t i = saturatedTypeArguments.size(); i < typesRequired; ++i) + { + saturatedTypeArguments.push_back(singletonTypes->errorRecoveryType()); + } + + for (size_t i = saturatedPackArguments.size(); i < packsRequired; ++i) + { + saturatedPackArguments.push_back(singletonTypes->errorRecoveryTypePack()); + } + + // At this point, these two conditions should be true. If they aren't we + // will run into access violations. + LUAU_ASSERT(saturatedTypeArguments.size() == fn.typeParams.size()); + LUAU_ASSERT(saturatedPackArguments.size() == fn.typePackParams.size()); + + return {saturatedTypeArguments, saturatedPackArguments}; +} + +bool InstantiationSignature::operator==(const InstantiationSignature& rhs) const +{ + return fn == rhs.fn && arguments == rhs.arguments && packArguments == rhs.packArguments; +} + +size_t HashInstantiationSignature::operator()(const InstantiationSignature& signature) const +{ + size_t hash = std::hash{}(signature.fn.type); + for (const GenericTypeDefinition& p : signature.fn.typeParams) + { + hash ^= (std::hash{}(p.ty) << 1); + } + + for (const GenericTypePackDefinition& p : signature.fn.typePackParams) + { + hash ^= (std::hash{}(p.tp) << 1); + } + + for (const TypeId a : signature.arguments) + { + hash ^= (std::hash{}(a) << 1); + } + + for (const TypePackId a : signature.packArguments) + { + hash ^= (std::hash{}(a) << 1); + } + + return hash; +} + +void dump(ConstraintSolver* cs, ToStringOptions& opts) +{ + printf("constraints:\n"); + for (NotNull c : cs->unsolvedConstraints) + { + auto it = cs->blockedConstraints.find(c); + int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); + printf("\t%d\t%s\n", blockCount, toString(*c, opts).c_str()); + + for (NotNull dep : c->dependencies) + { + auto unsolvedIter = std::find(begin(cs->unsolvedConstraints), end(cs->unsolvedConstraints), dep); + if (unsolvedIter == cs->unsolvedConstraints.end()) + continue; + + auto it = cs->blockedConstraints.find(dep); + int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); + printf("\t%d\t\t%s\n", blockCount, toString(*dep, opts).c_str()); + } + + if (auto fcc = get(*c)) + { + for (NotNull inner : fcc->innerConstraints) + printf("\t ->\t\t%s\n", toString(*inner, opts).c_str()); + } + } +} + +ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) + : arena(normalizer->arena) + , singletonTypes(normalizer->singletonTypes) + , normalizer(normalizer) + , constraints(std::move(constraints)) + , rootScope(rootScope) + , currentModuleName(std::move(moduleName)) + , moduleResolver(moduleResolver) + , requireCycles(requireCycles) + , logger(logger) +{ + opts.exhaustive = true; + + for (NotNull c : this->constraints) + { + unsolvedConstraints.push_back(c); + + for (NotNull dep : c->dependencies) + { + block(dep, c); + } + } + + if (FFlag::DebugLuauLogSolverToJson) + LUAU_ASSERT(logger); +} + +void ConstraintSolver::randomize(unsigned seed) +{ + if (unsolvedConstraints.empty()) + return; + + unsigned int rng = seed; + + for (size_t i = unsolvedConstraints.size() - 1; i > 0; --i) + { + // Fisher-Yates shuffle + size_t j = rng % (i + 1); + + std::swap(unsolvedConstraints[i], unsolvedConstraints[j]); + + // LCG RNG, constants from Numerical Recipes + // This may occasionally result in skewed shuffles due to distribution properties, but this is a debugging tool so it should be good enough + rng = rng * 1664525 + 1013904223; + } +} + +void ConstraintSolver::run() +{ + if (isDone()) + return; + + if (FFlag::DebugLuauLogSolver) + { + printf("Starting solver\n"); + dump(this, opts); + printf("Bindings:\n"); + dumpBindings(rootScope, opts); + } + + if (FFlag::DebugLuauLogSolverToJson) + { + logger->captureInitialSolverState(rootScope, unsolvedConstraints); + } + + auto runSolverPass = [&](bool force) { + bool progress = false; + + size_t i = 0; + while (i < unsolvedConstraints.size()) + { + NotNull c = unsolvedConstraints[i]; + if (!force && isBlocked(c)) + { + ++i; + continue; + } + + std::string saveMe = FFlag::DebugLuauLogSolver ? toString(*c, opts) : std::string{}; + StepSnapshot snapshot; + + if (FFlag::DebugLuauLogSolverToJson) + { + snapshot = logger->prepareStepSnapshot(rootScope, c, force, unsolvedConstraints); + } + + bool success = tryDispatch(c, force); + + progress |= success; + + if (success) + { + unblock(c); + unsolvedConstraints.erase(unsolvedConstraints.begin() + i); + + if (FFlag::DebugLuauLogSolverToJson) + { + logger->commitStepSnapshot(snapshot); + } + + if (FFlag::DebugLuauLogSolver) + { + if (force) + printf("Force "); + printf("Dispatched\n\t%s\n", saveMe.c_str()); + dump(this, opts); + } + } + else + ++i; + + if (force && success) + return true; + } + + return progress; + }; + + bool progress = false; + do + { + progress = runSolverPass(false); + if (!progress) + progress |= runSolverPass(true); + } while (progress); + + finalizeModule(); + + if (FFlag::DebugLuauLogSolver) + { + dumpBindings(rootScope, opts); + } + + if (FFlag::DebugLuauLogSolverToJson) + { + logger->captureFinalSolverState(rootScope, unsolvedConstraints); + } +} + +bool ConstraintSolver::isDone() +{ + return unsolvedConstraints.empty(); +} + +void ConstraintSolver::finalizeModule() +{ + Anyification a{arena, rootScope, singletonTypes, &iceReporter, singletonTypes->anyType, singletonTypes->anyTypePack}; + std::optional returnType = a.substitute(rootScope->returnType); + if (!returnType) + { + reportError(CodeTooComplex{}, Location{}); + rootScope->returnType = singletonTypes->errorTypePack; + } + else + rootScope->returnType = *returnType; +} + +bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) +{ + if (!force && isBlocked(constraint)) + return false; + + bool success = false; + + if (auto sc = get(*constraint)) + success = tryDispatch(*sc, constraint, force); + else if (auto psc = get(*constraint)) + success = tryDispatch(*psc, constraint, force); + else if (auto gc = get(*constraint)) + success = tryDispatch(*gc, constraint, force); + else if (auto ic = get(*constraint)) + success = tryDispatch(*ic, constraint, force); + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint, force); + else if (auto bc = get(*constraint)) + success = tryDispatch(*bc, constraint, force); + else if (auto ic = get(*constraint)) + success = tryDispatch(*ic, constraint, force); + else if (auto nc = get(*constraint)) + success = tryDispatch(*nc, constraint); + else if (auto taec = get(*constraint)) + success = tryDispatch(*taec, constraint); + else if (auto fcc = get(*constraint)) + success = tryDispatch(*fcc, constraint); + else if (auto fcc = get(*constraint)) + success = tryDispatch(*fcc, constraint); + else if (auto hpc = get(*constraint)) + success = tryDispatch(*hpc, constraint); + else if (auto spc = get(*constraint)) + success = tryDispatch(*spc, constraint); + else if (auto sottc = get(*constraint)) + success = tryDispatch(*sottc, constraint); + else + LUAU_ASSERT(false); + + if (success) + { + unblock(constraint); + } + + return success; +} + +bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) +{ + if (!recursiveBlock(c.subType, constraint)) + return false; + if (!recursiveBlock(c.superType, constraint)) + return false; + + if (isBlocked(c.subType)) + return block(c.subType, constraint); + else if (isBlocked(c.superType)) + return block(c.superType, constraint); + + unify(c.subType, c.superType, constraint->scope); + + return true; +} + +bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) +{ + if (!recursiveBlock(c.subPack, constraint) || !recursiveBlock(c.superPack, constraint)) + return false; + + if (isBlocked(c.subPack)) + return block(c.subPack, constraint); + else if (isBlocked(c.superPack)) + return block(c.superPack, constraint); + + unify(c.subPack, c.superPack, constraint->scope); + + return true; +} + +bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force) +{ + if (isBlocked(c.sourceType)) + return block(c.sourceType, constraint); + + TypeId generalized = quantify(arena, c.sourceType, constraint->scope); + + if (isBlocked(c.generalizedType)) + asMutable(c.generalizedType)->ty.emplace(generalized); + else + unify(c.generalizedType, generalized, constraint->scope); + + unblock(c.generalizedType); + unblock(c.sourceType); + + return true; +} + +bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force) +{ + if (isBlocked(c.superType)) + return block(c.superType, constraint); + + Instantiation inst(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); + + std::optional instantiated = inst.substitute(c.superType); + LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS + + if (isBlocked(c.subType)) + asMutable(c.subType)->ty.emplace(*instantiated); + else + unify(c.subType, *instantiated, constraint->scope); + + unblock(c.subType); + + return true; +} + +bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force) +{ + TypeId operandType = follow(c.operandType); + + if (isBlocked(operandType)) + return block(operandType, constraint); + + if (get(operandType)) + return block(operandType, constraint); + + LUAU_ASSERT(get(c.resultType)); + + switch (c.op) + { + case AstExprUnary::Not: + { + asMutable(c.resultType)->ty.emplace(singletonTypes->booleanType); + return true; + } + case AstExprUnary::Len: + { + // __len must return a number. + asMutable(c.resultType)->ty.emplace(singletonTypes->numberType); + return true; + } + case AstExprUnary::Minus: + { + if (isNumber(operandType) || get(operandType) || get(operandType)) + { + asMutable(c.resultType)->ty.emplace(c.operandType); + } + else if (std::optional mm = findMetatableEntry(singletonTypes, errors, operandType, "__unm", constraint->location)) + { + const FunctionTypeVar* ftv = get(follow(*mm)); + + if (!ftv) + { + if (std::optional callMm = findMetatableEntry(singletonTypes, errors, follow(*mm), "__call", constraint->location)) + { + ftv = get(follow(*callMm)); + } + } + + if (!ftv) + { + asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + return true; + } + + TypePackId argsPack = arena->addTypePack({operandType}); + unify(ftv->argTypes, argsPack, constraint->scope); + + TypeId result = singletonTypes->errorRecoveryType(); + if (ftv) + { + result = first(ftv->retTypes).value_or(singletonTypes->errorRecoveryType()); + } + + asMutable(c.resultType)->ty.emplace(result); + } + else + { + asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + } + + return true; + } + } + + LUAU_ASSERT(false); + return false; +} + +bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force) +{ + TypeId leftType = follow(c.leftType); + TypeId rightType = follow(c.rightType); + TypeId resultType = follow(c.resultType); + + bool isLogical = c.op == AstExprBinary::Op::And || c.op == AstExprBinary::Op::Or; + + /* Compound assignments create constraints of the form + * + * A <: Binary + * + * This constraint is the one that is meant to unblock A, so it doesn't + * make any sense to stop and wait for someone else to do it. + */ + + if (isBlocked(leftType) && leftType != resultType) + return block(c.leftType, constraint); + + if (isBlocked(rightType) && rightType != resultType) + return block(c.rightType, constraint); + + if (!force) + { + // Logical expressions may proceed if the LHS is free. + if (get(leftType) && !isLogical) + return block(leftType, constraint); + } + + // Logical expressions may proceed if the LHS is free. + if (isBlocked(leftType) || (get(leftType) && !isLogical)) + { + asMutable(resultType)->ty.emplace(errorRecoveryType()); + unblock(resultType); + return true; + } + + // Metatables go first, even if there is primitive behavior. + if (auto it = kBinaryOpMetamethods.find(c.op); it != kBinaryOpMetamethods.end()) + { + // Metatables are not the same. The metamethod will not be invoked. + if ((c.op == AstExprBinary::Op::CompareEq || c.op == AstExprBinary::Op::CompareNe) && + getMetatable(leftType, singletonTypes) != getMetatable(rightType, singletonTypes)) + { + // TODO: Boolean singleton false? The result is _always_ boolean false. + asMutable(resultType)->ty.emplace(singletonTypes->booleanType); + unblock(resultType); + return true; + } + + std::optional mm; + + // The LHS metatable takes priority over the RHS metatable, where + // present. + if (std::optional leftMm = findMetatableEntry(singletonTypes, errors, leftType, it->second, constraint->location)) + mm = leftMm; + else if (std::optional rightMm = findMetatableEntry(singletonTypes, errors, rightType, it->second, constraint->location)) + mm = rightMm; + + if (mm) + { + Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, constraint->scope}; + std::optional instantiatedMm = instantiation.substitute(*mm); + if (!instantiatedMm) + { + reportError(CodeTooComplex{}, constraint->location); + return true; + } + + // TODO: Is a table with __call legal here? + // TODO: Overloads + if (const FunctionTypeVar* ftv = get(follow(*instantiatedMm))) + { + TypePackId inferredArgs; + // For >= and > we invoke __lt and __le respectively with + // swapped argument ordering. + if (c.op == AstExprBinary::Op::CompareGe || c.op == AstExprBinary::Op::CompareGt) + { + inferredArgs = arena->addTypePack({rightType, leftType}); + } + else + { + inferredArgs = arena->addTypePack({leftType, rightType}); + } + + unify(inferredArgs, ftv->argTypes, constraint->scope); + + TypeId mmResult; + + // Comparison operations always evaluate to a boolean, + // regardless of what the metamethod returns. + switch (c.op) + { + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + mmResult = singletonTypes->booleanType; + break; + default: + mmResult = first(ftv->retTypes).value_or(errorRecoveryType()); + } + + asMutable(resultType)->ty.emplace(mmResult); + unblock(resultType); + + (*c.astOriginalCallTypes)[c.expr] = *mm; + (*c.astOverloadResolvedTypes)[c.expr] = *instantiatedMm; + return true; + } + } + + // If there's no metamethod available, fall back to primitive behavior. + } + + // If any is present, the expression must evaluate to any as well. + bool leftAny = get(leftType) || get(leftType); + bool rightAny = get(rightType) || get(rightType); + bool anyPresent = leftAny || rightAny; + + switch (c.op) + { + // For arithmetic operators, if the LHS is a number, the RHS must be a + // number as well. The result will also be a number. + case AstExprBinary::Op::Add: + case AstExprBinary::Op::Sub: + case AstExprBinary::Op::Mul: + case AstExprBinary::Op::Div: + case AstExprBinary::Op::Pow: + case AstExprBinary::Op::Mod: + if (isNumber(leftType)) + { + unify(leftType, rightType, constraint->scope); + asMutable(resultType)->ty.emplace(anyPresent ? singletonTypes->anyType : leftType); + unblock(resultType); + return true; + } + + break; + // For concatenation, if the LHS is a string, the RHS must be a string as + // well. The result will also be a string. + case AstExprBinary::Op::Concat: + if (isString(leftType)) + { + unify(leftType, rightType, constraint->scope); + asMutable(resultType)->ty.emplace(anyPresent ? singletonTypes->anyType : leftType); + unblock(resultType); + return true; + } + + break; + // Inexact comparisons require that the types be both numbers or both + // strings, and evaluate to a boolean. + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + if ((isNumber(leftType) && isNumber(rightType)) || (isString(leftType) && isString(rightType))) + { + asMutable(resultType)->ty.emplace(singletonTypes->booleanType); + unblock(resultType); + return true; + } + + break; + // == and ~= always evaluate to a boolean, and impose no other constraints + // on their parameters. + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + asMutable(resultType)->ty.emplace(singletonTypes->booleanType); + unblock(resultType); + return true; + // And evalutes to a boolean if the LHS is falsey, and the RHS type if LHS is + // truthy. + case AstExprBinary::Op::And: + { + TypeId leftFilteredTy = arena->addType(IntersectionTypeVar{{singletonTypes->falsyType, leftType}}); + + asMutable(resultType)->ty.emplace(arena->addType(UnionTypeVar{{leftFilteredTy, rightType}})); + unblock(resultType); + return true; + } + // Or evaluates to the LHS type if the LHS is truthy, and the RHS type if + // LHS is falsey. + case AstExprBinary::Op::Or: + { + TypeId leftFilteredTy = arena->addType(IntersectionTypeVar{{singletonTypes->truthyType, leftType}}); + + asMutable(resultType)->ty.emplace(arena->addType(UnionTypeVar{{leftFilteredTy, rightType}})); + unblock(resultType); + return true; + } + default: + iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location); + break; + } + + // We failed to either evaluate a metamethod or invoke primitive behavior. + unify(leftType, errorRecoveryType(), constraint->scope); + unify(rightType, errorRecoveryType(), constraint->scope); + asMutable(resultType)->ty.emplace(errorRecoveryType()); + unblock(resultType); + + return true; +} + +bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull constraint, bool force) +{ + /* + * for .. in loops can play out in a bunch of different ways depending on + * the shape of iteratee. + * + * iteratee might be: + * * (nextFn) + * * (nextFn, table) + * * (nextFn, table, firstIndex) + * * table with a metatable and __index + * * table with a metatable and __call but no __index (if the metatable has + * both, __index takes precedence) + * * table with an indexer but no __index or __call (or no metatable) + * + * To dispatch this constraint, we need first to know enough about iteratee + * to figure out which of the above shapes we are actually working with. + * + * If `force` is true and we still do not know, we must flag a warning. Type + * families are the fix for this. + * + * Since we need to know all of this stuff about the types of the iteratee, + * we have no choice but for ConstraintSolver to also be the thing that + * applies constraints to the types of the iterators. + */ + + auto block_ = [&](auto&& t) { + if (force) + { + // If we haven't figured out the type of the iteratee by now, + // there's nothing we can do. + return true; + } + + block(t, constraint); + return false; + }; + + auto [iteratorTypes, iteratorTail] = flatten(c.iterator); + if (iteratorTail) + return block_(*iteratorTail); + + { + bool blocked = false; + for (TypeId t : iteratorTypes) + { + if (isBlocked(t)) + { + block(t, constraint); + blocked = true; + } + } + + if (blocked) + return false; + } + + if (0 == iteratorTypes.size()) + { + Anyification anyify{arena, constraint->scope, singletonTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; + std::optional anyified = anyify.substitute(c.variables); + LUAU_ASSERT(anyified); + unify(*anyified, c.variables, constraint->scope); + + return true; + } + + TypeId nextTy = follow(iteratorTypes[0]); + if (get(nextTy)) + return block_(nextTy); + + if (get(nextTy)) + { + TypeId tableTy = singletonTypes->nilType; + if (iteratorTypes.size() >= 2) + tableTy = iteratorTypes[1]; + + TypeId firstIndexTy = singletonTypes->nilType; + if (iteratorTypes.size() >= 3) + firstIndexTy = iteratorTypes[2]; + + return tryDispatchIterableFunction(nextTy, tableTy, firstIndexTy, c, constraint, force); + } + + else + return tryDispatchIterableTable(iteratorTypes[0], c, constraint, force); + + return true; +} + +bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull constraint) +{ + if (isBlocked(c.namedType)) + return block(c.namedType, constraint); + + TypeId target = follow(c.namedType); + + if (target->persistent || target->owningArena != arena) + return true; + + if (TableTypeVar* ttv = getMutable(target)) + ttv->name = c.name; + else if (MetatableTypeVar* mtv = getMutable(target)) + mtv->syntheticName = c.name; + else if (get(target) || get(target)) + { + // nothing (yet) + } + else + return block(c.namedType, constraint); + + return true; +} + +struct InfiniteTypeFinder : TypeVarOnceVisitor +{ + ConstraintSolver* solver; + const InstantiationSignature& signature; + NotNull scope; + bool foundInfiniteType = false; + + explicit InfiniteTypeFinder(ConstraintSolver* solver, const InstantiationSignature& signature, NotNull scope) + : solver(solver) + , signature(signature) + , scope(scope) + { + } + + bool visit(TypeId ty, const PendingExpansionTypeVar& petv) override + { + std::optional tf = + (petv.prefix) ? scope->lookupImportedType(petv.prefix->value, petv.name.value) : scope->lookupType(petv.name.value); + + if (!tf.has_value()) + return true; + + auto [typeArguments, packArguments] = saturateArguments(solver->arena, solver->singletonTypes, *tf, petv.typeArguments, petv.packArguments); + + if (follow(tf->type) == follow(signature.fn.type) && (signature.arguments != typeArguments || signature.packArguments != packArguments)) + { + foundInfiniteType = true; + return false; + } + + return true; + } +}; + +struct InstantiationQueuer : TypeVarOnceVisitor +{ + ConstraintSolver* solver; + const InstantiationSignature& signature; + NotNull scope; + Location location; + + explicit InstantiationQueuer(NotNull scope, const Location& location, ConstraintSolver* solver, const InstantiationSignature& signature) + : solver(solver) + , signature(signature) + , scope(scope) + , location(location) + { + } + + bool visit(TypeId ty, const PendingExpansionTypeVar& petv) override + { + solver->pushConstraint(scope, location, TypeAliasExpansionConstraint{ty}); + return false; + } +}; + +bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint) +{ + const PendingExpansionTypeVar* petv = get(follow(c.target)); + if (!petv) + { + unblock(c.target); + return true; + } + + auto bindResult = [this, &c](TypeId result) { + asMutable(c.target)->ty.emplace(result); + unblock(c.target); + }; + + std::optional tf = (petv->prefix) ? constraint->scope->lookupImportedType(petv->prefix->value, petv->name.value) + : constraint->scope->lookupType(petv->name.value); + + if (!tf.has_value()) + { + reportError(UnknownSymbol{petv->name.value, UnknownSymbol::Context::Type}, constraint->location); + bindResult(errorRecoveryType()); + return true; + } + + // If there are no parameters to the type function we can just use the type + // directly. + if (tf->typeParams.empty() && tf->typePackParams.empty()) + { + bindResult(tf->type); + return true; + } + + auto [typeArguments, packArguments] = saturateArguments(arena, singletonTypes, *tf, petv->typeArguments, petv->packArguments); + + bool sameTypes = std::equal(typeArguments.begin(), typeArguments.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& p) { + return itp == p.ty; + }); + + bool samePacks = + std::equal(packArguments.begin(), packArguments.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itp, auto&& p) { + return itp == p.tp; + }); + + // If we're instantiating the type with its generic saturatedTypeArguments we are + // performing the identity substitution. We can just short-circuit and bind + // to the TypeFun's type. + if (sameTypes && samePacks) + { + bindResult(tf->type); + return true; + } + + InstantiationSignature signature{ + *tf, + typeArguments, + packArguments, + }; + + // If we use the same signature, we don't need to bother trying to + // instantiate the alias again, since the instantiation should be + // deterministic. + if (TypeId* cached = instantiatedAliases.find(signature)) + { + bindResult(*cached); + return true; + } + + // In order to prevent infinite types from being expanded and causing us to + // cycle infinitely, we need to scan the type function for cases where we + // expand the same alias with different type saturatedTypeArguments. See + // https://github.com/Roblox/luau/pull/68 for the RFC responsible for this. + // This is a little nicer than using a recursion limit because we can catch + // the infinite expansion before actually trying to expand it. + InfiniteTypeFinder itf{this, signature, constraint->scope}; + itf.traverse(tf->type); + + if (itf.foundInfiniteType) + { + // TODO (CLI-56761): Report an error. + bindResult(errorRecoveryType()); + return true; + } + + ApplyTypeFunction applyTypeFunction{arena}; + for (size_t i = 0; i < typeArguments.size(); ++i) + { + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeArguments[i]; + } + + for (size_t i = 0; i < packArguments.size(); ++i) + { + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = packArguments[i]; + } + + std::optional maybeInstantiated = applyTypeFunction.substitute(tf->type); + // Note that ApplyTypeFunction::encounteredForwardedType is never set in + // DCR, because we do not use free types for forward-declared generic + // aliases. + + if (!maybeInstantiated.has_value()) + { + // TODO (CLI-56761): Report an error. + bindResult(errorRecoveryType()); + return true; + } + + TypeId instantiated = *maybeInstantiated; + TypeId target = follow(instantiated); + + if (target->persistent) + return true; + + // Type function application will happily give us the exact same type if + // there are e.g. generic saturatedTypeArguments that go unused. + bool needsClone = follow(tf->type) == target; + // Only tables have the properties we're trying to set. + TableTypeVar* ttv = getMutableTableType(target); + + if (ttv) + { + if (needsClone) + { + // Substitution::clone is a shallow clone. If this is a + // metatable type, we want to mutate its table, so we need to + // explicitly clone that table as well. If we don't, we will + // mutate another module's type surface and cause a + // use-after-free. + if (get(target)) + { + instantiated = applyTypeFunction.clone(target); + MetatableTypeVar* mtv = getMutable(instantiated); + mtv->table = applyTypeFunction.clone(mtv->table); + ttv = getMutable(mtv->table); + } + else if (get(target)) + { + instantiated = applyTypeFunction.clone(target); + ttv = getMutable(instantiated); + } + + target = follow(instantiated); + } + + ttv->instantiatedTypeParams = typeArguments; + ttv->instantiatedTypePackParams = packArguments; + // TODO: Fill in definitionModuleName. + } + + bindResult(target); + + // The application is not recursive, so we need to queue up application of + // any child type function instantiations within the result in order for it + // to be complete. + InstantiationQueuer queuer{constraint->scope, constraint->location, this, signature}; + queuer.traverse(target); + + instantiatedAliases[signature] = target; + + return true; +} + +bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull constraint) +{ + TypeId fn = follow(c.fn); + TypePackId result = follow(c.result); + + if (isBlocked(c.fn)) + { + return block(c.fn, constraint); + } + + // We don't support magic __call metamethods. + if (std::optional callMm = findMetatableEntry(singletonTypes, errors, fn, "__call", constraint->location)) + { + std::vector args{fn}; + + for (TypeId arg : c.argsPack) + args.push_back(arg); + + TypeId instantiatedType = arena->addType(BlockedTypeVar{}); + TypeId inferredFnType = + arena->addType(FunctionTypeVar(TypeLevel{}, constraint->scope.get(), arena->addTypePack(TypePack{args, {}}), c.result)); + + // Alter the inner constraints. + LUAU_ASSERT(c.innerConstraints.size() == 2); + + // Anything that is blocked on this constraint must also be blocked on our inner constraints + auto blockedIt = blocked.find(constraint.get()); + if (blockedIt != blocked.end()) + { + for (const auto& ic : c.innerConstraints) + { + for (const auto& blockedConstraint : blockedIt->second) + block(ic, blockedConstraint); + } + } + + asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm}; + asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType}; + + unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); + + asMutable(c.result)->ty.emplace(constraint->scope); + unblock(c.result); + return true; + } + + const FunctionTypeVar* ftv = get(fn); + bool usedMagic = false; + + if (ftv && ftv->dcrMagicFunction != nullptr) + { + usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull(this), c.callSite, c.argsPack, result}); + } + + if (usedMagic) + { + // There are constraints that are blocked on these constraints. If we + // are never going to even examine them, then we should not block + // anything else on them. + // + // TODO CLI-58842 +#if 0 + for (auto& c: c.innerConstraints) + unblock(c); +#endif + } + else + { + // Anything that is blocked on this constraint must also be blocked on our inner constraints + auto blockedIt = blocked.find(constraint.get()); + if (blockedIt != blocked.end()) + { + for (const auto& ic : c.innerConstraints) + { + for (const auto& blockedConstraint : blockedIt->second) + block(ic, blockedConstraint); + } + } + + unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); + asMutable(c.result)->ty.emplace(constraint->scope); + } + + unblock(c.result); + + return true; +} + +bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint) +{ + TypeId expectedType = follow(c.expectedType); + if (isBlocked(expectedType) || get(expectedType)) + return block(expectedType, constraint); + + TypeId bindTo = maybeSingleton(expectedType) ? c.singletonType : c.multitonType; + asMutable(c.resultType)->ty.emplace(bindTo); + + return true; +} + +bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull constraint) +{ + TypeId subjectType = follow(c.subjectType); + + if (isBlocked(subjectType) || get(subjectType)) + return block(subjectType, constraint); + + if (get(subjectType)) + { + TableTypeVar& ttv = asMutable(subjectType)->ty.emplace(TableState::Free, TypeLevel{}, constraint->scope); + ttv.props[c.prop] = Property{c.resultType}; + asMutable(c.resultType)->ty.emplace(constraint->scope); + unblock(c.resultType); + return true; + } + + std::optional resultType = lookupTableProp(subjectType, c.prop); + if (!resultType) + { + asMutable(c.resultType)->ty.emplace(singletonTypes->errorRecoveryType()); + unblock(c.resultType); + return true; + } + + if (isBlocked(*resultType)) + { + block(*resultType, constraint); + return false; + } + + asMutable(c.resultType)->ty.emplace(*resultType); + return true; +} + +static bool isUnsealedTable(TypeId ty) +{ + ty = follow(ty); + const TableTypeVar* ttv = get(ty); + return ttv && ttv->state == TableState::Unsealed; +} + +/** + * Create a shallow copy of `ty` and its properties along `path`. Insert a new + * property (the last segment of `path`) into the tail table with the value `t`. + * + * On success, returns the new outermost table type. If the root table or any + * of its subkeys are not unsealed tables, the function fails and returns + * std::nullopt. + * + * TODO: Prove that we completely give up in the face of indexers and + * metatables. + */ +static std::optional updateTheTableType(NotNull arena, TypeId ty, const std::vector& path, TypeId replaceTy) +{ + if (path.empty()) + return std::nullopt; + + // First walk the path and ensure that it's unsealed tables all the way + // to the end. + { + TypeId t = ty; + for (size_t i = 0; i < path.size() - 1; ++i) + { + if (!isUnsealedTable(t)) + return std::nullopt; + + const TableTypeVar* tbl = get(t); + auto it = tbl->props.find(path[i]); + if (it == tbl->props.end()) + return std::nullopt; + + t = it->second.type; + } + + // The last path segment should not be a property of the table at all. + // We are not changing property types. We are only admitting this one + // new property to be appended. + if (!isUnsealedTable(t)) + return std::nullopt; + const TableTypeVar* tbl = get(t); + if (0 != tbl->props.count(path.back())) + return std::nullopt; + } + + const TypeId res = shallowClone(ty, arena); + TypeId t = res; + + for (size_t i = 0; i < path.size() - 1; ++i) + { + const std::string segment = path[i]; + + TableTypeVar* ttv = getMutable(t); + LUAU_ASSERT(ttv); + + auto propIt = ttv->props.find(segment); + if (propIt != ttv->props.end()) + { + LUAU_ASSERT(isUnsealedTable(propIt->second.type)); + t = shallowClone(follow(propIt->second.type), arena); + ttv->props[segment].type = t; + } + else + return std::nullopt; + } + + TableTypeVar* ttv = getMutable(t); + LUAU_ASSERT(ttv); + + const std::string lastSegment = path.back(); + LUAU_ASSERT(0 == ttv->props.count(lastSegment)); + ttv->props[lastSegment] = Property{replaceTy}; + return res; +} + +bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint) +{ + TypeId subjectType = follow(c.subjectType); + + if (isBlocked(subjectType)) + return block(subjectType, constraint); + + std::optional existingPropType = subjectType; + for (const std::string& segment : c.path) + { + ErrorVec e; + std::optional propTy = lookupTableProp(*existingPropType, segment); + if (!propTy) + { + existingPropType = std::nullopt; + break; + } + else if (isBlocked(*propTy)) + return block(*propTy, constraint); + else + existingPropType = follow(*propTy); + } + + auto bind = [](TypeId a, TypeId b) { + asMutable(a)->ty.emplace(b); + }; + + if (existingPropType) + { + unify(c.propType, *existingPropType, constraint->scope); + bind(c.resultType, c.subjectType); + return true; + } + + if (get(subjectType)) + { + TypeId ty = arena->freshType(constraint->scope); + + // Mint a chain of free tables per c.path + for (auto it = rbegin(c.path); it != rend(c.path); ++it) + { + TableTypeVar t{TableState::Free, TypeLevel{}, constraint->scope}; + t.props[*it] = {ty}; + + ty = arena->addType(std::move(t)); + } + + LUAU_ASSERT(ty); + + bind(subjectType, ty); + bind(c.resultType, ty); + return true; + } + else if (auto ttv = getMutable(subjectType)) + { + if (ttv->state == TableState::Free) + { + ttv->props[c.path[0]] = Property{c.propType}; + bind(c.resultType, c.subjectType); + return true; + } + else if (ttv->state == TableState::Unsealed) + { + std::optional augmented = updateTheTableType(NotNull{arena}, subjectType, c.path, c.propType); + bind(c.resultType, augmented.value_or(subjectType)); + return true; + } + else + { + bind(c.resultType, subjectType); + return true; + } + } + else if (get(subjectType) || get(subjectType)) + { + bind(c.resultType, subjectType); + return true; + } + + LUAU_ASSERT(0); + return true; +} + +bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint) +{ + if (isBlocked(c.discriminantType)) + return false; + + TypeId followed = follow(c.discriminantType); + + // `nil` is a singleton type too! There's only one value of type `nil`. + if (c.negated && (get(followed) || isNil(followed))) + *asMutable(c.resultType) = NegationTypeVar{c.discriminantType}; + else if (!c.negated && get(followed)) + *asMutable(c.resultType) = BoundTypeVar{c.discriminantType}; + else + *asMutable(c.resultType) = BoundTypeVar{singletonTypes->unknownType}; + + return true; +} + +bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) +{ + auto block_ = [&](auto&& t) { + if (force) + { + // TODO: I believe it is the case that, if we are asked to force + // this constraint, then we can do nothing but fail. I'd like to + // find a code sample that gets here. + LUAU_ASSERT(false); + } + else + block(t, constraint); + return false; + }; + + // We may have to block here if we don't know what the iteratee type is, + // if it's a free table, if we don't know it has a metatable, and so on. + iteratorTy = follow(iteratorTy); + if (get(iteratorTy)) + return block_(iteratorTy); + + auto anyify = [&](auto ty) { + Anyification anyify{arena, constraint->scope, singletonTypes, &iceReporter, singletonTypes->anyType, singletonTypes->anyTypePack}; + std::optional anyified = anyify.substitute(ty); + if (!anyified) + reportError(CodeTooComplex{}, constraint->location); + else + unify(*anyified, ty, constraint->scope); + }; + + auto errorify = [&](auto ty) { + Anyification anyify{arena, constraint->scope, singletonTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; + std::optional errorified = anyify.substitute(ty); + if (!errorified) + reportError(CodeTooComplex{}, constraint->location); + else + unify(*errorified, ty, constraint->scope); + }; + + if (get(iteratorTy)) + { + anyify(c.variables); + return true; + } + + if (get(iteratorTy)) + { + errorify(c.variables); + return true; + } + + // Irksome: I don't think we have any way to guarantee that this table + // type never has a metatable. + + if (auto iteratorTable = get(iteratorTy)) + { + if (iteratorTable->state == TableState::Free) + return block_(iteratorTy); + + if (iteratorTable->indexer) + { + TypePackId expectedVariablePack = arena->addTypePack({iteratorTable->indexer->indexType, iteratorTable->indexer->indexResultType}); + unify(c.variables, expectedVariablePack, constraint->scope); + } + else + errorify(c.variables); + } + else if (std::optional iterFn = findMetatableEntry(singletonTypes, errors, iteratorTy, "__iter", Location{})) + { + if (isBlocked(*iterFn)) + { + return block(*iterFn, constraint); + } + + Instantiation instantiation(TxnLog::empty(), arena, TypeLevel{}, constraint->scope); + + if (std::optional instantiatedIterFn = instantiation.substitute(*iterFn)) + { + if (auto iterFtv = get(*instantiatedIterFn)) + { + TypePackId expectedIterArgs = arena->addTypePack({iteratorTy}); + unify(iterFtv->argTypes, expectedIterArgs, constraint->scope); + + TypePack iterRets = extendTypePack(*arena, singletonTypes, iterFtv->retTypes, 2); + + if (iterRets.head.size() < 1) + { + // We've done what we can; this will get reported as an + // error by the type checker. + return true; + } + + TypeId nextFn = iterRets.head[0]; + TypeId table = iterRets.head.size() == 2 ? iterRets.head[1] : arena->freshType(constraint->scope); + + if (std::optional instantiatedNextFn = instantiation.substitute(nextFn)) + { + const TypeId firstIndex = arena->freshType(constraint->scope); + + // nextTy : (iteratorTy, indexTy?) -> (indexTy, valueTailTy...) + const TypePackId nextArgPack = arena->addTypePack({table, arena->addType(UnionTypeVar{{firstIndex, singletonTypes->nilType}})}); + const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); + const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); + + const TypeId expectedNextTy = arena->addType(FunctionTypeVar{nextArgPack, nextRetPack}); + unify(*instantiatedNextFn, expectedNextTy, constraint->scope); + + pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, nextRetPack}); + } + else + { + reportError(UnificationTooComplex{}, constraint->location); + } + } + else + { + // TODO: Support __call and function overloads (what does an overload even mean for this?) + } + } + else + { + reportError(UnificationTooComplex{}, constraint->location); + } + } + else if (auto iteratorMetatable = get(iteratorTy)) + { + TypeId metaTy = follow(iteratorMetatable->metatable); + if (get(metaTy)) + return block_(metaTy); + + LUAU_ASSERT(false); + } + else + errorify(c.variables); + + return true; +} + +bool ConstraintSolver::tryDispatchIterableFunction( + TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force) +{ + // We need to know whether or not this type is nil or not. + // If we don't know, block and reschedule ourselves. + firstIndexTy = follow(firstIndexTy); + if (get(firstIndexTy)) + { + if (force) + LUAU_ASSERT(false); + else + block(firstIndexTy, constraint); + return false; + } + + const TypeId firstIndex = isNil(firstIndexTy) ? arena->freshType(constraint->scope) // FIXME: Surely this should be a union (free | nil) + : firstIndexTy; + + // nextTy : (tableTy, indexTy?) -> (indexTy, valueTailTy...) + const TypePackId nextArgPack = arena->addTypePack({tableTy, arena->addType(UnionTypeVar{{firstIndex, singletonTypes->nilType}})}); + const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); + const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); + + const TypeId expectedNextTy = arena->addType(FunctionTypeVar{TypeLevel{}, constraint->scope, nextArgPack, nextRetPack}); + unify(nextTy, expectedNextTy, constraint->scope); + + pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, nextRetPack}); + + return true; +} + +std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) +{ + auto collectParts = [&](auto&& unionOrIntersection) -> std::pair, std::vector> { + std::optional blocked; + + std::vector parts; + for (TypeId expectedPart : unionOrIntersection) + { + expectedPart = follow(expectedPart); + if (isBlocked(expectedPart) || get(expectedPart)) + blocked = expectedPart; + else if (const TableTypeVar* ttv = get(follow(expectedPart))) + { + if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) + parts.push_back(prop->second.type); + else if (ttv->indexer && maybeString(ttv->indexer->indexType)) + parts.push_back(ttv->indexer->indexResultType); + } + } + + return {blocked, parts}; + }; + + std::optional resultType; + + if (auto ttv = get(subjectType)) + { + if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) + resultType = prop->second.type; + else if (ttv->indexer && maybeString(ttv->indexer->indexType)) + resultType = ttv->indexer->indexResultType; + } + else if (auto utv = get(subjectType)) + { + auto [blocked, parts] = collectParts(utv); + + if (blocked) + resultType = *blocked; + else if (parts.size() == 1) + resultType = parts[0]; + else if (parts.size() > 1) + resultType = arena->addType(UnionTypeVar{std::move(parts)}); + else + LUAU_ASSERT(false); // parts.size() == 0 + } + else if (auto itv = get(subjectType)) + { + auto [blocked, parts] = collectParts(itv); + + if (blocked) + resultType = *blocked; + else if (parts.size() == 1) + resultType = parts[0]; + else if (parts.size() > 1) + resultType = arena->addType(IntersectionTypeVar{std::move(parts)}); + else + LUAU_ASSERT(false); // parts.size() == 0 + } + + return resultType; +} + +void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) +{ + blocked[target].push_back(constraint); + + auto& count = blockedConstraints[constraint]; + count += 1; +} + +void ConstraintSolver::block(NotNull target, NotNull constraint) +{ + if (FFlag::DebugLuauLogSolverToJson) + logger->pushBlock(constraint, target); + + if (FFlag::DebugLuauLogSolver) + printf("block Constraint %s on\t%s\n", toString(*target, opts).c_str(), toString(*constraint, opts).c_str()); + + block_(target, constraint); +} + +bool ConstraintSolver::block(TypeId target, NotNull constraint) +{ + if (FFlag::DebugLuauLogSolverToJson) + logger->pushBlock(constraint, target); + + if (FFlag::DebugLuauLogSolver) + printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + + block_(target, constraint); + return false; +} + +bool ConstraintSolver::block(TypePackId target, NotNull constraint) +{ + if (FFlag::DebugLuauLogSolverToJson) + logger->pushBlock(constraint, target); + + if (FFlag::DebugLuauLogSolver) + printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); + + block_(target, constraint); + return false; +} + +struct Blocker : TypeVarOnceVisitor +{ + NotNull solver; + NotNull constraint; + + bool blocked = false; + + explicit Blocker(NotNull solver, NotNull constraint) + : solver(solver) + , constraint(constraint) + { + } + + bool visit(TypeId ty, const BlockedTypeVar&) + { + blocked = true; + solver->block(ty, constraint); + return false; + } + + bool visit(TypeId ty, const PendingExpansionTypeVar&) + { + blocked = true; + solver->block(ty, constraint); + return false; + } +}; + +bool ConstraintSolver::recursiveBlock(TypeId target, NotNull constraint) +{ + Blocker blocker{NotNull{this}, constraint}; + blocker.traverse(target); + return !blocker.blocked; +} + +bool ConstraintSolver::recursiveBlock(TypePackId pack, NotNull constraint) +{ + Blocker blocker{NotNull{this}, constraint}; + blocker.traverse(pack); + return !blocker.blocked; +} + +void ConstraintSolver::unblock_(BlockedConstraintId progressed) +{ + auto it = blocked.find(progressed); + if (it == blocked.end()) + return; + + // unblocked should contain a value always, because of the above check + for (NotNull unblockedConstraint : it->second) + { + auto& count = blockedConstraints[unblockedConstraint]; + if (FFlag::DebugLuauLogSolver) + printf("Unblocking count=%d\t%s\n", int(count), toString(*unblockedConstraint, opts).c_str()); + + // This assertion being hit indicates that `blocked` and + // `blockedConstraints` desynchronized at some point. This is problematic + // because we rely on this count being correct to skip over blocked + // constraints. + LUAU_ASSERT(count > 0); + count -= 1; + } + + blocked.erase(it); +} + +void ConstraintSolver::unblock(NotNull progressed) +{ + if (FFlag::DebugLuauLogSolverToJson) + logger->popBlock(progressed); + + return unblock_(progressed); +} + +void ConstraintSolver::unblock(TypeId progressed) +{ + if (FFlag::DebugLuauLogSolverToJson) + logger->popBlock(progressed); + + return unblock_(progressed); +} + +void ConstraintSolver::unblock(TypePackId progressed) +{ + if (FFlag::DebugLuauLogSolverToJson) + logger->popBlock(progressed); + + return unblock_(progressed); +} + +void ConstraintSolver::unblock(const std::vector& types) +{ + for (TypeId t : types) + unblock(t); +} + +void ConstraintSolver::unblock(const std::vector& packs) +{ + for (TypePackId t : packs) + unblock(t); +} + +bool ConstraintSolver::isBlocked(TypeId ty) +{ + return nullptr != get(follow(ty)) || nullptr != get(follow(ty)); +} + +bool ConstraintSolver::isBlocked(TypePackId tp) +{ + return nullptr != get(follow(tp)); +} + +bool ConstraintSolver::isBlocked(NotNull constraint) +{ + auto blockedIt = blockedConstraints.find(constraint); + return blockedIt != blockedConstraints.end() && blockedIt->second > 0; +} + +void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull scope) +{ + Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(subType, superType); + + if (!u.errors.empty()) + { + TypeId errorType = errorRecoveryType(); + u.tryUnify(subType, errorType); + u.tryUnify(superType, errorType); + } + + const auto [changedTypes, changedPacks] = u.log.getChanges(); + + u.log.commit(); + + unblock(changedTypes); + unblock(changedPacks); +} + +void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) +{ + UnifierSharedState sharedState{&iceReporter}; + Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(subPack, superPack); + + const auto [changedTypes, changedPacks] = u.log.getChanges(); + + u.log.commit(); + + unblock(changedTypes); + unblock(changedPacks); +} + +void ConstraintSolver::pushConstraint(NotNull scope, const Location& location, ConstraintV cv) +{ + std::unique_ptr c = std::make_unique(scope, location, std::move(cv)); + NotNull borrow = NotNull(c.get()); + solverConstraints.push_back(std::move(c)); + unsolvedConstraints.push_back(borrow); +} + +TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& location) +{ + if (info.name.empty()) + { + reportError(UnknownRequire{}, location); + return errorRecoveryType(); + } + + std::string humanReadableName = moduleResolver->getHumanReadableModuleName(info.name); + + for (const auto& [location, path] : requireCycles) + { + if (!path.empty() && path.front() == humanReadableName) + return singletonTypes->anyType; + } + + ModulePtr module = moduleResolver->getModule(info.name); + if (!module) + { + if (!moduleResolver->moduleExists(info.name) && !info.optional) + reportError(UnknownRequire{humanReadableName}, location); + + return errorRecoveryType(); + } + + if (module->type != SourceCode::Type::Module) + { + reportError(IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}, location); + return errorRecoveryType(); + } + + TypePackId modulePack = module->getModuleScope()->returnType; + if (get(modulePack)) + return errorRecoveryType(); + + std::optional moduleType = first(modulePack); + if (!moduleType) + { + reportError(IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}, location); + return errorRecoveryType(); + } + + return *moduleType; +} + +void ConstraintSolver::reportError(TypeErrorData&& data, const Location& location) +{ + errors.emplace_back(location, std::move(data)); + errors.back().moduleName = currentModuleName; +} + +void ConstraintSolver::reportError(TypeError e) +{ + errors.emplace_back(std::move(e)); + errors.back().moduleName = currentModuleName; +} + +TypeId ConstraintSolver::errorRecoveryType() const +{ + return singletonTypes->errorRecoveryType(); +} + +TypePackId ConstraintSolver::errorRecoveryTypePack() const +{ + return singletonTypes->errorRecoveryTypePack(); +} + +TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes) +{ + a = follow(a); + b = follow(b); + + if (unifyFreeTypes && (get(a) || get(b))) + { + Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; + u.useScopes = true; + u.tryUnify(b, a); + + if (u.errors.empty()) + { + u.log.commit(); + return a; + } + else + { + return singletonTypes->errorRecoveryType(singletonTypes->anyType); + } + } + + if (*a == *b) + return a; + + std::vector types = reduceUnion({a, b}); + if (types.empty()) + return singletonTypes->neverType; + + if (types.size() == 1) + return types[0]; + + return arena->addType(UnionTypeVar{types}); +} + +} // namespace Luau diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp new file mode 100644 index 00000000..cffd00c9 --- /dev/null +++ b/Analysis/src/DataFlowGraph.cpp @@ -0,0 +1,451 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/DataFlowGraph.h" + +#include "Luau/Error.h" + +LUAU_FASTFLAG(DebugLuauFreezeArena) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + +namespace Luau +{ + +std::optional DataFlowGraph::getDef(const AstExpr* expr) const +{ + // We need to skip through AstExprGroup because DFG doesn't try its best to transitively + while (auto group = expr->as()) + expr = group->expr; + if (auto def = astDefs.find(expr)) + return NotNull{*def}; + return std::nullopt; +} + +std::optional DataFlowGraph::getDef(const AstLocal* local) const +{ + if (auto def = localDefs.find(local)) + return NotNull{*def}; + return std::nullopt; +} + +std::optional DataFlowGraph::getDef(const Symbol& symbol) const +{ + if (symbol.local) + return getDef(symbol.local); + else + return std::nullopt; +} + +DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull handle) +{ + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + + DataFlowGraphBuilder builder; + builder.handle = handle; + builder.visit(nullptr, block); // nullptr is the root DFG scope. + if (FFlag::DebugLuauFreezeArena) + builder.arena->allocator.freeze(); + return std::move(builder.graph); +} + +DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope) +{ + return scopes.emplace_back(new DfgScope{scope}).get(); +} + +std::optional DataFlowGraphBuilder::use(DfgScope* scope, Symbol symbol, AstExpr* e) +{ + for (DfgScope* current = scope; current; current = current->parent) + { + if (auto def = current->bindings.find(symbol)) + { + graph.astDefs[e] = *def; + return NotNull{*def}; + } + } + + return std::nullopt; +} + +DefId DataFlowGraphBuilder::use(DefId def, AstExprIndexName* e) +{ + auto& propertyDef = props[def][e->index.value]; + if (!propertyDef) + propertyDef = arena->freshCell(def, e->index.value); + graph.astDefs[e] = propertyDef; + return NotNull{propertyDef}; +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b) +{ + DfgScope* child = childScope(scope); + return visitBlockWithoutChildScope(child, b); +} + +void DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b) +{ + for (AstStat* s : b->body) + visit(scope, s); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) +{ + if (auto b = s->as()) + return visit(scope, b); + else if (auto i = s->as()) + return visit(scope, i); + else if (auto w = s->as()) + return visit(scope, w); + else if (auto r = s->as()) + return visit(scope, r); + else if (auto b = s->as()) + return visit(scope, b); + else if (auto c = s->as()) + return visit(scope, c); + else if (auto r = s->as()) + return visit(scope, r); + else if (auto e = s->as()) + return visit(scope, e); + else if (auto l = s->as()) + return visit(scope, l); + else if (auto f = s->as()) + return visit(scope, f); + else if (auto f = s->as()) + return visit(scope, f); + else if (auto a = s->as()) + return visit(scope, a); + else if (auto c = s->as()) + return visit(scope, c); + else if (auto f = s->as()) + return visit(scope, f); + else if (auto l = s->as()) + return visit(scope, l); + else if (auto t = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto d = s->as()) + return; // ok + else if (auto _ = s->as()) + return; // ok + else + handle->ice("Unknown AstStat in DataFlowGraphBuilder"); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) +{ + DfgScope* condScope = childScope(scope); + visitExpr(condScope, i->condition); + visit(condScope, i->thenbody); + + if (i->elsebody) + visit(scope, i->elsebody); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w) +{ + // TODO(controlflow): entry point has a back edge from exit point + DfgScope* whileScope = childScope(scope); + visitExpr(whileScope, w->condition); + visit(whileScope, w->body); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r) +{ + // TODO(controlflow): entry point has a back edge from exit point + DfgScope* repeatScope = childScope(scope); // TODO: loop scope. + visitBlockWithoutChildScope(repeatScope, r->body); + visitExpr(repeatScope, r->condition); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b) +{ + // TODO: Control flow analysis + return; // ok +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c) +{ + // TODO: Control flow analysis + return; // ok +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r) +{ + // TODO: Control flow analysis + for (AstExpr* e : r->list) + visitExpr(scope, e); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e) +{ + visitExpr(scope, e->expr); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) +{ + // TODO: alias tracking + for (AstExpr* e : l->values) + visitExpr(scope, e); + + for (AstLocal* local : l->vars) + { + DefId def = arena->freshCell(); + graph.localDefs[local] = def; + scope->bindings[local] = def; + } +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) +{ + DfgScope* forScope = childScope(scope); // TODO: loop scope. + DefId def = arena->freshCell(); + graph.localDefs[f->var] = def; + scope->bindings[f->var] = def; + + // TODO(controlflow): entry point has a back edge from exit point + visit(forScope, f->body); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) +{ + DfgScope* forScope = childScope(scope); // TODO: loop scope. + + for (AstLocal* local : f->vars) + { + DefId def = arena->freshCell(); + graph.localDefs[local] = def; + forScope->bindings[local] = def; + } + + // TODO(controlflow): entry point has a back edge from exit point + for (AstExpr* e : f->values) + visitExpr(forScope, e); + + visit(forScope, f->body); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) +{ + for (AstExpr* r : a->values) + visitExpr(scope, r); + + for (AstExpr* l : a->vars) + { + AstExpr* root = l; + + bool isUpdatable = true; + while (true) + { + if (root->is() || root->is()) + break; + + AstExprIndexName* indexName = root->as(); + if (!indexName) + { + isUpdatable = false; + break; + } + + root = indexName->expr; + } + + if (isUpdatable) + { + // TODO global? + if (auto exprLocal = root->as()) + { + DefId def = arena->freshCell(); + graph.astDefs[exprLocal] = def; + + // Update the def in the scope that introduced the local. Not + // the current scope. + AstLocal* local = exprLocal->local; + DfgScope* s = scope; + while (s && !s->bindings.find(local)) + s = s->parent; + LUAU_ASSERT(s && s->bindings.find(local)); + s->bindings[local] = def; + } + } + + visitExpr(scope, l); // TODO: they point to a new def!! + } +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c) +{ + // TODO(typestates): The lhs is being read and written to. This might or might not be annoying. + visitExpr(scope, c->value); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) +{ + visitExpr(scope, f->name); + visitExpr(scope, f->func); +} + +void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) +{ + DefId def = arena->freshCell(); + graph.localDefs[l->name] = def; + scope->bindings[l->name] = def; + + visitExpr(scope, l->func); +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) +{ + if (auto g = e->as()) + return visitExpr(scope, g->expr); + else if (auto c = e->as()) + return {}; // ok + else if (auto c = e->as()) + return {}; // ok + else if (auto c = e->as()) + return {}; // ok + else if (auto c = e->as()) + return {}; // ok + else if (auto l = e->as()) + return visitExpr(scope, l); + else if (auto g = e->as()) + return visitExpr(scope, g); + else if (auto v = e->as()) + return {}; // ok + else if (auto c = e->as()) + return visitExpr(scope, c); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto f = e->as()) + return visitExpr(scope, f); + else if (auto t = e->as()) + return visitExpr(scope, t); + else if (auto u = e->as()) + return visitExpr(scope, u); + else if (auto b = e->as()) + return visitExpr(scope, b); + else if (auto t = e->as()) + return visitExpr(scope, t); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto i = e->as()) + return visitExpr(scope, i); + else if (auto _ = e->as()) + return {}; // ok + else + handle->ice("Unknown AstExpr in DataFlowGraphBuilder"); +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) +{ + return {use(scope, l->local, l)}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) +{ + return {use(scope, g->name, g)}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) +{ + visitExpr(scope, c->func); + + for (AstExpr* arg : c->args) + visitExpr(scope, arg); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) +{ + std::optional def = visitExpr(scope, i->expr).def; + if (!def) + return {}; + + return {use(*def, i)}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) +{ + visitExpr(scope, i->expr); + visitExpr(scope, i->expr); + + if (i->index->as()) + { + // TODO: properties for the def + } + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) +{ + if (AstLocal* self = f->self) + { + DefId def = arena->freshCell(); + graph.localDefs[self] = def; + scope->bindings[self] = def; + } + + for (AstLocal* param : f->args) + { + DefId def = arena->freshCell(); + graph.localDefs[param] = def; + scope->bindings[param] = def; + } + + visit(scope, f->body); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) +{ + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) +{ + visitExpr(scope, u->expr); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) +{ + visitExpr(scope, b->left); + visitExpr(scope, b->right); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) +{ + ExpressionFlowGraph result = visitExpr(scope, t->expr); + // TODO: visit type + return result; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) +{ + DfgScope* condScope = childScope(scope); + visitExpr(condScope, i->condition); + visitExpr(condScope, i->trueExpr); + + visitExpr(scope, i->falseExpr); + + return {}; +} + +ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) +{ + for (AstExpr* e : i->expressions) + visitExpr(scope, e); + return {}; +} + +} // namespace Luau diff --git a/Analysis/src/DcrLogger.cpp b/Analysis/src/DcrLogger.cpp new file mode 100644 index 00000000..ef33aa60 --- /dev/null +++ b/Analysis/src/DcrLogger.cpp @@ -0,0 +1,396 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/DcrLogger.h" + +#include + +#include "Luau/JsonEmitter.h" + +namespace Luau +{ + +namespace Json +{ + +void write(JsonEmitter& emitter, const Location& location) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("beginLine", location.begin.line); + o.writePair("beginColumn", location.begin.column); + o.writePair("endLine", location.end.line); + o.writePair("endColumn", location.end.column); + o.finish(); +} + +void write(JsonEmitter& emitter, const ErrorSnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("message", snapshot.message); + o.writePair("location", snapshot.location); + o.finish(); +} + +void write(JsonEmitter& emitter, const BindingSnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("typeId", snapshot.typeId); + o.writePair("typeString", snapshot.typeString); + o.writePair("location", snapshot.location); + o.finish(); +} + +void write(JsonEmitter& emitter, const TypeBindingSnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("typeId", snapshot.typeId); + o.writePair("typeString", snapshot.typeString); + o.finish(); +} + +void write(JsonEmitter& emitter, const ConstraintGenerationLog& log) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("source", log.source); + + emitter.writeComma(); + write(emitter, "constraintLocations"); + emitter.writeRaw(":"); + + ObjectEmitter locationEmitter = emitter.writeObject(); + + for (const auto& [id, location] : log.constraintLocations) + { + locationEmitter.writePair(id, location); + } + + locationEmitter.finish(); + o.writePair("errors", log.errors); + o.finish(); +} + +void write(JsonEmitter& emitter, const ScopeSnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("bindings", snapshot.bindings); + o.writePair("typeBindings", snapshot.typeBindings); + o.writePair("typePackBindings", snapshot.typePackBindings); + o.writePair("children", snapshot.children); + o.finish(); +} + +void write(JsonEmitter& emitter, const ConstraintBlockKind& kind) +{ + switch (kind) + { + case ConstraintBlockKind::TypeId: + return write(emitter, "type"); + case ConstraintBlockKind::TypePackId: + return write(emitter, "typePack"); + case ConstraintBlockKind::ConstraintId: + return write(emitter, "constraint"); + default: + LUAU_ASSERT(0); + } +} + +void write(JsonEmitter& emitter, const ConstraintBlock& block) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("kind", block.kind); + o.writePair("stringification", block.stringification); + o.finish(); +} + +void write(JsonEmitter& emitter, const ConstraintSnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("stringification", snapshot.stringification); + o.writePair("blocks", snapshot.blocks); + o.finish(); +} + +void write(JsonEmitter& emitter, const BoundarySnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("rootScope", snapshot.rootScope); + o.writePair("constraints", snapshot.constraints); + o.finish(); +} + +void write(JsonEmitter& emitter, const StepSnapshot& snapshot) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("currentConstraint", snapshot.currentConstraint); + o.writePair("forced", snapshot.forced); + o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints); + o.writePair("rootScope", snapshot.rootScope); + o.finish(); +} + +void write(JsonEmitter& emitter, const TypeSolveLog& log) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("initialState", log.initialState); + o.writePair("stepStates", log.stepStates); + o.writePair("finalState", log.finalState); + o.finish(); +} + +void write(JsonEmitter& emitter, const TypeCheckLog& log) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("errors", log.errors); + o.finish(); +} + +} // namespace Json + +static std::string toPointerId(NotNull ptr) +{ + return std::to_string(reinterpret_cast(ptr.get())); +} + +static ScopeSnapshot snapshotScope(const Scope* scope, ToStringOptions& opts) +{ + std::unordered_map bindings; + std::unordered_map typeBindings; + std::unordered_map typePackBindings; + std::vector children; + + for (const auto& [name, binding] : scope->bindings) + { + std::string id = std::to_string(reinterpret_cast(binding.typeId)); + ToStringResult result = toStringDetailed(binding.typeId, opts); + + bindings[name.c_str()] = BindingSnapshot{ + id, + result.name, + binding.location, + }; + } + + for (const auto& [name, tf] : scope->exportedTypeBindings) + { + std::string id = std::to_string(reinterpret_cast(tf.type)); + + typeBindings[name] = TypeBindingSnapshot{ + id, + toString(tf.type, opts), + }; + } + + for (const auto& [name, tf] : scope->privateTypeBindings) + { + std::string id = std::to_string(reinterpret_cast(tf.type)); + + typeBindings[name] = TypeBindingSnapshot{ + id, + toString(tf.type, opts), + }; + } + + for (const auto& [name, tp] : scope->privateTypePackBindings) + { + std::string id = std::to_string(reinterpret_cast(tp)); + + typePackBindings[name] = TypeBindingSnapshot{ + id, + toString(tp, opts), + }; + } + + for (const auto& child : scope->children) + { + children.push_back(snapshotScope(child.get(), opts)); + } + + return ScopeSnapshot{ + bindings, + typeBindings, + typePackBindings, + children, + }; +} + +std::string DcrLogger::compileOutput() +{ + Json::JsonEmitter emitter; + Json::ObjectEmitter o = emitter.writeObject(); + o.writePair("generation", generationLog); + o.writePair("solve", solveLog); + o.writePair("check", checkLog); + o.finish(); + + return emitter.str(); +} + +void DcrLogger::captureSource(std::string source) +{ + generationLog.source = std::move(source); +} + +void DcrLogger::captureGenerationError(const TypeError& error) +{ + std::string stringifiedError = toString(error); + generationLog.errors.push_back(ErrorSnapshot{ + /* message */ stringifiedError, + /* location */ error.location, + }); +} + +void DcrLogger::captureConstraintLocation(NotNull constraint, Location location) +{ + std::string id = toPointerId(constraint); + generationLog.constraintLocations[id] = location; +} + +void DcrLogger::pushBlock(NotNull constraint, TypeId block) +{ + constraintBlocks[constraint].push_back(block); +} + +void DcrLogger::pushBlock(NotNull constraint, TypePackId block) +{ + constraintBlocks[constraint].push_back(block); +} + +void DcrLogger::pushBlock(NotNull constraint, NotNull block) +{ + constraintBlocks[constraint].push_back(block); +} + +void DcrLogger::popBlock(TypeId block) +{ + for (auto& [_, list] : constraintBlocks) + { + list.erase(std::remove(list.begin(), list.end(), block), list.end()); + } +} + +void DcrLogger::popBlock(TypePackId block) +{ + for (auto& [_, list] : constraintBlocks) + { + list.erase(std::remove(list.begin(), list.end(), block), list.end()); + } +} + +void DcrLogger::popBlock(NotNull block) +{ + for (auto& [_, list] : constraintBlocks) + { + list.erase(std::remove(list.begin(), list.end(), block), list.end()); + } +} + +void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) +{ + solveLog.initialState.rootScope = snapshotScope(rootScope, opts); + solveLog.initialState.constraints.clear(); + + for (NotNull c : unsolvedConstraints) + { + std::string id = toPointerId(c); + solveLog.initialState.constraints[id] = { + toString(*c.get(), opts), + snapshotBlocks(c), + }; + } +} + +StepSnapshot DcrLogger::prepareStepSnapshot( + const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints) +{ + ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts); + std::string currentId = toPointerId(current); + std::unordered_map constraints; + + for (NotNull c : unsolvedConstraints) + { + std::string id = toPointerId(c); + constraints[id] = { + toString(*c.get(), opts), + snapshotBlocks(c), + }; + } + + return StepSnapshot{ + currentId, + force, + constraints, + scopeSnapshot, + }; +} + +void DcrLogger::commitStepSnapshot(StepSnapshot snapshot) +{ + solveLog.stepStates.push_back(std::move(snapshot)); +} + +void DcrLogger::captureFinalSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) +{ + solveLog.finalState.rootScope = snapshotScope(rootScope, opts); + solveLog.finalState.constraints.clear(); + + for (NotNull c : unsolvedConstraints) + { + std::string id = toPointerId(c); + solveLog.finalState.constraints[id] = { + toString(*c.get(), opts), + snapshotBlocks(c), + }; + } +} + +void DcrLogger::captureTypeCheckError(const TypeError& error) +{ + std::string stringifiedError = toString(error); + checkLog.errors.push_back(ErrorSnapshot{ + /* message */ stringifiedError, + /* location */ error.location, + }); +} + +std::vector DcrLogger::snapshotBlocks(NotNull c) +{ + auto it = constraintBlocks.find(c); + if (it == constraintBlocks.end()) + { + return {}; + } + + std::vector snapshot; + + for (const ConstraintBlockTarget& target : it->second) + { + if (const TypeId* ty = get_if(&target)) + { + snapshot.push_back({ + ConstraintBlockKind::TypeId, + toString(*ty, opts), + }); + } + else if (const TypePackId* tp = get_if(&target)) + { + snapshot.push_back({ + ConstraintBlockKind::TypePackId, + toString(*tp, opts), + }); + } + else if (const NotNull* c = get_if>(&target)) + { + snapshot.push_back({ + ConstraintBlockKind::ConstraintId, + toString(*(c->get()), opts), + }); + } + else + { + LUAU_ASSERT(0); + } + } + + return snapshot; +} + +} // namespace Luau diff --git a/Analysis/src/Def.cpp b/Analysis/src/Def.cpp new file mode 100644 index 00000000..8ce1129c --- /dev/null +++ b/Analysis/src/Def.cpp @@ -0,0 +1,17 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Def.h" + +namespace Luau +{ + +DefId DefArena::freshCell() +{ + return NotNull{allocator.allocate(Def{Cell{std::nullopt}})}; +} + +DefId DefArena::freshCell(DefId parent, const std::string& prop) +{ + return NotNull{allocator.allocate(Def{Cell{FieldMetadata{parent, prop}}})}; +} + +} // namespace Luau diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 61a63f06..b0f21737 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,8 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAG(LuauParseGenericFunctions) -LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(LuauOptionalNextKey) namespace Luau { @@ -10,59 +10,65 @@ namespace Luau static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC( declare bit32: { - -- band, bor, bxor, and btest are declared in C++ - rrotate: (number, number) -> number, - lrotate: (number, number) -> number, - lshift: (number, number) -> number, - arshift: (number, number) -> number, - rshift: (number, number) -> number, - bnot: (number) -> number, - extract: (number, number, number?) -> number, - replace: (number, number, number, number?) -> number, + band: (...number) -> number, + bor: (...number) -> number, + bxor: (...number) -> number, + btest: (number, ...number) -> boolean, + rrotate: (x: number, disp: number) -> number, + lrotate: (x: number, disp: number) -> number, + lshift: (x: number, disp: number) -> number, + arshift: (x: number, disp: number) -> number, + rshift: (x: number, disp: number) -> number, + bnot: (x: number) -> number, + extract: (n: number, field: number, width: number?) -> number, + replace: (n: number, v: number, field: number, width: number?) -> number, + countlz: (n: number) -> number, + countrz: (n: number) -> number, } declare math: { - frexp: (number) -> (number, number), - ldexp: (number, number) -> number, - fmod: (number, number) -> number, - modf: (number) -> (number, number), - pow: (number, number) -> number, - exp: (number) -> number, + frexp: (n: number) -> (number, number), + ldexp: (s: number, e: number) -> number, + fmod: (x: number, y: number) -> number, + modf: (n: number) -> (number, number), + pow: (x: number, y: number) -> number, + exp: (n: number) -> number, - ceil: (number) -> number, - floor: (number) -> number, - abs: (number) -> number, - sqrt: (number) -> number, + ceil: (n: number) -> number, + floor: (n: number) -> number, + abs: (n: number) -> number, + sqrt: (n: number) -> number, - log: (number, number?) -> number, - log10: (number) -> number, + log: (n: number, base: number?) -> number, + log10: (n: number) -> number, - rad: (number) -> number, - deg: (number) -> number, + rad: (n: number) -> number, + deg: (n: number) -> number, - sin: (number) -> number, - cos: (number) -> number, - tan: (number) -> number, - sinh: (number) -> number, - cosh: (number) -> number, - tanh: (number) -> number, - atan: (number) -> number, - acos: (number) -> number, - asin: (number) -> number, - atan2: (number, number) -> number, + sin: (n: number) -> number, + cos: (n: number) -> number, + tan: (n: number) -> number, + sinh: (n: number) -> number, + cosh: (n: number) -> number, + tanh: (n: number) -> number, + atan: (n: number) -> number, + acos: (n: number) -> number, + asin: (n: number) -> number, + atan2: (y: number, x: number) -> number, - -- min and max are declared in C++. + min: (number, ...number) -> number, + max: (number, ...number) -> number, pi: number, huge: number, - randomseed: (number) -> (), + randomseed: (seed: number) -> (), random: (number?, number?) -> number, - sign: (number) -> number, - clamp: (number, number, number) -> number, - noise: (number, number?, number?) -> number, - round: (number) -> number, + sign: (n: number) -> number, + clamp: (n: number, min: number, max: number) -> number, + noise: (x: number, y: number?, z: number?) -> number, + round: (n: number) -> number, } type DateTypeArg = { @@ -88,151 +94,126 @@ type DateTypeResult = { } declare os: { - time: (DateTypeArg?) -> number, - date: (string?, number?) -> DateTypeResult | string, - difftime: (DateTypeResult | number, DateTypeResult | number) -> number, + time: (time: DateTypeArg?) -> number, + date: (formatString: string?, time: number?) -> DateTypeResult | string, + difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, clock: () -> number, } declare function require(target: any): any -declare function getfenv(target: any?): { [string]: any } +declare function getfenv(target: any): { [string]: any } declare _G: any declare _VERSION: string declare function gcinfo(): number +declare function print(...: T...) + +declare function type(value: T): string +declare function typeof(value: T): string + +-- `assert` has a magic function attached that will give more detailed type information +declare function assert(value: T, errorMessage: string?): T + +declare function tostring(value: T): string +declare function tonumber(value: T, radix: number?): number? + +declare function rawequal(a: T1, b: T2): boolean +declare function rawget(tab: {[K]: V}, k: K): V +declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} +declare function rawlen(obj: {[K]: V} | string): number + +declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? + +-- TODO: place ipairs definition here with removal of FFlagLuauOptionalNextKey + +declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) + +-- FIXME: The actual type of `xpcall` is: +-- (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...) +-- Since we can't represent the return value, we use (boolean, R1...). +declare function xpcall(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...) + +-- `select` has a magic function attached to provide more detailed type information +declare function select(i: string | number, ...: A...): ...any + +-- FIXME: This type is not entirely correct - `loadstring` returns a function or +-- (nil, string). +declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) + +declare function newproxy(mt: boolean?): any + +declare coroutine: { + create: (f: (A...) -> R...) -> thread, + resume: (co: thread, A...) -> (boolean, R...), + running: () -> thread, + status: (co: thread) -> "dead" | "running" | "normal" | "suspended", + -- FIXME: This technically returns a function, but we can't represent this yet. + wrap: (f: (A...) -> R...) -> any, + yield: (A...) -> R..., + isyieldable: () -> boolean, + close: (co: thread) -> (boolean, any) +} + +declare table: { + concat: (t: {V}, sep: string?, i: number?, j: number?) -> string, + insert: ((t: {V}, value: V) -> ()) & ((t: {V}, pos: number, value: V) -> ()), + maxn: (t: {V}) -> number, + remove: (t: {V}, number?) -> V?, + sort: (t: {V}, comp: ((V, V) -> boolean)?) -> (), + create: (count: number, value: V?) -> {V}, + find: (haystack: {V}, needle: V, init: number?) -> number?, + + unpack: (list: {V}, i: number?, j: number?) -> ...V, + pack: (...V) -> { n: number, [number]: V }, + + getn: (t: {V}) -> number, + foreach: (t: {[K]: V}, f: (K, V) -> ()) -> (), + foreachi: ({V}, (number, V) -> ()) -> (), + + move: (src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V}, + clear: (table: {[K]: V}) -> (), + + isfrozen: (t: {[K]: V}) -> boolean, +} + +declare debug: { + info: ((thread: thread, level: number, options: string) -> R...) & ((level: number, options: string) -> R...) & ((func: (A...) -> R1..., options: string) -> R2...), + traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), +} + +declare utf8: { + char: (...number) -> string, + charpattern: string, + codes: (str: string) -> ((string, number) -> (number, number), string, number), + codepoint: (str: string, i: number?, j: number?) -> ...number, + len: (s: string, i: number?, j: number?) -> (number?, number?), + offset: (s: string, n: number?, i: number?) -> number, +} + +-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. +declare function unpack(tab: {V}, i: number?, j: number?): ...V + )BUILTIN_SRC"; std::string getBuiltinDefinitionSource() { - std::string src = kBuiltinDefinitionLuaSrc; - if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions) - { - src += R"( - declare function print(...: T...) + std::string result = kBuiltinDefinitionLuaSrc; - declare function type(value: T): string - declare function typeof(value: T): string - - -- `assert` has a magic function attached that will give more detailed type information - declare function assert(value: T, errorMessage: string?): T + if (FFlag::LuauUnknownAndNeverType) + result += "declare function error(message: T, level: number?): never\n"; + else + result += "declare function error(message: T, level: number?)\n"; - declare function error(message: T, level: number?) + if (FFlag::LuauOptionalNextKey) + result += "declare function ipairs(tab: {V}): (({V}, number) -> (number?, V), {V}, number)\n"; + else + result += "declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number)\n"; - declare function tostring(value: T): string - declare function tonumber(value: T, radix: number?): number - - declare function rawequal(a: T1, b: T2): boolean - declare function rawget(tab: {[K]: V}, k: K): V - declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} - - declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? - - declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number) - - declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) - - -- FIXME: The actual type of `xpcall` is: - -- (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...) - -- Since we can't represent the return value, we use (boolean, R1...). - declare function xpcall(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...) - - -- `select` has a magic function attached to provide more detailed type information - declare function select(i: string | number, ...: A...): ...any - - -- FIXME: This type is not entirely correct - `loadstring` returns a function or - -- (nil, string). - declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) - - -- a userdata object is "roughly" the same as a sealed empty table - -- except `type(newproxy(false))` evaluates to "userdata" so we may need another special type here too. - -- another important thing to note: the value passed in conditionally creates an empty metatable, and you have to use getmetatable, NOT - -- setmetatable. - -- FIXME: change this to something Luau can understand how to reject `setmetatable(newproxy(false or true), {})`. - declare function newproxy(mt: boolean?): {} - - declare coroutine: { - create: ((A...) -> R...) -> thread, - resume: (thread, A...) -> (boolean, R...), - running: () -> thread, - status: (thread) -> string, - -- FIXME: This technically returns a function, but we can't represent this yet. - wrap: ((A...) -> R...) -> any, - yield: (A...) -> R..., - isyieldable: () -> boolean, - } - - declare table: { - concat: ({V}, string?, number?, number?) -> string, - insert: (({V}, V) -> ()) & (({V}, number, V) -> ()), - maxn: ({V}) -> number, - remove: ({V}, number?) -> V?, - sort: ({V}, ((V, V) -> boolean)?) -> (), - create: (number, V?) -> {V}, - find: ({V}, V, number?) -> number?, - - unpack: ({V}, number?, number?) -> ...V, - pack: (...V) -> { n: number, [number]: V }, - - getn: ({V}) -> number, - foreach: ({[K]: V}, (K, V) -> ()) -> (), - foreachi: ({V}, (number, V) -> ()) -> (), - - move: ({V}, number, number, number, {V}?) -> (), - clear: ({[K]: V}) -> (), - - freeze: ({[K]: V}) -> {[K]: V}, - isfrozen: ({[K]: V}) -> boolean, - } - - declare debug: { - info: ((thread, number, string) -> R...) & ((number, string) -> R...) & (((A...) -> R1..., string) -> R2...), - traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string), - } - - declare utf8: { - char: (number, ...number) -> string, - charpattern: string, - codes: (string) -> ((string, number) -> (number, number), string, number), - -- FIXME - codepoint: (string, number?, number?) -> (number, ...number), - len: (string, number?, number?) -> (number?, number?), - offset: (string, number?, number?) -> number, - nfdnormalize: (string) -> string, - nfcnormalize: (string) -> string, - graphemes: (string, number?, number?) -> (() -> (number, number)), - } - - declare string: { - byte: (string, number?, number?) -> ...number, - char: (number, ...number) -> string, - find: (string, string, number?, boolean?) -> (number?, number?), - -- `string.format` has a magic function attached that will provide more type information for literal format strings. - format: (string, A...) -> string, - gmatch: (string, string) -> () -> (...string), - -- gsub is defined in C++ because we don't have syntax for describing a generic table. - len: (string) -> number, - lower: (string) -> string, - match: (string, string, number?) -> string?, - rep: (string, number) -> string, - reverse: (string) -> string, - sub: (string, number, number?) -> string, - upper: (string) -> string, - split: (string, string, string?) -> {string}, - pack: (string, A...) -> string, - packsize: (string) -> number, - unpack: (string, string, number?) -> R..., - } - - -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. - declare function unpack(tab: {V}, i: number?, j: number?): ...V - )"; - } - - return src; + return result; } } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 680bcf3f..748cf20f 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -1,23 +1,35 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Error.h" -#include "Luau/Module.h" +#include "Luau/Clone.h" +#include "Luau/Common.h" +#include "Luau/FileResolver.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" #include +#include -LUAU_FASTFLAG(LuauFasterStringifier) +LUAU_FASTFLAGVARIABLE(LuauTypeMismatchInvarianceInError, false) -static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) +static std::string wrongNumberOfArgsString( + size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { - std::string s = "expects " + std::to_string(expectedCount) + " "; + std::string s = "expects "; - if (isTypeArgs) - s += "type "; + if (isVariadic) + s += "at least "; + + s += std::to_string(expectedCount) + " "; + + if (maximumCount && expectedCount != *maximumCount) + s += "to " + std::to_string(*maximumCount) + " "; + + if (argPrefix) + s += std::string(argPrefix) + " "; s += "argument"; - if (expectedCount != 1) + if ((maximumCount ? *maximumCount : expectedCount) != 1) s += "s"; s += ", but "; @@ -46,10 +58,60 @@ namespace Luau struct ErrorConverter { + FileResolver* fileResolver = nullptr; + std::string operator()(const Luau::TypeMismatch& tm) const { - ToStringOptions opts; - return "Type '" + Luau::toString(tm.givenType, opts) + "' could not be converted into '" + Luau::toString(tm.wantedType, opts) + "'"; + std::string givenTypeName = Luau::toString(tm.givenType); + std::string wantedTypeName = Luau::toString(tm.wantedType); + + std::string result; + + if (givenTypeName == wantedTypeName) + { + if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType)) + { + if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) + { + if (fileResolver != nullptr) + { + std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule); + std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule); + result = "Type '" + givenTypeName + "' from '" + givenModuleName + "' could not be converted into '" + wantedTypeName + + "' from '" + wantedModuleName + "'"; + } + else + { + result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + + "' from '" + *wantedDefinitionModule + "'"; + } + } + } + } + + if (result.empty()) + result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; + + + if (tm.error) + { + result += "\ncaused by:\n "; + + if (!tm.reason.empty()) + result += tm.reason + " "; + + result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); + } + else if (!tm.reason.empty()) + { + result += "; " + tm.reason; + } + else if (FFlag::LuauTypeMismatchInvarianceInError && tm.context == TypeMismatch::InvariantContext) + { + result += " in an invariant context"; + } + + return result; } std::string operator()(const Luau::UnknownSymbol& e) const @@ -60,8 +122,6 @@ struct ErrorConverter return "Unknown global '" + e.name + "'"; case UnknownSymbol::Type: return "Unknown type '" + e.name + "'"; - case UnknownSymbol::Generic: - return "Unknown generic '" + e.name + "'"; } LUAU_ASSERT(!"Unexpected context for UnknownSymbol"); @@ -107,28 +167,38 @@ struct ErrorConverter std::string operator()(const Luau::DuplicateTypeDefinition& e) const { - return "Redefinition of type '" + e.name + "', previously defined at line " + std::to_string(e.previousLocation.begin.line + 1); + std::string s = "Redefinition of type '" + e.name + "'"; + if (e.previousLocation) + s += ", previously defined at line " + std::to_string(e.previousLocation->begin.line + 1); + return s; } std::string operator()(const Luau::CountMismatch& e) const { + const std::string expectedS = e.expected == 1 ? "" : "s"; + const std::string actualS = e.actual == 1 ? "" : "s"; + const std::string actualVerb = e.actual == 1 ? "is" : "are"; + switch (e.context) { case CountMismatch::Return: - { - const std::string expectedS = e.expected == 1 ? "" : "s"; - const std::string actualS = e.actual == 1 ? "is" : "are"; - return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + actualS + - " returned here"; - } - case CountMismatch::Result: - if (e.expected > e.actual) - return "Function returns " + std::to_string(e.expected) + " values but there are only " + std::to_string(e.expected) + - " values to unpack them into."; - else - return "Function only returns " + std::to_string(e.expected) + " values. " + std::to_string(e.actual) + " are required here"; + return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + + actualVerb + " returned here"; + case CountMismatch::FunctionResult: + // It is alright if right hand side produces more values than the + // left hand side accepts. In this context consider only the opposite case. + return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + + actualVerb + " required here"; + case CountMismatch::ExprListResult: + return "Expression list has " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + + actualVerb + " required here"; case CountMismatch::Arg: - return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); + if (!e.function.empty()) + return "Argument count mismatch. Function '" + e.function + "' " + + wrongNumberOfArgsString(e.expected, e.maximum, e.actual, /*argPrefix*/ nullptr, e.isVariadic); + else + return "Argument count mismatch. Function " + + wrongNumberOfArgsString(e.expected, e.maximum, e.actual, /*argPrefix*/ nullptr, e.isVariadic); } LUAU_ASSERT(!"Unknown context"); @@ -142,15 +212,7 @@ struct ErrorConverter std::string operator()(const Luau::FunctionRequiresSelf& e) const { - if (e.requiredExtraNils) - { - const char* plural = e.requiredExtraNils == 1 ? "" : "s"; - return format("This function was declared to accept self, but you did not pass enough arguments. Use a colon instead of a dot or " - "pass %i extra nil%s to suppress this warning", - e.requiredExtraNils, plural); - } - else - return "This function must be called with self. Did you mean to use a colon instead of a dot?"; + return "This function must be called with self. Did you mean to use a colon instead of a dot?"; } std::string operator()(const Luau::OccursCheckFailed&) const @@ -160,34 +222,53 @@ struct ErrorConverter std::string operator()(const Luau::UnknownRequire& e) const { - return "Unknown require: " + e.modulePath; + if (e.modulePath.empty()) + return "Unknown require: unsupported path"; + else + return "Unknown require: " + e.modulePath; } std::string operator()(const Luau::IncorrectGenericParameterCount& e) const { std::string name = e.name; - if (!e.typeFun.typeParams.empty()) + if (!e.typeFun.typeParams.empty() || !e.typeFun.typePackParams.empty()) { name += "<"; bool first = true; - for (TypeId t : e.typeFun.typeParams) + for (auto param : e.typeFun.typeParams) { if (first) first = false; else name += ", "; - name += toString(t); + name += toString(param.ty); } + + for (auto param : e.typeFun.typePackParams) + { + if (first) + first = false; + else + name += ", "; + + name += toString(param.tp); + } + name += ">"; } - return "Generic type '" + name + "' " + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); + if (e.typeFun.typeParams.size() != e.actualParameters) + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typeParams.size(), std::nullopt, e.actualParameters, "type", !e.typeFun.typePackParams.empty()); + + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typePackParams.size(), std::nullopt, e.actualPackParameters, "type pack", /*isVariadic*/ false); } std::string operator()(const Luau::SyntaxError& e) const { - return "Syntax error: " + e.message; + return e.message; } std::string operator()(const Luau::CodeTooComplex&) const @@ -234,6 +315,11 @@ struct ErrorConverter return e.message; } + std::string operator()(const Luau::InternalError& e) const + { + return e.message; + } + std::string operator()(const Luau::CannotCallNonFunction& e) const { return "Cannot call non-function " + toString(e.ty); @@ -374,6 +460,26 @@ struct ErrorConverter return ss + " in the type '" + toString(e.type) + "'"; } + + std::string operator()(const TypesAreUnrelated& e) const + { + return "Cannot cast '" + toString(e.left) + "' into '" + toString(e.right) + "' because the types are unrelated"; + } + + std::string operator()(const NormalizationTooComplex&) const + { + return "Code is too complex to typecheck! Consider simplifying the code around this area"; + } + + std::string operator()(const TypePackMismatch& e) const + { + return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'"; + } + + std::string operator()(const DynamicPropertyLookupOnClassesUnsafe& e) const + { + return "Attempting a dynamic property access on type '" + Luau::toString(e.ty) + "' is unsafe and may cause exceptions at runtime"; + } }; struct InvalidNameChecker @@ -400,9 +506,60 @@ struct InvalidNameChecker } }; +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType) + : wantedType(wantedType) + , givenType(givenType) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason) + : wantedType(wantedType) + , givenType(givenType) + , reason(reason) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional error) + : wantedType(wantedType) + , givenType(givenType) + , reason(reason) + , error(error ? std::make_shared(std::move(*error)) : nullptr) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, TypeMismatch::Context context) + : wantedType(wantedType) + , givenType(givenType) + , context(context) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeMismatch::Context context) + : wantedType(wantedType) + , givenType(givenType) + , context(context) + , reason(reason) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional error, TypeMismatch::Context context) + : wantedType(wantedType) + , givenType(givenType) + , context(context) + , reason(reason) + , error(error ? std::make_shared(std::move(*error)) : nullptr) +{ +} + bool TypeMismatch::operator==(const TypeMismatch& rhs) const { - return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType; + if (!!error != !!rhs.error) + return false; + + if (error && !(*error == *rhs.error)) + return false; + + return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType && reason == rhs.reason && context == rhs.context; } bool UnknownSymbol::operator==(const UnknownSymbol& rhs) const @@ -437,7 +594,7 @@ bool DuplicateTypeDefinition::operator==(const DuplicateTypeDefinition& rhs) con bool CountMismatch::operator==(const CountMismatch& rhs) const { - return expected == rhs.expected && actual == rhs.actual && context == rhs.context; + return expected == rhs.expected && maximum == rhs.maximum && actual == rhs.actual && context == rhs.context && function == rhs.function; } bool FunctionDoesNotTakeSelf::operator==(const FunctionDoesNotTakeSelf&) const @@ -447,7 +604,7 @@ bool FunctionDoesNotTakeSelf::operator==(const FunctionDoesNotTakeSelf&) const bool FunctionRequiresSelf::operator==(const FunctionRequiresSelf& e) const { - return requiredExtraNils == e.requiredExtraNils; + return true; } bool OccursCheckFailed::operator==(const OccursCheckFailed&) const @@ -471,9 +628,20 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size()) return false; + if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size()) + return false; + for (size_t i = 0; i < typeFun.typeParams.size(); ++i) - if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i]) + { + if (typeFun.typeParams[i].ty != rhs.typeFun.typeParams[i].ty) return false; + } + + for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) + { + if (typeFun.typePackParams[i].tp != rhs.typeFun.typePackParams[i].tp) + return false; + } return true; } @@ -504,6 +672,11 @@ bool GenericError::operator==(const GenericError& rhs) const return message == rhs.message; } +bool InternalError::operator==(const InternalError& rhs) const +{ + return message == rhs.message; +} + bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const { return ty == rhs.ty; @@ -526,7 +699,17 @@ bool FunctionExitsWithoutReturning::operator==(const FunctionExitsWithoutReturni int TypeError::code() const { - return 1000 + int(data.index()); + return minCode() + int(data.index()); +} + +int TypeError::minCode() +{ + return 1000; +} + +TypeErrorSummary TypeError::summary() const +{ + return TypeErrorSummary{location, moduleName, code()}; } bool TypeError::operator==(const TypeError& rhs) const @@ -584,9 +767,29 @@ bool MissingUnionProperty::operator==(const MissingUnionProperty& rhs) const return *type == *rhs.type && key == rhs.key; } +bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const +{ + return left == rhs.left && right == rhs.right; +} + +bool TypePackMismatch::operator==(const TypePackMismatch& rhs) const +{ + return *wantedTp == *rhs.wantedTp && *givenTp == *rhs.givenTp; +} + +bool DynamicPropertyLookupOnClassesUnsafe::operator==(const DynamicPropertyLookupOnClassesUnsafe& rhs) const +{ + return ty == rhs.ty; +} + std::string toString(const TypeError& error) { - ErrorConverter converter; + return toString(error, TypeErrorToStringOptions{}); +} + +std::string toString(const TypeError& error, TypeErrorToStringOptions options) +{ + ErrorConverter converter{options.fileResolver}; return Luau::visit(converter, error.data); } @@ -595,130 +798,158 @@ bool containsParseErrorName(const TypeError& error) return Luau::visit(InvalidNameChecker{}, error.data); } -void copyErrors(ErrorVec& errors, struct TypeArena& destArena) +template +void copyError(T& e, TypeArena& destArena, CloneState cloneState) { - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - auto clone = [&](auto&& ty) { - return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks); + return ::Luau::clone(ty, destArena, cloneState); }; auto visitErrorData = [&](auto&& e) { - using T = std::decay_t; + copyError(e, destArena, cloneState); + }; - if constexpr (false) - { - } - else if constexpr (std::is_same_v) - { - e.wantedType = clone(e.wantedType); - e.givenType = clone(e.givenType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.table = clone(e.table); - } - else if constexpr (std::is_same_v) - { - e.ty = clone(e.ty); - } - else if constexpr (std::is_same_v) - { - e.tableType = clone(e.tableType); - } - else if constexpr (std::is_same_v) - { - e.tableType = clone(e.tableType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.typeFun = clone(e.typeFun); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.table = clone(e.table); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.ty = clone(e.ty); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.expectedReturnType = clone(e.expectedReturnType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.superType = clone(e.superType); - e.subType = clone(e.subType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.optional = clone(e.optional); - } - else if constexpr (std::is_same_v) - { - e.type = clone(e.type); + if constexpr (false) + { + } + else if constexpr (std::is_same_v) + { + e.wantedType = clone(e.wantedType); + e.givenType = clone(e.givenType); - for (auto& ty : e.missing) - ty = clone(ty); - } - else - static_assert(always_false_v, "Non-exhaustive type switch"); + if (e.error) + visit(visitErrorData, e.error->data); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.table = clone(e.table); + } + else if constexpr (std::is_same_v) + { + e.ty = clone(e.ty); + } + else if constexpr (std::is_same_v) + { + e.tableType = clone(e.tableType); + } + else if constexpr (std::is_same_v) + { + e.tableType = clone(e.tableType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.typeFun = clone(e.typeFun); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.table = clone(e.table); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.ty = clone(e.ty); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.expectedReturnType = clone(e.expectedReturnType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.superType = clone(e.superType); + e.subType = clone(e.subType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.optional = clone(e.optional); + } + else if constexpr (std::is_same_v) + { + e.type = clone(e.type); + + for (auto& ty : e.missing) + ty = clone(ty); + } + else if constexpr (std::is_same_v) + { + e.left = clone(e.left); + e.right = clone(e.right); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.wantedTp = clone(e.wantedTp); + e.givenTp = clone(e.givenTp); + } + else if constexpr (std::is_same_v) + e.ty = clone(e.ty); + else + static_assert(always_false_v, "Non-exhaustive type switch"); +} + +void copyErrors(ErrorVec& errors, TypeArena& destArena) +{ + CloneState cloneState; + + auto visitErrorData = [&](auto&& e) { + copyError(e, destArena, cloneState); }; LUAU_ASSERT(!destArena.typeVars.isFrozen()); @@ -730,7 +961,7 @@ void copyErrors(ErrorVec& errors, struct TypeArena& destArena) void InternalErrorReporter::ice(const std::string& message, const Location& location) { - std::runtime_error error("Internal error in " + moduleName + " at " + toString(location) + ": " + message); + InternalCompilerError error(message, moduleName, location); if (onInternalError) onInternalError(error.what()); @@ -740,7 +971,7 @@ void InternalErrorReporter::ice(const std::string& message, const Location& loca void InternalErrorReporter::ice(const std::string& message) { - std::runtime_error error("Internal error in " + moduleName + ": " + message); + InternalCompilerError error(message, moduleName); if (onInternalError) onInternalError(error.what()); @@ -748,4 +979,9 @@ void InternalErrorReporter::ice(const std::string& message) throw error; } +const char* InternalCompilerError::what() const throw() +{ + return this->message.data(); +} + } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 4d385ec1..e21e42e1 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,39 +1,53 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/Config.h" +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintSolver.h" +#include "Luau/DataFlowGraph.h" +#include "Luau/DcrLogger.h" #include "Luau/FileResolver.h" +#include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/StringUtils.h" +#include "Luau/TimeTrace.h" +#include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" -#include "Luau/Common.h" #include #include #include +#include +LUAU_FASTINT(LuauTypeInferIterationLimit) +LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) -LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false) -LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) -LUAU_FASTFLAG(LuauTraceRequireLookupChild) -LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) +LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) +LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) +LUAU_FASTFLAG(DebugLuauLogSolverToJson); namespace Luau { -std::optional parseMode(const std::vector& hotcomments) +std::optional parseMode(const std::vector& hotcomments) { - for (const std::string& hc : hotcomments) + for (const HotComment& hc : hotcomments) { - if (hc == "nocheck") + if (!hc.header) + continue; + + if (hc.content == "nocheck") return Mode::NoCheck; - if (hc == "nonstrict") + if (hc.content == "nonstrict") return Mode::Nonstrict; - if (hc == "strict") + if (hc.content == "strict") return Mode::Strict; } @@ -67,8 +81,70 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) } } +LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, const std::string& packageName) +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, source, packageName); + + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); + + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + + ParseOptions options; + options.allowDeclarationSyntax = true; + + Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); + + if (parseResult.errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, nullptr}; + + Luau::SourceModule module; + module.root = parseResult.root; + module.mode = Mode::Definition; + + ModulePtr checkedModule = check(module, Mode::Definition, globalScope, {}); + + if (checkedModule->errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, checkedModule}; + + CloneState cloneState; + + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); + + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + typesToPersist.push_back(globalTy); + } + + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + globalScope->exportedTypeBindings[name] = globalTy; + + typesToPersist.push_back(globalTy.type); + } + + for (TypeId ty : typesToPersist) + { + persist(ty); + } + + return LoadDefinitionFileResult{true, parseResult, checkedModule}; +} + LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName) { + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); + Luau::Allocator allocator; Luau::AstNameTable names(allocator); @@ -89,29 +165,34 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, checkedModule}; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; + CloneState cloneState; + + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size()); for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks); + TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - if (FFlag::LuauPersistDefinitionFileTypes) - persist(globalTy); + typesToPersist.push_back(globalTy); } for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks); + TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; - if (FFlag::LuauPersistDefinitionFileTypes) - persist(globalTy.type); + typesToPersist.push_back(globalTy.type); + } + + for (TypeId ty : typesToPersist) + { + persist(ty); } return LoadDefinitionFileResult{true, parseResult, checkedModule}; @@ -208,7 +289,7 @@ ErrorVec accumulateErrors( continue; const SourceNode& sourceNode = it->second; - queue.insert(queue.end(), sourceNode.requires.begin(), sourceNode.requires.end()); + queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end()); // FIXME: If a module has a syntax error, we won't be able to re-report it here. // The solution is probably to move errors from Module to SourceNode @@ -231,18 +312,12 @@ ErrorVec accumulateErrors( return result; } -struct RequireCycle -{ - Location location; - std::vector path; // one of the paths for a require() to go all the way back to the originating module -}; - // Given a source node (start), find all requires that start a transitive dependency path that ends back at start // For each such path, record the full path and the location of the require in the starting module. // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) // However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true) std::vector getRequireCycles( - const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) + const FileResolver* resolver, const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) { std::vector result; @@ -276,9 +351,9 @@ std::vector getRequireCycles( if (top == start) { for (const SourceNode* node : path) - cycle.push_back(node->name); + cycle.push_back(resolver->getHumanReadableModuleName(node->name)); - cycle.push_back(top->name); + cycle.push_back(resolver->getHumanReadableModuleName(top->name)); break; } } @@ -333,13 +408,15 @@ double getTimestamp() } // namespace Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options) - : fileResolver(fileResolver) + : singletonTypes(NotNull{&singletonTypes_}) + , fileResolver(fileResolver) , moduleResolver(this) , moduleResolverForAutocomplete(this) - , typeChecker(&moduleResolver, &iceHandler) - , typeCheckerForAutocomplete(&moduleResolverForAutocomplete, &iceHandler) + , typeChecker(&moduleResolver, singletonTypes, &iceHandler) + , typeCheckerForAutocomplete(&moduleResolverForAutocomplete, singletonTypes, &iceHandler) , configResolver(configResolver) , options(options) + , globalScope(typeChecker.globalScope) { } @@ -348,33 +425,49 @@ FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) { } -CheckResult Frontend::check(const ModuleName& name) +CheckResult Frontend::check(const ModuleName& name, std::optional optionOverride) { + LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + + FrontendOptions frontendOptions = optionOverride.value_or(options); CheckResult checkResult; auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.dirty) + if (it != sourceNodes.end() && !it->second.hasDirtyModule(frontendOptions.forAutocomplete)) { // No recheck required. - auto it2 = moduleResolver.modules.find(name); - if (it2 == moduleResolver.modules.end() || it2->second == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); + if (frontendOptions.forAutocomplete) + { + auto it2 = moduleResolverForAutocomplete.modules.find(name); + if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr) + throw InternalCompilerError("Frontend::modules does not have data for " + name, name); + } + else + { + auto it2 = moduleResolver.modules.find(name); + if (it2 == moduleResolver.modules.end() || it2->second == nullptr) + throw InternalCompilerError("Frontend::modules does not have data for " + name, name); + } - return CheckResult{accumulateErrors(sourceNodes, moduleResolver.modules, name)}; + return CheckResult{ + accumulateErrors(sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; } std::vector buildQueue; - bool cycleDetected = parseGraph(buildQueue, checkResult, name); + bool cycleDetected = parseGraph(buildQueue, checkResult, name, frontendOptions.forAutocomplete); // Keep track of which AST nodes we've reported cycles in std::unordered_set reportedCycles; + double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; + for (const ModuleName& moduleName : buildQueue) { LUAU_ASSERT(sourceNodes.count(moduleName)); SourceNode& sourceNode = sourceNodes[moduleName]; - if (!sourceNode.dirty) + if (!sourceNode.hasDirtyModule(frontendOptions.forAutocomplete)) continue; LUAU_ASSERT(sourceModules.count(moduleName)); @@ -384,7 +477,7 @@ CheckResult Frontend::check(const ModuleName& name) Mode mode = sourceModule.mode.value_or(config.mode); - ScopePtr environmentScope = getModuleEnvironment(sourceModule, config); + ScopePtr environmentScope = getModuleEnvironment(sourceModule, config, frontendOptions.forAutocomplete); double timestamp = getTimestamp(); @@ -395,49 +488,75 @@ CheckResult Frontend::check(const ModuleName& name) // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term // all correct programs must be acyclic so this code triggers rarely if (cycleDetected) - requireCycles = getRequireCycles(sourceNodes, &sourceNode, mode == Mode::NoCheck); + requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck); // This is used by the type checker to replace the resulting type of cyclic modules with any sourceModule.cyclic = !requireCycles.empty(); - ModulePtr module = typeChecker.check(sourceModule, mode, environmentScope); - - // If we're typechecking twice, we do so. - // The second typecheck is always in strict mode with DM awareness - // to provide better typen information for IDE features. - if (options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel) + if (frontendOptions.forAutocomplete) { - ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); + // The autocomplete typecheck is always in strict mode with DM awareness + // to provide better type information for IDE features + typeCheckerForAutocomplete.requireCycles = requireCycles; + + if (autocompleteTimeLimit != 0.0) + typeCheckerForAutocomplete.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; + else + typeCheckerForAutocomplete.finishTime = std::nullopt; + + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (FInt::LuauTarjanChildLimit > 0) + typeCheckerForAutocomplete.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckerForAutocomplete.unifierIterationLimit = + std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; + + ModulePtr moduleForAutocomplete = FFlag::DebugLuauDeferredConstraintResolution + ? check(sourceModule, mode, environmentScope, requireCycles, /*forAutocomplete*/ true) + : typeCheckerForAutocomplete.check(sourceModule, Mode::Strict, environmentScope); + moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; + + double duration = getTimestamp() - timestamp; + + if (moduleForAutocomplete->timeout) + { + checkResult.timeoutHits.push_back(moduleName); + + sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + } + else if (duration < autocompleteTimeLimit / 2.0) + { + sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); + } + + stats.timeCheck += duration; + stats.filesStrict += 1; + + sourceNode.dirtyModuleForAutocomplete = false; + continue; } - else if (options.retainFullTypeGraphs && options.typecheckTwice && mode != Mode::Strict) - { - ModulePtr strictModule = typeChecker.check(sourceModule, Mode::Strict, environmentScope); - module->astTypes.clear(); - module->astOriginalCallTypes.clear(); - module->astExpectedTypes.clear(); - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; + typeChecker.requireCycles = requireCycles; - for (const auto& [expr, strictTy] : strictModule->astTypes) - module->astTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); - - for (const auto& [expr, strictTy] : strictModule->astOriginalCallTypes) - module->astOriginalCallTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); - - for (const auto& [expr, strictTy] : strictModule->astExpectedTypes) - module->astExpectedTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); - } + ModulePtr module = FFlag::DebugLuauDeferredConstraintResolution ? check(sourceModule, mode, environmentScope, requireCycles) + : typeChecker.check(sourceModule, mode, environmentScope); stats.timeCheck += getTimestamp() - timestamp; stats.filesStrict += mode == Mode::Strict; stats.filesNonstrict += mode == Mode::Nonstrict; if (module == nullptr) - throw std::runtime_error("Frontend::check produced a nullptr module for " + moduleName); + throw InternalCompilerError("Frontend::check produced a nullptr module for " + moduleName, moduleName); - if (!options.retainFullTypeGraphs) + if (!frontendOptions.retainFullTypeGraphs) { // copyErrors needs to allocate into interfaceTypes as it copies // types out of internalTypes, so we unfreeze it here. @@ -449,6 +568,9 @@ CheckResult Frontend::check(const ModuleName& name) module->astTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); + module->astResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); + module->scopes.resize(1); } if (mode != Mode::NoCheck) @@ -471,14 +593,17 @@ CheckResult Frontend::check(const ModuleName& name) checkResult.errors.insert(checkResult.errors.end(), module->errors.begin(), module->errors.end()); moduleResolver.modules[moduleName] = std::move(module); - sourceNode.dirty = false; + sourceNode.dirtyModule = false; } return checkResult; } -bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root) +bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete) { + LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); + // https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search enum Mark { @@ -536,7 +661,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec path.push_back(top); // push children - for (const ModuleName& dep : top->requires) + for (const ModuleName& dep : top->requireSet) { auto it = sourceNodes.find(dep); if (it != sourceNodes.end()) @@ -545,7 +670,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec // this relies on the fact that markDirty marks reverse-dependencies dirty as well // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need // to be built, *and* can't form a cycle with any nodes we did process. - if (!it->second.dirty) + if (!it->second.hasDirtyModule(forAutocomplete)) continue; // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization @@ -572,9 +697,13 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec return cyclic; } -ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config) +ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) { - ScopePtr result = typeChecker.globalScope; + ScopePtr result; + if (forAutocomplete) + result = typeCheckerForAutocomplete.globalScope; + else + result = typeChecker.globalScope; if (module.environmentName) result = getEnvironmentScope(*module.environmentName); @@ -597,6 +726,9 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config LintResult Frontend::lint(const ModuleName& name, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + CheckResult checkResult; auto [_sourceNode, sourceModule] = getSourceNode(checkResult, name); @@ -606,52 +738,17 @@ LintResult Frontend::lint(const ModuleName& name, std::optional Frontend::lintFragment(std::string_view source, std::optional enabledLintWarnings) -{ - const Config& config = configResolver->getConfig(""); - - SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions); - - Luau::LintOptions lintOptions = enabledLintWarnings.value_or(config.enabledLint); - lintOptions.warningMask &= sourceModule.ignoreLints; - - double timestamp = getTimestamp(); - - std::vector warnings = - Luau::lint(sourceModule.root, *sourceModule.names.get(), typeChecker.globalScope, nullptr, enabledLintWarnings.value_or(config.enabledLint)); - - stats.timeLint += getTimestamp() - timestamp; - - return {std::move(sourceModule), classifyLints(warnings, config)}; -} - -CheckResult Frontend::check(const SourceModule& module) -{ - const Config& config = configResolver->getConfig(module.name); - - Mode mode = module.mode.value_or(config.mode); - - double timestamp = getTimestamp(); - - ModulePtr checkedModule = typeChecker.check(module, mode); - - stats.timeCheck += getTimestamp() - timestamp; - stats.filesStrict += mode == Mode::Strict; - stats.filesNonstrict += mode == Mode::Nonstrict; - - if (checkedModule == nullptr) - throw std::runtime_error("Frontend::check produced a nullptr module for module " + module.name); - moduleResolver.modules[module.name] = checkedModule; - - return CheckResult{checkedModule->errors}; -} - LintResult Frontend::lint(const SourceModule& module, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + const Config& config = configResolver->getConfig(module.name); + uint64_t ignoreLints = LintWarning::parseMask(module.hotcomments); + LintOptions options = enabledLintWarnings.value_or(config.enabledLint); - options.warningMask &= ~module.ignoreLints; + options.warningMask &= ~ignoreLints; Mode mode = module.mode.value_or(config.mode); if (mode != Mode::NoCheck) @@ -670,17 +767,17 @@ LintResult Frontend::lint(const SourceModule& module, std::optional warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), options); + std::vector warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), module.hotcomments, options); stats.timeLint += getTimestamp() - timestamp; return classifyLints(warnings, config); } -bool Frontend::isDirty(const ModuleName& name) const +bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const { auto it = sourceNodes.find(name); - return it == sourceNodes.end() || it->second.dirty; + return it == sourceNodes.end() || it->second.hasDirtyModule(forAutocomplete); } /* @@ -691,13 +788,13 @@ bool Frontend::isDirty(const ModuleName& name) const */ void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) { - if (!moduleResolver.modules.count(name)) + if (!moduleResolver.modules.count(name) && !moduleResolverForAutocomplete.modules.count(name)) return; std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) { - for (const auto& dep : module.second.requires) + for (const auto& dep : module.second.requireSet) reverseDeps[dep].push_back(module.first); } @@ -714,17 +811,19 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked if (markedDirty) markedDirty->push_back(next); - if (sourceNode.dirty) + if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) continue; - sourceNode.dirty = true; + sourceNode.dirtySourceModule = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; - if (0 == reverseDeps.count(name)) + if (0 == reverseDeps.count(next)) continue; - sourceModules.erase(name); + sourceModules.erase(next); - const std::vector& dependents = reverseDeps[name]; + const std::vector& dependents = reverseDeps[next]; queue.insert(queue.end(), dependents.begin(), dependents.end()); } } @@ -743,11 +842,95 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons return const_cast(this)->getSourceModule(moduleName); } +ScopePtr Frontend::getGlobalScope() +{ + if (!globalScope) + { + globalScope = typeChecker.globalScope; + } + + return globalScope; +} + +ModulePtr Frontend::check( + const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope, std::vector requireCycles, bool forAutocomplete) +{ + ModulePtr result = std::make_shared(); + + std::unique_ptr logger; + if (FFlag::DebugLuauLogSolverToJson) + { + logger = std::make_unique(); + std::optional source = fileResolver->readSource(sourceModule.name); + if (source) + { + logger->captureSource(source->source); + } + } + + DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&iceHandler}); + + const NotNull mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}; + const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope}; + + Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}}; + + ConstraintGraphBuilder cgb{ + sourceModule.name, + result, + &result->internalTypes, + mr, + singletonTypes, + NotNull(&iceHandler), + globalScope, + logger.get(), + NotNull{&dfg}, + }; + + cgb.visit(sourceModule.root); + result->errors = std::move(cgb.errors); + + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, NotNull(&moduleResolver), + requireCycles, logger.get()}; + + if (options.randomizeConstraintResolutionSeed) + cs.randomize(*options.randomizeConstraintResolutionSeed); + + cs.run(); + + for (TypeError& e : cs.errors) + result->errors.emplace_back(std::move(e)); + + result->scopes = std::move(cgb.scopes); + result->astTypes = std::move(cgb.astTypes); + result->astTypePacks = std::move(cgb.astTypePacks); + result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes); + result->astOverloadResolvedTypes = std::move(cgb.astOverloadResolvedTypes); + result->astResolvedTypes = std::move(cgb.astResolvedTypes); + result->astResolvedTypePacks = std::move(cgb.astResolvedTypePacks); + result->type = sourceModule.type; + + Luau::check(singletonTypes, logger.get(), sourceModule, result.get()); + + if (FFlag::DebugLuauLogSolverToJson) + { + std::string output = logger->compileOutput(); + printf("%s\n", output.c_str()); + } + + result->clonePublicInterface(singletonTypes, iceHandler); + + return result; +} + // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) { + LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.dirty) + if (it != sourceNodes.end() && !it->second.hasDirtySourceModule()) { auto moduleIt = sourceModules.find(name); if (moduleIt != sourceModules.end()) @@ -778,8 +961,8 @@ std::pair Frontend::getSourceNode(CheckResult& check SourceModule result = parse(name, source->source, opts); result.type = source->type; - RequireTraceResult& requireTrace = requires[name]; - requireTrace = traceRequires(fileResolver, result.root, name); + RequireTraceResult& require = requireTrace[name]; + require = traceRequires(fileResolver, result.root, name); SourceNode& sourceNode = sourceNodes[name]; SourceModule& sourceModule = sourceModules[name]; @@ -788,14 +971,20 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceModule.environmentName = environmentName; sourceNode.name = name; - sourceNode.requires.clear(); + sourceNode.requireSet.clear(); sourceNode.requireLocations.clear(); - sourceNode.dirty = true; + sourceNode.dirtySourceModule = false; - for (const auto& [moduleName, location] : requireTrace.requires) - sourceNode.requires.insert(moduleName); + if (it == sourceNodes.end()) + { + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; + } - sourceNode.requireLocations = requireTrace.requires; + for (const auto& [moduleName, location] : require.requireList) + sourceNode.requireSet.insert(moduleName); + + sourceNode.requireLocations = require.requireList; return {&sourceNode, &sourceModule}; } @@ -815,15 +1004,18 @@ std::pair Frontend::getSourceNode(CheckResult& check */ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions) { + LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + SourceModule sourceModule; double timestamp = getTimestamp(); - auto parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions); + Luau::ParseResult parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions); stats.timeParse += getTimestamp() - timestamp; stats.files++; - stats.lines += std::count(src.begin(), src.end(), '\n') + (src.size() && src.back() != '\n'); + stats.lines += parseResult.lines; if (!parseResult.errors.empty()) sourceModule.parseErrors.insert(sourceModule.parseErrors.end(), parseResult.errors.begin(), parseResult.errors.end()); @@ -832,7 +1024,6 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const { sourceModule.root = parseResult.root; sourceModule.mode = parseMode(parseResult.hotcomments); - sourceModule.ignoreLints = LintWarning::parseMask(parseResult.hotcomments); } else { @@ -841,43 +1032,36 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const } sourceModule.name = name; + if (parseOptions.captureComments) + { sourceModule.commentLocations = std::move(parseResult.commentLocations); + sourceModule.hotcomments = std::move(parseResult.hotcomments); + } + return sourceModule; } std::optional FrontendModuleResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) { // FIXME I think this can be pushed into the FileResolver. - auto it = frontend->requires.find(currentModuleName); - if (it == frontend->requires.end()) + auto it = frontend->requireTrace.find(currentModuleName); + if (it == frontend->requireTrace.end()) { // CLI-43699 // If we can't find the current module name, that's because we bypassed the frontend's initializer - // and called typeChecker.check directly. (This is done by autocompleteSource, for example). + // and called typeChecker.check directly. // In that case, requires will always fail. - if (FFlag::LuauResolveModuleNameWithoutACurrentModule) - return std::nullopt; - else - throw std::runtime_error("Frontend::resolveModuleName: Unknown currentModuleName '" + currentModuleName + "'"); + return std::nullopt; } const auto& exprs = it->second.exprs; - const ModuleName* relativeName = exprs.find(&pathExpr); - if (!relativeName || relativeName->empty()) + const ModuleInfo* info = exprs.find(&pathExpr); + if (!info) return std::nullopt; - if (FFlag::LuauTraceRequireLookupChild) - { - const bool* optional = it->second.optional.find(&pathExpr); - - return {{*relativeName, optional ? *optional : false}}; - } - else - { - return {{*relativeName, false}}; - } + return *info; } const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const @@ -891,12 +1075,12 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const { - return frontend->fileResolver->moduleExists(moduleName); + return frontend->sourceNodes.count(moduleName) != 0; } std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const { - return frontend->fileResolver->getHumanReadableModuleName_(moduleName).value_or(moduleName); + return frontend->fileResolver->getHumanReadableModuleName(moduleName); } ScopePtr Frontend::addEnvironment(const std::string& environmentName) @@ -961,7 +1145,7 @@ void Frontend::clear() sourceModules.clear(); moduleResolver.modules.clear(); moduleResolverForAutocomplete.modules.clear(); - requires.clear(); + requireTrace.clear(); } } // namespace Luau diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp new file mode 100644 index 00000000..3d0cd0d1 --- /dev/null +++ b/Analysis/src/Instantiation.cpp @@ -0,0 +1,131 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" +#include "Luau/Instantiation.h" +#include "Luau/TxnLog.h" +#include "Luau/TypeArena.h" + +LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) + +namespace Luau +{ + +bool Instantiation::isDirty(TypeId ty) +{ + if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (ftv->hasNoGenerics) + return false; + + return true; + } + else + { + return false; + } +} + +bool Instantiation::isDirty(TypePackId tp) +{ + return false; +} + +bool Instantiation::ignoreChildren(TypeId ty) +{ + if (log->getMutable(ty)) + return true; + else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + return true; + else + return false; +} + +TypeId Instantiation::clean(TypeId ty) +{ + const FunctionTypeVar* ftv = log->getMutable(ty); + LUAU_ASSERT(ftv); + + FunctionTypeVar clone = FunctionTypeVar{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; + clone.magicFunction = ftv->magicFunction; + clone.dcrMagicFunction = ftv->dcrMagicFunction; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + TypeId result = addType(std::move(clone)); + + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + ReplaceGenerics replaceGenerics{log, arena, level, scope, ftv->generics, ftv->genericPacks}; + + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = replaceGenerics.substitute(result).value_or(result); + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; +} + +TypePackId Instantiation::clean(TypePackId tp) +{ + LUAU_ASSERT(false); + return tp; +} + +bool ReplaceGenerics::ignoreChildren(TypeId ty) +{ + if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (ftv->hasNoGenerics) + return true; + + // We aren't recursing in the case of a generic function which + // binds the same generics. This can happen if, for example, there's recursive types. + // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. + // It's OK to use vector equality here, since we always generate fresh generics + // whenever we quantify, so the vectors overlap if and only if they are equal. + return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); + } + else if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + return true; + else + { + return false; + } +} + +bool ReplaceGenerics::isDirty(TypeId ty) +{ + if (const TableTypeVar* ttv = log->getMutable(ty)) + return ttv->state == TableState::Generic; + else if (log->getMutable(ty)) + return std::find(generics.begin(), generics.end(), ty) != generics.end(); + else + return false; +} + +bool ReplaceGenerics::isDirty(TypePackId tp) +{ + if (log->getMutable(tp)) + return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); + else + return false; +} + +TypeId ReplaceGenerics::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + if (const TableTypeVar* ttv = log->getMutable(ty)) + { + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, scope, TableState::Free}; + clone.definitionModuleName = ttv->definitionModuleName; + return addType(std::move(clone)); + } + else + return addType(FreeTypeVar{scope, level}); +} + +TypePackId ReplaceGenerics::clean(TypePackId tp) +{ + LUAU_ASSERT(isDirty(tp)); + return addTypePack(TypePackVar(FreeTypePack{level})); +} + +} // namespace Luau diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 84e9b77f..cd59fdfb 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -23,232 +23,191 @@ std::ostream& operator<<(std::ostream& stream, const AstName& name) return stream << ""; } -std::ostream& operator<<(std::ostream& stream, const TypeMismatch& tm) +template +static void errorToString(std::ostream& stream, const T& err) { - return stream << "TypeMismatch { " << toString(tm.wantedType) << ", " << toString(tm.givenType) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const TypeError& error) -{ - return stream << "TypeError { \"" << error.moduleName << "\", " << error.location << ", " << error.data << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownSymbol& error) -{ - return stream << "UnknownSymbol { " << error.name << " , context " << error.context << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownProperty& error) -{ - return stream << "UnknownProperty { " << toString(error.table) << ", key = " << error.key << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const NotATable& ge) -{ - return stream << "NotATable { " << toString(ge.ty) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CannotExtendTable& error) -{ - return stream << "CannotExtendTable { " << toString(error.tableType) << ", context " << error.context << ", prop \"" << error.prop << "\" }"; -} - -std::ostream& operator<<(std::ostream& stream, const OnlyTablesCanHaveMethods& error) -{ - return stream << "OnlyTablesCanHaveMethods { " << toString(error.tableType) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const DuplicateTypeDefinition& error) -{ - return stream << "DuplicateTypeDefinition { " << error.name << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CountMismatch& error) -{ - return stream << "CountMismatch { expected " << error.expected << ", got " << error.actual << ", context " << error.context << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const FunctionDoesNotTakeSelf&) -{ - return stream << "FunctionDoesNotTakeSelf { }"; -} - -std::ostream& operator<<(std::ostream& stream, const FunctionRequiresSelf& error) -{ - return stream << "FunctionRequiresSelf { extraNils " << error.requiredExtraNils << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const OccursCheckFailed&) -{ - return stream << "OccursCheckFailed { }"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownRequire& error) -{ - return stream << "UnknownRequire { " << error.modulePath << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCount& error) -{ - stream << "IncorrectGenericParameterCount { name = " << error.name; - - if (!error.typeFun.typeParams.empty()) + if constexpr (false) { - stream << "<"; + } + else if constexpr (std::is_same_v) + stream << "TypeMismatch { " << toString(err.wantedType) << ", " << toString(err.givenType) << " }"; + else if constexpr (std::is_same_v) + stream << "UnknownSymbol { " << err.name << " , context " << err.context << " }"; + else if constexpr (std::is_same_v) + stream << "UnknownProperty { " << toString(err.table) << ", key = " << err.key << " }"; + else if constexpr (std::is_same_v) + stream << "NotATable { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "CannotExtendTable { " << toString(err.tableType) << ", context " << err.context << ", prop \"" << err.prop << "\" }"; + else if constexpr (std::is_same_v) + stream << "OnlyTablesCanHaveMethods { " << toString(err.tableType) << " }"; + else if constexpr (std::is_same_v) + stream << "DuplicateTypeDefinition { " << err.name << " }"; + else if constexpr (std::is_same_v) + stream << "CountMismatch { expected " << err.expected << ", got " << err.actual << ", context " << err.context << " }"; + else if constexpr (std::is_same_v) + stream << "FunctionDoesNotTakeSelf { }"; + else if constexpr (std::is_same_v) + stream << "FunctionRequiresSelf { }"; + else if constexpr (std::is_same_v) + stream << "OccursCheckFailed { }"; + else if constexpr (std::is_same_v) + stream << "UnknownRequire { " << err.modulePath << " }"; + else if constexpr (std::is_same_v) + { + stream << "IncorrectGenericParameterCount { name = " << err.name; + + if (!err.typeFun.typeParams.empty() || !err.typeFun.typePackParams.empty()) + { + stream << "<"; + bool first = true; + for (auto param : err.typeFun.typeParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(param.ty); + } + + for (auto param : err.typeFun.typePackParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(param.tp); + } + + stream << ">"; + } + + stream << ", typeFun = " << toString(err.typeFun.type) << ", actualCount = " << err.actualParameters << " }"; + } + else if constexpr (std::is_same_v) + stream << "SyntaxError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "CodeTooComplex {}"; + else if constexpr (std::is_same_v) + stream << "UnificationTooComplex {}"; + else if constexpr (std::is_same_v) + { + stream << "UnknownPropButFoundLikeProp { key = '" << err.key << "', suggested = { "; + bool first = true; - for (TypeId t : error.typeFun.typeParams) + for (Name name : err.candidates) { if (first) first = false; else stream << ", "; - stream << toString(t); + stream << "'" << name << "'"; } - stream << ">"; - } - stream << ", typeFun = " << toString(error.typeFun.type) << ", actualCount = " << error.actualParameters << " }"; + stream << " }, table = " << toString(err.table) << " } "; + } + else if constexpr (std::is_same_v) + stream << "GenericError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "InternalError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "CannotCallNonFunction { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "ExtraInformation { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "DeprecatedApiUsed { " << err.symbol << ", useInstead = " << err.useInstead << " }"; + else if constexpr (std::is_same_v) + { + stream << "ModuleHasCyclicDependency {"; + + bool first = true; + for (const ModuleName& name : err.cycle) + { + if (first) + first = false; + else + stream << ", "; + + stream << name; + } + + stream << "}"; + } + else if constexpr (std::is_same_v) + stream << "IllegalRequire { " << err.moduleName << ", reason = " << err.reason << " }"; + else if constexpr (std::is_same_v) + stream << "FunctionExitsWithoutReturning {" << toString(err.expectedReturnType) << "}"; + else if constexpr (std::is_same_v) + stream << "DuplicateGenericParameter { " + err.parameterName + " }"; + else if constexpr (std::is_same_v) + stream << "CannotInferBinaryOperation { op = " + toString(err.op) + ", suggested = '" + + (err.suggestedToAnnotate ? *err.suggestedToAnnotate : "") + "', kind " + << err.kind << "}"; + else if constexpr (std::is_same_v) + { + stream << "MissingProperties { superType = '" << toString(err.superType) << "', subType = '" << toString(err.subType) << "', properties = { "; + + bool first = true; + for (Name name : err.properties) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << name << "'"; + } + + stream << " }, context " << err.context << " } "; + } + else if constexpr (std::is_same_v) + stream << "SwappedGenericTypeParameter { name = '" + err.name + "', kind = " + std::to_string(err.kind) + " }"; + else if constexpr (std::is_same_v) + stream << "OptionalValueAccess { optional = '" + toString(err.optional) + "' }"; + else if constexpr (std::is_same_v) + { + stream << "MissingUnionProperty { type = '" + toString(err.type) + "', missing = { "; + + bool first = true; + for (auto ty : err.missing) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << toString(ty) << "'"; + } + + stream << " }, key = '" + err.key + "' }"; + } + else if constexpr (std::is_same_v) + stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }"; + else if constexpr (std::is_same_v) + stream << "NormalizationTooComplex { }"; + else if constexpr (std::is_same_v) + stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }"; + else if constexpr (std::is_same_v) + stream << "DynamicPropertyLookupOnClassesUnsafe { " << toString(err.ty) << " }"; + else + static_assert(always_false_v, "Non-exhaustive type switch"); +} + +std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data) +{ + auto cb = [&](const auto& e) { + return errorToString(stream, e); + }; + visit(cb, data); return stream; } -std::ostream& operator<<(std::ostream& stream, const SyntaxError& ge) +std::ostream& operator<<(std::ostream& stream, const TypeError& error) { - return stream << "SyntaxError { " << ge.message << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CodeTooComplex&) -{ - return stream << "CodeTooComplex {}"; -} - -std::ostream& operator<<(std::ostream& stream, const UnificationTooComplex&) -{ - return stream << "UnificationTooComplex {}"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownPropButFoundLikeProp& e) -{ - stream << "UnknownPropButFoundLikeProp { key = '" << e.key << "', suggested = { "; - - bool first = true; - for (Name name : e.candidates) - { - if (first) - first = false; - else - stream << ", "; - - stream << "'" << name << "'"; - } - - return stream << " }, table = " << toString(e.table) << " } "; -} - -std::ostream& operator<<(std::ostream& stream, const GenericError& ge) -{ - return stream << "GenericError { " << ge.message << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CannotCallNonFunction& e) -{ - return stream << "CannotCallNonFunction { " << toString(e.ty) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const FunctionExitsWithoutReturning& error) -{ - return stream << "FunctionExitsWithoutReturning {" << toString(error.expectedReturnType) << "}"; -} - -std::ostream& operator<<(std::ostream& stream, const ExtraInformation& e) -{ - return stream << "ExtraInformation { " << e.message << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const DeprecatedApiUsed& e) -{ - return stream << "DeprecatedApiUsed { " << e.symbol << ", useInstead = " << e.useInstead << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const ModuleHasCyclicDependency& e) -{ - stream << "ModuleHasCyclicDependency {"; - - bool first = true; - for (const ModuleName& name : e.cycle) - { - if (first) - first = false; - else - stream << ", "; - - stream << name; - } - - return stream << "}"; -} - -std::ostream& operator<<(std::ostream& stream, const IllegalRequire& e) -{ - return stream << "IllegalRequire { " << e.moduleName << ", reason = " << e.reason << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const MissingProperties& e) -{ - stream << "MissingProperties { superType = '" << toString(e.superType) << "', subType = '" << toString(e.subType) << "', properties = { "; - - bool first = true; - for (Name name : e.properties) - { - if (first) - first = false; - else - stream << ", "; - - stream << "'" << name << "'"; - } - - return stream << " }, context " << e.context << " } "; -} - -std::ostream& operator<<(std::ostream& stream, const DuplicateGenericParameter& error) -{ - return stream << "DuplicateGenericParameter { " + error.parameterName + " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CannotInferBinaryOperation& error) -{ - return stream << "CannotInferBinaryOperation { op = " + toString(error.op) + ", suggested = '" + - (error.suggestedToAnnotate ? *error.suggestedToAnnotate : "") + "', kind " - << error.kind << "}"; -} - -std::ostream& operator<<(std::ostream& stream, const SwappedGenericTypeParameter& error) -{ - return stream << "SwappedGenericTypeParameter { name = '" + error.name + "', kind = " + std::to_string(error.kind) + " }"; -} - -std::ostream& operator<<(std::ostream& stream, const OptionalValueAccess& error) -{ - return stream << "OptionalValueAccess { optional = '" + toString(error.optional) + "' }"; -} - -std::ostream& operator<<(std::ostream& stream, const MissingUnionProperty& error) -{ - stream << "MissingUnionProperty { type = '" + toString(error.type) + "', missing = { "; - - bool first = true; - for (auto ty : error.missing) - { - if (first) - first = false; - else - stream << ", "; - - stream << "'" << toString(ty) << "'"; - } - - return stream << " }, key = '" + error.key + "' }"; + return stream << "TypeError { \"" << error.moduleName << "\", " << error.location << ", " << error.data << " }"; } std::ostream& operator<<(std::ostream& stream, const TableState& tv) @@ -266,15 +225,4 @@ std::ostream& operator<<(std::ostream& stream, const TypePackVar& tv) return stream << toString(tv); } -std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted) -{ - Luau::visit( - [&](const auto& a) { - lhs << a; - }, - ted); - - return lhs; -} - } // namespace Luau diff --git a/Analysis/src/JsonEmitter.cpp b/Analysis/src/JsonEmitter.cpp new file mode 100644 index 00000000..9c8a7af9 --- /dev/null +++ b/Analysis/src/JsonEmitter.cpp @@ -0,0 +1,222 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/JsonEmitter.h" + +#include "Luau/StringUtils.h" + +#include + +namespace Luau::Json +{ + +static constexpr int CHUNK_SIZE = 1024; + +ObjectEmitter::ObjectEmitter(NotNull emitter) + : emitter(emitter) + , finished(false) +{ + comma = emitter->pushComma(); + emitter->writeRaw('{'); +} + +ObjectEmitter::~ObjectEmitter() +{ + finish(); +} + +void ObjectEmitter::finish() +{ + if (finished) + return; + + emitter->writeRaw('}'); + emitter->popComma(comma); + finished = true; +} + +ArrayEmitter::ArrayEmitter(NotNull emitter) + : emitter(emitter) + , finished(false) +{ + comma = emitter->pushComma(); + emitter->writeRaw('['); +} + +ArrayEmitter::~ArrayEmitter() +{ + finish(); +} + +void ArrayEmitter::finish() +{ + if (finished) + return; + + emitter->writeRaw(']'); + emitter->popComma(comma); + finished = true; +} + +JsonEmitter::JsonEmitter() +{ + newChunk(); +} + +std::string JsonEmitter::str() +{ + return join(chunks, ""); +} + +bool JsonEmitter::pushComma() +{ + bool current = comma; + comma = false; + return current; +} + +void JsonEmitter::popComma(bool c) +{ + comma = c; +} + +void JsonEmitter::writeRaw(std::string_view sv) +{ + if (sv.size() > CHUNK_SIZE) + { + chunks.emplace_back(sv); + newChunk(); + return; + } + + auto& chunk = chunks.back(); + if (chunk.size() + sv.size() < CHUNK_SIZE) + { + chunk.append(sv.data(), sv.size()); + return; + } + + size_t prefix = CHUNK_SIZE - chunk.size(); + chunk.append(sv.data(), prefix); + newChunk(); + + chunks.back().append(sv.data() + prefix, sv.size() - prefix); +} + +void JsonEmitter::writeRaw(char c) +{ + writeRaw(std::string_view{&c, 1}); +} + +void write(JsonEmitter& emitter, bool b) +{ + if (b) + emitter.writeRaw("true"); + else + emitter.writeRaw("false"); +} + +void write(JsonEmitter& emitter, double d) +{ + emitter.writeRaw(std::to_string(d)); +} + +void write(JsonEmitter& emitter, int i) +{ + emitter.writeRaw(std::to_string(i)); +} + +void write(JsonEmitter& emitter, long i) +{ + emitter.writeRaw(std::to_string(i)); +} + +void write(JsonEmitter& emitter, long long i) +{ + emitter.writeRaw(std::to_string(i)); +} + +void write(JsonEmitter& emitter, unsigned int i) +{ + emitter.writeRaw(std::to_string(i)); +} + +void write(JsonEmitter& emitter, unsigned long i) +{ + emitter.writeRaw(std::to_string(i)); +} + +void write(JsonEmitter& emitter, unsigned long long i) +{ + emitter.writeRaw(std::to_string(i)); +} + +void write(JsonEmitter& emitter, std::string_view sv) +{ + emitter.writeRaw('\"'); + + for (char c : sv) + { + if (c == '"') + emitter.writeRaw("\\\""); + else if (c == '\\') + emitter.writeRaw("\\\\"); + else if (c == '\n') + emitter.writeRaw("\\n"); + else if (c < ' ') + emitter.writeRaw(format("\\u%04x", c)); + else + emitter.writeRaw(c); + } + + emitter.writeRaw('\"'); +} + +void write(JsonEmitter& emitter, char c) +{ + write(emitter, std::string_view{&c, 1}); +} + +void write(JsonEmitter& emitter, const char* str) +{ + write(emitter, std::string_view{str, strlen(str)}); +} + +void write(JsonEmitter& emitter, const std::string& str) +{ + write(emitter, std::string_view{str}); +} + +void write(JsonEmitter& emitter, std::nullptr_t) +{ + emitter.writeRaw("null"); +} + +void write(JsonEmitter& emitter, std::nullopt_t) +{ + emitter.writeRaw("null"); +} + +void JsonEmitter::writeComma() +{ + if (comma) + writeRaw(','); + else + comma = true; +} + +ObjectEmitter JsonEmitter::writeObject() +{ + return ObjectEmitter{NotNull(this)}; +} + +ArrayEmitter JsonEmitter::writeArray() +{ + return ArrayEmitter{NotNull(this)}; +} + +void JsonEmitter::newChunk() +{ + chunks.emplace_back(); + chunks.back().reserve(CHUNK_SIZE); +} + +} // namespace Luau::Json diff --git a/Analysis/src/LValue.cpp b/Analysis/src/LValue.cpp new file mode 100644 index 00000000..38dfe1ae --- /dev/null +++ b/Analysis/src/LValue.cpp @@ -0,0 +1,107 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/LValue.h" + +#include "Luau/Ast.h" + +#include + +namespace Luau +{ + +bool Field::operator==(const Field& rhs) const +{ + LUAU_ASSERT(parent && rhs.parent); + return key == rhs.key && (parent == rhs.parent || *parent == *rhs.parent); +} + +bool Field::operator!=(const Field& rhs) const +{ + return !(*this == rhs); +} + +size_t LValueHasher::operator()(const LValue& lvalue) const +{ + // Most likely doesn't produce high quality hashes, but we're probably ok enough with it. + // When an evidence is shown that operator==(LValue) is used more often than it should, we can have a look at improving the hash quality. + size_t acc = 0; + size_t offset = 0; + + const LValue* current = &lvalue; + while (current) + { + if (auto field = get(*current)) + acc ^= (std::hash{}(field->key) << 1) >> ++offset; + else if (auto symbol = get(*current)) + acc ^= std::hash{}(*symbol) << 1; + else + LUAU_ASSERT(!"Hash not accumulated for this new LValue alternative."); + + current = baseof(*current); + } + + return acc; +} + +const LValue* baseof(const LValue& lvalue) +{ + if (auto field = get(lvalue)) + return field->parent.get(); + + auto symbol = get(lvalue); + LUAU_ASSERT(symbol); + return nullptr; // Base of root is null. +} + +std::optional tryGetLValue(const AstExpr& node) +{ + const AstExpr* expr = &node; + while (auto e = expr->as()) + expr = e->expr; + + if (auto local = expr->as()) + return Symbol{local->local}; + else if (auto global = expr->as()) + return Symbol{global->name}; + else if (auto indexname = expr->as()) + { + if (auto lvalue = tryGetLValue(*indexname->expr)) + return Field{std::make_shared(*lvalue), indexname->index.value}; + } + else if (auto indexexpr = expr->as()) + { + if (auto lvalue = tryGetLValue(*indexexpr->expr)) + if (auto string = indexexpr->index->as()) + return Field{std::make_shared(*lvalue), std::string(string->value.data, string->value.size)}; + } + + return std::nullopt; +} + +Symbol getBaseSymbol(const LValue& lvalue) +{ + const LValue* current = &lvalue; + while (auto field = get(*current)) + current = baseof(*current); + + const Symbol* symbol = get(*current); + LUAU_ASSERT(symbol); + return *symbol; +} + +void merge(RefinementMap& l, const RefinementMap& r, std::function f) +{ + for (const auto& [k, a] : r) + { + if (auto it = l.find(k); it != l.end()) + l[k] = f(it->second, a); + else + l[k] = a; + } +} + +void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty) +{ + refis[lvalue] = ty; +} + +} // namespace Luau diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index f97f6a4a..d5578446 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -3,6 +3,7 @@ #include "Luau/AstQuery.h" #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/StringUtils.h" #include "Luau/Common.h" @@ -11,7 +12,8 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false) +LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) +LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) namespace Luau { @@ -44,6 +46,10 @@ static const char* kWarningNames[] = { "DeprecatedApi", "TableOperations", "DuplicateCondition", + "MisleadingAndOr", + "CommentDirective", + "IntegerParsing", + "ComparisonPrecedence", }; // clang-format on @@ -85,10 +91,10 @@ struct LintContext return std::nullopt; auto it = module->astTypes.find(expr); - if (it == module->astTypes.end()) + if (!it) return std::nullopt; - return it->second; + return *it; } }; @@ -198,6 +204,26 @@ static bool similar(AstExpr* lhs, AstExpr* rhs) return true; } + CASE(AstExprIfElse) return similar(le->condition, re->condition) && similar(le->trueExpr, re->trueExpr) && similar(le->falseExpr, re->falseExpr); + CASE(AstExprInterpString) + { + if (le->strings.size != re->strings.size) + return false; + + if (le->expressions.size != re->expressions.size) + return false; + + for (size_t i = 0; i < le->strings.size; ++i) + if (le->strings.data[i].size != re->strings.data[i].size || + memcmp(le->strings.data[i].data, re->strings.data[i].data, le->strings.data[i].size) != 0) + return false; + + for (size_t i = 0; i < le->expressions.size; ++i) + if (!similar(le->expressions.data[i], re->expressions.data[i])) + return false; + + return true; + } else { LUAU_ASSERT(!"Unknown expression type"); @@ -229,6 +255,20 @@ public: } private: + struct FunctionInfo + { + explicit FunctionInfo(AstExprFunction* ast) + : ast(ast) + , dominatedGlobals({}) + , conditionalExecution(false) + { + } + + AstExprFunction* ast; + DenseHashSet dominatedGlobals; + bool conditionalExecution; + }; + struct Global { AstExprGlobal* firstRef = nullptr; @@ -237,6 +277,9 @@ private: bool assigned = false; bool builtin = false; + bool definedInModuleScope = false; + bool definedAsFunction = false; + bool readBeforeWritten = false; std::optional deprecated; }; @@ -244,7 +287,8 @@ private: DenseHashMap globals; std::vector globalRefs; - std::vector functionStack; + std::vector functionStack; + LintGlobalLocal() : globals(AstName()) @@ -262,9 +306,9 @@ private: emitWarning(*context, LintWarning::Code_UnknownGlobal, gv->location, "Unknown global '%s'", gv->name.value); else if (g->deprecated) { - if (*g->deprecated) + if (const char* replacement = *g->deprecated; replacement && strlen(replacement)) emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead", - gv->name.value, *g->deprecated); + gv->name.value, replacement); else emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value); } @@ -287,12 +331,18 @@ private: "Global '%s' is only used in the enclosing function defined at line %d; consider changing it to local", g.firstRef->name.value, top->location.begin.line + 1); } + else if (FFlag::LuauLintGlobalNeverReadBeforeWritten && g.assigned && !g.readBeforeWritten && !g.definedInModuleScope && + g.firstRef->name != context->placeholder) + { + emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, + "Global '%s' is never read before being written. Consider changing it to local", g.firstRef->name.value); + } } } bool visit(AstExprFunction* node) override { - functionStack.push_back(node); + functionStack.emplace_back(node); node->body->visit(this); @@ -303,6 +353,11 @@ private: bool visit(AstExprGlobal* node) override { + if (FFlag::LuauLintGlobalNeverReadBeforeWritten && !functionStack.empty() && !functionStack.back().dominatedGlobals.contains(node->name)) + { + Global& g = globals[node->name]; + g.readBeforeWritten = true; + } trackGlobalRef(node); if (node->name == context->placeholder) @@ -331,6 +386,21 @@ private: { Global& g = globals[gv->name]; + if (FFlag::LuauLintGlobalNeverReadBeforeWritten) + { + if (functionStack.empty()) + { + g.definedInModuleScope = true; + } + else + { + if (!functionStack.back().conditionalExecution) + { + functionStack.back().dominatedGlobals.insert(gv->name); + } + } + } + if (g.builtin) emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); @@ -365,7 +435,14 @@ private: emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); else + { g.assigned = true; + if (FFlag::LuauLintGlobalNeverReadBeforeWritten) + { + g.definedAsFunction = true; + g.definedInModuleScope = functionStack.empty(); + } + } trackGlobalRef(gv); } @@ -373,6 +450,98 @@ private: return true; } + class HoldConditionalExecution + { + public: + HoldConditionalExecution(LintGlobalLocal& p) + : p(p) + { + if (!p.functionStack.empty() && !p.functionStack.back().conditionalExecution) + { + resetToFalse = true; + p.functionStack.back().conditionalExecution = true; + } + } + ~HoldConditionalExecution() + { + if (resetToFalse) + p.functionStack.back().conditionalExecution = false; + } + + private: + bool resetToFalse = false; + LintGlobalLocal& p; + }; + + bool visit(AstStatIf* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + node->condition->visit(this); + node->thenbody->visit(this); + if (node->elsebody) + node->elsebody->visit(this); + + return false; + } + + bool visit(AstStatWhile* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + node->condition->visit(this); + node->body->visit(this); + + return false; + } + + bool visit(AstStatRepeat* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + node->condition->visit(this); + node->body->visit(this); + + return false; + } + + bool visit(AstStatFor* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + node->from->visit(this); + node->to->visit(this); + + if (node->step) + node->step->visit(this); + + node->body->visit(this); + + return false; + } + + bool visit(AstStatForIn* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + for (AstExpr* expr : node->values) + expr->visit(this); + + node->body->visit(this); + + return false; + } + void trackGlobalRef(AstExprGlobal* node) { Global& g = globals[node->name]; @@ -386,7 +555,12 @@ private: // to reduce the cost of tracking we only track this for user globals if (!g.builtin) { - g.functionRef = functionStack; + g.functionRef.clear(); + g.functionRef.reserve(functionStack.size()); + for (const FunctionInfo& entry : functionStack) + { + g.functionRef.push_back(entry.ast); + } } } else @@ -397,7 +571,7 @@ private: // we need to find a common prefix between all uses of a global size_t prefix = 0; - while (prefix < g.functionRef.size() && prefix < functionStack.size() && g.functionRef[prefix] == functionStack[prefix]) + while (prefix < g.functionRef.size() && prefix < functionStack.size() && g.functionRef[prefix] == functionStack[prefix].ast) prefix++; g.functionRef.resize(prefix); @@ -732,13 +906,13 @@ private: bool visit(AstTypeReference* node) override { - if (!node->hasPrefix) + if (!node->prefix) return true; - if (!imports.contains(node->prefix)) + if (!imports.contains(*node->prefix)) return true; - AstLocal* astLocal = imports[node->prefix]; + AstLocal* astLocal = imports[*node->prefix]; Local& local = locals[astLocal]; LUAU_ASSERT(local.import); local.used = true; @@ -982,25 +1156,12 @@ private: enum TypeKind { - Kind_Invalid, + Kind_Unknown, Kind_Primitive, // primitive type supported by VM - boolean/userdata/etc. No differentiation between types of userdata. - Kind_Vector, // For 'vector' but only used when type is used - Kind_Userdata, // custom userdata type - Vector3/etc. - Kind_Class, // custom userdata type that reflects Roblox Instance-derived hierarchy - Part/etc. - Kind_Enum, // custom userdata type referring to an enum item of enum classes, e.g. Enum.NormalId.Back/Enum.Axis.X/etc. + Kind_Vector, // 'vector' but only used when type is used + Kind_Userdata, // custom userdata type }; - bool containsPropName(TypeId ty, const std::string& propName) - { - if (auto ctv = get(ty)) - return lookupClassProp(ctv, propName) != nullptr; - - if (auto ttv = get(ty)) - return ttv->props.find(propName) != ttv->props.end(); - - return false; - } - TypeKind getTypeKind(const std::string& name) { if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" || @@ -1011,12 +1172,9 @@ private: return Kind_Vector; if (std::optional maybeTy = context->scope->lookupType(name)) - // Kind_Userdata is probably not 100% precise but is close enough - return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata; - else if (std::optional maybeTy = context->scope->lookupImportedType("Enum", name)) - return Kind_Enum; + return Kind_Userdata; - return Kind_Invalid; + return Kind_Unknown; } void validateType(AstExprConstantString* expr, std::initializer_list expected, const char* expectedString) @@ -1024,7 +1182,7 @@ private: std::string name(expr->value.data, expr->value.size); TypeKind kind = getTypeKind(name); - if (kind == Kind_Invalid) + if (kind == Kind_Unknown) { emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s'", name.c_str()); return; @@ -1034,61 +1192,11 @@ private: { if (kind == ek) return; - - // as a special case, Instance and EnumItem are both a userdata type (as returned by typeof) and a class type - if (ek == Kind_Userdata && (name == "Instance" || name == "EnumItem")) - return; } emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s' (expected %s)", name.c_str(), expectedString); } - bool acceptsClassName(AstName method) - { - return method.value[0] == 'F' && (method == "FindFirstChildOfClass" || method == "FindFirstChildWhichIsA" || - method == "FindFirstAncestorOfClass" || method == "FindFirstAncestorWhichIsA"); - } - - bool visit(AstExprCall* node) override - { - if (AstExprIndexName* index = node->func->as()) - { - AstExprConstantString* arg0 = node->args.size > 0 ? node->args.data[0]->as() : NULL; - - if (arg0) - { - if (node->self && index->index == "IsA" && node->args.size == 1) - { - validateType(arg0, {Kind_Class, Kind_Enum}, "class or enum type"); - } - else if (node->self && (index->index == "GetService" || index->index == "FindService") && node->args.size == 1) - { - AstExprGlobal* g = index->expr->as(); - - if (g && (g->name == "game" || g->name == "Game")) - { - validateType(arg0, {Kind_Class}, "class type"); - } - } - else if (node->self && acceptsClassName(index->index) && node->args.size == 1) - { - validateType(arg0, {Kind_Class}, "class type"); - } - else if (!node->self && index->index == "new" && node->args.size <= 2) - { - AstExprGlobal* g = index->expr->as(); - - if (g && g->name == "Instance") - { - validateType(arg0, {Kind_Class}, "class type"); - } - } - } - } - - return true; - } - bool visit(AstExprBinary* node) override { if (node->op == AstExprBinary::CompareNe || node->op == AstExprBinary::CompareEq) @@ -1108,10 +1216,7 @@ private: if (g && g->name == "type") { - if (FFlag::LuauLinterUnknownTypeVectorAware) - validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type"); - else - validateType(arg, {Kind_Primitive}, "primitive type"); + validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type"); } else if (g && g->name == "typeof") { @@ -1349,7 +1454,7 @@ private: const char* checkStringFormat(const char* data, size_t size) { const char* flags = "-+ #0"; - const char* options = "cdiouxXeEfgGqs"; + const char* options = "cdiouxXeEfgGqs*"; for (size_t i = 0; i < size; ++i) { @@ -2044,18 +2149,28 @@ private: const Property* prop = lookupClassProp(cty, node->index.value); if (prop && prop->deprecated) - { - if (!prop->deprecatedSuggestion.empty()) - emitWarning(*context, LintWarning::Code_DeprecatedApi, node->location, "Member '%s.%s' is deprecated, use '%s' instead", - cty->name.c_str(), node->index.value, prop->deprecatedSuggestion.c_str()); - else - emitWarning(*context, LintWarning::Code_DeprecatedApi, node->location, "Member '%s.%s' is deprecated", cty->name.c_str(), - node->index.value); - } + report(node->location, *prop, cty->name.c_str(), node->index.value); + } + else if (const TableTypeVar* tty = get(follow(*ty))) + { + auto prop = tty->props.find(node->index.value); + + if (prop != tty->props.end() && prop->second.deprecated) + report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value); } return true; } + + void report(const Location& location, const Property& prop, const char* container, const char* field) + { + std::string suggestion = prop.deprecatedSuggestion.empty() ? "" : format(", use '%s' instead", prop.deprecatedSuggestion.c_str()); + + if (container) + emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Member '%s.%s' is deprecated%s", container, field, suggestion.c_str()); + else + emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Member '%s' is deprecated%s", field, suggestion.c_str()); + } }; class LintTableOperations : AstVisitor @@ -2144,6 +2259,32 @@ private: "wrap it in parentheses to silence"); } + if (func->index == "move" && node->args.size >= 4) + { + // table.move(t, 0, _, _) + if (isConstant(args[1], 0.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + + // table.move(t, _, _, 0) + else if (isConstant(args[3], 0.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + } + + if (func->index == "create" && node->args.size == 2) + { + // table.create(n, {...}) + if (args[1]->is()) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + + // table.create(n, {...} :: ?) + if (AstExprTypeAssertion* as = args[1]->as(); as && as->expr->is()) + emitWarning(*context, LintWarning::Code_TableOperations, as->expr->location, + "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + } + return true; } @@ -2162,7 +2303,7 @@ private: size_t getReturnCount(TypeId ty) { if (auto ftv = get(ty)) - return size(ftv->retType); + return size(ftv->retTypes); if (auto itv = get(ty)) { @@ -2171,7 +2312,7 @@ private: for (TypeId part : itv->parts) if (auto ftv = get(follow(part))) - result = std::max(result, size(ftv->retType)); + result = std::max(result, size(ftv->retTypes)); return result; } @@ -2235,6 +2376,39 @@ private: return false; } + bool visit(AstExprIfElse* expr) override + { + if (!expr->falseExpr->is()) + return true; + + // if..elseif chain detected, we need to unroll it + std::vector conditions; + conditions.reserve(2); + + AstExprIfElse* head = expr; + while (head) + { + head->condition->visit(this); + head->trueExpr->visit(this); + + conditions.push_back(head->condition); + + if (head->falseExpr->is()) + { + head = head->falseExpr->as(); + continue; + } + + head->falseExpr->visit(this); + break; + } + + detectDuplicates(conditions); + + // block recursive visits so that we only analyze each chain once + return false; + } + bool visit(AstExprBinary* expr) override { if (expr->op != AstExprBinary::And && expr->op != AstExprBinary::Or) @@ -2396,6 +2570,153 @@ private: } }; +class LintMisleadingAndOr : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintMisleadingAndOr pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + bool visit(AstExprBinary* node) override + { + if (node->op != AstExprBinary::Or) + return true; + + AstExprBinary* and_ = node->left->as(); + if (!and_ || and_->op != AstExprBinary::And) + return true; + + const char* alt = nullptr; + + if (and_->right->is()) + alt = "nil"; + else if (AstExprConstantBool* c = and_->right->as(); c && c->value == false) + alt = "false"; + + if (alt) + emitWarning(*context, LintWarning::Code_MisleadingAndOr, node->location, + "The and-or expression always evaluates to the second alternative because the first alternative is %s; consider using if-then-else " + "expression instead", + alt); + + return true; + } +}; + +class LintIntegerParsing : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintIntegerParsing pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + bool visit(AstExprConstantNumber* node) override + { + switch (node->parseResult) + { + case ConstantNumberParseResult::Ok: + case ConstantNumberParseResult::Malformed: + break; + case ConstantNumberParseResult::BinOverflow: + emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, + "Binary number literal exceeded available precision and has been truncated to 2^64"); + break; + case ConstantNumberParseResult::HexOverflow: + emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, + "Hexadecimal number literal exceeded available precision and has been truncated to 2^64"); + break; + case ConstantNumberParseResult::DoublePrefix: + emitWarning(*context, LintWarning::Code_IntegerParsing, node->location, + "Hexadecimal number literal has a double prefix, which will fail to parse in the future; remove the extra 0x to fix"); + break; + } + + return true; + } +}; + +class LintComparisonPrecedence : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintComparisonPrecedence pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + static bool isEquality(AstExprBinary::Op op) + { + return op == AstExprBinary::CompareNe || op == AstExprBinary::CompareEq; + } + + static bool isComparison(AstExprBinary::Op op) + { + return op == AstExprBinary::CompareNe || op == AstExprBinary::CompareEq || op == AstExprBinary::CompareLt || op == AstExprBinary::CompareLe || + op == AstExprBinary::CompareGt || op == AstExprBinary::CompareGe; + } + + static bool isNot(AstExpr* node) + { + AstExprUnary* expr = node->as(); + + return expr && expr->op == AstExprUnary::Not; + } + + bool visit(AstExprBinary* node) override + { + if (!isComparison(node->op)) + return true; + + // not X == Y; we silence this for not X == not Y as it's likely an intentional boolean comparison + if (isNot(node->left) && !isNot(node->right)) + { + std::string op = toString(node->op); + + if (isEquality(node->op)) + emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, + "not X %s Y is equivalent to (not X) %s Y; consider using X %s Y, or add parentheses to silence", op.c_str(), op.c_str(), + node->op == AstExprBinary::CompareEq ? "~=" : "=="); + else + emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, + "not X %s Y is equivalent to (not X) %s Y; add parentheses to silence", op.c_str(), op.c_str()); + } + else if (AstExprBinary* left = node->left->as(); left && isComparison(left->op)) + { + std::string lop = toString(left->op); + std::string rop = toString(node->op); + + if (isEquality(left->op) || isEquality(node->op)) + emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, + "X %s Y %s Z is equivalent to (X %s Y) %s Z; add parentheses to silence", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str()); + else + emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location, + "X %s Y %s Z is equivalent to (X %s Y) %s Z; did you mean X %s Y and Y %s Z?", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str(), + lop.c_str(), rop.c_str()); + } + + return true; + } +}; + static void fillBuiltinGlobals(LintContext& context, const AstNameTable& names, const ScopePtr& env) { ScopePtr current = env; @@ -2421,13 +2742,124 @@ static void fillBuiltinGlobals(LintContext& context, const AstNameTable& names, } } +static const char* fuzzyMatch(std::string_view str, const char** array, size_t size) +{ + if (FInt::LuauSuggestionDistance == 0) + return nullptr; + + size_t bestDistance = FInt::LuauSuggestionDistance; + size_t bestMatch = size; + + for (size_t i = 0; i < size; ++i) + { + size_t ed = editDistance(str, array[i]); + + if (ed <= bestDistance) + { + bestDistance = ed; + bestMatch = i; + } + } + + return bestMatch < size ? array[bestMatch] : nullptr; +} + +static void lintComments(LintContext& context, const std::vector& hotcomments) +{ + bool seenMode = false; + + for (const HotComment& hc : hotcomments) + { + // We reserve --! for various informational (non-directive) comments + if (hc.content.empty() || hc.content[0] == ' ' || hc.content[0] == '\t') + continue; + + if (!hc.header) + { + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "Comment directive is ignored because it is placed after the first non-comment token"); + } + else + { + size_t space = hc.content.find_first_of(" \t"); + std::string_view first = std::string_view(hc.content).substr(0, space); + + if (first == "nolint") + { + size_t notspace = hc.content.find_first_not_of(" \t", space); + + if (space == std::string::npos || notspace == std::string::npos) + { + // disables all lints + } + else if (LintWarning::parseName(hc.content.c_str() + notspace) == LintWarning::Code_Unknown) + { + const char* rule = hc.content.c_str() + notspace; + + // skip Unknown + if (const char* suggestion = fuzzyMatch(rule, kWarningNames + 1, LintWarning::Code__Count - 1)) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "nolint directive refers to unknown lint rule '%s'; did you mean '%s'?", rule, suggestion); + else + emitWarning( + context, LintWarning::Code_CommentDirective, hc.location, "nolint directive refers to unknown lint rule '%s'", rule); + } + } + else if (first == "nocheck" || first == "nonstrict" || first == "strict") + { + if (space != std::string::npos) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "Comment directive with the type checking mode has extra symbols at the end of the line"); + else if (seenMode) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "Comment directive with the type checking mode has already been used"); + else + seenMode = true; + } + else if (first == "optimize") + { + size_t notspace = hc.content.find_first_not_of(" \t", space); + + if (space == std::string::npos || notspace == std::string::npos) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "optimize directive requires an optimization level"); + else + { + const char* level = hc.content.c_str() + notspace; + + if (strcmp(level, "0") && strcmp(level, "1") && strcmp(level, "2")) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "optimize directive uses unknown optimization level '%s', 0..2 expected", level); + } + } + else + { + static const char* kHotComments[] = { + "nolint", + "nocheck", + "nonstrict", + "strict", + "optimize", + }; + + if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments))) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'; did you mean '%s'?", + int(first.size()), first.data(), suggestion); + else + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'", int(first.size()), + first.data()); + } + } + } +} + void LintOptions::setDefaults() { // By default, we enable all warnings warningMask = ~0ull; } -std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const LintOptions& options) +std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, + const std::vector& hotcomments, const LintOptions& options) { LintContext context; @@ -2500,6 +2932,18 @@ std::vector lint(AstStat* root, const AstNameTable& names, const Sc if (context.warningEnabled(LintWarning::Code_DuplicateLocal)) LintDuplicateLocal::process(context); + if (context.warningEnabled(LintWarning::Code_MisleadingAndOr)) + LintMisleadingAndOr::process(context); + + if (context.warningEnabled(LintWarning::Code_CommentDirective)) + lintComments(context, hotcomments); + + if (context.warningEnabled(LintWarning::Code_IntegerParsing)) + LintIntegerParsing::process(context); + + if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence)) + LintComparisonPrecedence::process(context); + std::sort(context.result.begin(), context.result.end(), WarningComparator()); return context.result; @@ -2521,23 +2965,30 @@ LintWarning::Code LintWarning::parseName(const char* name) return Code_Unknown; } -uint64_t LintWarning::parseMask(const std::vector& hotcomments) +uint64_t LintWarning::parseMask(const std::vector& hotcomments) { uint64_t result = 0; - for (const std::string& hc : hotcomments) + for (const HotComment& hc : hotcomments) { - if (hc.compare(0, 6, "nolint") != 0) + if (!hc.header) continue; - std::string::size_type name = hc.find_first_not_of(" \t", 6); + if (hc.content.compare(0, 6, "nolint") != 0) + continue; + + size_t name = hc.content.find_first_not_of(" \t", 6); // --!nolint disables everything if (name == std::string::npos) return ~0ull; + // --!nolint needs to be followed by a whitespace character + if (name == 6) + continue; + // --!nolint name disables the specific lint - LintWarning::Code code = LintWarning::parseName(hc.c_str() + name); + LintWarning::Code code = LintWarning::parseName(hc.content.c_str() + name); if (code != LintWarning::Code_Unknown) result |= 1ull << int(code); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index f1d975fe..62674aa8 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -1,18 +1,24 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Clone.h" +#include "Luau/Common.h" +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/Normalize.h" +#include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" -#include "Luau/Common.h" #include -LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) -LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) -LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false); +LUAU_FASTFLAG(LuauSubstitutionReentrant); +LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); +LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); namespace Luau { @@ -21,7 +27,7 @@ static bool contains(Position pos, Comment comment) { if (comment.location.contains(pos)) return true; - else if (FFlag::LuauCaptureBrokenCommentSpans && comment.type == Lexeme::BrokenComment && + else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't have an end return true; else if (comment.type == Lexeme::Comment && comment.location.end == pos) @@ -52,411 +58,178 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos) return contains(pos, *iter); } -void TypeArena::clear() +struct ClonePublicInterface : Substitution { - typeVars.clear(); - typePacks.clear(); -} + NotNull singletonTypes; + NotNull module; -TypeId TypeArena::addTV(TypeVar&& tv) -{ - TypeId allocated = typeVars.allocate(std::move(tv)); - - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypeId TypeArena::freshType(TypeLevel level) -{ - TypeId allocated = typeVars.allocate(FreeTypeVar{level}); - - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(std::initializer_list types) -{ - TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(std::vector types) -{ - TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(TypePack tp) -{ - TypePackId allocated = typePacks.allocate(std::move(tp)); - - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(TypePackVar tp) -{ - TypePackId allocated = typePacks.allocate(std::move(tp)); - - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = this; - - return allocated; -} - -using SeenTypes = std::unordered_map; -using SeenTypePacks = std::unordered_map; - -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType); - -namespace -{ - -struct TypePackCloner; - -/* - * Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set. - * They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage. - */ - -struct TypeCloner -{ - TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) - : dest(dest) - , typeId(typeId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) + ClonePublicInterface(const TxnLog* log, NotNull singletonTypes, Module* module) + : Substitution(log, &module->interfaceTypes) + , singletonTypes(singletonTypes) + , module(module) { + LUAU_ASSERT(module); } - TypeArena& dest; - TypeId typeId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - - bool* encounteredFreeType = nullptr; - - template - void defaultClone(const T& t); - - void operator()(const Unifiable::Free& t); - void operator()(const Unifiable::Generic& t); - void operator()(const Unifiable::Bound& t); - void operator()(const Unifiable::Error& t); - void operator()(const PrimitiveTypeVar& t); - void operator()(const FunctionTypeVar& t); - void operator()(const TableTypeVar& t); - void operator()(const MetatableTypeVar& t); - void operator()(const ClassTypeVar& t); - void operator()(const AnyTypeVar& t); - void operator()(const UnionTypeVar& t); - void operator()(const IntersectionTypeVar& t); - void operator()(const LazyTypeVar& t); -}; - -struct TypePackCloner -{ - TypeArena& dest; - TypePackId typePackId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - bool* encounteredFreeType = nullptr; - - TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) - : dest(dest) - , typePackId(typePackId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) + bool isDirty(TypeId ty) override { + if (ty->owningArena == &module->internalTypes) + return true; + + if (const FunctionTypeVar* ftv = get(ty)) + return ftv->level.level != 0; + if (const TableTypeVar* ttv = get(ty)) + return ttv->level.level != 0; + return false; } - template - void defaultClone(const T& t) + bool isDirty(TypePackId tp) override { - TypePackId cloned = dest.typePacks.allocate(t); - seenTypePacks[typePackId] = cloned; + return tp->owningArena == &module->internalTypes; } - void operator()(const Unifiable::Free& t) + TypeId clean(TypeId ty) override { - if (encounteredFreeType) - *encounteredFreeType = true; + TypeId result = clone(ty); - seenTypePacks[typePackId] = dest.typePacks.allocate(TypePackVar{Unifiable::Error{}}); + if (FunctionTypeVar* ftv = getMutable(result)) + ftv->level = TypeLevel{0, 0}; + else if (TableTypeVar* ttv = getMutable(result)) + ttv->level = TypeLevel{0, 0}; + + return result; } - void operator()(const Unifiable::Generic& t) + TypePackId clean(TypePackId tp) override { - defaultClone(t); - } - void operator()(const Unifiable::Error& t) - { - defaultClone(t); + return clone(tp); } - // While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter. - // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. - void operator()(const Unifiable::Bound& t) + TypeId cloneType(TypeId ty) { - TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); - seenTypePacks[typePackId] = cloned; - } + LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields); - void operator()(const VariadicTypePack& t) - { - TypePackId cloned = dest.typePacks.allocate(VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const TypePack& t) - { - TypePackId cloned = dest.typePacks.allocate(TypePack{}); - TypePack* destTp = getMutable(cloned); - LUAU_ASSERT(destTp != nullptr); - seenTypePacks[typePackId] = cloned; - - for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); - - if (t.tail) - destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, encounteredFreeType); - } -}; - -template -void TypeCloner::defaultClone(const T& t) -{ - TypeId cloned = dest.typeVars.allocate(t); - seenTypes[typeId] = cloned; -} - -void TypeCloner::operator()(const Unifiable::Free& t) -{ - if (encounteredFreeType) - *encounteredFreeType = true; - - seenTypes[typeId] = dest.typeVars.allocate(ErrorTypeVar{}); -} - -void TypeCloner::operator()(const Unifiable::Generic& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const Unifiable::Bound& t) -{ - TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); - seenTypes[typeId] = boundTo; -} - -void TypeCloner::operator()(const Unifiable::Error& t) -{ - defaultClone(t); -} -void TypeCloner::operator()(const PrimitiveTypeVar& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const FunctionTypeVar& t) -{ - TypeId result = dest.typeVars.allocate(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); - FunctionTypeVar* ftv = getMutable(result); - LUAU_ASSERT(ftv != nullptr); - - seenTypes[typeId] = result; - - for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, encounteredFreeType)); - - for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, encounteredFreeType)); - - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - ftv->tags = t.tags; - - ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, encounteredFreeType); - ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, encounteredFreeType); -} - -void TypeCloner::operator()(const TableTypeVar& t) -{ - TypeId result = dest.typeVars.allocate(TableTypeVar{}); - TableTypeVar* ttv = getMutable(result); - LUAU_ASSERT(ttv != nullptr); - - *ttv = t; - - seenTypes[typeId] = result; - - ttv->level = TypeLevel{0, 0}; - - for (const auto& [name, prop] : t.props) - { - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; - else - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location}; - } - - if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType), - clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)}; - - if (t.boundTo) - ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); - - for (TypeId& arg : ttv->instantiatedTypeParams) - arg = (clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType)); - - if (ttv->state == TableState::Free) - { - if (!t.boundTo) + std::optional result = substitute(ty); + if (result) { - if (encounteredFreeType) - *encounteredFreeType = true; + return *result; + } + else + { + module->errors.push_back(TypeError{module->scopes[0].first, UnificationTooComplex{}}); + return singletonTypes->errorRecoveryType(); + } + } + + TypePackId cloneTypePack(TypePackId tp) + { + LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields); + + std::optional result = substitute(tp); + if (result) + { + return *result; + } + else + { + module->errors.push_back(TypeError{module->scopes[0].first, UnificationTooComplex{}}); + return singletonTypes->errorRecoveryTypePack(); + } + } + + TypeFun cloneTypeFun(const TypeFun& tf) + { + LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields); + + std::vector typeParams; + std::vector typePackParams; + + for (GenericTypeDefinition typeParam : tf.typeParams) + { + TypeId ty = cloneType(typeParam.ty); + std::optional defaultValue; + + if (typeParam.defaultValue) + defaultValue = cloneType(*typeParam.defaultValue); + + typeParams.push_back(GenericTypeDefinition{ty, defaultValue}); } - ttv->state = TableState::Sealed; + for (GenericTypePackDefinition typePackParam : tf.typePackParams) + { + TypePackId tp = cloneTypePack(typePackParam.tp); + std::optional defaultValue; + + if (typePackParam.defaultValue) + defaultValue = cloneTypePack(*typePackParam.defaultValue); + + typePackParams.push_back(GenericTypePackDefinition{tp, defaultValue}); + } + + TypeId type = cloneType(tf.type); + + return TypeFun{typeParams, typePackParams, type}; } +}; - ttv->definitionModuleName = t.definitionModuleName; - ttv->methodDefinitionLocations = t.methodDefinitionLocations; - ttv->tags = t.tags; +Module::~Module() +{ + unfreeze(interfaceTypes); + unfreeze(internalTypes); } -void TypeCloner::operator()(const MetatableTypeVar& t) +void Module::clonePublicInterface(NotNull singletonTypes, InternalErrorReporter& ice) { - TypeId result = dest.typeVars.allocate(MetatableTypeVar{}); - MetatableTypeVar* mtv = getMutable(result); - seenTypes[typeId] = result; + LUAU_ASSERT(interfaceTypes.typeVars.empty()); + LUAU_ASSERT(interfaceTypes.typePacks.empty()); - mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, encounteredFreeType); - mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, encounteredFreeType); -} + CloneState cloneState; -void TypeCloner::operator()(const ClassTypeVar& t) -{ - TypeId result = dest.typeVars.allocate(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); - ClassTypeVar* ctv = getMutable(result); + ScopePtr moduleScope = getModuleScope(); - seenTypes[typeId] = result; + TypePackId returnType = moduleScope->returnType; + std::optional varargPack = FFlag::DebugLuauDeferredConstraintResolution ? std::nullopt : moduleScope->varargPack; + std::unordered_map* exportedTypeBindings = &moduleScope->exportedTypeBindings; - for (const auto& [name, prop] : t.props) - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; + TxnLog log; + ClonePublicInterface clonePublicInterface{&log, singletonTypes, this}; + + if (FFlag::LuauClonePublicInterfaceLess) + returnType = clonePublicInterface.cloneTypePack(returnType); + else + returnType = clone(returnType, interfaceTypes, cloneState); + + moduleScope->returnType = returnType; + if (varargPack) + { + if (FFlag::LuauClonePublicInterfaceLess) + varargPack = clonePublicInterface.cloneTypePack(*varargPack); else - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location}; - - if (t.parent) - ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, encounteredFreeType); - - if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, encounteredFreeType); -} - -void TypeCloner::operator()(const AnyTypeVar& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const UnionTypeVar& t) -{ - TypeId result = dest.typeVars.allocate(UnionTypeVar{}); - seenTypes[typeId] = result; - - UnionTypeVar* option = getMutable(result); - LUAU_ASSERT(option != nullptr); - - for (TypeId ty : t.options) - option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); -} - -void TypeCloner::operator()(const IntersectionTypeVar& t) -{ - TypeId result = dest.typeVars.allocate(IntersectionTypeVar{}); - seenTypes[typeId] = result; - - IntersectionTypeVar* option = getMutable(result); - LUAU_ASSERT(option != nullptr); - - for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); -} - -void TypeCloner::operator()(const LazyTypeVar& t) -{ - defaultClone(t); -} - -} // anonymous namespace - -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) -{ - if (tp->persistent) - return tp; - - TypePackId& res = seenTypePacks[tp]; - - if (res == nullptr) - { - TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks}; - cloner.encounteredFreeType = encounteredFreeType; - Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. + varargPack = clone(*varargPack, interfaceTypes, cloneState); + moduleScope->varargPack = varargPack; } - if (FFlag::DebugLuauTrackOwningArena) - asMutable(res)->owningArena = &dest; - - return res; -} - -TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) -{ - if (typeId->persistent) - return typeId; - - TypeId& res = seenTypes[typeId]; - - if (res == nullptr) + if (exportedTypeBindings) { - TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks}; - cloner.encounteredFreeType = encounteredFreeType; - Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. - asMutable(res)->documentationSymbol = typeId->documentationSymbol; + for (auto& [name, tf] : *exportedTypeBindings) + { + if (FFlag::LuauClonePublicInterfaceLess) + tf = clonePublicInterface.cloneTypeFun(tf); + else + tf = clone(tf, interfaceTypes, cloneState); + } } - if (FFlag::DebugLuauTrackOwningArena) - asMutable(res)->owningArena = &dest; + for (auto& [name, ty] : declaredGlobals) + { + if (FFlag::LuauClonePublicInterfaceLess) + ty = clonePublicInterface.cloneType(ty); + else + ty = clone(ty, interfaceTypes, cloneState); + } - return res; -} - -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) -{ - TypeFun result; - for (TypeId param : typeFun.typeParams) - result.typeParams.push_back(clone(param, dest, seenTypes, seenTypePacks, encounteredFreeType)); - - result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType); - - return result; + freeze(internalTypes); + freeze(interfaceTypes); } ScopePtr Module::getModuleScope() const @@ -465,57 +238,4 @@ ScopePtr Module::getModuleScope() const return scopes.front().second; } -void freeze(TypeArena& arena) -{ - if (!FFlag::DebugLuauFreezeArena) - return; - - arena.typeVars.freeze(); - arena.typePacks.freeze(); -} - -void unfreeze(TypeArena& arena) -{ - if (!FFlag::DebugLuauFreezeArena) - return; - - arena.typeVars.unfreeze(); - arena.typePacks.unfreeze(); -} - -Module::~Module() -{ - unfreeze(interfaceTypes); - unfreeze(internalTypes); -} - -bool Module::clonePublicInterface() -{ - LUAU_ASSERT(interfaceTypes.typeVars.empty()); - LUAU_ASSERT(interfaceTypes.typePacks.empty()); - - bool encounteredFreeType = false; - - SeenTypePacks seenTypePacks; - SeenTypes seenTypes; - - ScopePtr moduleScope = getModuleScope(); - - moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); - if (moduleScope->varargPack) - moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); - - for (auto& pair : moduleScope->exportedTypeBindings) - pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); - - for (TypeId ty : moduleScope->returnType) - if (get(follow(ty))) - *asMutable(ty) = AnyTypeVar{}; - - freeze(internalTypes); - freeze(interfaceTypes); - - return encounteredFreeType; -} - } // namespace Luau diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp new file mode 100644 index 00000000..09f0595d --- /dev/null +++ b/Analysis/src/Normalize.cpp @@ -0,0 +1,2194 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Normalize.h" +#include "Luau/ToString.h" + +#include + +#include "Luau/Clone.h" +#include "Luau/Common.h" +#include "Luau/RecursionCounter.h" +#include "Luau/TypeVar.h" +#include "Luau/Unifier.h" + +LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) +LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) + +// This could theoretically be 2000 on amd64, but x86 requires this. +LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); +LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); +LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); +LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); +LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauOverloadedFunctionSubtypingPerf); +LUAU_FASTFLAG(LuauUninhabitedSubAnything2) + +namespace Luau +{ +void TypeIds::insert(TypeId ty) +{ + ty = follow(ty); + auto [_, fresh] = types.insert(ty); + if (fresh) + { + order.push_back(ty); + hash ^= std::hash{}(ty); + } +} + +void TypeIds::clear() +{ + order.clear(); + types.clear(); + hash = 0; +} + +TypeId TypeIds::front() const +{ + return order.at(0); +} + +TypeIds::iterator TypeIds::begin() +{ + return order.begin(); +} + +TypeIds::iterator TypeIds::end() +{ + return order.end(); +} + +TypeIds::const_iterator TypeIds::begin() const +{ + return order.begin(); +} + +TypeIds::const_iterator TypeIds::end() const +{ + return order.end(); +} + +TypeIds::iterator TypeIds::erase(TypeIds::const_iterator it) +{ + TypeId ty = *it; + types.erase(ty); + hash ^= std::hash{}(ty); + return order.erase(it); +} + +size_t TypeIds::size() const +{ + return types.size(); +} + +bool TypeIds::empty() const +{ + return types.empty(); +} + +size_t TypeIds::count(TypeId ty) const +{ + ty = follow(ty); + return types.count(ty); +} + +void TypeIds::retain(const TypeIds& there) +{ + for (auto it = begin(); it != end();) + { + if (there.count(*it)) + it++; + else + it = erase(it); + } +} + +size_t TypeIds::getHash() const +{ + return hash; +} + +bool TypeIds::operator==(const TypeIds& there) const +{ + return hash == there.hash && types == there.types; +} + +NormalizedStringType::NormalizedStringType() +{} + +NormalizedStringType::NormalizedStringType(bool isCofinite, std::map singletons) + : isCofinite(isCofinite) + , singletons(std::move(singletons)) +{ +} + +void NormalizedStringType::resetToString() +{ + isCofinite = true; + singletons.clear(); +} + +void NormalizedStringType::resetToNever() +{ + isCofinite = false; + singletons.clear(); +} + +bool NormalizedStringType::isNever() const +{ + return !isCofinite && singletons.empty(); +} + +bool NormalizedStringType::isString() const +{ + return isCofinite && singletons.empty(); +} + +bool NormalizedStringType::isUnion() const +{ + return !isCofinite; +} + +bool NormalizedStringType::isIntersection() const +{ + return isCofinite; +} + +bool NormalizedStringType::includes(const std::string& str) const +{ + if (isString()) + return true; + else if (isUnion() && singletons.count(str)) + return true; + else if (isIntersection() && !singletons.count(str)) + return true; + else + return false; +} + +const NormalizedStringType NormalizedStringType::never; + +bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr) +{ + if (subStr.isUnion() && superStr.isUnion()) + { + for (auto [name, ty] : subStr.singletons) + { + if (!superStr.singletons.count(name)) + return false; + } + } + else if (subStr.isString() && superStr.isUnion()) + return false; + + return true; +} + +NormalizedFunctionType::NormalizedFunctionType() + : parts(FFlag::LuauNegatedFunctionTypes ? std::optional{TypeIds{}} : std::nullopt) +{ +} + +void NormalizedFunctionType::resetToTop() +{ + isTop = true; + parts.emplace(); +} + +void NormalizedFunctionType::resetToNever() +{ + isTop = false; + parts.emplace(); +} + +bool NormalizedFunctionType::isNever() const +{ + return !isTop && (!parts || parts->empty()); +} + +NormalizedType::NormalizedType(NotNull singletonTypes) + : tops(singletonTypes->neverType) + , booleans(singletonTypes->neverType) + , errors(singletonTypes->neverType) + , nils(singletonTypes->neverType) + , numbers(singletonTypes->neverType) + , strings{NormalizedStringType::never} + , threads(singletonTypes->neverType) +{ +} + +static bool isShallowInhabited(const NormalizedType& norm) +{ + // This test is just a shallow check, for example it returns `true` for `{ p : never }` + return !get(norm.tops) || !get(norm.booleans) || !norm.classes.empty() || !get(norm.errors) || + !get(norm.nils) || !get(norm.numbers) || !norm.strings.isNever() || !get(norm.threads) || + !norm.functions.isNever() || !norm.tables.empty() || !norm.tyvars.empty(); +} + +bool isInhabited_DEPRECATED(const NormalizedType& norm) +{ + LUAU_ASSERT(!FFlag::LuauUninhabitedSubAnything2); + return isShallowInhabited(norm); +} + +bool Normalizer::isInhabited(const NormalizedType* norm, std::unordered_set seen) +{ + // If normalization failed, the type is complex, and so is more likely than not to be inhabited. + if (!norm) + return true; + + if (!get(norm->tops) || !get(norm->booleans) || !get(norm->errors) || + !get(norm->nils) || !get(norm->numbers) || !get(norm->threads) || + !norm->classes.empty() || !norm->strings.isNever() || !norm->functions.isNever()) + return true; + + for (const auto& [_, intersect] : norm->tyvars) + { + if (isInhabited(intersect.get(), seen)) + return true; + } + + for (TypeId table : norm->tables) + { + if (isInhabited(table, seen)) + return true; + } + + return false; +} + +bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) +{ + // TODO: use log.follow(ty), CLI-64291 + ty = follow(ty); + + if (get(ty)) + return false; + + if (!get(ty) && !get(ty) && !get(ty) && !get(ty)) + return true; + + if (seen.count(ty)) + return true; + + seen.insert(ty); + + if (const TableTypeVar* ttv = get(ty)) + { + for (const auto& [_, prop] : ttv->props) + { + if (!isInhabited(prop.type, seen)) + return false; + } + return true; + } + + if (const MetatableTypeVar* mtv = get(ty)) + return isInhabited(mtv->table, seen) && isInhabited(mtv->metatable, seen); + + const NormalizedType* norm = normalize(ty); + return isInhabited(norm, seen); +} + +static int tyvarIndex(TypeId ty) +{ + if (const GenericTypeVar* gtv = get(ty)) + return gtv->index; + else if (const FreeTypeVar* ftv = get(ty)) + return ftv->index; + else + return 0; +} + +#ifdef LUAU_ASSERTENABLED + +static bool isNormalizedTop(TypeId ty) +{ + return get(ty) || get(ty) || get(ty); +} + +static bool isNormalizedBoolean(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveTypeVar* ptv = get(ty)) + return ptv->type == PrimitiveTypeVar::Boolean; + else if (const SingletonTypeVar* stv = get(ty)) + return get(stv); + else + return false; +} + +static bool isNormalizedError(TypeId ty) +{ + if (get(ty) || get(ty)) + return true; + else + return false; +} + +static bool isNormalizedNil(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveTypeVar* ptv = get(ty)) + return ptv->type == PrimitiveTypeVar::NilType; + else + return false; +} + +static bool isNormalizedNumber(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveTypeVar* ptv = get(ty)) + return ptv->type == PrimitiveTypeVar::Number; + else + return false; +} + +static bool isNormalizedString(const NormalizedStringType& ty) +{ + if (ty.isString()) + return true; + + for (auto& [str, ty] : ty.singletons) + { + if (const SingletonTypeVar* stv = get(ty)) + { + if (const StringSingleton* sstv = get(stv)) + { + if (sstv->value != str) + return false; + } + else + return false; + } + else + return false; + } + + return true; +} + +static bool isNormalizedThread(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveTypeVar* ptv = get(ty)) + return ptv->type == PrimitiveTypeVar::Thread; + else + return false; +} + +static bool areNormalizedFunctions(const NormalizedFunctionType& tys) +{ + if (tys.parts) + { + for (TypeId ty : *tys.parts) + { + if (!get(ty) && !get(ty)) + return false; + } + } + return true; +} + +static bool areNormalizedTables(const TypeIds& tys) +{ + for (TypeId ty : tys) + if (!get(ty) && !get(ty)) + return false; + return true; +} + +static bool areNormalizedClasses(const TypeIds& tys) +{ + for (TypeId ty : tys) + if (!get(ty)) + return false; + return true; +} + +static bool isPlainTyvar(TypeId ty) +{ + return (get(ty) || get(ty)); +} + +static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) +{ + for (auto& [tyvar, intersect] : tyvars) + { + if (!isPlainTyvar(tyvar)) + return false; + if (!isShallowInhabited(*intersect)) + return false; + for (auto& [other, _] : intersect->tyvars) + if (tyvarIndex(other) <= tyvarIndex(tyvar)) + return false; + } + return true; +} + +#endif // LUAU_ASSERTENABLED + +static void assertInvariant(const NormalizedType& norm) +{ +#ifdef LUAU_ASSERTENABLED + if (!FFlag::DebugLuauCheckNormalizeInvariant) + return; + + LUAU_ASSERT(isNormalizedTop(norm.tops)); + LUAU_ASSERT(isNormalizedBoolean(norm.booleans)); + LUAU_ASSERT(areNormalizedClasses(norm.classes)); + LUAU_ASSERT(isNormalizedError(norm.errors)); + LUAU_ASSERT(isNormalizedNil(norm.nils)); + LUAU_ASSERT(isNormalizedNumber(norm.numbers)); + LUAU_ASSERT(isNormalizedString(norm.strings)); + LUAU_ASSERT(isNormalizedThread(norm.threads)); + LUAU_ASSERT(areNormalizedFunctions(norm.functions)); + LUAU_ASSERT(areNormalizedTables(norm.tables)); + LUAU_ASSERT(isNormalizedTyvar(norm.tyvars)); + for (auto& [_, child] : norm.tyvars) + assertInvariant(*child); +#endif +} + +Normalizer::Normalizer(TypeArena* arena, NotNull singletonTypes, NotNull sharedState) + : arena(arena) + , singletonTypes(singletonTypes) + , sharedState(sharedState) +{ +} + +const NormalizedType* Normalizer::normalize(TypeId ty) +{ + if (!arena) + sharedState->iceHandler->ice("Normalizing types outside a module"); + + auto found = cachedNormals.find(ty); + if (found != cachedNormals.end()) + return found->second.get(); + + NormalizedType norm{singletonTypes}; + if (!unionNormalWithTy(norm, ty)) + return nullptr; + std::unique_ptr uniq = std::make_unique(std::move(norm)); + const NormalizedType* result = uniq.get(); + cachedNormals[ty] = std::move(uniq); + return result; +} + +void Normalizer::clearNormal(NormalizedType& norm) +{ + norm.tops = singletonTypes->neverType; + norm.booleans = singletonTypes->neverType; + norm.classes.clear(); + norm.errors = singletonTypes->neverType; + norm.nils = singletonTypes->neverType; + norm.numbers = singletonTypes->neverType; + norm.strings.resetToNever(); + norm.threads = singletonTypes->neverType; + norm.tables.clear(); + norm.functions.resetToNever(); + norm.tyvars.clear(); +} + +// ------- Cached TypeIds +const TypeIds* Normalizer::cacheTypeIds(TypeIds tys) +{ + auto found = cachedTypeIds.find(&tys); + if (found != cachedTypeIds.end()) + return found->first; + + std::unique_ptr uniq = std::make_unique(std::move(tys)); + const TypeIds* result = uniq.get(); + cachedTypeIds[result] = std::move(uniq); + return result; +} + +TypeId Normalizer::unionType(TypeId here, TypeId there) +{ + here = follow(here); + there = follow(there); + + if (here == there) + return here; + if (get(here) || get(there)) + return there; + if (get(there) || get(here)) + return here; + + TypeIds tmps; + + if (const UnionTypeVar* utv = get(here)) + { + TypeIds heres; + heres.insert(begin(utv), end(utv)); + tmps.insert(heres.begin(), heres.end()); + cachedUnions[cacheTypeIds(std::move(heres))] = here; + } + else + tmps.insert(here); + + if (const UnionTypeVar* utv = get(there)) + { + TypeIds theres; + theres.insert(begin(utv), end(utv)); + tmps.insert(theres.begin(), theres.end()); + cachedUnions[cacheTypeIds(std::move(theres))] = there; + } + else + tmps.insert(there); + + auto cacheHit = cachedUnions.find(&tmps); + if (cacheHit != cachedUnions.end()) + return cacheHit->second; + + std::vector parts; + parts.insert(parts.end(), tmps.begin(), tmps.end()); + TypeId result = arena->addType(UnionTypeVar{std::move(parts)}); + cachedUnions[cacheTypeIds(std::move(tmps))] = result; + + return result; +} + +TypeId Normalizer::intersectionType(TypeId here, TypeId there) +{ + here = follow(here); + there = follow(there); + + if (here == there) + return here; + if (get(here) || get(there)) + return here; + if (get(there) || get(here)) + return there; + + TypeIds tmps; + + if (const IntersectionTypeVar* utv = get(here)) + { + TypeIds heres; + heres.insert(begin(utv), end(utv)); + tmps.insert(heres.begin(), heres.end()); + cachedIntersections[cacheTypeIds(std::move(heres))] = here; + } + else + tmps.insert(here); + + if (const IntersectionTypeVar* utv = get(there)) + { + TypeIds theres; + theres.insert(begin(utv), end(utv)); + tmps.insert(theres.begin(), theres.end()); + cachedIntersections[cacheTypeIds(std::move(theres))] = there; + } + else + tmps.insert(there); + + if (tmps.size() == 1) + return *tmps.begin(); + + auto cacheHit = cachedIntersections.find(&tmps); + if (cacheHit != cachedIntersections.end()) + return cacheHit->second; + + std::vector parts; + parts.insert(parts.end(), tmps.begin(), tmps.end()); + TypeId result = arena->addType(IntersectionTypeVar{std::move(parts)}); + cachedIntersections[cacheTypeIds(std::move(tmps))] = result; + + return result; +} + +void Normalizer::clearCaches() +{ + cachedNormals.clear(); + cachedIntersections.clear(); + cachedUnions.clear(); + cachedTypeIds.clear(); +} + +// ------- Normalizing unions +TypeId Normalizer::unionOfTops(TypeId here, TypeId there) +{ + if (get(here) || get(there)) + return there; + else + return here; +} + +TypeId Normalizer::unionOfBools(TypeId here, TypeId there) +{ + if (get(here)) + return there; + if (get(there)) + return here; + if (const BooleanSingleton* hbool = get(get(here))) + if (const BooleanSingleton* tbool = get(get(there))) + if (hbool->value == tbool->value) + return here; + return singletonTypes->booleanType; +} + +void Normalizer::unionClassesWithClass(TypeIds& heres, TypeId there) +{ + if (heres.count(there)) + return; + + const ClassTypeVar* tctv = get(there); + + for (auto it = heres.begin(); it != heres.end();) + { + TypeId here = *it; + const ClassTypeVar* hctv = get(here); + if (isSubclass(tctv, hctv)) + return; + else if (isSubclass(hctv, tctv)) + it = heres.erase(it); + else + it++; + } + + heres.insert(there); +} + +void Normalizer::unionClasses(TypeIds& heres, const TypeIds& theres) +{ + for (TypeId there : theres) + unionClassesWithClass(heres, there); +} + +void Normalizer::unionStrings(NormalizedStringType& here, const NormalizedStringType& there) +{ + if (there.isString()) + here.resetToString(); + else if (here.isUnion() && there.isUnion()) + here.singletons.insert(there.singletons.begin(), there.singletons.end()); + else if (here.isUnion() && there.isIntersection()) + { + here.isCofinite = true; + for (const auto& pair : there.singletons) + { + auto it = here.singletons.find(pair.first); + if (it != end(here.singletons)) + here.singletons.erase(it); + else + here.singletons.insert(pair); + } + } + else if (here.isIntersection() && there.isUnion()) + { + for (const auto& [name, ty] : there.singletons) + here.singletons.erase(name); + } + else if (here.isIntersection() && there.isIntersection()) + { + auto iter = begin(here.singletons); + auto endIter = end(here.singletons); + + while (iter != endIter) + { + if (!there.singletons.count(iter->first)) + { + auto eraseIt = iter; + ++iter; + here.singletons.erase(eraseIt); + } + else + ++iter; + } + } + else + LUAU_ASSERT(!"Unreachable"); +} + +std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePackId there) +{ + if (here == there) + return here; + + std::vector head; + std::optional tail; + + bool hereSubThere = true; + bool thereSubHere = true; + + TypePackIterator ith = begin(here); + TypePackIterator itt = begin(there); + + while (ith != end(here) && itt != end(there)) + { + TypeId hty = *ith; + TypeId tty = *itt; + TypeId ty = unionType(hty, tty); + if (ty != hty) + thereSubHere = false; + if (ty != tty) + hereSubThere = false; + head.push_back(ty); + ith++; + itt++; + } + + auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, + bool& thereSubHere) { + if (ith != end(here)) + { + TypeId tty = singletonTypes->nilType; + if (std::optional ttail = itt.tail()) + { + if (const VariadicTypePack* tvtp = get(*ttail)) + tty = tvtp->ty; + else + // Luau doesn't have unions of type pack variables + return false; + } + else + // Type packs of different arities are incomparable + return false; + + while (ith != end(here)) + { + TypeId hty = *ith; + TypeId ty = unionType(hty, tty); + if (ty != hty) + thereSubHere = false; + if (ty != tty) + hereSubThere = false; + head.push_back(ty); + ith++; + } + } + return true; + }; + + if (!dealWithDifferentArities(ith, itt, here, there, hereSubThere, thereSubHere)) + return std::nullopt; + + if (!dealWithDifferentArities(itt, ith, there, here, thereSubHere, hereSubThere)) + return std::nullopt; + + if (std::optional htail = ith.tail()) + { + if (std::optional ttail = itt.tail()) + { + if (*htail == *ttail) + tail = htail; + else if (const VariadicTypePack* hvtp = get(*htail)) + { + if (const VariadicTypePack* tvtp = get(*ttail)) + { + TypeId ty = unionType(hvtp->ty, tvtp->ty); + if (ty != hvtp->ty) + thereSubHere = false; + if (ty != tvtp->ty) + hereSubThere = false; + bool hidden = hvtp->hidden & tvtp->hidden; + tail = arena->addTypePack(VariadicTypePack{ty, hidden}); + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else if (get(*htail)) + { + hereSubThere = false; + tail = htail; + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else if (std::optional ttail = itt.tail()) + { + if (get(*ttail)) + { + thereSubHere = false; + tail = htail; + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + + if (hereSubThere) + return there; + else if (thereSubHere) + return here; + if (!head.empty()) + return arena->addTypePack(TypePack{head, tail}); + else if (tail) + return *tail; + else + // TODO: Add an emptyPack to singleton types + return arena->addTypePack({}); +} + +std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) +{ + if (get(here)) + return here; + + if (get(there)) + return there; + + const FunctionTypeVar* hftv = get(here); + LUAU_ASSERT(hftv); + const FunctionTypeVar* tftv = get(there); + LUAU_ASSERT(tftv); + + if (hftv->generics != tftv->generics) + return std::nullopt; + if (hftv->genericPacks != tftv->genericPacks) + return std::nullopt; + + std::optional argTypes = intersectionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!argTypes) + return std::nullopt; + + std::optional retTypes = unionOfTypePacks(hftv->retTypes, tftv->retTypes); + if (!retTypes) + return std::nullopt; + + if (*argTypes == hftv->argTypes && *retTypes == hftv->retTypes) + return here; + if (*argTypes == tftv->argTypes && *retTypes == tftv->retTypes) + return there; + + FunctionTypeVar result{*argTypes, *retTypes}; + result.generics = hftv->generics; + result.genericPacks = hftv->genericPacks; + return arena->addType(std::move(result)); +} + +void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) +{ + if (FFlag::LuauNegatedFunctionTypes) + { + if (heres.isTop) + return; + if (theres.isTop) + heres.resetToTop(); + } + + if (theres.isNever()) + return; + + TypeIds tmps; + + if (heres.isNever()) + { + tmps.insert(theres.parts->begin(), theres.parts->end()); + heres.parts = std::move(tmps); + return; + } + + for (TypeId here : *heres.parts) + for (TypeId there : *theres.parts) + { + if (std::optional fun = unionOfFunctions(here, there)) + tmps.insert(*fun); + else + tmps.insert(singletonTypes->errorRecoveryType(there)); + } + + heres.parts = std::move(tmps); +} + +void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) +{ + if (heres.isNever()) + { + TypeIds tmps; + tmps.insert(there); + heres.parts = std::move(tmps); + return; + } + + TypeIds tmps; + for (TypeId here : *heres.parts) + { + if (std::optional fun = unionOfFunctions(here, there)) + tmps.insert(*fun); + else + tmps.insert(singletonTypes->errorRecoveryType(there)); + } + heres.parts = std::move(tmps); +} + +void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) +{ + // TODO: remove unions of tables where possible + heres.insert(there); +} + +void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) +{ + for (TypeId there : theres) + unionTablesWithTable(heres, there); +} + +// So why `ignoreSmallerTyvars`? +// +// First up, what it does... Every tyvar has an index, and this parameter says to ignore +// any tyvars in `there` if their index is less than or equal to the parameter. +// The parameter is always greater than any tyvars mentioned in here, so the result is +// a lower bound on any tyvars in `here.tyvars`. +// +// This is used to maintain in invariant, which is that in any tyvar `X&T`, any any tyvar +// `Y&U` in `T`, the index of `X` is less than the index of `Y`. This is an implementation +// of *ordered decision diagrams* (https://en.wikipedia.org/wiki/Binary_decision_diagram#Variable_ordering) +// which are a compression technique used to save memory usage when representing boolean formulae. +// +// The idea is that if you have an out-of-order decision diagram +// like `Z&(X|Y)`, to re-order it in this case to `(X&Z)|(Y&Z)`. +// The hope is that by imposing a global order, there's a higher chance of sharing opportunities, +// and hence reduced memory. +// +// And yes, this is essentially a SAT solver hidden inside a typechecker. +// That's what you get for having a type system with generics, intersection and union types. +bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) +{ + TypeId tops = unionOfTops(here.tops, there.tops); + if (!get(tops)) + { + clearNormal(here); + here.tops = tops; + return true; + } + + for (auto it = there.tyvars.begin(); it != there.tyvars.end(); it++) + { + TypeId tyvar = it->first; + const NormalizedType& inter = *it->second; + int index = tyvarIndex(tyvar); + if (index <= ignoreSmallerTyvars) + continue; + auto [emplaced, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{singletonTypes})); + if (fresh) + if (!unionNormals(*emplaced->second, here, index)) + return false; + if (!unionNormals(*emplaced->second, inter, index)) + return false; + } + + here.booleans = unionOfBools(here.booleans, there.booleans); + unionClasses(here.classes, there.classes); + here.errors = (get(there.errors) ? here.errors : there.errors); + here.nils = (get(there.nils) ? here.nils : there.nils); + here.numbers = (get(there.numbers) ? here.numbers : there.numbers); + unionStrings(here.strings, there.strings); + here.threads = (get(there.threads) ? here.threads : there.threads); + unionFunctions(here.functions, there.functions); + unionTables(here.tables, there.tables); + return true; +} + +bool Normalizer::withinResourceLimits() +{ + // If cache is too large, clear it + if (FInt::LuauNormalizeCacheLimit > 0) + { + size_t cacheUsage = cachedNormals.size() + cachedIntersections.size() + cachedUnions.size() + cachedTypeIds.size(); + if (cacheUsage > size_t(FInt::LuauNormalizeCacheLimit)) + { + clearCaches(); + return false; + } + } + + // Check the recursion count + if (sharedState->counters.recursionLimit > 0) + if (sharedState->counters.recursionLimit < sharedState->counters.recursionCount) + return false; + + return true; +} + +// See above for an explaination of `ignoreSmallerTyvars`. +bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars) +{ + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits()) + return false; + + there = follow(there); + if (get(there) || get(there)) + { + TypeId tops = unionOfTops(here.tops, there); + clearNormal(here); + here.tops = tops; + return true; + } + else if (get(there) || !get(here.tops)) + return true; + else if (const UnionTypeVar* utv = get(there)) + { + for (UnionTypeVarIterator it = begin(utv); it != end(utv); ++it) + if (!unionNormalWithTy(here, *it)) + return false; + return true; + } + else if (const IntersectionTypeVar* itv = get(there)) + { + NormalizedType norm{singletonTypes}; + norm.tops = singletonTypes->anyType; + for (IntersectionTypeVarIterator it = begin(itv); it != end(itv); ++it) + if (!intersectNormalWithTy(norm, *it)) + return false; + return unionNormals(here, norm); + } + else if (get(there) || get(there)) + { + if (tyvarIndex(there) <= ignoreSmallerTyvars) + return true; + NormalizedType inter{singletonTypes}; + inter.tops = singletonTypes->unknownType; + here.tyvars.insert_or_assign(there, std::make_unique(std::move(inter))); + } + else if (get(there)) + unionFunctionsWithFunction(here.functions, there); + else if (get(there) || get(there)) + unionTablesWithTable(here.tables, there); + else if (get(there)) + unionClassesWithClass(here.classes, there); + else if (get(there)) + here.errors = there; + else if (const PrimitiveTypeVar* ptv = get(there)) + { + if (ptv->type == PrimitiveTypeVar::Boolean) + here.booleans = there; + else if (ptv->type == PrimitiveTypeVar::NilType) + here.nils = there; + else if (ptv->type == PrimitiveTypeVar::Number) + here.numbers = there; + else if (ptv->type == PrimitiveTypeVar::String) + here.strings.resetToString(); + else if (ptv->type == PrimitiveTypeVar::Thread) + here.threads = there; + else if (ptv->type == PrimitiveTypeVar::Function) + { + LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); + here.functions.resetToTop(); + } + else + LUAU_ASSERT(!"Unreachable"); + } + else if (const SingletonTypeVar* stv = get(there)) + { + if (get(stv)) + here.booleans = unionOfBools(here.booleans, there); + else if (const StringSingleton* sstv = get(stv)) + { + if (here.strings.isCofinite) + { + auto it = here.strings.singletons.find(sstv->value); + if (it != here.strings.singletons.end()) + here.strings.singletons.erase(it); + } + else + here.strings.singletons.insert({sstv->value, there}); + } + else + LUAU_ASSERT(!"Unreachable"); + } + else if (const NegationTypeVar* ntv = get(there)) + { + const NormalizedType* thereNormal = normalize(ntv->ty); + std::optional tn = negateNormal(*thereNormal); + if (!tn) + return false; + + if (!unionNormals(here, *tn)) + return false; + } + else + LUAU_ASSERT(!"Unreachable"); + + for (auto& [tyvar, intersect] : here.tyvars) + if (!unionNormalWithTy(*intersect, there, tyvarIndex(tyvar))) + return false; + + assertInvariant(here); + return true; +} + +// ------- Negations + +std::optional Normalizer::negateNormal(const NormalizedType& here) +{ + NormalizedType result{singletonTypes}; + if (!get(here.tops)) + { + // The negation of unknown or any is never. Easy. + return result; + } + + if (!get(here.errors)) + { + // Negating an error yields the same error. + result.errors = here.errors; + return result; + } + + if (get(here.booleans)) + result.booleans = singletonTypes->booleanType; + else if (get(here.booleans)) + result.booleans = singletonTypes->neverType; + else if (auto stv = get(here.booleans)) + { + auto boolean = get(stv); + LUAU_ASSERT(boolean != nullptr); + if (boolean->value) + result.booleans = singletonTypes->falseType; + else + result.booleans = singletonTypes->trueType; + } + + result.classes = negateAll(here.classes); + result.nils = get(here.nils) ? singletonTypes->nilType : singletonTypes->neverType; + result.numbers = get(here.numbers) ? singletonTypes->numberType : singletonTypes->neverType; + + result.strings = here.strings; + result.strings.isCofinite = !result.strings.isCofinite; + + result.threads = get(here.threads) ? singletonTypes->threadType : singletonTypes->neverType; + + /* + * Things get weird and so, so complicated if we allow negations of + * arbitrary function types. Ordinary code can never form these kinds of + * types, so we decline to negate them. + */ + if (FFlag::LuauNegatedFunctionTypes) + { + if (here.functions.isNever()) + result.functions.resetToTop(); + else if (here.functions.isTop) + result.functions.resetToNever(); + else + return std::nullopt; + } + + // TODO: negating tables + // TODO: negating tyvars? + + return result; +} + +TypeIds Normalizer::negateAll(const TypeIds& theres) +{ + TypeIds tys; + for (TypeId there : theres) + tys.insert(negate(there)); + return tys; +} + +TypeId Normalizer::negate(TypeId there) +{ + there = follow(there); + if (get(there)) + return there; + else if (get(there)) + return singletonTypes->neverType; + else if (get(there)) + return singletonTypes->unknownType; + else if (auto ntv = get(there)) + return ntv->ty; // TODO: do we want to normalize this? + else if (auto utv = get(there)) + { + std::vector parts; + for (TypeId option : utv) + parts.push_back(negate(option)); + return arena->addType(IntersectionTypeVar{std::move(parts)}); + } + else if (auto itv = get(there)) + { + std::vector options; + for (TypeId part : itv) + options.push_back(negate(part)); + return arena->addType(UnionTypeVar{std::move(options)}); + } + else + return there; +} + +void Normalizer::subtractPrimitive(NormalizedType& here, TypeId ty) +{ + const PrimitiveTypeVar* ptv = get(follow(ty)); + LUAU_ASSERT(ptv); + switch (ptv->type) + { + case PrimitiveTypeVar::NilType: + here.nils = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Boolean: + here.booleans = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Number: + here.numbers = singletonTypes->neverType; + break; + case PrimitiveTypeVar::String: + here.strings.resetToNever(); + break; + case PrimitiveTypeVar::Thread: + here.threads = singletonTypes->neverType; + break; + case PrimitiveTypeVar::Function: + here.functions.resetToNever(); + break; + } +} + +void Normalizer::subtractSingleton(NormalizedType& here, TypeId ty) +{ + const SingletonTypeVar* stv = get(ty); + LUAU_ASSERT(stv); + + if (const StringSingleton* ss = get(stv)) + { + if (here.strings.isCofinite) + here.strings.singletons.insert({ss->value, ty}); + else + { + auto it = here.strings.singletons.find(ss->value); + if (it != here.strings.singletons.end()) + here.strings.singletons.erase(it); + } + } + else if (const BooleanSingleton* bs = get(stv)) + { + if (get(here.booleans)) + { + // Nothing + } + else if (get(here.booleans)) + here.booleans = bs->value ? singletonTypes->falseType : singletonTypes->trueType; + else if (auto hereSingleton = get(here.booleans)) + { + const BooleanSingleton* hereBooleanSingleton = get(hereSingleton); + LUAU_ASSERT(hereBooleanSingleton); + + // Crucial subtlety: ty (and thus bs) are the value that is being + // negated out. We therefore reduce to never when the values match, + // rather than when they differ. + if (bs->value == hereBooleanSingleton->value) + here.booleans = singletonTypes->neverType; + } + else + LUAU_ASSERT(!"Unreachable"); + } + else + LUAU_ASSERT(!"Unreachable"); +} + +// ------- Normalizing intersections +TypeId Normalizer::intersectionOfTops(TypeId here, TypeId there) +{ + if (get(here) || get(there)) + return here; + else + return there; +} + +TypeId Normalizer::intersectionOfBools(TypeId here, TypeId there) +{ + if (get(here)) + return here; + if (get(there)) + return there; + if (const BooleanSingleton* hbool = get(get(here))) + if (const BooleanSingleton* tbool = get(get(there))) + return (hbool->value == tbool->value ? here : singletonTypes->neverType); + else + return here; + else + return there; +} + +void Normalizer::intersectClasses(TypeIds& heres, const TypeIds& theres) +{ + TypeIds tmp; + for (auto it = heres.begin(); it != heres.end();) + { + const ClassTypeVar* hctv = get(*it); + LUAU_ASSERT(hctv); + bool keep = false; + for (TypeId there : theres) + { + const ClassTypeVar* tctv = get(there); + LUAU_ASSERT(tctv); + if (isSubclass(hctv, tctv)) + { + keep = true; + break; + } + else if (isSubclass(tctv, hctv)) + { + keep = false; + tmp.insert(there); + break; + } + } + if (keep) + it++; + else + it = heres.erase(it); + } + heres.insert(tmp.begin(), tmp.end()); +} + +void Normalizer::intersectClassesWithClass(TypeIds& heres, TypeId there) +{ + bool foundSuper = false; + const ClassTypeVar* tctv = get(there); + LUAU_ASSERT(tctv); + for (auto it = heres.begin(); it != heres.end();) + { + const ClassTypeVar* hctv = get(*it); + LUAU_ASSERT(hctv); + if (isSubclass(hctv, tctv)) + it++; + else if (isSubclass(tctv, hctv)) + { + foundSuper = true; + break; + } + else + it = heres.erase(it); + } + if (foundSuper) + { + heres.clear(); + heres.insert(there); + } +} + +void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) +{ + if (there.isString()) + return; + if (here.isString()) + here.resetToNever(); + + for (auto it = here.singletons.begin(); it != here.singletons.end();) + { + if (there.singletons.count(it->first)) + it++; + else + it = here.singletons.erase(it); + } +} + +std::optional Normalizer::intersectionOfTypePacks(TypePackId here, TypePackId there) +{ + if (here == there) + return here; + + std::vector head; + std::optional tail; + + bool hereSubThere = true; + bool thereSubHere = true; + + TypePackIterator ith = begin(here); + TypePackIterator itt = begin(there); + + while (ith != end(here) && itt != end(there)) + { + TypeId hty = *ith; + TypeId tty = *itt; + TypeId ty = intersectionType(hty, tty); + if (ty != hty) + hereSubThere = false; + if (ty != tty) + thereSubHere = false; + head.push_back(ty); + ith++; + itt++; + } + + auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, + bool& thereSubHere) { + if (ith != end(here)) + { + TypeId tty = singletonTypes->nilType; + if (std::optional ttail = itt.tail()) + { + if (const VariadicTypePack* tvtp = get(*ttail)) + tty = tvtp->ty; + else + // Luau doesn't have intersections of type pack variables + return false; + } + else + // Type packs of different arities are incomparable + return false; + + while (ith != end(here)) + { + TypeId hty = *ith; + TypeId ty = intersectionType(hty, tty); + if (ty != hty) + hereSubThere = false; + if (ty != tty) + thereSubHere = false; + head.push_back(ty); + ith++; + } + } + return true; + }; + + if (!dealWithDifferentArities(ith, itt, here, there, hereSubThere, thereSubHere)) + return std::nullopt; + + if (!dealWithDifferentArities(itt, ith, there, here, thereSubHere, hereSubThere)) + return std::nullopt; + + if (std::optional htail = ith.tail()) + { + if (std::optional ttail = itt.tail()) + { + if (*htail == *ttail) + tail = htail; + else if (const VariadicTypePack* hvtp = get(*htail)) + { + if (const VariadicTypePack* tvtp = get(*ttail)) + { + TypeId ty = intersectionType(hvtp->ty, tvtp->ty); + if (ty != hvtp->ty) + thereSubHere = false; + if (ty != tvtp->ty) + hereSubThere = false; + bool hidden = hvtp->hidden & tvtp->hidden; + tail = arena->addTypePack(VariadicTypePack{ty, hidden}); + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else if (get(*htail)) + hereSubThere = false; + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else if (std::optional ttail = itt.tail()) + { + if (get(*ttail)) + thereSubHere = false; + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + + if (hereSubThere) + return here; + else if (thereSubHere) + return there; + if (!head.empty()) + return arena->addTypePack(TypePack{head, tail}); + else if (tail) + return *tail; + else + // TODO: Add an emptyPack to singleton types + return arena->addTypePack({}); +} + +std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there) +{ + if (here == there) + return here; + + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (sharedState->counters.recursionLimit > 0 && sharedState->counters.recursionLimit < sharedState->counters.recursionCount) + return std::nullopt; + + TypeId htable = here; + TypeId hmtable = nullptr; + if (const MetatableTypeVar* hmtv = get(here)) + { + htable = hmtv->table; + hmtable = hmtv->metatable; + } + TypeId ttable = there; + TypeId tmtable = nullptr; + if (const MetatableTypeVar* tmtv = get(there)) + { + ttable = tmtv->table; + tmtable = tmtv->metatable; + } + + const TableTypeVar* httv = get(htable); + LUAU_ASSERT(httv); + const TableTypeVar* tttv = get(ttable); + LUAU_ASSERT(tttv); + + if (httv->state == TableState::Free || tttv->state == TableState::Free) + return std::nullopt; + if (httv->state == TableState::Generic || tttv->state == TableState::Generic) + return std::nullopt; + + TableState state = httv->state; + if (tttv->state == TableState::Unsealed) + state = tttv->state; + + TypeLevel level = max(httv->level, tttv->level); + TableTypeVar result{state, level}; + + bool hereSubThere = true; + bool thereSubHere = true; + + for (const auto& [name, hprop] : httv->props) + { + Property prop = hprop; + auto tfound = tttv->props.find(name); + if (tfound == tttv->props.end()) + thereSubHere = false; + else + { + const auto& [_name, tprop] = *tfound; + // TODO: variance issues here, which can't be fixed until we have read/write property types + prop.type = intersectionType(hprop.type, tprop.type); + hereSubThere &= (prop.type == hprop.type); + thereSubHere &= (prop.type == tprop.type); + } + // TODO: string indexers + result.props[name] = prop; + } + + for (const auto& [name, tprop] : tttv->props) + { + if (httv->props.count(name) == 0) + { + result.props[name] = tprop; + hereSubThere = false; + } + } + + if (httv->indexer && tttv->indexer) + { + // TODO: What should intersection of indexes be? + TypeId index = unionType(httv->indexer->indexType, tttv->indexer->indexType); + TypeId indexResult = intersectionType(httv->indexer->indexResultType, tttv->indexer->indexResultType); + result.indexer = {index, indexResult}; + hereSubThere &= (httv->indexer->indexType == index) && (httv->indexer->indexResultType == indexResult); + thereSubHere &= (tttv->indexer->indexType == index) && (tttv->indexer->indexResultType == indexResult); + } + else if (httv->indexer) + { + result.indexer = httv->indexer; + thereSubHere = false; + } + else if (tttv->indexer) + { + result.indexer = tttv->indexer; + hereSubThere = false; + } + + TypeId table; + if (hereSubThere) + table = htable; + else if (thereSubHere) + table = ttable; + else + table = arena->addType(std::move(result)); + + if (tmtable && hmtable) + { + // NOTE: this assumes metatables are ivariant + if (std::optional mtable = intersectionOfTables(hmtable, tmtable)) + { + if (table == htable && *mtable == hmtable) + return here; + else if (table == ttable && *mtable == tmtable) + return there; + else + return arena->addType(MetatableTypeVar{table, *mtable}); + } + else + return std::nullopt; + } + else if (hmtable) + { + if (table == htable) + return here; + else + return arena->addType(MetatableTypeVar{table, hmtable}); + } + else if (tmtable) + { + if (table == ttable) + return there; + else + return arena->addType(MetatableTypeVar{table, tmtable}); + } + else + return table; +} + +void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there) +{ + TypeIds tmp; + for (TypeId here : heres) + if (std::optional inter = intersectionOfTables(here, there)) + tmp.insert(*inter); + heres.retain(tmp); + heres.insert(tmp.begin(), tmp.end()); +} + +void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres) +{ + TypeIds tmp; + for (TypeId here : heres) + for (TypeId there : theres) + if (std::optional inter = intersectionOfTables(here, there)) + tmp.insert(*inter); + heres.retain(tmp); + heres.insert(tmp.begin(), tmp.end()); +} + +std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId there) +{ + const FunctionTypeVar* hftv = get(here); + LUAU_ASSERT(hftv); + const FunctionTypeVar* tftv = get(there); + LUAU_ASSERT(tftv); + + if (hftv->generics != tftv->generics) + return std::nullopt; + if (hftv->genericPacks != tftv->genericPacks) + return std::nullopt; + + TypePackId argTypes; + TypePackId retTypes; + + if (hftv->retTypes == tftv->retTypes) + { + std::optional argTypesOpt = unionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!argTypesOpt) + return std::nullopt; + argTypes = *argTypesOpt; + retTypes = hftv->retTypes; + } + else if (FFlag::LuauOverloadedFunctionSubtypingPerf && hftv->argTypes == tftv->argTypes) + { + std::optional retTypesOpt = intersectionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!retTypesOpt) + return std::nullopt; + argTypes = hftv->argTypes; + retTypes = *retTypesOpt; + } + else + return std::nullopt; + + if (argTypes == hftv->argTypes && retTypes == hftv->retTypes) + return here; + if (argTypes == tftv->argTypes && retTypes == tftv->retTypes) + return there; + + FunctionTypeVar result{argTypes, retTypes}; + result.generics = hftv->generics; + result.genericPacks = hftv->genericPacks; + return arena->addType(std::move(result)); +} + +std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId there) +{ + // Deep breath... + // + // When we come to check overloaded functions for subtyping, + // we have to compare (F1 & ... & FM) <: (G1 & ... G GN) + // where each Fi or Gj is a function type. Now that intersection on the right is no + // problem, since that's true if and only if (F1 & ... & FM) <: Gj for every j. + // But the intersection on the left is annoying, since we might have + // (F1 & ... & FM) <: G but no Fi <: G. For example + // + // ((number? -> number?) & (string? -> string?)) <: (nil -> nil) + // + // So in this case, what we do is define Apply for the result of applying + // a function of type F to an argument of type T, and then F <: (T -> U) + // if and only if Apply <: U. For example: + // + // if f : ((number? -> number?) & (string? -> string?)) + // then f(nil) must be nil, so + // Apply<((number? -> number?) & (string? -> string?)), nil> is nil + // + // So subtyping on overloaded functions "just" boils down to defining Apply. + // + // Now for non-overloaded functions, this is easy! + // Apply<(R -> S), T> is S if T <: R, and an error type otherwise. + // + // But for overloaded functions it's not so simple. We'd like Apply + // to just be Apply & ... & Apply but oh dear + // + // if f : ((number -> number) & (string -> string)) + // and x : (number | string) + // then f(x) : (number | string) + // + // so we want + // + // Apply<((number -> number) & (string -> string)), (number | string)> is (number | string) + // + // but + // + // Apply<(number -> number), (number | string)> is an error + // Apply<(string -> string), (number | string)> is an error + // + // that is Apply should consider all possible combinations of overloads of F, + // not just individual overloads. + // + // For this reason, when we're normalizing function types (in order to check subtyping + // or perform overload resolution) we should first *union-saturate* them. An overloaded + // function is union-saturated whenever: + // + // if (R -> S) is an overload of F + // and (T -> U) is an overload of F + // then ((R | T) -> (S | U)) is a subtype of an overload of F + // + // Any overloaded function can be normalized to a union-saturated one by adding enough extra overloads. + // For example, union-saturating + // + // ((number -> number) & (string -> string)) + // + // is + // + // ((number -> number) & (string -> string) & ((number | string) -> (number | string))) + // + // For union-saturated overloaded functions, the "obvious" algorithm works: + // + // Apply is Apply & ... & Apply + // + // so we can define Apply, so we can perform overloaded function resolution + // and check subtyping on overloaded function types, yay! + // + // This is yet another potential source of exponential blow-up, sigh, since + // the union-saturation of a function with N overloads may have 2^N overloads + // (one for every subset). In practice, that hopefully won't happen that often, + // in particular we only union-saturate overloads with different return types, + // and there are hopefully not very many cases of that. + // + // All of this is mechanically verified in Agda, at https://github.com/luau-lang/agda-typeck + // + // It is essentially the algorithm defined in https://pnwamk.github.io/sst-tutorial/ + // except that we're precomputing the union-saturation rather than converting + // to disjunctive normal form on the fly. + // + // This is all built on semantic subtyping: + // + // Covariance and Contravariance, Giuseppe Castagna, + // Logical Methods in Computer Science 16(1), 2022 + // https://arxiv.org/abs/1809.01427 + // + // A gentle introduction to semantic subtyping, Giuseppe Castagna and Alain Frisch, + // Proc. Principles and practice of declarative programming 2005, pp 198–208 + // https://doi.org/10.1145/1069774.1069793 + + const FunctionTypeVar* hftv = get(here); + if (!hftv) + return std::nullopt; + const FunctionTypeVar* tftv = get(there); + if (!tftv) + return std::nullopt; + + if (hftv->generics != tftv->generics) + return std::nullopt; + if (hftv->genericPacks != tftv->genericPacks) + return std::nullopt; + + std::optional argTypes = unionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!argTypes) + return std::nullopt; + std::optional retTypes = unionOfTypePacks(hftv->retTypes, tftv->retTypes); + if (!retTypes) + return std::nullopt; + + FunctionTypeVar result{*argTypes, *retTypes}; + result.generics = hftv->generics; + result.genericPacks = hftv->genericPacks; + return arena->addType(std::move(result)); +} + +void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) +{ + if (heres.isNever()) + return; + + heres.isTop = false; + + for (auto it = heres.parts->begin(); it != heres.parts->end();) + { + TypeId here = *it; + if (get(here)) + it++; + else if (std::optional tmp = intersectionOfFunctions(here, there)) + { + heres.parts->erase(it); + heres.parts->insert(*tmp); + return; + } + else + it++; + } + + TypeIds tmps; + for (TypeId here : *heres.parts) + { + if (std::optional tmp = unionSaturatedFunctions(here, there)) + tmps.insert(*tmp); + } + heres.parts->insert(there); + heres.parts->insert(tmps.begin(), tmps.end()); +} + +void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) +{ + if (heres.isNever()) + return; + else if (theres.isNever()) + { + heres.resetToNever(); + return; + } + else + { + for (TypeId there : *theres.parts) + intersectFunctionsWithFunction(heres, there); + } +} + +bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there) +{ + for (auto it = here.begin(); it != here.end();) + { + NormalizedType& inter = *it->second; + if (!intersectNormalWithTy(inter, there)) + return false; + if (isShallowInhabited(inter)) + ++it; + else + it = here.erase(it); + } + return true; +} + +// See above for an explaination of `ignoreSmallerTyvars`. +bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) +{ + if (!get(there.tops)) + { + here.tops = intersectionOfTops(here.tops, there.tops); + return true; + } + else if (!get(here.tops)) + { + clearNormal(here); + return unionNormals(here, there, ignoreSmallerTyvars); + } + + here.booleans = intersectionOfBools(here.booleans, there.booleans); + intersectClasses(here.classes, there.classes); + here.errors = (get(there.errors) ? there.errors : here.errors); + here.nils = (get(there.nils) ? there.nils : here.nils); + here.numbers = (get(there.numbers) ? there.numbers : here.numbers); + intersectStrings(here.strings, there.strings); + here.threads = (get(there.threads) ? there.threads : here.threads); + intersectFunctions(here.functions, there.functions); + intersectTables(here.tables, there.tables); + + for (auto& [tyvar, inter] : there.tyvars) + { + int index = tyvarIndex(tyvar); + if (ignoreSmallerTyvars < index) + { + auto [found, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{singletonTypes})); + if (fresh) + { + if (!unionNormals(*found->second, here, index)) + return false; + } + } + } + for (auto it = here.tyvars.begin(); it != here.tyvars.end();) + { + TypeId tyvar = it->first; + NormalizedType& inter = *it->second; + int index = tyvarIndex(tyvar); + LUAU_ASSERT(ignoreSmallerTyvars < index); + auto found = there.tyvars.find(tyvar); + if (found == there.tyvars.end()) + { + if (!intersectNormals(inter, there, index)) + return false; + } + else + { + if (!intersectNormals(inter, *found->second, index)) + return false; + } + if (isShallowInhabited(inter)) + it++; + else + it = here.tyvars.erase(it); + } + return true; +} + +bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) +{ + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits()) + return false; + + there = follow(there); + if (get(there) || get(there)) + { + here.tops = intersectionOfTops(here.tops, there); + return true; + } + else if (!get(here.tops)) + { + clearNormal(here); + return unionNormalWithTy(here, there); + } + else if (const UnionTypeVar* utv = get(there)) + { + NormalizedType norm{singletonTypes}; + for (UnionTypeVarIterator it = begin(utv); it != end(utv); ++it) + if (!unionNormalWithTy(norm, *it)) + return false; + return intersectNormals(here, norm); + } + else if (const IntersectionTypeVar* itv = get(there)) + { + for (IntersectionTypeVarIterator it = begin(itv); it != end(itv); ++it) + if (!intersectNormalWithTy(here, *it)) + return false; + return true; + } + else if (get(there) || get(there)) + { + NormalizedType thereNorm{singletonTypes}; + NormalizedType topNorm{singletonTypes}; + topNorm.tops = singletonTypes->unknownType; + thereNorm.tyvars.insert_or_assign(there, std::make_unique(std::move(topNorm))); + return intersectNormals(here, thereNorm); + } + + NormalizedTyvars tyvars = std::move(here.tyvars); + + if (const FunctionTypeVar* utv = get(there)) + { + NormalizedFunctionType functions = std::move(here.functions); + clearNormal(here); + intersectFunctionsWithFunction(functions, there); + here.functions = std::move(functions); + } + else if (get(there) || get(there)) + { + TypeIds tables = std::move(here.tables); + clearNormal(here); + intersectTablesWithTable(tables, there); + here.tables = std::move(tables); + } + else if (get(there)) + { + TypeIds classes = std::move(here.classes); + clearNormal(here); + intersectClassesWithClass(classes, there); + here.classes = std::move(classes); + } + else if (get(there)) + { + TypeId errors = here.errors; + clearNormal(here); + here.errors = errors; + } + else if (const PrimitiveTypeVar* ptv = get(there)) + { + TypeId booleans = here.booleans; + TypeId nils = here.nils; + TypeId numbers = here.numbers; + NormalizedStringType strings = std::move(here.strings); + NormalizedFunctionType functions = std::move(here.functions); + TypeId threads = here.threads; + + clearNormal(here); + + if (ptv->type == PrimitiveTypeVar::Boolean) + here.booleans = booleans; + else if (ptv->type == PrimitiveTypeVar::NilType) + here.nils = nils; + else if (ptv->type == PrimitiveTypeVar::Number) + here.numbers = numbers; + else if (ptv->type == PrimitiveTypeVar::String) + here.strings = std::move(strings); + else if (ptv->type == PrimitiveTypeVar::Thread) + here.threads = threads; + else if (ptv->type == PrimitiveTypeVar::Function) + { + LUAU_ASSERT(FFlag::LuauNegatedFunctionTypes); + here.functions = std::move(functions); + } + else + LUAU_ASSERT(!"Unreachable"); + } + else if (const SingletonTypeVar* stv = get(there)) + { + TypeId booleans = here.booleans; + NormalizedStringType strings = std::move(here.strings); + + clearNormal(here); + + if (get(stv)) + here.booleans = intersectionOfBools(booleans, there); + else if (const StringSingleton* sstv = get(stv)) + { + if (strings.includes(sstv->value)) + here.strings.singletons.insert({sstv->value, there}); + } + else + LUAU_ASSERT(!"Unreachable"); + } + else if (const NegationTypeVar* ntv = get(there)) + { + TypeId t = follow(ntv->ty); + if (const PrimitiveTypeVar* ptv = get(t)) + subtractPrimitive(here, ntv->ty); + else if (const SingletonTypeVar* stv = get(t)) + subtractSingleton(here, follow(ntv->ty)); + else if (const UnionTypeVar* itv = get(t)) + { + for (TypeId part : itv->options) + { + const NormalizedType* normalPart = normalize(part); + std::optional negated = negateNormal(*normalPart); + if (!negated) + return false; + intersectNormals(here, *negated); + } + } + else + { + // TODO negated unions, intersections, table, and function. + // Report a TypeError for other types. + LUAU_ASSERT(!"Unimplemented"); + } + } + else + LUAU_ASSERT(!"Unreachable"); + + if (!intersectTyvarsWithTy(tyvars, there)) + return false; + here.tyvars = std::move(tyvars); + + return true; +} + +// -------- Convert back from a normalized type to a type +TypeId Normalizer::typeFromNormal(const NormalizedType& norm) +{ + assertInvariant(norm); + if (!get(norm.tops)) + return norm.tops; + + std::vector result; + + if (!get(norm.booleans)) + result.push_back(norm.booleans); + result.insert(result.end(), norm.classes.begin(), norm.classes.end()); + if (!get(norm.errors)) + result.push_back(norm.errors); + if (FFlag::LuauNegatedFunctionTypes && norm.functions.isTop) + result.push_back(singletonTypes->functionType); + else if (!norm.functions.isNever()) + { + if (norm.functions.parts->size() == 1) + result.push_back(*norm.functions.parts->begin()); + else + { + std::vector parts; + parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); + result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); + } + } + if (!get(norm.nils)) + result.push_back(norm.nils); + if (!get(norm.numbers)) + result.push_back(norm.numbers); + if (norm.strings.isString()) + result.push_back(singletonTypes->stringType); + else if (norm.strings.isUnion()) + { + for (auto& [_, ty] : norm.strings.singletons) + result.push_back(ty); + } + else if (norm.strings.isIntersection()) + { + std::vector parts; + parts.push_back(singletonTypes->stringType); + for (const auto& [name, ty] : norm.strings.singletons) + parts.push_back(arena->addType(NegationTypeVar{ty})); + + result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); + } + if (!get(norm.threads)) + result.push_back(singletonTypes->threadType); + + result.insert(result.end(), norm.tables.begin(), norm.tables.end()); + for (auto& [tyvar, intersect] : norm.tyvars) + { + if (get(intersect->tops)) + { + TypeId ty = typeFromNormal(*intersect); + result.push_back(arena->addType(IntersectionTypeVar{{tyvar, ty}})); + } + else + result.push_back(tyvar); + } + + if (result.size() == 0) + return singletonTypes->neverType; + else if (result.size() == 1) + return result[0]; + else + return arena->addType(UnionTypeVar{std::move(result)}); +} + +bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) +{ + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + + u.tryUnify(subTy, superTy); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; +} + +bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull singletonTypes, InternalErrorReporter& ice) +{ + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + + u.tryUnify(subPack, superPack); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; +} + +} // namespace Luau diff --git a/Analysis/src/Predicate.cpp b/Analysis/src/Predicate.cpp deleted file mode 100644 index 25e63bff..00000000 --- a/Analysis/src/Predicate.cpp +++ /dev/null @@ -1,93 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Predicate.h" - -#include "Luau/Ast.h" - -LUAU_FASTFLAG(LuauOrPredicate) - -namespace Luau -{ - -std::optional tryGetLValue(const AstExpr& node) -{ - const AstExpr* expr = &node; - while (auto e = expr->as()) - expr = e->expr; - - if (auto local = expr->as()) - return Symbol{local->local}; - else if (auto global = expr->as()) - return Symbol{global->name}; - else if (auto indexname = expr->as()) - { - if (auto lvalue = tryGetLValue(*indexname->expr)) - return Field{std::make_shared(*lvalue), indexname->index.value}; - } - else if (auto indexexpr = expr->as()) - { - if (auto lvalue = tryGetLValue(*indexexpr->expr)) - if (auto string = indexexpr->expr->as()) - return Field{std::make_shared(*lvalue), std::string(string->value.data, string->value.size)}; - } - - return std::nullopt; -} - -std::pair> getFullName(const LValue& lvalue) -{ - const LValue* current = &lvalue; - std::vector keys; - while (auto field = get(*current)) - { - keys.push_back(field->key); - current = field->parent.get(); - if (!current) - LUAU_ASSERT(!"LValue root is a Field?"); - } - - const Symbol* symbol = get(*current); - return {*symbol, std::vector(keys.rbegin(), keys.rend())}; -} - -std::string toString(const LValue& lvalue) -{ - auto [symbol, keys] = getFullName(lvalue); - std::string s = toString(symbol); - for (std::string key : keys) - s += "." + key; - return s; -} - -void merge(RefinementMap& l, const RefinementMap& r, std::function f) -{ - LUAU_ASSERT(FFlag::LuauOrPredicate); - - auto itL = l.begin(); - auto itR = r.begin(); - while (itL != l.end() && itR != r.end()) - { - const auto& [k, a] = *itR; - if (itL->first == k) - { - l[k] = f(itL->second, a); - ++itL; - ++itR; - } - else if (itL->first > k) - { - l[k] = a; - ++itR; - } - else - ++itL; - } - - l.insert(itR, r.end()); -} - -void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty) -{ - refis[toString(lvalue)] = ty; -} - -} // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp new file mode 100644 index 00000000..e9de094b --- /dev/null +++ b/Analysis/src/Quantify.cpp @@ -0,0 +1,259 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Quantify.h" + +#include "Luau/Scope.h" +#include "Luau/Substitution.h" +#include "Luau/TxnLog.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +LUAU_FASTFLAG(DebugLuauSharedSelf) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) + +namespace Luau +{ + +struct Quantifier final : TypeVarOnceVisitor +{ + TypeLevel level; + std::vector generics; + std::vector genericPacks; + Scope* scope = nullptr; + bool seenGenericType = false; + bool seenMutableType = false; + + explicit Quantifier(TypeLevel level) + : level(level) + { + LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); + } + + /// @return true if outer encloses inner + bool subsumes(Scope* outer, Scope* inner) + { + while (inner) + { + if (inner == outer) + return true; + inner = inner->parent.get(); + } + + return false; + } + + bool visit(TypeId ty, const FreeTypeVar& ftv) override + { + seenMutableType = true; + + if (!level.subsumes(ftv.level)) + return false; + + *asMutable(ty) = GenericTypeVar{level}; + + generics.push_back(ty); + + return false; + } + + bool visit(TypeId ty, const TableTypeVar&) override + { + LUAU_ASSERT(getMutable(ty)); + TableTypeVar& ttv = *getMutable(ty); + + if (ttv.state == TableState::Generic) + seenGenericType = true; + + if (ttv.state == TableState::Free) + seenMutableType = true; + + if (!level.subsumes(ttv.level)) + { + if (ttv.state == TableState::Unsealed) + seenMutableType = true; + return false; + } + + if (ttv.state == TableState::Free) + { + ttv.state = TableState::Generic; + seenGenericType = true; + } + else if (ttv.state == TableState::Unsealed) + ttv.state = TableState::Sealed; + + ttv.level = level; + + return true; + } + + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + seenMutableType = true; + + if (!level.subsumes(ftp.level)) + return false; + + *asMutable(tp) = GenericTypePack{level}; + genericPacks.push_back(tp); + return true; + } +}; + +void quantify(TypeId ty, TypeLevel level) +{ + if (FFlag::DebugLuauSharedSelf) + { + ty = follow(ty); + + if (auto ttv = getTableType(ty); ttv && ttv->selfTy) + { + Quantifier selfQ{level}; + selfQ.traverse(*ttv->selfTy); + + Quantifier q{level}; + q.traverse(ty); + + for (const auto& [_, prop] : ttv->props) + { + auto ftv = getMutable(follow(prop.type)); + if (!ftv || !ftv->hasSelf) + continue; + + if (Luau::first(ftv->argTypes) == ttv->selfTy) + { + ftv->generics.insert(ftv->generics.end(), selfQ.generics.begin(), selfQ.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), selfQ.genericPacks.begin(), selfQ.genericPacks.end()); + } + } + } + else if (auto ftv = getMutable(ty)) + { + Quantifier q{level}; + q.traverse(ty); + + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + + if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) + ftv->hasNoGenerics = true; + } + } + else + { + Quantifier q{level}; + q.traverse(ty); + + FunctionTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + } +} + +struct PureQuantifier : Substitution +{ + Scope* scope; + std::vector insertedGenerics; + std::vector insertedGenericPacks; + + PureQuantifier(TypeArena* arena, Scope* scope) + : Substitution(TxnLog::empty(), arena) + , scope(scope) + { + } + + bool isDirty(TypeId ty) override + { + LUAU_ASSERT(ty == follow(ty)); + + if (auto ftv = get(ty)) + { + return subsumes(scope, ftv->scope); + } + else if (auto ttv = get(ty)) + { + return ttv->state == TableState::Free && subsumes(scope, ttv->scope); + } + + return false; + } + + bool isDirty(TypePackId tp) override + { + if (auto ftp = get(tp)) + { + return subsumes(scope, ftp->scope); + } + + return false; + } + + TypeId clean(TypeId ty) override + { + if (auto ftv = get(ty)) + { + TypeId result = arena->addType(GenericTypeVar{scope}); + insertedGenerics.push_back(result); + return result; + } + else if (auto ttv = get(ty)) + { + TypeId result = arena->addType(TableTypeVar{}); + TableTypeVar* resultTable = getMutable(result); + LUAU_ASSERT(resultTable); + + *resultTable = *ttv; + resultTable->level = TypeLevel{}; + resultTable->scope = scope; + resultTable->state = TableState::Generic; + + return result; + } + + return ty; + } + + TypePackId clean(TypePackId tp) override + { + if (auto ftp = get(tp)) + { + TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{}}); + insertedGenericPacks.push_back(result); + return result; + } + + return tp; + } + + bool ignoreChildren(TypeId ty) override + { + if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + return true; + + return ty->persistent; + } + bool ignoreChildren(TypePackId ty) override + { + return ty->persistent; + } +}; + +TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope) +{ + PureQuantifier quantifier{arena, scope}; + std::optional result = quantifier.substitute(ty); + LUAU_ASSERT(result); + + FunctionTypeVar* ftv = getMutable(*result); + LUAU_ASSERT(ftv); + ftv->scope = scope; + ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); + ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty(); + + return *result; +} + +} // namespace Luau diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index 5b3997e2..c036a7a5 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -4,187 +4,163 @@ #include "Luau/Ast.h" #include "Luau/Module.h" -LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false) - namespace Luau { -namespace -{ - struct RequireTracer : AstVisitor { - explicit RequireTracer(FileResolver* fileResolver, ModuleName currentModuleName) - : fileResolver(fileResolver) - , currentModuleName(std::move(currentModuleName)) + RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName) + : result(result) + , fileResolver(fileResolver) + , currentModuleName(currentModuleName) + , locals(nullptr) { } - FileResolver* const fileResolver; - ModuleName currentModuleName; - DenseHashMap locals{0}; - RequireTraceResult result; - - std::optional fromAstFragment(AstExpr* expr) - { - if (auto g = expr->as(); g && g->name == "script") - return currentModuleName; - - return fileResolver->fromAstFragment(expr); - } - - bool visit(AstStatLocal* stat) override - { - for (size_t i = 0; i < stat->vars.size; ++i) - { - AstLocal* local = stat->vars.data[i]; - - if (local->annotation) - { - if (AstTypeTypeof* ann = local->annotation->as()) - ann->expr->visit(this); - } - - if (i < stat->values.size) - { - AstExpr* expr = stat->values.data[i]; - expr->visit(this); - - const ModuleName* name = result.exprs.find(expr); - if (name) - locals[local] = *name; - } - } - - return false; - } - - bool visit(AstExprGlobal* global) override - { - std::optional name = fromAstFragment(global); - if (name) - result.exprs[global] = *name; - - return false; - } - - bool visit(AstExprLocal* local) override - { - const ModuleName* name = locals.find(local->local); - if (name) - result.exprs[local] = *name; - - return false; - } - - bool visit(AstExprIndexName* indexName) override - { - indexName->expr->visit(this); - - const ModuleName* name = result.exprs.find(indexName->expr); - if (name) - { - if (indexName->index == "parent" || indexName->index == "Parent") - { - if (auto parent = fileResolver->getParentModuleName(*name)) - result.exprs[indexName] = *parent; - } - else - result.exprs[indexName] = fileResolver->concat(*name, indexName->index.value); - } - - return false; - } - - bool visit(AstExprIndexExpr* indexExpr) override - { - indexExpr->expr->visit(this); - - const ModuleName* name = result.exprs.find(indexExpr->expr); - const AstExprConstantString* str = indexExpr->index->as(); - if (name && str) - { - result.exprs[indexExpr] = fileResolver->concat(*name, std::string_view(str->value.data, str->value.size)); - } - - indexExpr->index->visit(this); - - return false; - } - bool visit(AstExprTypeAssertion* expr) override { + // suppress `require() :: any` return false; } - // If we see game:GetService("StringLiteral") or Game:GetService("StringLiteral"), then rewrite to game.StringLiteral. - // Else traverse arguments and trace requires to them. - bool visit(AstExprCall* call) override + bool visit(AstExprCall* expr) override { - for (AstExpr* arg : call->args) - arg->visit(this); + AstExprGlobal* global = expr->func->as(); - call->func->visit(this); + if (global && global->name == "require" && expr->args.size >= 1) + requireCalls.push_back(expr); - AstExprGlobal* globalName = call->func->as(); - if (globalName && globalName->name == "require" && call->args.size >= 1) - { - if (const ModuleName* moduleName = result.exprs.find(call->args.data[0])) - result.requires.push_back({*moduleName, call->location}); - - return false; - } - - AstExprIndexName* indexName = call->func->as(); - if (!indexName) - return false; - - std::optional rootName = fromAstFragment(indexName->expr); - - if (FFlag::LuauTraceRequireLookupChild && !rootName) - { - if (const ModuleName* moduleName = result.exprs.find(indexName->expr)) - rootName = *moduleName; - } - - if (!rootName) - return false; - - bool supportedLookup = indexName->index == "GetService" || - (FFlag::LuauTraceRequireLookupChild && (indexName->index == "FindFirstChild" || indexName->index == "WaitForChild")); - - if (!supportedLookup) - return false; - - if (call->args.size != 1) - return false; - - AstExprConstantString* name = call->args.data[0]->as(); - if (!name) - return false; - - std::string_view v{name->value.data, name->value.size}; - if (v.end() != std::find(v.begin(), v.end(), '/')) - return false; - - result.exprs[call] = fileResolver->concat(*rootName, v); - - // 'WaitForChild' can be used on modules that are not awailable at the typecheck time, but will be awailable at runtime - // If we fail to find such module, we will not report an UnknownRequire error - if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild") - result.optional[call] = true; - - return false; + return true; } + + bool visit(AstStatLocal* stat) override + { + for (size_t i = 0; i < stat->vars.size && i < stat->values.size; ++i) + { + AstLocal* local = stat->vars.data[i]; + AstExpr* expr = stat->values.data[i]; + + // track initializing expression to be able to trace modules through locals + locals[local] = expr; + } + + return true; + } + + bool visit(AstStatAssign* stat) override + { + for (size_t i = 0; i < stat->vars.size; ++i) + { + // locals that are assigned don't have a known expression + if (AstExprLocal* expr = stat->vars.data[i]->as()) + locals[expr->local] = nullptr; + } + + return true; + } + + bool visit(AstType* node) override + { + // allow resolving require inside `typeof` annotations + return true; + } + + AstExpr* getDependent(AstExpr* node) + { + if (AstExprLocal* expr = node->as()) + return locals[expr->local]; + else if (AstExprIndexName* expr = node->as()) + return expr->expr; + else if (AstExprIndexExpr* expr = node->as()) + return expr->expr; + else if (AstExprCall* expr = node->as(); expr && expr->self) + return expr->func->as()->expr; + else + return nullptr; + } + + void process() + { + ModuleInfo moduleContext{currentModuleName}; + + // seed worklist with require arguments + work.reserve(requireCalls.size()); + + for (AstExprCall* require : requireCalls) + work.push_back(require->args.data[0]); + + // push all dependent expressions to the work stack; note that the vector is modified during traversal + for (size_t i = 0; i < work.size(); ++i) + if (AstExpr* dep = getDependent(work[i])) + work.push_back(dep); + + // resolve all expressions to a module info + for (size_t i = work.size(); i > 0; --i) + { + AstExpr* expr = work[i - 1]; + + // when multiple expressions depend on the same one we push it to work queue multiple times + if (result.exprs.contains(expr)) + continue; + + std::optional info; + + if (AstExpr* dep = getDependent(expr)) + { + const ModuleInfo* context = result.exprs.find(dep); + + // locals just inherit their dependent context, no resolution required + if (expr->is()) + info = context ? std::optional(*context) : std::nullopt; + else + info = fileResolver->resolveModule(context, expr); + } + else + { + info = fileResolver->resolveModule(&moduleContext, expr); + } + + if (info) + result.exprs[expr] = std::move(*info); + } + + // resolve all requires according to their argument + result.requireList.reserve(requireCalls.size()); + + for (AstExprCall* require : requireCalls) + { + AstExpr* arg = require->args.data[0]; + + if (const ModuleInfo* info = result.exprs.find(arg)) + { + result.requireList.push_back({info->name, require->location}); + + ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info! + result.exprs[require] = std::move(infoCopy); + } + else + { + result.exprs[require] = {}; // mark require as unresolved + } + } + } + + RequireTraceResult& result; + FileResolver* fileResolver; + ModuleName currentModuleName; + + DenseHashMap locals; + std::vector work; + std::vector requireCalls; }; -} // anonymous namespace - -RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName) +RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) { - RequireTracer tracer{fileResolver, std::move(currentModuleName)}; + RequireTraceResult result; + RequireTracer tracer{result, fileResolver, currentModuleName}; root->visit(&tracer); - return tracer.result; + tracer.process(); + return result; } } // namespace Luau diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp new file mode 100644 index 00000000..84925f79 --- /dev/null +++ b/Analysis/src/Scope.cpp @@ -0,0 +1,170 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Scope.h" + +namespace Luau +{ + +Scope::Scope(TypePackId returnType) + : parent(nullptr) + , returnType(returnType) + , level(TypeLevel()) +{ +} + +Scope::Scope(const ScopePtr& parent, int subLevel) + : parent(parent) + , returnType(parent->returnType) + , level(parent->level.incr()) +{ + level = level.incr(); + level.subLevel = subLevel; +} + +void Scope::addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun) +{ + exportedTypeBindings[name] = tyFun; + builtinTypeNames.insert(name); +} + +std::optional Scope::lookup(Symbol sym) const +{ + auto r = const_cast(this)->lookupEx(sym); + if (r) + return r->first; + else + return std::nullopt; +} + +std::optional> Scope::lookupEx(Symbol sym) +{ + Scope* s = this; + + while (true) + { + auto it = s->bindings.find(sym); + if (it != s->bindings.end()) + return std::pair{it->second.typeId, s}; + + if (s->parent) + s = s->parent.get(); + else + return std::nullopt; + } +} + +// TODO: We might kill Scope::lookup(Symbol) once data flow is fully fleshed out with type states and control flow analysis. +std::optional Scope::lookup(DefId def) const +{ + for (const Scope* current = this; current; current = current->parent.get()) + { + if (auto ty = current->dcrRefinements.find(def)) + return *ty; + } + + return std::nullopt; +} + +std::optional Scope::lookupType(const Name& name) +{ + const Scope* scope = this; + while (true) + { + auto it = scope->exportedTypeBindings.find(name); + if (it != scope->exportedTypeBindings.end()) + return it->second; + + it = scope->privateTypeBindings.find(name); + if (it != scope->privateTypeBindings.end()) + return it->second; + + if (scope->parent) + scope = scope->parent.get(); + else + return std::nullopt; + } +} + +std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) +{ + const Scope* scope = this; + while (scope) + { + auto it = scope->importedTypeBindings.find(moduleAlias); + if (it == scope->importedTypeBindings.end()) + { + scope = scope->parent.get(); + continue; + } + + auto it2 = it->second.find(name); + if (it2 == it->second.end()) + { + scope = scope->parent.get(); + continue; + } + + return it2->second; + } + + return std::nullopt; +} + +std::optional Scope::lookupPack(const Name& name) +{ + const Scope* scope = this; + while (true) + { + auto it = scope->privateTypePackBindings.find(name); + if (it != scope->privateTypePackBindings.end()) + return it->second; + + if (scope->parent) + scope = scope->parent.get(); + else + return std::nullopt; + } +} + +std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) const +{ + const Scope* scope = this; + + while (scope) + { + for (const auto& [n, binding] : scope->bindings) + { + if (n.local && n.local->name == name.c_str()) + return binding; + else if (n.global.value && n.global == name.c_str()) + return binding; + } + + scope = scope->parent.get(); + + if (!traverseScopeChain) + break; + } + + return std::nullopt; +} + +bool subsumesStrict(Scope* left, Scope* right) +{ + while (right) + { + if (right->parent.get() == left) + return true; + + right = right->parent.get(); + } + + return false; +} + +bool subsumes(Scope* left, Scope* right) +{ + return left == right || subsumesStrict(left, right); +} + +} // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 7223998a..20ed34f6 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -2,28 +2,44 @@ #include "Luau/Substitution.h" #include "Luau/Common.h" +#include "Luau/Clone.h" +#include "Luau/TxnLog.h" #include #include -LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 0) -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) -LUAU_FASTFLAG(LuauRankNTypes) +LUAU_FASTFLAGVARIABLE(LuauSubstitutionFixMissingFields, false) +LUAU_FASTFLAG(LuauClonePublicInterfaceLess) +LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) +LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false) +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false) namespace Luau { void Tarjan::visitChildren(TypeId ty, int index) { - ty = follow(ty); + LUAU_ASSERT(ty == log->follow(ty)); - if (FFlag::LuauRankNTypes && ignoreChildren(ty)) + if (ignoreChildren(ty)) return; + if (auto pty = log->pending(ty)) + ty = &pty->pending; + if (const FunctionTypeVar* ftv = get(ty)) { + if (FFlag::LuauSubstitutionFixMissingFields) + { + for (TypeId generic : ftv->generics) + visitChild(generic); + for (TypePackId genericPack : ftv->genericPacks) + visitChild(genericPack); + } + visitChild(ftv->argTypes); - visitChild(ftv->retType); + visitChild(ftv->retTypes); } else if (const TableTypeVar* ttv = get(ty)) { @@ -35,8 +51,12 @@ void Tarjan::visitChildren(TypeId ty, int index) visitChild(ttv->indexer->indexType); visitChild(ttv->indexer->indexResultType); } + for (TypeId itp : ttv->instantiatedTypeParams) visitChild(itp); + + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp); } else if (const MetatableTypeVar* mtv = get(ty)) { @@ -53,15 +73,41 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypeId part : itv->parts) visitChild(part); } + else if (const PendingExpansionTypeVar* petv = get(ty)) + { + for (TypeId a : petv->typeArguments) + visitChild(a); + + for (TypePackId a : petv->packArguments) + visitChild(a); + } + else if (const ClassTypeVar* ctv = get(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) + { + for (auto [name, prop] : ctv->props) + visitChild(prop.type); + + if (ctv->parent) + visitChild(*ctv->parent); + + if (ctv->metatable) + visitChild(*ctv->metatable); + } + else if (const NegationTypeVar* ntv = get(ty)) + { + visitChild(ntv->ty); + } } void Tarjan::visitChildren(TypePackId tp, int index) { - tp = follow(tp); + LUAU_ASSERT(tp == log->follow(tp)); - if (FFlag::LuauRankNTypes && ignoreChildren(tp)) + if (ignoreChildren(tp)) return; + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; + if (const TypePack* tpp = get(tp)) { for (TypeId tv : tpp->head) @@ -77,7 +123,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) std::pair Tarjan::indexify(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); bool fresh = !typeToIndex.contains(ty); int& index = typeToIndex[ty]; @@ -95,7 +141,7 @@ std::pair Tarjan::indexify(TypeId ty) std::pair Tarjan::indexify(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); bool fresh = !packToIndex.contains(tp); int& index = packToIndex[tp]; @@ -113,7 +159,7 @@ std::pair Tarjan::indexify(TypePackId tp) void Tarjan::visitChild(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); edgesTy.push_back(ty); edgesTp.push_back(nullptr); @@ -121,7 +167,7 @@ void Tarjan::visitChild(TypeId ty) void Tarjan::visitChild(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); edgesTy.push_back(nullptr); edgesTp.push_back(tp); @@ -138,7 +184,7 @@ TarjanResult Tarjan::loop() if (currEdge == -1) { ++childCount; - if (FInt::LuauTarjanChildLimit > 0 && FInt::LuauTarjanChildLimit < childCount) + if (childLimit > 0 && (FFlag::LuauUnknownAndNeverType ? childLimit <= childCount : childLimit < childCount)) return TarjanResult::TooManyChildren; stack.push_back(index); @@ -223,27 +269,14 @@ TarjanResult Tarjan::loop() return TarjanResult::Ok; } -void Tarjan::clear() -{ - typeToIndex.clear(); - indexToType.clear(); - packToIndex.clear(); - indexToPack.clear(); - lowlink.clear(); - stack.clear(); - onStack.clear(); - - edgesTy.clear(); - edgesTp.clear(); - worklist.clear(); -} - TarjanResult Tarjan::visitRoot(TypeId ty) { childCount = 0; - ty = follow(ty); + if (childLimit == 0) + childLimit = FInt::LuauTarjanChildLimit; + + ty = log->follow(ty); - clear(); auto [index, fresh] = indexify(ty); worklist.push_back({index, -1, -1}); return loop(); @@ -252,14 +285,34 @@ TarjanResult Tarjan::visitRoot(TypeId ty) TarjanResult Tarjan::visitRoot(TypePackId tp) { childCount = 0; - tp = follow(tp); + if (childLimit == 0) + childLimit = FInt::LuauTarjanChildLimit; + + tp = log->follow(tp); - clear(); auto [index, fresh] = indexify(tp); worklist.push_back({index, -1, -1}); return loop(); } +void FindDirty::clearTarjan() +{ + dirty.clear(); + + typeToIndex.clear(); + packToIndex.clear(); + indexToType.clear(); + indexToPack.clear(); + + stack.clear(); + onStack.clear(); + lowlink.clear(); + + edgesTy.clear(); + edgesTp.clear(); + worklist.clear(); +} + bool FindDirty::getDirty(int index) { if (dirty.size() <= size_t(index)) @@ -311,107 +364,122 @@ void FindDirty::visitSCC(int index) TarjanResult FindDirty::findDirty(TypeId ty) { - dirty.clear(); return visitRoot(ty); } TarjanResult FindDirty::findDirty(TypePackId tp) { - dirty.clear(); return visitRoot(tp); } std::optional Substitution::substitute(TypeId ty) { - ty = follow(ty); - newTypes.clear(); - newPacks.clear(); + ty = log->follow(ty); + + // clear algorithm state for reentrancy + if (FFlag::LuauSubstitutionReentrant) + clearTarjan(); auto result = findDirty(ty); if (result != TarjanResult::Ok) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - replaceChildren(newTy); + { + if (FFlag::LuauSubstitutionReentrant) + { + if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) + { + replaceChildren(newTy); + replacedTypes.insert(newTy); + } + } + else + { + if (!ignoreChildren(oldTy)) + replaceChildren(newTy); + } + } for (auto [oldTp, newTp] : newPacks) - replaceChildren(newTp); + { + if (FFlag::LuauSubstitutionReentrant) + { + if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) + { + replaceChildren(newTp); + replacedTypePacks.insert(newTp); + } + } + else + { + if (!ignoreChildren(oldTp)) + replaceChildren(newTp); + } + } TypeId newTy = replace(ty); return newTy; } std::optional Substitution::substitute(TypePackId tp) { - tp = follow(tp); - newTypes.clear(); - newPacks.clear(); + tp = log->follow(tp); + + // clear algorithm state for reentrancy + if (FFlag::LuauSubstitutionReentrant) + clearTarjan(); auto result = findDirty(tp); if (result != TarjanResult::Ok) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - replaceChildren(newTy); + { + if (FFlag::LuauSubstitutionReentrant) + { + if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) + { + replaceChildren(newTy); + replacedTypes.insert(newTy); + } + } + else + { + if (!ignoreChildren(oldTy)) + replaceChildren(newTy); + } + } for (auto [oldTp, newTp] : newPacks) - replaceChildren(newTp); + { + if (FFlag::LuauSubstitutionReentrant) + { + if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) + { + replaceChildren(newTp); + replacedTypePacks.insert(newTp); + } + } + else + { + if (!ignoreChildren(oldTp)) + replaceChildren(newTp); + } + } TypePackId newTp = replace(tp); return newTp; } TypeId Substitution::clone(TypeId ty) { - ty = follow(ty); - - TypeId result = ty; - - if (const FunctionTypeVar* ftv = get(ty)) - { - FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; - clone.generics = ftv->generics; - clone.genericPacks = ftv->genericPacks; - clone.magicFunction = ftv->magicFunction; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - result = addType(std::move(clone)); - } - else if (const TableTypeVar* ttv = get(ty)) - { - LUAU_ASSERT(!ttv->boundTo); - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - clone.name = ttv->name; - clone.syntheticName = ttv->syntheticName; - clone.instantiatedTypeParams = ttv->instantiatedTypeParams; - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - clone.tags = ttv->tags; - result = addType(std::move(clone)); - } - else if (const MetatableTypeVar* mtv = get(ty)) - { - MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; - clone.syntheticName = mtv->syntheticName; - result = addType(std::move(clone)); - } - else if (const UnionTypeVar* utv = get(ty)) - { - UnionTypeVar clone; - clone.options = utv->options; - result = addType(std::move(clone)); - } - else if (const IntersectionTypeVar* itv = get(ty)) - { - IntersectionTypeVar clone; - clone.parts = itv->parts; - result = addType(std::move(clone)); - } - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; + return shallowClone(ty, *arena, log, /* alwaysClone */ FFlag::LuauClonePublicInterfaceLess); } TypePackId Substitution::clone(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); + + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; + if (const TypePack* tpp = get(tp)) { TypePack clone; @@ -423,33 +491,48 @@ TypePackId Substitution::clone(TypePackId tp) { VariadicTypePack clone; clone.ty = vtp->ty; + if (FFlag::LuauSubstitutionFixMissingFields) + clone.hidden = vtp->hidden; return addTypePack(std::move(clone)); } + else if (FFlag::LuauClonePublicInterfaceLess) + { + return addTypePack(*tp); + } else return tp; } void Substitution::foundDirty(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); + + if (FFlag::LuauSubstitutionReentrant && newTypes.contains(ty)) + return; + if (isDirty(ty)) - newTypes[ty] = clean(ty); + newTypes[ty] = follow(clean(ty)); else - newTypes[ty] = clone(ty); + newTypes[ty] = follow(clone(ty)); } void Substitution::foundDirty(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); + + if (FFlag::LuauSubstitutionReentrant && newPacks.contains(tp)) + return; + if (isDirty(tp)) - newPacks[tp] = clean(tp); + newPacks[tp] = follow(clean(tp)); else - newPacks[tp] = clone(tp); + newPacks[tp] = follow(clone(tp)); } TypeId Substitution::replace(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); + if (TypeId* prevTy = newTypes.find(ty)) return *prevTy; else @@ -458,7 +541,8 @@ TypeId Substitution::replace(TypeId ty) TypePackId Substitution::replace(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); + if (TypePackId* prevTp = newPacks.find(tp)) return *prevTp; else @@ -467,15 +551,26 @@ TypePackId Substitution::replace(TypePackId tp) void Substitution::replaceChildren(TypeId ty) { - ty = follow(ty); + LUAU_ASSERT(ty == log->follow(ty)); - if (FFlag::LuauRankNTypes && ignoreChildren(ty)) + if (ignoreChildren(ty)) + return; + + if (ty->owningArena != arena) return; if (FunctionTypeVar* ftv = getMutable(ty)) { + if (FFlag::LuauSubstitutionFixMissingFields) + { + for (TypeId& generic : ftv->generics) + generic = replace(generic); + for (TypePackId& genericPack : ftv->genericPacks) + genericPack = replace(genericPack); + } + ftv->argTypes = replace(ftv->argTypes); - ftv->retType = replace(ftv->retType); + ftv->retTypes = replace(ftv->retTypes); } else if (TableTypeVar* ttv = getMutable(ty)) { @@ -487,8 +582,12 @@ void Substitution::replaceChildren(TypeId ty) ttv->indexer->indexType = replace(ttv->indexer->indexType); ttv->indexer->indexResultType = replace(ttv->indexer->indexResultType); } + for (TypeId& itp : ttv->instantiatedTypeParams) itp = replace(itp); + + for (TypePackId& itp : ttv->instantiatedTypePackParams) + itp = replace(itp); } else if (MetatableTypeVar* mtv = getMutable(ty)) { @@ -505,13 +604,39 @@ void Substitution::replaceChildren(TypeId ty) for (TypeId& part : itv->parts) part = replace(part); } + else if (PendingExpansionTypeVar* petv = getMutable(ty)) + { + for (TypeId& a : petv->typeArguments) + a = replace(a); + + for (TypePackId& a : petv->packArguments) + a = replace(a); + } + else if (ClassTypeVar* ctv = getMutable(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) + { + for (auto& [name, prop] : ctv->props) + prop.type = replace(prop.type); + + if (ctv->parent) + ctv->parent = replace(*ctv->parent); + + if (ctv->metatable) + ctv->metatable = replace(*ctv->metatable); + } + else if (NegationTypeVar* ntv = getMutable(ty)) + { + ntv->ty = replace(ntv->ty); + } } void Substitution::replaceChildren(TypePackId tp) { - tp = follow(tp); + LUAU_ASSERT(tp == log->follow(tp)); - if (FFlag::LuauRankNTypes && ignoreChildren(tp)) + if (ignoreChildren(tp)) + return; + + if (tp->owningArena != arena) return; if (TypePack* tpp = getMutable(tp)) diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp new file mode 100644 index 00000000..68fa5393 --- /dev/null +++ b/Analysis/src/ToDot.cpp @@ -0,0 +1,400 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ToDot.h" + +#include "Luau/ToString.h" +#include "Luau/TypePack.h" +#include "Luau/TypeVar.h" +#include "Luau/StringUtils.h" + +#include +#include + +namespace Luau +{ + +namespace +{ + +struct StateDot +{ + StateDot(ToDotOptions opts) + : opts(opts) + { + } + + ToDotOptions opts; + + std::unordered_set seenTy; + std::unordered_set seenTp; + std::unordered_map tyToIndex; + std::unordered_map tpToIndex; + int nextIndex = 1; + std::string result; + + bool canDuplicatePrimitive(TypeId ty); + + void visitChildren(TypeId ty, int index); + void visitChildren(TypePackId ty, int index); + + void visitChild(TypeId ty, int parentIndex, const char* linkName = nullptr); + void visitChild(TypePackId tp, int parentIndex, const char* linkName = nullptr); + + void startNode(int index); + void finishNode(); + + void startNodeLabel(); + void finishNodeLabel(TypeId ty); + void finishNodeLabel(TypePackId tp); +}; + +bool StateDot::canDuplicatePrimitive(TypeId ty) +{ + if (get(ty)) + return false; + + return get(ty) || get(ty); +} + +void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) +{ + if (!tyToIndex.count(ty) || (opts.duplicatePrimitives && canDuplicatePrimitive(ty))) + tyToIndex[ty] = nextIndex++; + + int index = tyToIndex[ty]; + + if (parentIndex != 0) + { + if (linkName) + formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, index, linkName); + else + formatAppend(result, "n%d -> n%d;\n", parentIndex, index); + } + + if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) + { + if (get(ty)) + formatAppend(result, "n%d [label=\"%s\"];\n", index, toString(ty).c_str()); + else if (get(ty)) + formatAppend(result, "n%d [label=\"any\"];\n", index); + } + else + { + visitChildren(ty, index); + } +} + +void StateDot::visitChild(TypePackId tp, int parentIndex, const char* linkName) +{ + if (!tpToIndex.count(tp)) + tpToIndex[tp] = nextIndex++; + + if (parentIndex != 0) + { + if (linkName) + formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, tpToIndex[tp], linkName); + else + formatAppend(result, "n%d -> n%d;\n", parentIndex, tpToIndex[tp]); + } + + visitChildren(tp, tpToIndex[tp]); +} + +void StateDot::startNode(int index) +{ + formatAppend(result, "n%d [", index); +} + +void StateDot::finishNode() +{ + formatAppend(result, "];\n"); +} + +void StateDot::startNodeLabel() +{ + formatAppend(result, "label=\""); +} + +void StateDot::finishNodeLabel(TypeId ty) +{ + if (opts.showPointers) + formatAppend(result, "\n0x%p", ty); + // additional common attributes can be added here as well + result += "\""; +} + +void StateDot::finishNodeLabel(TypePackId tp) +{ + if (opts.showPointers) + formatAppend(result, "\n0x%p", tp); + // additional common attributes can be added here as well + result += "\""; +} + +void StateDot::visitChildren(TypeId ty, int index) +{ + if (seenTy.count(ty)) + return; + seenTy.insert(ty); + + startNode(index); + startNodeLabel(); + + if (const BoundTypeVar* btv = get(ty)) + { + formatAppend(result, "BoundTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(btv->boundTo, index); + } + else if (const FunctionTypeVar* ftv = get(ty)) + { + formatAppend(result, "FunctionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(ftv->argTypes, index, "arg"); + visitChild(ftv->retTypes, index, "ret"); + } + else if (const TableTypeVar* ttv = get(ty)) + { + if (ttv->name) + formatAppend(result, "TableTypeVar %s", ttv->name->c_str()); + else if (ttv->syntheticName) + formatAppend(result, "TableTypeVar %s", ttv->syntheticName->c_str()); + else + formatAppend(result, "TableTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + if (ttv->boundTo) + return visitChild(*ttv->boundTo, index, "boundTo"); + + for (const auto& [name, prop] : ttv->props) + visitChild(prop.type, index, name.c_str()); + if (ttv->indexer) + { + visitChild(ttv->indexer->indexType, index, "[index]"); + visitChild(ttv->indexer->indexResultType, index, "[value]"); + } + for (TypeId itp : ttv->instantiatedTypeParams) + visitChild(itp, index, "typeParam"); + + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp, index, "typePackParam"); + } + else if (const MetatableTypeVar* mtv = get(ty)) + { + formatAppend(result, "MetatableTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(mtv->table, index, "table"); + visitChild(mtv->metatable, index, "metatable"); + } + else if (const UnionTypeVar* utv = get(ty)) + { + formatAppend(result, "UnionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId opt : utv->options) + visitChild(opt, index); + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + formatAppend(result, "IntersectionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId part : itv->parts) + visitChild(part, index); + } + else if (const GenericTypeVar* gtv = get(ty)) + { + if (gtv->explicitName) + formatAppend(result, "GenericTypeVar %s", gtv->name.c_str()); + else + formatAppend(result, "GenericTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (const FreeTypeVar* ftv = get(ty)) + { + formatAppend(result, "FreeTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (get(ty)) + { + formatAppend(result, "AnyTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (get(ty)) + { + formatAppend(result, "PrimitiveTypeVar %s", toString(ty).c_str()); + finishNodeLabel(ty); + finishNode(); + } + else if (get(ty)) + { + formatAppend(result, "ErrorTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (const ClassTypeVar* ctv = get(ty)) + { + formatAppend(result, "ClassTypeVar %s", ctv->name.c_str()); + finishNodeLabel(ty); + finishNode(); + + for (const auto& [name, prop] : ctv->props) + visitChild(prop.type, index, name.c_str()); + + if (ctv->parent) + visitChild(*ctv->parent, index, "[parent]"); + + if (ctv->metatable) + visitChild(*ctv->metatable, index, "[metatable]"); + } + else if (const SingletonTypeVar* stv = get(ty)) + { + std::string res; + + if (const StringSingleton* ss = get(stv)) + { + // Don't put in quotes anywhere. If it's outside of the call to escape, + // then it's invalid syntax. If it's inside, then escaping is super noisy. + res = "string: " + escape(ss->value); + } + else if (const BooleanSingleton* bs = get(stv)) + { + res = "boolean: "; + res += bs->value ? "true" : "false"; + } + else + LUAU_ASSERT(!"unknown singleton type"); + + formatAppend(result, "SingletonTypeVar %s", res.c_str()); + finishNodeLabel(ty); + finishNode(); + } + else + { + LUAU_ASSERT(!"unknown type kind"); + finishNodeLabel(ty); + finishNode(); + } +} + +void StateDot::visitChildren(TypePackId tp, int index) +{ + if (seenTp.count(tp)) + return; + seenTp.insert(tp); + + startNode(index); + startNodeLabel(); + + if (const BoundTypePack* btp = get(tp)) + { + formatAppend(result, "BoundTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + visitChild(btp->boundTo, index); + } + else if (const TypePack* tpp = get(tp)) + { + formatAppend(result, "TypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + for (TypeId tv : tpp->head) + visitChild(tv, index); + if (tpp->tail) + visitChild(*tpp->tail, index, "tail"); + } + else if (const VariadicTypePack* vtp = get(tp)) + { + formatAppend(result, "VariadicTypePack %s%d", vtp->hidden ? "hidden " : "", index); + finishNodeLabel(tp); + finishNode(); + + visitChild(vtp->ty, index); + } + else if (const FreeTypePack* ftp = get(tp)) + { + formatAppend(result, "FreeTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else if (const GenericTypePack* gtp = get(tp)) + { + if (gtp->explicitName) + formatAppend(result, "GenericTypePack %s", gtp->name.c_str()); + else + formatAppend(result, "GenericTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else if (get(tp)) + { + formatAppend(result, "ErrorTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else + { + LUAU_ASSERT(!"unknown type pack kind"); + finishNodeLabel(tp); + finishNode(); + } +} + +} // namespace + +std::string toDot(TypeId ty, const ToDotOptions& opts) +{ + StateDot state{opts}; + + state.result = "digraph graphname {\n"; + state.visitChild(ty, 0); + state.result += "}"; + + return state.result; +} + +std::string toDot(TypePackId tp, const ToDotOptions& opts) +{ + StateDot state{opts}; + + state.result = "digraph graphname {\n"; + state.visitChild(tp, 0); + state.result += "}"; + + return state.result; +} + +std::string toDot(TypeId ty) +{ + return toDot(ty, {}); +} + +std::string toDot(TypePackId tp) +{ + return toDot(tp, {}); +} + +void dumpDot(TypeId ty) +{ + printf("%s\n", toDot(ty).c_str()); +} + +void dumpDot(TypePackId tp) +{ + printf("%s\n", toDot(tp).c_str()); +} + +} // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 9d2f47ba..01144a3d 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1,6 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/ToString.h" +#include "Luau/Constraint.h" +#include "Luau/Location.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -9,10 +12,18 @@ #include #include -LUAU_FASTFLAG(LuauToStringFollowsBoundTo) -LUAU_FASTFLAG(LuauExtraNilRecovery) -LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) -LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauLineBreaksDetermineIndents, false) +LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) +LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) + +/* + * Prefix generic typenames with gen- + * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 + * Fair warning: Setting this will break a lot of Luau unit tests. + */ +LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) namespace Luau { @@ -20,7 +31,7 @@ namespace Luau namespace { -struct FindCyclicTypes +struct FindCyclicTypes final : TypeVarVisitor { FindCyclicTypes() = default; FindCyclicTypes(const FindCyclicTypes&) = delete; @@ -29,28 +40,30 @@ struct FindCyclicTypes bool exhaustive = false; std::unordered_set visited; std::unordered_set visitedPacks; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; - void cycle(TypeId ty) + void cycle(TypeId ty) override { cycles.insert(ty); } - void cycle(TypePackId tp) + void cycle(TypePackId tp) override { cycleTPs.insert(tp); } - template - bool operator()(TypeId ty, const T&) + bool visit(TypeId ty) override { return visited.insert(ty).second; } - bool operator()(TypeId ty, const TableTypeVar& ttv) = delete; + bool visit(TypePackId tp) override + { + return visitedPacks.insert(tp).second; + } - bool operator()(TypeId ty, const TableTypeVar& ttv, std::unordered_set& seen) + bool visit(TypeId ty, const TableTypeVar& ttv) override { if (!visited.insert(ty).second) return false; @@ -58,31 +71,29 @@ struct FindCyclicTypes if (ttv.name || ttv.syntheticName) { for (TypeId itp : ttv.instantiatedTypeParams) - visitTypeVar(itp, *this, seen); + traverse(itp); + + for (TypePackId itp : ttv.instantiatedTypePackParams) + traverse(itp); + return exhaustive; } return true; } - bool operator()(TypeId, const ClassTypeVar&) + bool visit(TypeId ty, const ClassTypeVar&) override { return false; } - - template - bool operator()(TypePackId tp, const T&) - { - return visitedPacks.insert(tp).second; - } }; template -void findCyclicTypes(std::unordered_set& cycles, std::unordered_set& cycleTPs, TID ty, bool exhaustive) +void findCyclicTypes(std::set& cycles, std::set& cycleTPs, TID ty, bool exhaustive) { FindCyclicTypes fct; fct.exhaustive = exhaustive; - visitTypeVar(ty, fct); + fct.traverse(ty); cycles = std::move(fct.cycles); cycleTPs = std::move(fct.cycleTPs); @@ -109,27 +120,25 @@ static std::pair> canUseTypeNameInScope(ScopePtr struct StringifierState { - const ToStringOptions& opts; + ToStringOptions& opts; ToStringResult& result; std::unordered_map cycleNames; std::unordered_map cycleTpNames; std::unordered_set seen; std::unordered_set usedNames; + size_t indentation = 0; bool exhaustive; - StringifierState(const ToStringOptions& opts, ToStringResult& result, const std::optional& nameMap) + StringifierState(ToStringOptions& opts, ToStringResult& result) : opts(opts) , result(result) , exhaustive(opts.exhaustive) { - if (nameMap) - result.nameMap = *nameMap; - - for (const auto& [_, v] : result.nameMap.typeVars) + for (const auto& [_, v] : opts.nameMap.typeVars) usedNames.insert(v); - for (const auto& [_, v] : result.nameMap.typePacks) + for (const auto& [_, v] : opts.nameMap.typePacks) usedNames.insert(v); } @@ -151,19 +160,10 @@ struct StringifierState seen.erase(iter); } - static std::string generateName(size_t i) - { - std::string n; - n = char('a' + i % 26); - if (i >= 26) - n += std::to_string(i / 26); - return n; - } - std::string getName(TypeId ty) { - const size_t s = result.nameMap.typeVars.size(); - std::string& n = result.nameMap.typeVars[ty]; + const size_t s = opts.nameMap.typeVars.size(); + std::string& n = opts.nameMap.typeVars[ty]; if (!n.empty()) return n; @@ -181,18 +181,21 @@ struct StringifierState return generateName(s); } + int previousNameIndex = 0; + std::string getName(TypePackId ty) { - const size_t s = result.nameMap.typePacks.size(); - std::string& n = result.nameMap.typePacks[ty]; + const size_t s = opts.nameMap.typePacks.size(); + std::string& n = opts.nameMap.typePacks[ty]; if (!n.empty()) return n; for (int count = 0; count < 256; ++count) { - std::string candidate = generateName(usedNames.size() + count); + std::string candidate = generateName(previousNameIndex + count); if (!usedNames.count(candidate)) { + previousNameIndex += count; usedNames.insert(candidate); n = candidate; return candidate; @@ -209,6 +212,83 @@ struct StringifierState result.name += s; } + + void emitLevel(Scope* scope) + { + size_t count = 0; + for (Scope* s = scope; s; s = s->parent.get()) + ++count; + + emit(count); + emit("-"); + char buffer[16]; + uint32_t s = uint32_t(intptr_t(scope) & 0xFFFFFF); + snprintf(buffer, sizeof(buffer), "0x%x", s); + emit(buffer); + } + + void emit(TypeLevel level) + { + emit(std::to_string(level.level)); + emit("-"); + emit(std::to_string(level.subLevel)); + } + + void emit(const char* s) + { + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + return; + + result.name += s; + } + + void emit(int i) + { + emit(std::to_string(i).c_str()); + } + + void emit(size_t i) + { + emit(std::to_string(i).c_str()); + } + + void indent() + { + indentation += 4; + } + + void dedent() + { + indentation -= 4; + } + + void newline() + { + if (!opts.useLineBreaks) + return emit(" "); + + emit("\n"); + emitIndentation(); + } + +private: + void emitIndentation() + { + if (!FFlag::LuauLineBreaksDetermineIndents) + { + if (!opts.DEPRECATED_indent) + return; + + emit(std::string(indentation, ' ')); + } + else + { + if (!opts.useLineBreaks) + return; + + emit(std::string(indentation, ' ')); + } + } }; struct TypeVarStringifier @@ -228,7 +308,7 @@ struct TypeVarStringifier if (tv->ty.valueless_by_exception()) { state.result.error = true; - state.emit("< VALUELESS BY EXCEPTION >"); + state.emit("* VALUELESS BY EXCEPTION *"); return; } @@ -239,17 +319,9 @@ struct TypeVarStringifier return; } - if (!FFlag::LuauAddMissingFollow) - { - if (get(tv)) - { - state.emit(state.getName(tv)); - return; - } - } - Luau::visit( - [this, tv](auto&& t) { + [this, tv](auto&& t) + { return (*this)(tv, t); }, tv->ty); @@ -258,34 +330,67 @@ struct TypeVarStringifier void stringify(TypePackId tp); void stringify(TypePackId tpid, const std::vector>& names); - void stringify(const std::vector& types) + void stringify(const std::vector& types, const std::vector& typePacks) { - if (types.size() == 0) + if (types.size() == 0 && typePacks.size() == 0) return; - if (types.size()) + if (types.size() || typePacks.size()) state.emit("<"); - for (size_t i = 0; i < types.size(); ++i) - { - if (i > 0) - state.emit(", "); + bool first = true; - stringify(types[i]); + for (TypeId ty : types) + { + if (!first) + state.emit(", "); + first = false; + + stringify(ty); } - if (types.size()) + bool singleTp = typePacks.size() == 1; + + for (TypePackId tp : typePacks) + { + if (isEmpty(tp) && singleTp) + continue; + + if (!first) + state.emit(", "); + else + first = false; + + bool wrap = !singleTp && get(follow(tp)); + + if (wrap) + state.emit("("); + + stringify(tp); + + if (wrap) + state.emit(")"); + } + + if (types.size() || typePacks.size()) state.emit(">"); } void operator()(TypeId ty, const Unifiable::Free& ftv) { state.result.invalid = true; + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("free-"); + state.emit(state.getName(ty)); - if (FFlag::LuauAddMissingFollow) - state.emit(state.getName(ty)); - else - state.emit(""); + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(ftv.scope); + else + state.emit(ftv.level); + } } void operator()(TypeId, const BoundTypeVar& btv) @@ -293,15 +398,39 @@ struct TypeVarStringifier stringify(btv.boundTo); } - void operator()(TypeId ty, const Unifiable::Generic& gtv) + void operator()(TypeId ty, const GenericTypeVar& gtv) { if (gtv.explicitName) { - state.result.nameMap.typeVars[ty] = gtv.name; + state.usedNames.insert(gtv.name); + state.opts.nameMap.typeVars[ty] = gtv.name; state.emit(gtv.name); } else state.emit(state.getName(ty)); + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(gtv.scope); + else + state.emit(gtv.level); + } + } + + void operator()(TypeId, const BlockedTypeVar& btv) + { + state.emit("*blocked-"); + state.emit(btv.index); + state.emit("*"); + } + + void operator()(TypeId ty, const PendingExpansionTypeVar& petv) + { + state.emit("*pending-expansion-"); + state.emit(petv.index); + state.emit("*"); } void operator()(TypeId, const PrimitiveTypeVar& ptv) @@ -323,9 +452,29 @@ struct TypeVarStringifier case PrimitiveTypeVar::Thread: state.emit("thread"); return; + case PrimitiveTypeVar::Function: + state.emit("function"); + return; default: LUAU_ASSERT(!"Unknown primitive type"); - throw std::runtime_error("Unknown primitive type " + std::to_string(ptv.type)); + throw InternalCompilerError("Unknown primitive type " + std::to_string(ptv.type)); + } + } + + void operator()(TypeId, const SingletonTypeVar& stv) + { + if (const BooleanSingleton* bs = Luau::get(&stv)) + state.emit(bs->value ? "true" : "false"); + else if (const StringSingleton* ss = Luau::get(&stv)) + { + state.emit("\""); + state.emit(escape(ss->value)); + state.emit("\""); + } + else + { + LUAU_ASSERT(!"Unknown singleton type"); + throw InternalCompilerError("Unknown singleton type"); } } @@ -334,10 +483,11 @@ struct TypeVarStringifier if (state.hasSeen(&ftv)) { state.result.cycle = true; - state.emit(""); + state.emit("*CYCLE*"); return; } + // We should not be respecting opts.hideNamedFunctionTypeParameters here. if (ftv.generics.size() > 0 || ftv.genericPacks.size() > 0) { state.emit("<"); @@ -369,16 +519,20 @@ struct TypeVarStringifier state.emit(") -> "); bool plural = true; - if (auto retPack = get(follow(ftv.retType))) + + auto retBegin = begin(ftv.retTypes); + auto retEnd = end(ftv.retTypes); + if (retBegin != retEnd) { - if (retPack->head.size() == 1 && !retPack->tail) + ++retBegin; + if (retBegin == retEnd && !retBegin.tail()) plural = false; } if (plural) state.emit("("); - stringify(ftv.retType); + stringify(ftv.retTypes); if (plural) state.emit(")"); @@ -388,7 +542,7 @@ struct TypeVarStringifier void operator()(TypeId, const TableTypeVar& ttv) { - if (FFlag::LuauToStringFollowsBoundTo && ttv.boundTo) + if (ttv.boundTo) return stringify(*ttv.boundTo); if (!state.exhaustive) @@ -411,14 +565,14 @@ struct TypeVarStringifier } state.emit(*ttv.name); - stringify(ttv.instantiatedTypeParams); + stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams); return; } if (ttv.syntheticName) { state.result.invalid = true; state.emit(*ttv.syntheticName); - stringify(ttv.instantiatedTypeParams); + stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams); return; } } @@ -426,7 +580,7 @@ struct TypeVarStringifier if (state.hasSeen(&ttv)) { state.result.cycle = true; - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -436,22 +590,22 @@ struct TypeVarStringifier { case TableState::Sealed: state.result.invalid = true; - openbrace = "{| "; - closedbrace = " |}"; + openbrace = "{|"; + closedbrace = "|}"; break; case TableState::Unsealed: - openbrace = "{ "; - closedbrace = " }"; + openbrace = "{"; + closedbrace = "}"; break; case TableState::Free: state.result.invalid = true; - openbrace = "{- "; - closedbrace = " -}"; + openbrace = "{-"; + closedbrace = "-}"; break; case TableState::Generic: state.result.invalid = true; - openbrace = "{+ "; - closedbrace = " +}"; + openbrace = "{+"; + closedbrace = "+}"; break; } @@ -461,14 +615,20 @@ struct TypeVarStringifier state.emit("{"); stringify(ttv.indexer->indexResultType); state.emit("}"); + + if (FFlag::LuauUnseeArrayTtv) + state.unsee(&ttv); + return; } state.emit(openbrace); + state.indent(); bool comma = false; if (ttv.indexer) { + state.newline(); state.emit("["); stringify(ttv.indexer->indexType); state.emit("]: "); @@ -481,7 +641,12 @@ struct TypeVarStringifier for (const auto& [name, prop] : ttv.props) { if (comma) - state.emit(state.opts.useLineBreaks ? ",\n" : ", "); + { + state.emit(","); + state.newline(); + } + else + state.newline(); size_t length = state.result.name.length() - oldLength; @@ -493,13 +658,25 @@ struct TypeVarStringifier break; } - state.emit(name); + if (isIdentifier(name)) + state.emit(name); + else + { + state.emit("[\""); + state.emit(escape(name)); + state.emit("\"]"); + } state.emit(": "); stringify(prop.type); comma = true; ++index; } + state.dedent(); + if (comma) + state.newline(); + else + state.emit(" "); state.emit(closedbrace); state.unsee(&ttv); @@ -508,9 +685,16 @@ struct TypeVarStringifier void operator()(TypeId, const MetatableTypeVar& mtv) { state.result.invalid = true; + if (!state.exhaustive && mtv.syntheticName) + { + state.emit(*mtv.syntheticName); + return; + } + state.emit("{ @metatable "); stringify(mtv.metatable); - state.emit(state.opts.useLineBreaks ? ",\n" : ", "); + state.emit(","); + state.newline(); stringify(mtv.table); state.emit(" }"); } @@ -530,7 +714,7 @@ struct TypeVarStringifier if (state.hasSeen(&uv)) { state.result.cycle = true; - state.emit(""); + state.emit("*CYCLE*"); return; } @@ -539,8 +723,7 @@ struct TypeVarStringifier std::vector results = {}; for (auto el : &uv) { - if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow) - el = follow(el); + el = follow(el); if (isNil(el)) { @@ -550,9 +733,7 @@ struct TypeVarStringifier std::string saved = std::move(state.result.name); - bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions - ? !state.cycleNames.count(el) && (get(el) || get(el)) - : get(el) || get(el); + bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); if (needParens) state.emit("("); @@ -577,7 +758,10 @@ struct TypeVarStringifier for (std::string& ss : results) { if (!first) - state.emit(" | "); + { + state.newline(); + state.emit("| "); + } state.emit(ss); first = false; } @@ -597,21 +781,18 @@ struct TypeVarStringifier if (state.hasSeen(&uv)) { state.result.cycle = true; - state.emit(""); + state.emit("*CYCLE*"); return; } std::vector results = {}; for (auto el : uv.parts) { - if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow) - el = follow(el); + el = follow(el); std::string saved = std::move(state.result.name); - bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions - ? !state.cycleNames.count(el) && (get(el) || get(el)) - : get(el) || get(el); + bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); if (needParens) state.emit("("); @@ -633,7 +814,10 @@ struct TypeVarStringifier for (std::string& ss : results) { if (!first) - state.emit(" & "); + { + state.newline(); + state.emit("& "); + } state.emit(ss); first = false; } @@ -642,7 +826,7 @@ struct TypeVarStringifier void operator()(TypeId, const ErrorTypeVar& tv) { state.result.error = true; - state.emit("*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); } void operator()(TypeId, const LazyTypeVar& ltv) @@ -651,7 +835,33 @@ struct TypeVarStringifier state.emit("lazy?"); } -}; // namespace + void operator()(TypeId, const UnknownTypeVar& ttv) + { + state.emit("unknown"); + } + + void operator()(TypeId, const NeverTypeVar& ttv) + { + state.emit("never"); + } + + void operator()(TypeId, const NegationTypeVar& ntv) + { + state.emit("~"); + + // The precedence of `~` should be less than `|` and `&`. + TypeId followed = follow(ntv.ty); + bool parens = get(followed) || get(followed); + + if (parens) + state.emit("("); + + stringify(ntv.ty); + + if (parens) + state.emit(")"); + } +}; struct TypePackStringifier { @@ -687,20 +897,10 @@ struct TypePackStringifier if (tp->ty.valueless_by_exception()) { state.result.error = true; - state.emit("< VALUELESS TP BY EXCEPTION >"); + state.emit("* VALUELESS TP BY EXCEPTION *"); return; } - if (!FFlag::LuauAddMissingFollow) - { - if (get(tp)) - { - state.emit(state.getName(tp)); - state.emit("..."); - return; - } - } - auto it = state.cycleTpNames.find(tp); if (it != state.cycleTpNames.end()) { @@ -709,7 +909,8 @@ struct TypePackStringifier } Luau::visit( - [this, tp](auto&& t) { + [this, tp](auto&& t) + { return (*this)(tp, t); }, tp->ty); @@ -720,7 +921,7 @@ struct TypePackStringifier if (state.hasSeen(&tp)) { state.result.cycle = true; - state.emit(""); + state.emit("*CYCLETP*"); return; } @@ -733,13 +934,13 @@ struct TypePackStringifier else state.emit(", "); - LUAU_ASSERT(elemNames.empty() || elemIndex < elemNames.size()); - - if (!elemNames.empty() && elemNames[elemIndex]) + // Do not respect opts.namedFunctionOverrideArgNames here + if (elemIndex < elemNames.size() && elemNames[elemIndex]) { state.emit(elemNames[elemIndex]->name); state.emit(": "); } + elemIndex++; stringify(typeId); @@ -747,13 +948,16 @@ struct TypePackStringifier if (tp.tail && !isEmpty(*tp.tail)) { - const auto& tail = *tp.tail; - if (first) - first = false; - else - state.emit(", "); + TypePackId tail = follow(*tp.tail); + if (auto vtp = get(tail); !vtp || (!FFlag::DebugLuauVerboseTypeNames && !vtp->hidden)) + { + if (first) + first = false; + else + state.emit(", "); - stringify(tail); + stringify(tail); + } } state.unsee(&tp); @@ -762,12 +966,16 @@ struct TypePackStringifier void operator()(TypePackId, const Unifiable::Error& error) { state.result.error = true; - state.emit("*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*"); } void operator()(TypePackId, const VariadicTypePack& pack) { state.emit("..."); + if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) + { + state.emit("*hidden*"); + } stringify(pack.ty); } @@ -775,35 +983,56 @@ struct TypePackStringifier { if (pack.explicitName) { - state.result.nameMap.typePacks[tp] = pack.name; + state.usedNames.insert(pack.name); + state.opts.nameMap.typePacks[tp] = pack.name; state.emit(pack.name); } else { state.emit(state.getName(tp)); } + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(pack.scope); + else + state.emit(pack.level); + } state.emit("..."); } void operator()(TypePackId tp, const FreeTypePack& pack) { state.result.invalid = true; + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("free-"); + state.emit(state.getName(tp)); - if (FFlag::LuauAddMissingFollow) + if (FFlag::DebugLuauVerboseTypeNames) { - state.emit(state.getName(tp)); - state.emit("..."); - } - else - { - state.emit(""); + state.emit("-"); + if (FFlag::DebugLuauDeferredConstraintResolution) + state.emitLevel(pack.scope); + else + state.emit(pack.level); } + + state.emit("..."); } void operator()(TypePackId, const BoundTypePack& btv) { stringify(btv.boundTo); } + + void operator()(TypePackId, const BlockedTypePack& btp) + { + state.emit("*blocked-tp-"); + state.emit(btp.index); + state.emit("*"); + } }; void TypeVarStringifier::stringify(TypePackId tp) @@ -818,36 +1047,27 @@ void TypeVarStringifier::stringify(TypePackId tpid, const std::vector& cycles, const std::unordered_set& cycleTPs, +static void assignCycleNames(const std::set& cycles, const std::set& cycleTPs, std::unordered_map& cycleNames, std::unordered_map& cycleTpNames, bool exhaustive) { int nextIndex = 1; - std::vector sortedCycles{cycles.begin(), cycles.end()}; - std::sort(sortedCycles.begin(), sortedCycles.end(), std::less{}); - - for (TypeId cycleTy : sortedCycles) + for (TypeId cycleTy : cycles) { std::string name; // TODO: use the stringified type list if there are no cycles - if (FFlag::LuauInstantiatedTypeParamRecursion) + if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) { - if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) - { - // If we have a cycle type in type parameters, assign a cycle name for this named table - if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { + // If we have a cycle type in type parameters, assign a cycle name for this named table + if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), + [&](auto&& el) + { return cycles.count(follow(el)); }) != ttv->instantiatedTypeParams.end()) - cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; + cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; - continue; - } - } - else - { - if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) - continue; + continue; } name = "t" + std::to_string(nextIndex); @@ -856,10 +1076,7 @@ static void assignCycleNames(const std::unordered_set& cycles, const std cycleNames[cycleTy] = std::move(name); } - std::vector sortedCycleTps{cycleTPs.begin(), cycleTPs.end()}; - std::sort(sortedCycleTps.begin(), sortedCycleTps.end(), std::less()); - - for (TypePackId tp : sortedCycleTps) + for (TypePackId tp : cycleTPs) { std::string name = "tp" + std::to_string(nextIndex); ++nextIndex; @@ -867,7 +1084,7 @@ static void assignCycleNames(const std::unordered_set& cycles, const std } } -ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) +ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) { /* * 1. Walk the TypeVar and track seen TypeIds. When you reencounter a TypeId, add it to a set of seen cycles. @@ -879,49 +1096,10 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) ToStringResult result; - if (!FFlag::LuauInstantiatedTypeParamRecursion && !opts.exhaustive) - { - if (auto ttv = get(ty); ttv && (ttv->name || ttv->syntheticName)) - { - if (ttv->syntheticName) - result.invalid = true; + StringifierState state{opts, result}; - // If scope if provided, add module name and check visibility - if (ttv->name && opts.scope) - { - auto [success, moduleName] = canUseTypeNameInScope(opts.scope, *ttv->name); - - if (!success) - result.invalid = true; - - if (moduleName) - result.name = format("%s.", moduleName->c_str()); - } - - result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - - if (ttv->instantiatedTypeParams.empty()) - return result; - - std::vector params; - for (TypeId tp : ttv->instantiatedTypeParams) - params.push_back(toString(tp)); - - result.name += "<" + join(params, ", ") + ">"; - return result; - } - else if (auto mtv = get(ty); mtv && mtv->syntheticName) - { - result.invalid = true; - result.name = *mtv->syntheticName; - return result; - } - } - - StringifierState state{opts, result, opts.nameMap}; - - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; findCyclicTypes(cycles, cycleTPs, ty, opts.exhaustive); @@ -929,7 +1107,7 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) TypeVarStringifier tvs{state}; - if (FFlag::LuauInstantiatedTypeParamRecursion && !opts.exhaustive) + if (!opts.exhaustive) { if (auto ttv = get(ty); ttv && (ttv->name || ttv->syntheticName)) { @@ -950,31 +1128,7 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - if (ttv->instantiatedTypeParams.empty()) - return result; - - result.name += "<"; - - bool first = true; - for (TypeId ty : ttv->instantiatedTypeParams) - { - if (!first) - result.name += ", "; - else - first = false; - - tvs.stringify(ty); - } - - if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) - { - result.truncated = true; - result.name += "... "; - } - else - { - result.name += ">"; - } + tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams); return result; } @@ -997,7 +1151,7 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) else tvs.stringify(ty); - if (!state.cycleNames.empty()) + if (!state.cycleNames.empty() || !state.cycleTpNames.empty()) { result.cycle = true; state.emit(" where "); @@ -1006,9 +1160,11 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) state.exhaustive = true; std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; - std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { - return a.second < b.second; - }); + std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + }); bool semi = false; for (const auto& [cycleTy, name] : sortedCycleNames) @@ -1019,7 +1175,8 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto&& t) { + [&tvs, cycleTy = cycleTy](auto&& t) + { return tvs(cycleTy, t); }, cycleTy->ty); @@ -1027,16 +1184,40 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) semi = true; } + std::vector> sortedCycleTpNames(state.cycleTpNames.begin(), state.cycleTpNames.end()); + std::sort(sortedCycleTpNames.begin(), sortedCycleTpNames.end(), [](const auto& a, const auto& b) { + return a.second < b.second; + }); + + TypePackStringifier tps{state}; + + for (const auto& [cycleTp, name] : sortedCycleTpNames) + { + if (semi) + state.emit(" ; "); + + state.emit(name); + state.emit(" = "); + Luau::visit( + [&tps, cycleTy = cycleTp](auto&& t) { + return tps(cycleTy, t); + }, + cycleTp->ty); + + semi = true; + } + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) { result.truncated = true; - result.name += "... "; + + result.name += "... *TRUNCATED*"; } return result; } -ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) +ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) { /* * 1. Walk the TypeVar and track seen TypeIds. When you reencounter a TypeId, add it to a set of seen cycles. @@ -1045,10 +1226,10 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) * 4. Print out the root of the type using the same algorithm as step 3. */ ToStringResult result; - StringifierState state{opts, result, opts.nameMap}; + StringifierState state{opts, result}; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; findCyclicTypes(cycles, cycleTPs, tp, opts.exhaustive); @@ -1076,9 +1257,11 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) state.exhaustive = true; std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; - std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { - return a.second < b.second; - }); + std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), + [](const auto& a, const auto& b) + { + return a.second < b.second; + }); bool semi = false; for (const auto& [cycleTy, name] : sortedCycleNames) @@ -1089,7 +1272,8 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto&& t) { + [&tvs, cycleTy = cycleTy](auto t) + { return tvs(cycleTy, t); }, cycleTy->ty); @@ -1098,45 +1282,317 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) } if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) - result.name += "... "; + { + result.name += "... *TRUNCATED*"; + } return result; } -std::string toString(TypeId ty, const ToStringOptions& opts) +std::string toString(TypeId ty, ToStringOptions& opts) { return toStringDetailed(ty, opts).name; } -std::string toString(TypePackId tp, const ToStringOptions& opts) +std::string toString(TypePackId tp, ToStringOptions& opts) { return toStringDetailed(tp, opts).name; } -std::string toString(const TypeVar& tv, const ToStringOptions& opts) +std::string toString(const TypeVar& tv, ToStringOptions& opts) { - return toString(const_cast(&tv), std::move(opts)); + return toString(const_cast(&tv), opts); } -std::string toString(const TypePackVar& tp, const ToStringOptions& opts) +std::string toString(const TypePackVar& tp, ToStringOptions& opts) { - return toString(const_cast(&tp), std::move(opts)); + return toString(const_cast(&tp), opts); } -void dump(TypeId ty) +std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, ToStringOptions& opts) +{ + ToStringResult result; + StringifierState state{opts, result}; + TypeVarStringifier tvs{state}; + + state.emit(funcName); + + if (!opts.hideNamedFunctionTypeParameters) + tvs.stringify(ftv.generics, ftv.genericPacks); + + state.emit("("); + + auto argPackIter = begin(ftv.argTypes); + + bool first = true; + size_t idx = 0; + while (argPackIter != end(ftv.argTypes)) + { + // ftv takes a self parameter as the first argument, skip it if specified in option + if (idx == 0 && ftv.hasSelf && opts.hideFunctionSelfArgument) + { + ++argPackIter; + ++idx; + continue; + } + + if (!first) + state.emit(", "); + first = false; + + // We don't respect opts.functionTypeArguments + if (idx < opts.namedFunctionOverrideArgNames.size()) + { + state.emit(opts.namedFunctionOverrideArgNames[idx] + ": "); + } + else if (idx < ftv.argNames.size() && ftv.argNames[idx]) + { + state.emit(ftv.argNames[idx]->name + ": "); + } + else + { + state.emit("_: "); + } + tvs.stringify(*argPackIter); + + ++argPackIter; + ++idx; + } + + if (argPackIter.tail()) + { + if (auto vtp = get(*argPackIter.tail()); !vtp || !vtp->hidden) + { + if (!first) + state.emit(", "); + + state.emit("...: "); + + if (vtp) + tvs.stringify(vtp->ty); + else + tvs.stringify(*argPackIter.tail()); + } + } + + state.emit("): "); + + size_t retSize = size(ftv.retTypes); + bool hasTail = !finite(ftv.retTypes); + bool wrap = get(follow(ftv.retTypes)) && (hasTail ? retSize != 0 : retSize != 1); + + if (wrap) + state.emit("("); + + tvs.stringify(ftv.retTypes); + + if (wrap) + state.emit(")"); + + return result.name; +} + +static ToStringOptions& dumpOptions() +{ + static ToStringOptions opts = ([]() { + ToStringOptions o; + o.exhaustive = true; + o.functionTypeArguments = true; + o.maxTableLength = 0; + o.maxTypeLength = 0; + return o; + })(); + + return opts; +} + +std::string dump(TypeId ty) +{ + std::string s = toString(ty, dumpOptions()); + printf("%s\n", s.c_str()); + return s; +} + +std::string dump(TypePackId ty) +{ + std::string s = toString(ty, dumpOptions()); + printf("%s\n", s.c_str()); + return s; +} + +std::string dump(const ScopePtr& scope, const char* name) +{ + auto binding = scope->linearSearchForBinding(name); + if (!binding) + { + printf("No binding %s\n", name); + return {}; + } + + TypeId ty = binding->typeId; + std::string s = toString(ty, dumpOptions()); + printf("%s\n", s.c_str()); + return s; +} + +std::string generateName(size_t i) +{ + std::string n; + n = char('a' + i % 26); + if (i >= 26) + n += std::to_string(i / 26); + return n; +} + +std::string toString(const Constraint& constraint, ToStringOptions& opts) +{ + auto go = [&opts](auto&& c) -> std::string { + using T = std::decay_t; + + auto tos = [&opts](auto&& a) { + return toString(a, opts); + }; + + if constexpr (std::is_same_v) + { + std::string subStr = tos(c.subType); + std::string superStr = tos(c.superType); + return subStr + " <: " + superStr; + } + else if constexpr (std::is_same_v) + { + std::string subStr = tos(c.subPack); + std::string superStr = tos(c.superPack); + return subStr + " <: " + superStr; + } + else if constexpr (std::is_same_v) + { + std::string subStr = tos(c.generalizedType); + std::string superStr = tos(c.sourceType); + return subStr + " ~ gen " + superStr; + } + else if constexpr (std::is_same_v) + { + std::string subStr = tos(c.subType); + std::string superStr = tos(c.superType); + return subStr + " ~ inst " + superStr; + } + else if constexpr (std::is_same_v) + { + std::string resultStr = tos(c.resultType); + std::string operandStr = tos(c.operandType); + + return resultStr + " ~ Unary<" + toString(c.op) + ", " + operandStr + ">"; + } + else if constexpr (std::is_same_v) + { + std::string resultStr = tos(c.resultType); + std::string leftStr = tos(c.leftType); + std::string rightStr = tos(c.rightType); + + return resultStr + " ~ Binary<" + toString(c.op) + ", " + leftStr + ", " + rightStr + ">"; + } + else if constexpr (std::is_same_v) + { + std::string iteratorStr = tos(c.iterator); + std::string variableStr = tos(c.variables); + + return variableStr + " ~ Iterate<" + iteratorStr + ">"; + } + else if constexpr (std::is_same_v) + { + std::string namedStr = tos(c.namedType); + return "@name(" + namedStr + ") = " + c.name; + } + else if constexpr (std::is_same_v) + { + std::string targetStr = tos(c.target); + return "expand " + targetStr; + } + else if constexpr (std::is_same_v) + { + return "call " + tos(c.fn) + " with { result = " + tos(c.result) + " }"; + } + else if constexpr (std::is_same_v) + { + return tos(c.resultType) + " ~ prim " + tos(c.expectedType) + ", " + tos(c.singletonType) + ", " + tos(c.multitonType); + } + else if constexpr (std::is_same_v) + { + return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\""; + } + else if constexpr (std::is_same_v) + { + const std::string pathStr = c.path.size() == 1 ? "\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]"; + return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType); + } + else if constexpr (std::is_same_v) + { + std::string result = tos(c.resultType); + std::string discriminant = tos(c.discriminantType); + + if (c.negated) + return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant; + else + return result + " ~ if isSingleton D then D else unknown where D = " + discriminant; + } + else + static_assert(always_false_v, "Non-exhaustive constraint switch"); + }; + + return visit(go, constraint.c); +} + +std::string dump(const Constraint& c) { ToStringOptions opts; opts.exhaustive = true; opts.functionTypeArguments = true; - printf("%s\n", toString(ty, opts).c_str()); + std::string s = toString(c, opts); + printf("%s\n", s.c_str()); + return s; } -void dump(TypePackId ty) +std::optional getFunctionNameAsString(const AstExpr& expr) { - ToStringOptions opts; - opts.exhaustive = true; - opts.functionTypeArguments = true; - printf("%s\n", toString(ty, opts).c_str()); + const AstExpr* curr = &expr; + std::string s; + + for (;;) + { + if (auto local = curr->as()) + return local->local->name.value + s; + + if (auto global = curr->as()) + return global->name.value + s; + + if (auto indexname = curr->as()) + { + curr = indexname->expr; + + s = "." + std::string(indexname->index.value) + s; + } + else if (auto group = curr->as()) + { + curr = group->expr; + } + else + { + return std::nullopt; + } + } + + return s; +} + +std::string toString(const Position& position) +{ + return "{ line = " + std::to_string(position.line) + ", col = " + std::to_string(position.column) + " }"; +} + +std::string toString(const Location& location) +{ + return "Location { " + toString(location.begin) + ", " + toString(location.end) + " }"; } } // namespace Luau diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 2d356384..fbf74115 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TopoSortStatements.h" +#include "Luau/Error.h" /* Decide the order in which we typecheck Lua statements in a block. * * Algorithm: @@ -26,9 +27,10 @@ * 3. Cyclic dependencies can be resolved by picking an arbitrary statement to check first. */ -#include "Luau/Parser.h" +#include "Luau/Ast.h" #include "Luau/DenseHash.h" #include "Luau/Common.h" +#include "Luau/StringUtils.h" #include #include @@ -148,7 +150,7 @@ Identifier mkName(const AstStatFunction& function) auto name = mkName(*function.name); LUAU_ASSERT(bool(name)); if (!name) - throw std::runtime_error("Internal error: Function declaration has a bad name"); + throw InternalCompilerError("Internal error: Function declaration has a bad name"); return *name; } @@ -214,6 +216,7 @@ struct ArcCollector : public AstVisitor } } + // Adds a dependency from the current node to the named node. void add(const Identifier& name) { Node** it = map.find(name); @@ -253,7 +256,7 @@ struct ArcCollector : public AstVisitor { auto name = mkName(*node->name); if (!name) - throw std::runtime_error("Internal error: AstStatFunction has a bad name"); + throw InternalCompilerError("Internal error: AstStatFunction has a bad name"); add(*name); return true; @@ -298,8 +301,15 @@ struct ArcCollector : public AstVisitor struct ContainsFunctionCall : public AstVisitor { + bool alsoReturn = false; bool result = false; + ContainsFunctionCall() = default; + explicit ContainsFunctionCall(bool alsoReturn) + : alsoReturn(alsoReturn) + { + } + bool visit(AstExpr*) override { return !result; // short circuit if result is true @@ -318,6 +328,17 @@ struct ContainsFunctionCall : public AstVisitor return false; } + bool visit(AstStatReturn* stat) override + { + if (alsoReturn) + { + result = true; + return false; + } + else + return AstVisitor::visit(stat); + } + bool visit(AstExprFunction*) override { return false; @@ -479,6 +500,13 @@ bool containsFunctionCall(const AstStat& stat) return cfc.result; } +bool containsFunctionCallOrReturn(const AstStat& stat) +{ + detail::ContainsFunctionCall cfc{true}; + const_cast(stat).visit(&cfc); + return cfc.result; +} + bool isFunction(const AstStat& stat) { return stat.is() || stat.is(); diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 462c70ff..cdfe6549 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,65 +10,8 @@ #include #include -LUAU_FASTFLAG(LuauGenericFunctions) - namespace { - -std::string escape(std::string_view s) -{ - std::string r; - r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting - - for (uint8_t c : s) - { - if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') - r += c; - else - { - r += '\\'; - - switch (c) - { - case '\a': - r += 'a'; - break; - case '\b': - r += 'b'; - break; - case '\f': - r += 'f'; - break; - case '\n': - r += 'n'; - break; - case '\r': - r += 'r'; - break; - case '\t': - r += 't'; - break; - case '\v': - r += 'v'; - break; - case '\'': - r += '\''; - break; - case '\"': - r += '\"'; - break; - case '\\': - r += '\\'; - break; - default: - Luau::formatAppend(r, "%03u", c); - } - } - } - - return r; -} - bool isIdentifierStartChar(char c) { return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_'; @@ -96,9 +39,6 @@ struct Writer { virtual ~Writer() {} - virtual void begin() {} - virtual void end() {} - virtual void advance(const Position&) = 0; virtual void newline() = 0; virtual void space() = 0; @@ -130,6 +70,7 @@ struct StringWriter : Writer if (pos.column < newPos.column) write(std::string(newPos.column - pos.column, ' ')); } + void maybeSpace(const Position& newPos, int reserve) override { if (pos.column + reserve < newPos.column) @@ -264,26 +205,25 @@ struct Printer } } - void visualizeWithSelf(AstExpr& expr, bool self) + void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg) { - if (!self) - return visualize(expr); - - AstExprIndexName* func = expr.as(); - LUAU_ASSERT(func); - - visualize(*func->expr); - writer.symbol(":"); - advance(func->indexLocation.begin); - writer.identifier(func->index.value); - } - - void visualizeTypePackAnnotation(const AstTypePack& annotation) - { - if (const AstTypePackVariadic* variadic = annotation.as()) + advance(annotation.location.begin); + if (const AstTypePackVariadic* variadicTp = annotation.as()) { + if (!forVarArg) + writer.symbol("..."); + + visualizeTypeAnnotation(*variadicTp->variadicType); + } + else if (const AstTypePackGeneric* genericTp = annotation.as()) + { + writer.symbol(genericTp->genericName.value); writer.symbol("..."); - visualizeTypeAnnotation(*variadic->variadicType); + } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + LUAU_ASSERT(!forVarArg); + visualizeTypeList(explicitTp->typeList, true); } else { @@ -307,7 +247,7 @@ struct Printer // Only variadic tail if (list.types.size == 0) { - visualizeTypePackAnnotation(*list.tailType); + visualizeTypePackAnnotation(*list.tailType, false); } else { @@ -335,7 +275,7 @@ struct Printer if (list.tailType) { writer.symbol(","); - visualizeTypePackAnnotation(*list.tailType); + visualizeTypePackAnnotation(*list.tailType, false); } writer.symbol(")"); @@ -412,7 +352,7 @@ struct Printer } else if (const auto& a = expr.as()) { - visualizeWithSelf(*a->func, a->self); + visualize(*a->func); writer.symbol("("); bool first = true; @@ -431,7 +371,7 @@ struct Printer else if (const auto& a = expr.as()) { visualize(*a->expr); - writer.symbol("."); + writer.symbol(std::string(1, a->op)); writer.write(a->index.value); } else if (const auto& a = expr.as()) @@ -532,6 +472,7 @@ struct Printer case AstExprBinary::CompareLt: case AstExprBinary::CompareGt: writer.maybeSpace(a->right->location.begin, 2); + writer.symbol(toString(a->op)); break; case AstExprBinary::Concat: case AstExprBinary::CompareNe: @@ -540,19 +481,57 @@ struct Printer case AstExprBinary::CompareGe: case AstExprBinary::Or: writer.maybeSpace(a->right->location.begin, 3); + writer.keyword(toString(a->op)); break; case AstExprBinary::And: writer.maybeSpace(a->right->location.begin, 4); + writer.keyword(toString(a->op)); break; } - writer.symbol(toString(a->op)); - visualize(*a->right); } else if (const auto& a = expr.as()) { visualize(*a->expr); + + if (writeTypes) + { + writer.maybeSpace(a->annotation->location.begin, 2); + writer.symbol("::"); + visualizeTypeAnnotation(*a->annotation); + } + } + else if (const auto& a = expr.as()) + { + writer.keyword("if"); + visualize(*a->condition); + writer.keyword("then"); + visualize(*a->trueExpr); + writer.keyword("else"); + visualize(*a->falseExpr); + } + else if (const auto& a = expr.as()) + { + writer.symbol("`"); + + size_t index = 0; + + for (const auto& string : a->strings) + { + writer.write(escape(std::string_view(string.data, string.size), /* escapeForInterpString = */ true)); + + if (index < a->expressions.size) + { + writer.symbol("{"); + visualize(*a->expressions.data[index]); + writer.symbol("}"); + } + + index++; + } + + writer.symbol("`"); } else if (const auto& a = expr.as()) { @@ -759,24 +738,31 @@ struct Printer switch (a->op) { case AstExprBinary::Add: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("+="); break; case AstExprBinary::Sub: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("-="); break; case AstExprBinary::Mul: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("*="); break; case AstExprBinary::Div: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("/="); break; case AstExprBinary::Mod: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("%="); break; case AstExprBinary::Pow: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("^="); break; case AstExprBinary::Concat: + writer.maybeSpace(a->value->location.begin, 3); writer.symbol("..="); break; default: @@ -788,7 +774,7 @@ struct Printer else if (const auto& a = program.as()) { writer.keyword("function"); - visualizeWithSelf(*a->name, a->func->self != nullptr); + visualize(*a->name); visualizeFunctionBody(*a->func); } else if (const auto& a = program.as()) @@ -807,7 +793,7 @@ struct Printer writer.keyword("type"); writer.identifier(a->name.value); - if (a->generics.size > 0) + if (a->generics.size > 0 || a->genericPacks.size > 0) { writer.symbol("<"); CommaSeparatorInserter comma(writer); @@ -815,8 +801,34 @@ struct Printer for (auto o : a->generics) { comma(); - writer.identifier(o.value); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*o.defaultValue); + } } + + for (auto o : a->genericPacks) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypePackAnnotation(*o.defaultValue, false); + } + } + writer.symbol(">"); } writer.maybeSpace(a->type->location.begin, 2); @@ -853,19 +865,23 @@ struct Printer void visualizeFunctionBody(AstExprFunction& func) { - if (FFlag::LuauGenericFunctions && (func.generics.size > 0 || func.genericPacks.size > 0)) + if (func.generics.size > 0 || func.genericPacks.size > 0) { CommaSeparatorInserter comma(writer); writer.symbol("<"); for (const auto& o : func.generics) { comma(); - writer.identifier(o.value); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); } for (const auto& o : func.genericPacks) { comma(); - writer.identifier(o.value); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); writer.symbol("..."); } writer.symbol(">"); @@ -892,23 +908,24 @@ struct Printer if (func.vararg) { comma(); + advance(func.varargLocation.begin); writer.symbol("..."); if (func.varargAnnotation) { writer.symbol(":"); - visualizeTypePackAnnotation(*func.varargAnnotation); + visualizeTypePackAnnotation(*func.varargAnnotation, true); } } writer.symbol(")"); - if (writeTypes && func.hasReturnAnnotation) + if (writeTypes && func.returnAnnotation) { writer.symbol(":"); writer.space(); - visualizeTypeList(func.returnAnnotation, false); + visualizeTypeList(*func.returnAnnotation, false); } visualizeBlock(*func.body); @@ -959,34 +976,49 @@ struct Printer advance(typeAnnotation.location.begin); if (const auto& a = typeAnnotation.as()) { + if (a->prefix) + { + writer.write(a->prefix->value); + writer.symbol("."); + } + writer.write(a->name.value); - if (a->generics.size > 0) + if (a->parameters.size > 0 || a->hasParameterList) { CommaSeparatorInserter comma(writer); writer.symbol("<"); - for (auto o : a->generics) + for (auto o : a->parameters) { comma(); - visualizeTypeAnnotation(*o); + + if (o.type) + visualizeTypeAnnotation(*o.type); + else + visualizeTypePackAnnotation(*o.typePack, false); } + writer.symbol(">"); } } else if (const auto& a = typeAnnotation.as()) { - if (FFlag::LuauGenericFunctions && (a->generics.size > 0 || a->genericPacks.size > 0)) + if (a->generics.size > 0 || a->genericPacks.size > 0) { CommaSeparatorInserter comma(writer); writer.symbol("<"); for (const auto& o : a->generics) { comma(); - writer.identifier(o.value); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); } for (const auto& o : a->genericPacks) { comma(); - writer.identifier(o.value); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); writer.symbol("..."); } writer.symbol(">"); @@ -1001,31 +1033,42 @@ struct Printer } else if (const auto& a = typeAnnotation.as()) { - CommaSeparatorInserter comma(writer); + AstTypeReference* indexType = a->indexer ? a->indexer->indexType->as() : nullptr; - writer.symbol("{"); - - for (std::size_t i = 0; i < a->props.size; ++i) + if (a->props.size == 0 && indexType && indexType->name == "number") { - comma(); - advance(a->props.data[i].location.begin); - writer.identifier(a->props.data[i].name.value); - if (a->props.data[i].type) - { - writer.symbol(":"); - visualizeTypeAnnotation(*a->props.data[i].type); - } - } - if (a->indexer) - { - comma(); - writer.symbol("["); - visualizeTypeAnnotation(*a->indexer->indexType); - writer.symbol("]"); - writer.symbol(":"); + writer.symbol("{"); visualizeTypeAnnotation(*a->indexer->resultType); + writer.symbol("}"); + } + else + { + CommaSeparatorInserter comma(writer); + + writer.symbol("{"); + + for (std::size_t i = 0; i < a->props.size; ++i) + { + comma(); + advance(a->props.data[i].location.begin); + writer.identifier(a->props.data[i].name.value); + if (a->props.data[i].type) + { + writer.symbol(":"); + visualizeTypeAnnotation(*a->props.data[i].type); + } + } + if (a->indexer) + { + comma(); + writer.symbol("["); + visualizeTypeAnnotation(*a->indexer->indexType); + writer.symbol("]"); + writer.symbol(":"); + visualizeTypeAnnotation(*a->indexer->resultType); + } + writer.symbol("}"); } - writer.symbol("}"); } else if (auto a = typeAnnotation.as()) { @@ -1049,7 +1092,16 @@ struct Printer auto rta = r->as(); if (rta && rta->name == "nil") { + bool wrap = l->as() || l->as(); + + if (wrap) + writer.symbol("("); + visualizeTypeAnnotation(*l); + + if (wrap) + writer.symbol(")"); + writer.symbol("?"); return; } @@ -1063,7 +1115,15 @@ struct Printer writer.symbol("|"); } + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); } } else if (const auto& a = typeAnnotation.as()) @@ -1076,9 +1136,25 @@ struct Printer writer.symbol("&"); } + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); } } + else if (const auto& a = typeAnnotation.as()) + { + writer.keyword(a->value ? "true" : "false"); + } + else if (const auto& a = typeAnnotation.as()) + { + writer.string(std::string_view(a->value.data, a->value.size)); + } else if (typeAnnotation.is()) { writer.symbol("%error-type%"); @@ -1090,31 +1166,27 @@ struct Printer } }; -void dump(AstNode* node) +std::string toString(AstNode* node) { StringWriter writer; + writer.pos = node->location.begin; + Printer printer(writer); printer.writeTypes = true; if (auto statNode = dynamic_cast(node)) - { printer.visualize(*statNode); - printf("%s\n", writer.str().c_str()); - } else if (auto exprNode = dynamic_cast(node)) - { printer.visualize(*exprNode); - printf("%s\n", writer.str().c_str()); - } else if (auto typeNode = dynamic_cast(node)) - { printer.visualizeTypeAnnotation(*typeNode); - printf("%s\n", writer.str().c_str()); - } - else - { - printf("Can't dump this node\n"); - } + + return writer.str(); +} + +void dump(AstNode* node) +{ + printf("%s\n", toString(node).c_str()); } std::string transpile(AstStatBlock& block) @@ -1123,6 +1195,7 @@ std::string transpile(AstStatBlock& block) Printer(writer).visualizeBlock(block); return writer.str(); } + std::string transpileWithTypes(AstStatBlock& block) { StringWriter writer; @@ -1132,7 +1205,7 @@ std::string transpileWithTypes(AstStatBlock& block) return writer.str(); } -TranspileResult transpile(std::string_view source, ParseOptions options) +TranspileResult transpile(std::string_view source, ParseOptions options, bool withTypes) { auto allocator = Allocator{}; auto names = AstNameTable{allocator}; @@ -1150,6 +1223,9 @@ TranspileResult transpile(std::string_view source, ParseOptions options) if (!parseResult.root) return TranspileResult{"", {}, "Internal error: Parser yielded empty parse tree"}; + if (withTypes) + return TranspileResult{transpileWithTypes(*parseResult.root)}; + return TranspileResult{transpile(*parseResult.root)}; } diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 702d0ca2..18596a63 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -1,72 +1,426 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TxnLog.h" +#include "Luau/ToString.h" +#include "Luau/TypeArena.h" #include "Luau/TypePack.h" #include +#include + +LUAU_FASTFLAG(LuauUnknownAndNeverType) namespace Luau { -void TxnLog::operator()(TypeId a) +const std::string nullPendingResult = ""; + +std::string toString(PendingType* pending) { - typeVarChanges.emplace_back(a, *a); + if (pending == nullptr) + return nullPendingResult; + + return toString(pending->pending); } -void TxnLog::operator()(TypePackId a) +std::string dump(PendingType* pending) { - typePackChanges.emplace_back(a, *a); + if (pending == nullptr) + { + printf("%s\n", nullPendingResult.c_str()); + return nullPendingResult; + } + + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string result = toString(pending->pending, opts); + printf("%s\n", result.c_str()); + return result; } -void TxnLog::operator()(TableTypeVar* a) +std::string toString(PendingTypePack* pending) { - tableChanges.emplace_back(a, a->boundTo); + if (pending == nullptr) + return nullPendingResult; + + return toString(pending->pending); } -void TxnLog::rollback() +std::string dump(PendingTypePack* pending) { - for (auto it = typeVarChanges.rbegin(); it != typeVarChanges.rend(); ++it) - std::swap(*asMutable(it->first), it->second); + if (pending == nullptr) + { + printf("%s\n", nullPendingResult.c_str()); + return nullPendingResult; + } - for (auto it = typePackChanges.rbegin(); it != typePackChanges.rend(); ++it) - std::swap(*asMutable(it->first), it->second); + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string result = toString(pending->pending, opts); + printf("%s\n", result.c_str()); + return result; +} - for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it) - std::swap(it->first->boundTo, it->second); +static const TxnLog emptyLog; + +const TxnLog* TxnLog::empty() +{ + return &emptyLog; } void TxnLog::concat(TxnLog rhs) { - typeVarChanges.insert(typeVarChanges.end(), rhs.typeVarChanges.begin(), rhs.typeVarChanges.end()); - rhs.typeVarChanges.clear(); + for (auto& [ty, rep] : rhs.typeVarChanges) + typeVarChanges[ty] = std::move(rep); - typePackChanges.insert(typePackChanges.end(), rhs.typePackChanges.begin(), rhs.typePackChanges.end()); - rhs.typePackChanges.clear(); - - tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end()); - rhs.tableChanges.clear(); - - seen.swap(rhs.seen); - rhs.seen.clear(); + for (auto& [tp, rep] : rhs.typePackChanges) + typePackChanges[tp] = std::move(rep); } -bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) +void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) { - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - return (seen.end() != std::find(seen.begin(), seen.end(), sortedPair)); + for (auto& [ty, rightRep] : rhs.typeVarChanges) + { + if (auto leftRep = typeVarChanges.find(ty)) + { + TypeId leftTy = arena->addType((*leftRep)->pending); + TypeId rightTy = arena->addType(rightRep->pending); + typeVarChanges[ty]->pending.ty = IntersectionTypeVar{{leftTy, rightTy}}; + } + else + typeVarChanges[ty] = std::move(rightRep); + } + + for (auto& [tp, rep] : rhs.typePackChanges) + typePackChanges[tp] = std::move(rep); +} + +void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) +{ + for (auto& [ty, rightRep] : rhs.typeVarChanges) + { + if (auto leftRep = typeVarChanges.find(ty)) + { + TypeId leftTy = arena->addType((*leftRep)->pending); + TypeId rightTy = arena->addType(rightRep->pending); + typeVarChanges[ty]->pending.ty = UnionTypeVar{{leftTy, rightTy}}; + } + else + typeVarChanges[ty] = std::move(rightRep); + } + + for (auto& [tp, rep] : rhs.typePackChanges) + typePackChanges[tp] = std::move(rep); +} + +void TxnLog::commit() +{ + for (auto& [ty, rep] : typeVarChanges) + asMutable(ty)->reassign(rep.get()->pending); + + for (auto& [tp, rep] : typePackChanges) + asMutable(tp)->reassign(rep.get()->pending); + + clear(); +} + +void TxnLog::clear() +{ + typeVarChanges.clear(); + typePackChanges.clear(); +} + +TxnLog TxnLog::inverse() +{ + TxnLog inversed(sharedSeen); + + for (auto& [ty, _rep] : typeVarChanges) + inversed.typeVarChanges[ty] = std::make_unique(*ty); + + for (auto& [tp, _rep] : typePackChanges) + inversed.typePackChanges[tp] = std::make_unique(*tp); + + return inversed; +} + +bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) const +{ + return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); } void TxnLog::pushSeen(TypeId lhs, TypeId rhs) { - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - seen.push_back(sortedPair); + pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); } void TxnLog::popSeen(TypeId lhs, TypeId rhs) { - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - LUAU_ASSERT(sortedPair == seen.back()); - seen.pop_back(); + popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +bool TxnLog::haveSeen(TypePackId lhs, TypePackId rhs) const +{ + return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void TxnLog::pushSeen(TypePackId lhs, TypePackId rhs) +{ + pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void TxnLog::popSeen(TypePackId lhs, TypePackId rhs) +{ + popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const +{ + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) + { + return true; + } + + return false; +} + +void TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) +{ + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + sharedSeen->push_back(sortedPair); +} + +void TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) +{ + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + LUAU_ASSERT(sortedPair == sharedSeen->back()); + sharedSeen->pop_back(); +} + +PendingType* TxnLog::queue(TypeId ty) +{ + LUAU_ASSERT(!ty->persistent); + + // Explicitly don't look in ancestors. If we have discovered something new + // about this type, we don't want to mutate the parent's state. + auto& pending = typeVarChanges[ty]; + if (!pending) + { + pending = std::make_unique(*ty); + pending->pending.owningArena = nullptr; + } + + return pending.get(); +} + +PendingTypePack* TxnLog::queue(TypePackId tp) +{ + LUAU_ASSERT(!tp->persistent); + + // Explicitly don't look in ancestors. If we have discovered something new + // about this type, we don't want to mutate the parent's state. + auto& pending = typePackChanges[tp]; + if (!pending) + { + pending = std::make_unique(*tp); + pending->pending.owningArena = nullptr; + } + + return pending.get(); +} + +PendingType* TxnLog::pending(TypeId ty) const +{ + // This function will technically work if `this` is nullptr, but this + // indicates a bug, so we explicitly assert. + LUAU_ASSERT(static_cast(this) != nullptr); + + for (const TxnLog* current = this; current; current = current->parent) + { + if (auto it = current->typeVarChanges.find(ty)) + return it->get(); + } + + return nullptr; +} + +PendingTypePack* TxnLog::pending(TypePackId tp) const +{ + // This function will technically work if `this` is nullptr, but this + // indicates a bug, so we explicitly assert. + LUAU_ASSERT(static_cast(this) != nullptr); + + for (const TxnLog* current = this; current; current = current->parent) + { + if (auto it = current->typePackChanges.find(tp)) + return it->get(); + } + + return nullptr; +} + +PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) +{ + PendingType* newTy = queue(ty); + newTy->pending.reassign(replacement); + return newTy; +} + +PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) +{ + PendingTypePack* newTp = queue(tp); + newTp->pending.reassign(replacement); + return newTp; +} + +PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) +{ + LUAU_ASSERT(get(ty)); + + PendingType* newTy = queue(ty); + if (TableTypeVar* ttv = Luau::getMutable(newTy)) + ttv->boundTo = newBoundTo; + + return newTy; +} + +PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) +{ + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + + PendingType* newTy = queue(ty); + if (FreeTypeVar* ftv = Luau::getMutable(newTy)) + { + ftv->level = newLevel; + } + else if (TableTypeVar* ttv = Luau::getMutable(newTy)) + { + LUAU_ASSERT(ttv->state == TableState::Free || ttv->state == TableState::Generic); + ttv->level = newLevel; + } + else if (FunctionTypeVar* ftv = Luau::getMutable(newTy)) + { + ftv->level = newLevel; + } + + return newTy; +} + +PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) +{ + LUAU_ASSERT(get(tp)); + + PendingTypePack* newTp = queue(tp); + if (FreeTypePack* ftp = Luau::getMutable(newTp)) + { + ftp->level = newLevel; + } + + return newTp; +} + +PendingType* TxnLog::changeScope(TypeId ty, NotNull newScope) +{ + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + + PendingType* newTy = queue(ty); + if (FreeTypeVar* ftv = Luau::getMutable(newTy)) + { + ftv->scope = newScope; + } + else if (TableTypeVar* ttv = Luau::getMutable(newTy)) + { + LUAU_ASSERT(ttv->state == TableState::Free || ttv->state == TableState::Generic); + ttv->scope = newScope; + } + else if (FunctionTypeVar* ftv = Luau::getMutable(newTy)) + { + ftv->scope = newScope; + } + + return newTy; +} + +PendingTypePack* TxnLog::changeScope(TypePackId tp, NotNull newScope) +{ + LUAU_ASSERT(get(tp)); + + PendingTypePack* newTp = queue(tp); + if (FreeTypePack* ftp = Luau::getMutable(newTp)) + { + ftp->scope = newScope; + } + + return newTp; +} + +PendingType* TxnLog::changeIndexer(TypeId ty, std::optional indexer) +{ + LUAU_ASSERT(get(ty)); + + PendingType* newTy = queue(ty); + if (TableTypeVar* ttv = Luau::getMutable(newTy)) + { + ttv->indexer = indexer; + } + + return newTy; +} + +std::optional TxnLog::getLevel(TypeId ty) const +{ + if (FreeTypeVar* ftv = getMutable(ty)) + return ftv->level; + else if (TableTypeVar* ttv = getMutable(ty); ttv && (ttv->state == TableState::Free || ttv->state == TableState::Generic)) + return ttv->level; + else if (FunctionTypeVar* ftv = getMutable(ty)) + return ftv->level; + + return std::nullopt; +} + +TypeId TxnLog::follow(TypeId ty) const +{ + return Luau::follow(ty, [this](TypeId ty) { + PendingType* state = this->pending(ty); + + if (state == nullptr) + return ty; + + // Ugly: Fabricate a TypeId that doesn't adhere to most of the invariants + // that normally apply. This is safe because follow will only call get<> + // on the returned pointer. + return const_cast(&state->pending); + }); +} + +TypePackId TxnLog::follow(TypePackId tp) const +{ + return Luau::follow(tp, [this](TypePackId tp) { + PendingTypePack* state = this->pending(tp); + + if (state == nullptr) + return tp; + + // Ugly: Fabricate a TypePackId that doesn't adhere to most of the + // invariants that normally apply. This is safe because follow will + // only call get<> on the returned pointer. + return const_cast(&state->pending); + }); +} + +std::pair, std::vector> TxnLog::getChanges() const +{ + std::pair, std::vector> result; + + for (const auto& [typeId, _newState] : typeVarChanges) + result.first.push_back(typeId); + for (const auto& [typePackId, _newState] : typePackChanges) + result.second.push_back(typePackId); + + return result; } } // namespace Luau diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp new file mode 100644 index 00000000..666ab867 --- /dev/null +++ b/Analysis/src/TypeArena.cpp @@ -0,0 +1,115 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeArena.h" + +LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false); + +namespace Luau +{ + +void TypeArena::clear() +{ + typeVars.clear(); + typePacks.clear(); +} + +TypeId TypeArena::addTV(TypeVar&& tv) +{ + TypeId allocated = typeVars.allocate(std::move(tv)); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(TypeLevel level) +{ + TypeId allocated = typeVars.allocate(FreeTypeVar{level}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(Scope* scope) +{ + TypeId allocated = typeVars.allocate(FreeTypeVar{scope}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(Scope* scope, TypeLevel level) +{ + TypeId allocated = typeVars.allocate(FreeTypeVar{scope, level}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::freshTypePack(Scope* scope) +{ + TypePackId allocated = typePacks.allocate(FreeTypePack{scope}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(std::initializer_list types) +{ + TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(std::vector types, std::optional tail) +{ + TypePackId allocated = typePacks.allocate(TypePack{std::move(types), tail}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(TypePack tp) +{ + TypePackId allocated = typePacks.allocate(std::move(tp)); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(TypePackVar tp) +{ + TypePackId allocated = typePacks.allocate(std::move(tp)); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +void freeze(TypeArena& arena) +{ + if (!FFlag::DebugLuauFreezeArena) + return; + + arena.typeVars.freeze(); + arena.typePacks.freeze(); +} + +void unfreeze(TypeArena& arena) +{ + if (!FFlag::DebugLuauFreezeArena) + return; + + arena.typeVars.unfreeze(); + arena.typePacks.unfreeze(); +} + +} // namespace Luau diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 17c57c84..e483c047 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -3,16 +3,15 @@ #include "Luau/Error.h" #include "Luau/Module.h" -#include "Luau/Parser.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" +#include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include -LUAU_FASTFLAG(LuauGenericFunctions) - static char* allocateString(Luau::Allocator& allocator, std::string_view contents) { char* result = (char*)allocator.allocate(contents.size() + 1); @@ -31,15 +30,31 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data return result; } +using SyntheticNames = std::unordered_map; + namespace Luau { +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen) +{ + size_t s = syntheticNames->size(); + char*& n = (*syntheticNames)[&gen]; + if (!n) + { + std::string str = gen.explicitName ? gen.name : generateName(s); + n = static_cast(allocator->allocate(str.size() + 1)); + strcpy(n, str.c_str()); + } + + return n; +} + class TypeRehydrationVisitor { - mutable std::map seen; - mutable int count = 0; + std::map seen; + int count = 0; - bool hasSeen(const void* tv) const + bool hasSeen(const void* tv) { void* ttv = const_cast(tv); auto it = seen.find(ttv); @@ -51,13 +66,16 @@ class TypeRehydrationVisitor } public: - TypeRehydrationVisitor(Allocator* alloc, const TypeRehydrationOptions& options = TypeRehydrationOptions()) + TypeRehydrationVisitor(Allocator* alloc, SyntheticNames* syntheticNames, const TypeRehydrationOptions& options = TypeRehydrationOptions()) : allocator(alloc) + , syntheticNames(syntheticNames) , options(options) { } - AstType* operator()(const PrimitiveTypeVar& ptv) const + AstTypePack* rehydrate(TypePackId tp); + + AstType* operator()(const PrimitiveTypeVar& ptv) { switch (ptv.type) { @@ -75,26 +93,57 @@ public: return nullptr; } } - AstType* operator()(const AnyTypeVar&) const + + AstType* operator()(const BlockedTypeVar& btv) + { + return allocator->alloc(Location(), std::nullopt, AstName("*blocked*")); + } + + AstType* operator()(const PendingExpansionTypeVar& petv) + { + return allocator->alloc(Location(), std::nullopt, AstName("*pending-expansion*")); + } + + AstType* operator()(const SingletonTypeVar& stv) + { + if (const BooleanSingleton* bs = get(&stv)) + return allocator->alloc(Location(), bs->value); + else if (const StringSingleton* ss = get(&stv)) + { + AstArray value; + value.data = const_cast(ss->value.c_str()); + value.size = strlen(value.data); + return allocator->alloc(Location(), value); + } + else + return nullptr; + } + + AstType* operator()(const AnyTypeVar&) { return allocator->alloc(Location(), std::nullopt, AstName("any")); } - AstType* operator()(const TableTypeVar& ttv) const + AstType* operator()(const TableTypeVar& ttv) { RecursionCounter counter(&count); if (ttv.name && options.bannedNames.find(*ttv.name) == options.bannedNames.end()) { - AstArray generics; - generics.size = ttv.instantiatedTypeParams.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstType*) * generics.size)); + AstArray parameters; + parameters.size = ttv.instantiatedTypeParams.size(); + parameters.data = static_cast(allocator->allocate(sizeof(AstTypeOrPack) * parameters.size)); for (size_t i = 0; i < ttv.instantiatedTypeParams.size(); ++i) { - generics.data[i] = Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty); + parameters.data[i] = {Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty), {}}; } - return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), generics); + for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i) + { + parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; + } + + return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters); } if (hasSeen(&ttv)) @@ -133,12 +182,12 @@ public: return allocator->alloc(Location(), props, indexer); } - AstType* operator()(const MetatableTypeVar& mtv) const + AstType* operator()(const MetatableTypeVar& mtv) { return Luau::visit(*this, mtv.table->ty); } - AstType* operator()(const ClassTypeVar& ctv) const + AstType* operator()(const ClassTypeVar& ctv) { RecursionCounter counter(&count); @@ -165,47 +214,31 @@ public: return allocator->alloc(Location(), props); } - AstType* operator()(const FunctionTypeVar& ftv) const + AstType* operator()(const FunctionTypeVar& ftv) { RecursionCounter counter(&count); if (hasSeen(&ftv)) return allocator->alloc(Location(), std::nullopt, AstName("")); - AstArray generics; - if (FFlag::LuauGenericFunctions) + AstArray generics; + generics.size = ftv.generics.size(); + generics.data = static_cast(allocator->allocate(sizeof(AstGenericType) * generics.size)); + size_t numGenerics = 0; + for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) { - generics.size = ftv.generics.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstName) * generics.size)); - size_t i = 0; - for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) - { - if (auto gtv = get(*it)) - generics.data[i++] = AstName(gtv->name.c_str()); - } - } - else - { - generics.size = 0; - generics.data = nullptr; + if (auto gtv = get(*it)) + generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } - AstArray genericPacks; - if (FFlag::LuauGenericFunctions) + AstArray genericPacks; + genericPacks.size = ftv.genericPacks.size(); + genericPacks.data = static_cast(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size)); + size_t numGenericPacks = 0; + for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { - genericPacks.size = ftv.genericPacks.size(); - genericPacks.data = static_cast(allocator->allocate(sizeof(AstName) * genericPacks.size)); - size_t i = 0; - for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) - { - if (auto gtv = get(*it)) - genericPacks.data[i++] = AstName(gtv->name.c_str()); - } - } - else - { - generics.size = 0; - generics.data = nullptr; + if (auto gtv = get(*it)) + genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } AstArray argTypes; @@ -221,13 +254,7 @@ public: AstTypePack* argTailAnnotation = nullptr; if (argTail) - { - TypePackId tail = *argTail; - if (const VariadicTypePack* vtp = get(tail)) - { - argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); - } - } + argTailAnnotation = rehydrate(*argTail); AstArray> argNames; argNames.size = ftv.argNames.size(); @@ -235,14 +262,16 @@ public: size_t i = 0; for (const auto& el : ftv.argNames) { + std::optional* arg = &argNames.data[i++]; + if (el) - argNames.data[i++] = {AstName(el->name.c_str()), el->location}; + new (arg) std::optional(AstArgumentName(AstName(el->name.c_str()), el->location)); else - argNames.data[i++] = {}; + new (arg) std::optional(); } AstArray returnTypes; - const auto& [retVector, retTail] = flatten(ftv.retType); + const auto& [retVector, retTail] = flatten(ftv.retTypes); returnTypes.size = retVector.size(); returnTypes.data = static_cast(allocator->allocate(sizeof(AstType*) * returnTypes.size)); for (size_t i = 0; i < returnTypes.size; ++i) @@ -254,34 +283,28 @@ public: AstTypePack* retTailAnnotation = nullptr; if (retTail) - { - TypePackId tail = *retTail; - if (const VariadicTypePack* vtp = get(tail)) - { - retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); - } - } + retTailAnnotation = rehydrate(*retTail); return allocator->alloc( Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); } - AstType* operator()(const Unifiable::Error&) const + AstType* operator()(const Unifiable::Error&) { return allocator->alloc(Location(), std::nullopt, AstName("Unifiable")); } - AstType* operator()(const GenericTypeVar& gtv) const + AstType* operator()(const GenericTypeVar& gtv) { - return allocator->alloc(Location(), std::nullopt, AstName(gtv.name.c_str())); + return allocator->alloc(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv))); } - AstType* operator()(const Unifiable::Bound& bound) const + AstType* operator()(const Unifiable::Bound& bound) { return Luau::visit(*this, bound.boundTo->ty); } - AstType* operator()(Unifiable::Free ftv) const + AstType* operator()(const FreeTypeVar& ftv) { return allocator->alloc(Location(), std::nullopt, AstName("free")); } - AstType* operator()(const UnionTypeVar& uv) const + AstType* operator()(const UnionTypeVar& uv) { AstArray unionTypes; unionTypes.size = uv.options.size(); @@ -292,7 +315,7 @@ public: } return allocator->alloc(Location(), unionTypes); } - AstType* operator()(const IntersectionTypeVar& uv) const + AstType* operator()(const IntersectionTypeVar& uv) { AstArray intersectionTypes; intersectionTypes.size = uv.parts.size(); @@ -303,16 +326,105 @@ public: } return allocator->alloc(Location(), intersectionTypes); } - AstType* operator()(const LazyTypeVar& ltv) const + AstType* operator()(const LazyTypeVar& ltv) { return allocator->alloc(Location(), std::nullopt, AstName("")); } + AstType* operator()(const UnknownTypeVar& ttv) + { + return allocator->alloc(Location(), std::nullopt, AstName{"unknown"}); + } + AstType* operator()(const NeverTypeVar& ttv) + { + return allocator->alloc(Location(), std::nullopt, AstName{"never"}); + } + AstType* operator()(const NegationTypeVar& ntv) + { + // FIXME: do the same thing we do with ErrorTypeVar + throw InternalCompilerError("Cannot convert NegationTypeVar into AstNode"); + } private: Allocator* allocator; + SyntheticNames* syntheticNames; const TypeRehydrationOptions& options; }; +class TypePackRehydrationVisitor +{ +public: + TypePackRehydrationVisitor(Allocator* allocator, SyntheticNames* syntheticNames, TypeRehydrationVisitor* typeVisitor) + : allocator(allocator) + , syntheticNames(syntheticNames) + , typeVisitor(typeVisitor) + { + LUAU_ASSERT(allocator); + LUAU_ASSERT(syntheticNames); + LUAU_ASSERT(typeVisitor); + } + + AstTypePack* operator()(const BoundTypePack& btp) const + { + return Luau::visit(*this, btp.boundTo->ty); + } + + AstTypePack* operator()(const BlockedTypePack& btp) const + { + return allocator->alloc(Location(), AstName("*blocked*")); + } + + AstTypePack* operator()(const TypePack& tp) const + { + AstArray head; + head.size = tp.head.size(); + head.data = static_cast(allocator->allocate(sizeof(AstType*) * tp.head.size())); + + for (size_t i = 0; i < tp.head.size(); i++) + head.data[i] = Luau::visit(*typeVisitor, tp.head[i]->ty); + + AstTypePack* tail = nullptr; + + if (tp.tail) + tail = Luau::visit(*this, (*tp.tail)->ty); + + return allocator->alloc(Location(), AstTypeList{head, tail}); + } + + AstTypePack* operator()(const VariadicTypePack& vtp) const + { + if (vtp.hidden) + return nullptr; + + return allocator->alloc(Location(), Luau::visit(*typeVisitor, vtp.ty->ty)); + } + + AstTypePack* operator()(const GenericTypePack& gtp) const + { + return allocator->alloc(Location(), AstName(getName(allocator, syntheticNames, gtp))); + } + + AstTypePack* operator()(const FreeTypePack& gtp) const + { + return allocator->alloc(Location(), AstName("free")); + } + + AstTypePack* operator()(const Unifiable::Error&) const + { + return allocator->alloc(Location(), AstName("Unifiable")); + } + +private: + Allocator* allocator; + SyntheticNames* syntheticNames; + TypeRehydrationVisitor* typeVisitor; +}; + +AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) +{ + TypePackRehydrationVisitor tprv(allocator, syntheticNames, this); + return Luau::visit(tprv, tp->ty); +} + class TypeAttacher : public AstVisitor { public: @@ -344,7 +456,7 @@ public: { if (!type) return nullptr; - return Luau::visit(TypeRehydrationVisitor(allocator), (*type)->ty); + return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), (*type)->ty); } AstArray typeAstPack(TypePackId type) @@ -356,7 +468,7 @@ public: result.data = static_cast(allocator->allocate(sizeof(AstType*) * v.size())); for (size_t i = 0; i < v.size(); ++i) { - result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator), v[i]->ty); + result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), v[i]->ty); } return result; } @@ -385,6 +497,20 @@ public: { return visitLocal(al->local); } + + virtual bool visit(AstStatFor* stat) override + { + visitLocal(stat->var); + return true; + } + + virtual bool visit(AstStatForIn* stat) override + { + for (size_t i = 0; i < stat->vars.size; ++i) + visitLocal(stat->vars.data[i]); + return true; + } + virtual bool visit(AstExprFunction* fn) override { // TODO: add generics if the inferred type of the function is generic CLI-39908 @@ -394,22 +520,17 @@ public: visitLocal(arg); } - if (!fn->hasReturnAnnotation) + if (!fn->returnAnnotation) { if (auto result = getScope(fn->body->location)) { TypePackId ret = result->returnType; - fn->hasReturnAnnotation = true; AstTypePack* variadicAnnotation = nullptr; const auto& [v, tail] = flatten(ret); if (tail) - { - TypePackId tailPack = *tail; - if (const VariadicTypePack* vtp = get(tailPack)) - variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); - } + variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail); fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation}; } @@ -421,6 +542,7 @@ public: private: Module& module; Allocator* allocator; + SyntheticNames syntheticNames; }; void attachTypeData(SourceModule& source, Module& result) @@ -431,7 +553,8 @@ void attachTypeData(SourceModule& source, Module& result) AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options) { - return Luau::visit(TypeRehydrationVisitor(allocator, options), type->ty); + SyntheticNames syntheticNames; + return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames, options), type->ty); } } // namespace Luau diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp new file mode 100644 index 00000000..8c44f90a --- /dev/null +++ b/Analysis/src/TypeChecker2.cpp @@ -0,0 +1,1693 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeChecker2.h" + +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" +#include "Luau/Clone.h" +#include "Luau/Instantiation.h" +#include "Luau/Metamethods.h" +#include "Luau/Normalize.h" +#include "Luau/ToString.h" +#include "Luau/TxnLog.h" +#include "Luau/TypeUtils.h" +#include "Luau/TypeVar.h" +#include "Luau/Unifier.h" +#include "Luau/ToString.h" +#include "Luau/DcrLogger.h" + +#include + +LUAU_FASTFLAG(DebugLuauLogSolverToJson); +LUAU_FASTFLAG(DebugLuauMagicTypes); + +namespace Luau +{ + +// TypeInfer.h +// TODO move these +using PrintLineProc = void (*)(const std::string&); +extern PrintLineProc luauPrintLine; + +/* Push a scope onto the end of a stack for the lifetime of the StackPusher instance. + * TypeChecker2 uses this to maintain knowledge about which scope encloses every + * given AstNode. + */ +struct StackPusher +{ + std::vector>* stack; + NotNull scope; + + explicit StackPusher(std::vector>& stack, Scope* scope) + : stack(&stack) + , scope(scope) + { + stack.push_back(NotNull{scope}); + } + + ~StackPusher() + { + if (stack) + { + LUAU_ASSERT(stack->back() == scope); + stack->pop_back(); + } + } + + StackPusher(const StackPusher&) = delete; + StackPusher&& operator=(const StackPusher&) = delete; + + StackPusher(StackPusher&& other) + : stack(std::exchange(other.stack, nullptr)) + , scope(other.scope) + { + } +}; + +static std::optional getIdentifierOfBaseVar(AstExpr* node) +{ + if (AstExprGlobal* expr = node->as()) + return expr->name.value; + + if (AstExprLocal* expr = node->as()) + return expr->local->name.value; + + if (AstExprIndexExpr* expr = node->as()) + return getIdentifierOfBaseVar(expr->expr); + + if (AstExprIndexName* expr = node->as()) + return getIdentifierOfBaseVar(expr->expr); + + return std::nullopt; +} + +struct TypeChecker2 +{ + NotNull singletonTypes; + DcrLogger* logger; + InternalErrorReporter ice; // FIXME accept a pointer from Frontend + const SourceModule* sourceModule; + Module* module; + + std::vector> stack; + + UnifierSharedState sharedState{&ice}; + Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; + + TypeChecker2(NotNull singletonTypes, DcrLogger* logger, const SourceModule* sourceModule, Module* module) + : singletonTypes(singletonTypes) + , logger(logger) + , sourceModule(sourceModule) + , module(module) + { + if (FFlag::DebugLuauLogSolverToJson) + LUAU_ASSERT(logger); + } + + std::optional pushStack(AstNode* node) + { + if (Scope** scope = module->astScopes.find(node)) + return StackPusher{stack, *scope}; + else + return std::nullopt; + } + + TypePackId lookupPack(AstExpr* expr) + { + // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. + // We'll just return anyType in these cases. Typechecking against any is very fast and this + // allows us not to think about this very much in the actual typechecking logic. + TypePackId* tp = module->astTypePacks.find(expr); + if (tp) + return follow(*tp); + else + return singletonTypes->anyTypePack; + } + + TypeId lookupType(AstExpr* expr) + { + // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. + // We'll just return anyType in these cases. Typechecking against any is very fast and this + // allows us not to think about this very much in the actual typechecking logic. + TypeId* ty = module->astTypes.find(expr); + if (ty) + return follow(*ty); + + TypePackId* tp = module->astTypePacks.find(expr); + if (tp) + return flattenPack(*tp); + + return singletonTypes->anyType; + } + + TypeId lookupAnnotation(AstType* annotation) + { + if (FFlag::DebugLuauMagicTypes) + { + if (auto ref = annotation->as(); ref && ref->name == "_luau_print" && ref->parameters.size > 0) + { + if (auto ann = ref->parameters.data[0].type) + { + TypeId argTy = lookupAnnotation(ref->parameters.data[0].type); + luauPrintLine(format( + "_luau_print (%d, %d): %s\n", annotation->location.begin.line, annotation->location.begin.column, toString(argTy).c_str())); + return follow(argTy); + } + } + } + + TypeId* ty = module->astResolvedTypes.find(annotation); + LUAU_ASSERT(ty); + return follow(*ty); + } + + TypePackId lookupPackAnnotation(AstTypePack* annotation) + { + TypePackId* tp = module->astResolvedTypePacks.find(annotation); + LUAU_ASSERT(tp); + return follow(*tp); + } + + TypePackId reconstructPack(AstArray exprs, TypeArena& arena) + { + if (exprs.size == 0) + return arena.addTypePack(TypePack{{}, std::nullopt}); + + std::vector head; + + for (size_t i = 0; i < exprs.size - 1; ++i) + { + head.push_back(lookupType(exprs.data[i])); + } + + TypePackId tail = lookupPack(exprs.data[exprs.size - 1]); + return arena.addTypePack(TypePack{head, tail}); + } + + Scope* findInnermostScope(Location location) + { + Scope* bestScope = module->getModuleScope().get(); + Location bestLocation = module->scopes[0].first; + + for (size_t i = 0; i < module->scopes.size(); ++i) + { + auto& [scopeBounds, scope] = module->scopes[i]; + if (scopeBounds.encloses(location)) + { + if (scopeBounds.begin > bestLocation.begin || scopeBounds.end < bestLocation.end) + { + bestScope = scope.get(); + bestLocation = scopeBounds; + } + } + else if (scopeBounds.begin > location.end) + { + // TODO: Is this sound? This relies on the fact that scopes are inserted + // into the scope list in the order that they appear in the AST. + break; + } + } + + return bestScope; + } + + void visit(AstStat* stat) + { + auto pusher = pushStack(stat); + + if (0) + { + } + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else if (auto s = stat->as()) + return visit(s); + else + LUAU_ASSERT(!"TypeChecker2 encountered an unknown node type"); + } + + void visit(AstStatBlock* block) + { + auto StackPusher = pushStack(block); + + for (AstStat* statement : block->body) + visit(statement); + } + + void visit(AstStatIf* ifStatement) + { + visit(ifStatement->condition); + visit(ifStatement->thenbody); + if (ifStatement->elsebody) + visit(ifStatement->elsebody); + } + + void visit(AstStatWhile* whileStatement) + { + visit(whileStatement->condition); + visit(whileStatement->body); + } + + void visit(AstStatRepeat* repeatStatement) + { + visit(repeatStatement->body); + visit(repeatStatement->condition); + } + + void visit(AstStatBreak*) {} + + void visit(AstStatContinue*) {} + + void visit(AstStatReturn* ret) + { + Scope* scope = findInnermostScope(ret->location); + TypePackId expectedRetType = scope->returnType; + + TypeArena* arena = &module->internalTypes; + TypePackId actualRetType = reconstructPack(ret->list, *arena); + + Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; + + u.tryUnify(actualRetType, expectedRetType); + const bool ok = u.errors.empty() && u.log.empty(); + + if (!ok) + { + for (const TypeError& e : u.errors) + reportError(e); + } + + for (AstExpr* expr : ret->list) + visit(expr); + } + + void visit(AstStatExpr* expr) + { + visit(expr->expr); + } + + void visit(AstStatLocal* local) + { + size_t count = std::max(local->values.size, local->vars.size); + for (size_t i = 0; i < count; ++i) + { + AstExpr* value = i < local->values.size ? local->values.data[i] : nullptr; + + if (value) + visit(value); + + TypeId* maybeValueType = value ? module->astTypes.find(value) : nullptr; + if (i != local->values.size - 1 || maybeValueType) + { + AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; + + if (var && var->annotation) + { + TypeId annotationType = lookupAnnotation(var->annotation); + TypeId valueType = value ? lookupType(value) : nullptr; + if (valueType) + { + ErrorVec errors = tryUnify(stack.back(), value->location, valueType, annotationType); + if (!errors.empty()) + reportErrors(std::move(errors)); + } + } + } + else + { + LUAU_ASSERT(value); + + TypePackId valueTypes = lookupPack(value); + auto it = begin(valueTypes); + for (size_t j = i; j < local->vars.size; ++j) + { + if (it == end(valueTypes)) + { + break; + } + + AstLocal* var = local->vars.data[i]; + if (var->annotation) + { + TypeId varType = lookupAnnotation(var->annotation); + ErrorVec errors = tryUnify(stack.back(), value->location, *it, varType); + if (!errors.empty()) + reportErrors(std::move(errors)); + } + + ++it; + } + } + } + } + + void visit(AstStatFor* forStatement) + { + if (forStatement->var->annotation) + visit(forStatement->var->annotation); + + visit(forStatement->from); + visit(forStatement->to); + if (forStatement->step) + visit(forStatement->step); + visit(forStatement->body); + } + + void visit(AstStatForIn* forInStatement) + { + for (AstLocal* local : forInStatement->vars) + { + if (local->annotation) + visit(local->annotation); + } + + for (AstExpr* expr : forInStatement->values) + visit(expr); + + visit(forInStatement->body); + + // Rule out crazy stuff. Maybe possible if the file is not syntactically valid. + if (!forInStatement->vars.size || !forInStatement->values.size) + return; + + NotNull scope = stack.back(); + TypeArena& arena = module->internalTypes; + + std::vector variableTypes; + for (AstLocal* var : forInStatement->vars) + { + std::optional ty = scope->lookup(var); + LUAU_ASSERT(ty); + variableTypes.emplace_back(*ty); + } + + // ugh. There's nothing in the AST to hang a whole type pack on for the + // set of iteratees, so we have to piece it back together by hand. + std::vector valueTypes; + for (size_t i = 0; i < forInStatement->values.size - 1; ++i) + valueTypes.emplace_back(lookupType(forInStatement->values.data[i])); + TypePackId iteratorTail = lookupPack(forInStatement->values.data[forInStatement->values.size - 1]); + TypePackId iteratorPack = arena.addTypePack(valueTypes, iteratorTail); + + // ... and then expand it out to 3 values (if possible) + TypePack iteratorTypes = extendTypePack(arena, singletonTypes, iteratorPack, 3); + if (iteratorTypes.head.empty()) + { + reportError(GenericError{"for..in loops require at least one value to iterate over. Got zero"}, getLocation(forInStatement->values)); + return; + } + TypeId iteratorTy = follow(iteratorTypes.head[0]); + + auto checkFunction = [this, &arena, &scope, &forInStatement, &variableTypes]( + const FunctionTypeVar* iterFtv, std::vector iterTys, bool isMm) { + if (iterTys.size() < 1 || iterTys.size() > 3) + { + if (isMm) + reportError(GenericError{"__iter metamethod must return (next[, table[, state]])"}, getLocation(forInStatement->values)); + else + reportError(GenericError{"for..in loops must be passed (next[, table[, state]])"}, getLocation(forInStatement->values)); + + return; + } + + // It is okay if there aren't enough iterators, but the iteratee must provide enough. + TypePack expectedVariableTypes = extendTypePack(arena, singletonTypes, iterFtv->retTypes, variableTypes.size()); + if (expectedVariableTypes.head.size() < variableTypes.size()) + { + if (isMm) + reportError( + GenericError{"__iter metamethod's next() function does not return enough values"}, getLocation(forInStatement->values)); + else + reportError(GenericError{"next() does not return enough values"}, forInStatement->values.data[0]->location); + } + + for (size_t i = 0; i < std::min(expectedVariableTypes.head.size(), variableTypes.size()); ++i) + reportErrors(tryUnify(scope, forInStatement->vars.data[i]->location, variableTypes[i], expectedVariableTypes.head[i])); + + // nextFn is going to be invoked with (arrayTy, startIndexTy) + + // It will be passed two arguments on every iteration save the + // first. + + // It may be invoked with 0 or 1 argument on the first iteration. + // This depends on the types in iterateePack and therefore + // iteratorTypes. + + // If iteratorTypes is too short to be a valid call to nextFn, we have to report a count mismatch error. + // If 2 is too short to be a valid call to nextFn, we have to report a count mismatch error. + // If 2 is too long to be a valid call to nextFn, we have to report a count mismatch error. + auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), iterFtv->argTypes, /*includeHiddenVariadics*/ true); + + if (minCount > 2) + reportError(CountMismatch{2, std::nullopt, minCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + if (maxCount && *maxCount < 2) + reportError(CountMismatch{2, std::nullopt, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + + TypePack flattenedArgTypes = extendTypePack(arena, singletonTypes, iterFtv->argTypes, 2); + size_t firstIterationArgCount = iterTys.empty() ? 0 : iterTys.size() - 1; + size_t actualArgCount = expectedVariableTypes.head.size(); + + if (firstIterationArgCount < minCount) + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + else if (actualArgCount < minCount) + reportError(CountMismatch{2, std::nullopt, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + + if (iterTys.size() >= 2 && flattenedArgTypes.head.size() > 0) + { + size_t valueIndex = forInStatement->values.size > 1 ? 1 : 0; + reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iterTys[1], flattenedArgTypes.head[0])); + } + + if (iterTys.size() == 3 && flattenedArgTypes.head.size() > 1) + { + size_t valueIndex = forInStatement->values.size > 2 ? 2 : 0; + reportErrors(tryUnify(scope, forInStatement->values.data[valueIndex]->location, iterTys[2], flattenedArgTypes.head[1])); + } + }; + + /* + * If the first iterator argument is a function + * * There must be 1 to 3 iterator arguments. Name them (nextTy, + * arrayTy, startIndexTy) + * * The return type of nextTy() must correspond to the variables' + * types and counts. HOWEVER the first iterator will never be nil. + * * The first return value of nextTy must be compatible with + * startIndexTy. + * * The first argument to nextTy() must be compatible with arrayTy if + * present. nil if not. + * * The second argument to nextTy() must be compatible with + * startIndexTy if it is present. Else, it must be compatible with + * nil. + * * nextTy() must be callable with only 2 arguments. + */ + if (const FunctionTypeVar* nextFn = get(iteratorTy)) + { + checkFunction(nextFn, iteratorTypes.head, false); + } + else if (const TableTypeVar* ttv = get(iteratorTy)) + { + if ((forInStatement->vars.size == 1 || forInStatement->vars.size == 2) && ttv->indexer) + { + reportErrors(tryUnify(scope, forInStatement->vars.data[0]->location, variableTypes[0], ttv->indexer->indexType)); + if (variableTypes.size() == 2) + reportErrors(tryUnify(scope, forInStatement->vars.data[1]->location, variableTypes[1], ttv->indexer->indexResultType)); + } + else + reportError(GenericError{"Cannot iterate over a table without indexer"}, forInStatement->values.data[0]->location); + } + else if (get(iteratorTy) || get(iteratorTy)) + { + // nothing + } + else if (std::optional iterMmTy = + findMetatableEntry(singletonTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) + { + Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}, scope}; + + if (std::optional instantiatedIterMmTy = instantiation.substitute(*iterMmTy)) + { + if (const FunctionTypeVar* iterMmFtv = get(*instantiatedIterMmTy)) + { + TypePackId argPack = arena.addTypePack({iteratorTy}); + reportErrors(tryUnify(scope, forInStatement->values.data[0]->location, argPack, iterMmFtv->argTypes)); + + TypePack mmIteratorTypes = extendTypePack(arena, singletonTypes, iterMmFtv->retTypes, 3); + + if (mmIteratorTypes.head.size() == 0) + { + reportError(GenericError{"__iter must return at least one value"}, forInStatement->values.data[0]->location); + return; + } + + TypeId nextFn = follow(mmIteratorTypes.head[0]); + + if (std::optional instantiatedNextFn = instantiation.substitute(nextFn)) + { + std::vector instantiatedIteratorTypes = mmIteratorTypes.head; + instantiatedIteratorTypes[0] = *instantiatedNextFn; + + if (const FunctionTypeVar* nextFtv = get(*instantiatedNextFn)) + { + checkFunction(nextFtv, instantiatedIteratorTypes, true); + } + else + { + reportError(CannotCallNonFunction{*instantiatedNextFn}, forInStatement->values.data[0]->location); + } + } + else + { + reportError(UnificationTooComplex{}, forInStatement->values.data[0]->location); + } + } + else + { + // TODO: This will not tell the user that this is because the + // metamethod isn't callable. This is not ideal, and we should + // improve this error message. + + // TODO: This will also not handle intersections of functions or + // callable tables (which are supported by the runtime). + reportError(CannotCallNonFunction{*iterMmTy}, forInStatement->values.data[0]->location); + } + } + else + { + reportError(UnificationTooComplex{}, forInStatement->values.data[0]->location); + } + } + else + { + reportError(CannotCallNonFunction{iteratorTy}, forInStatement->values.data[0]->location); + } + } + + void visit(AstStatAssign* assign) + { + size_t count = std::min(assign->vars.size, assign->values.size); + + for (size_t i = 0; i < count; ++i) + { + AstExpr* lhs = assign->vars.data[i]; + visit(lhs); + TypeId lhsType = lookupType(lhs); + + AstExpr* rhs = assign->values.data[i]; + visit(rhs); + TypeId rhsType = lookupType(rhs); + + if (!isSubtype(rhsType, lhsType, stack.back())) + { + reportError(TypeMismatch{lhsType, rhsType}, rhs->location); + } + } + } + + void visit(AstStatCompoundAssign* stat) + { + visit(stat->var); + visit(stat->value); + } + + void visit(AstStatFunction* stat) + { + visit(stat->name); + visit(stat->func); + } + + void visit(AstStatLocalFunction* stat) + { + visit(stat->func); + } + + void visit(const AstTypeList* typeList) + { + for (AstType* ty : typeList->types) + visit(ty); + + if (typeList->tailType) + visit(typeList->tailType); + } + + void visit(AstStatTypeAlias* stat) + { + for (const AstGenericType& el : stat->generics) + { + if (el.defaultValue) + visit(el.defaultValue); + } + + for (const AstGenericTypePack& el : stat->genericPacks) + { + if (el.defaultValue) + visit(el.defaultValue); + } + + visit(stat->type); + } + + void visit(AstTypeList types) + { + for (AstType* type : types.types) + visit(type); + if (types.tailType) + visit(types.tailType); + } + + void visit(AstStatDeclareFunction* stat) + { + visit(stat->params); + visit(stat->retTypes); + } + + void visit(AstStatDeclareGlobal* stat) + { + visit(stat->type); + } + + void visit(AstStatDeclareClass* stat) + { + for (const AstDeclaredClassProp& prop : stat->props) + visit(prop.ty); + } + + void visit(AstStatError* stat) + { + for (AstExpr* expr : stat->expressions) + visit(expr); + + for (AstStat* s : stat->statements) + visit(s); + } + + void visit(AstExpr* expr) + { + auto StackPusher = pushStack(expr); + + if (0) + { + } + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else if (auto e = expr->as()) + return visit(e); + else + LUAU_ASSERT(!"TypeChecker2 encountered an unknown expression type"); + } + + void visit(AstExprGroup* expr) + { + visit(expr->expr); + } + + void visit(AstExprConstantNil* expr) + { + // TODO! + } + + void visit(AstExprConstantBool* expr) + { + // TODO! + } + + void visit(AstExprConstantNumber* number) + { + TypeId actualType = lookupType(number); + TypeId numberType = singletonTypes->numberType; + + if (!isSubtype(numberType, actualType, stack.back())) + { + reportError(TypeMismatch{actualType, numberType}, number->location); + } + } + + void visit(AstExprConstantString* string) + { + TypeId actualType = lookupType(string); + TypeId stringType = singletonTypes->stringType; + + if (!isSubtype(actualType, stringType, stack.back())) + { + reportError(TypeMismatch{actualType, stringType}, string->location); + } + } + + void visit(AstExprLocal* expr) + { + // TODO! + } + + void visit(AstExprGlobal* expr) + { + // TODO! + } + + void visit(AstExprVarargs* expr) + { + // TODO! + } + + void visit(AstExprCall* call) + { + visit(call->func); + + for (AstExpr* arg : call->args) + visit(arg); + + TypeArena* arena = &module->internalTypes; + Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()}; + + TypePackId expectedRetType = lookupPack(call); + TypeId functionType = lookupType(call->func); + TypeId testFunctionType = functionType; + TypePack args; + + if (get(functionType) || get(functionType)) + return; + else if (std::optional callMm = findMetatableEntry(singletonTypes, module->errors, functionType, "__call", call->func->location)) + { + if (get(follow(*callMm))) + { + if (std::optional instantiatedCallMm = instantiation.substitute(*callMm)) + { + args.head.push_back(functionType); + testFunctionType = follow(*instantiatedCallMm); + } + else + { + reportError(UnificationTooComplex{}, call->func->location); + return; + } + } + else + { + // TODO: This doesn't flag the __call metamethod as the problem + // very clearly. + reportError(CannotCallNonFunction{*callMm}, call->func->location); + return; + } + } + else if (get(functionType)) + { + if (std::optional instantiatedFunctionType = instantiation.substitute(functionType)) + { + testFunctionType = *instantiatedFunctionType; + } + else + { + reportError(UnificationTooComplex{}, call->func->location); + return; + } + } + else if (auto utv = get(functionType)) + { + // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. + std::optional fst; + for (TypeId ty : utv) + { + if (!fst) + fst = follow(ty); + else if (fst != follow(ty)) + { + reportError(CannotCallNonFunction{functionType}, call->func->location); + return; + } + } + + if (!fst) + ice.ice("UnionTypeVar had no elements, so fst is nullopt?"); + + if (std::optional instantiatedFunctionType = instantiation.substitute(*fst)) + { + testFunctionType = *instantiatedFunctionType; + } + else + { + reportError(UnificationTooComplex{}, call->func->location); + return; + } + } + else + { + reportError(CannotCallNonFunction{functionType}, call->func->location); + return; + } + + if (call->self) + { + AstExprIndexName* indexExpr = call->func->as(); + if (!indexExpr) + ice.ice("method call expression has no 'self'"); + + args.head.push_back(lookupType(indexExpr->expr)); + } + + for (size_t i = 0; i < call->args.size; ++i) + { + AstExpr* arg = call->args.data[i]; + TypeId* argTy = module->astTypes.find(arg); + if (argTy) + args.head.push_back(*argTy); + else if (i == call->args.size - 1) + { + TypePackId* argTail = module->astTypePacks.find(arg); + if (argTail) + args.tail = *argTail; + else + args.tail = singletonTypes->anyTypePack; + } + else + args.head.push_back(singletonTypes->anyType); + } + + TypePackId argsTp = arena->addTypePack(args); + FunctionTypeVar ftv{argsTp, expectedRetType}; + TypeId expectedType = arena->addType(ftv); + + if (!isSubtype(testFunctionType, expectedType, stack.back())) + { + CloneState cloneState; + expectedType = clone(expectedType, module->internalTypes, cloneState); + reportError(TypeMismatch{expectedType, functionType}, call->location); + } + } + + void visit(AstExprIndexName* indexName) + { + TypeId leftType = lookupType(indexName->expr); + + const NormalizedType* norm = normalizer.normalize(leftType); + if (!norm) + reportError(NormalizationTooComplex{}, indexName->indexLocation); + + checkIndexTypeFromType(leftType, *norm, indexName->index.value, indexName->location); + } + + void visit(AstExprIndexExpr* indexExpr) + { + // TODO! + visit(indexExpr->expr); + visit(indexExpr->index); + } + + void visit(AstExprFunction* fn) + { + auto StackPusher = pushStack(fn); + + TypeId inferredFnTy = lookupType(fn); + const FunctionTypeVar* inferredFtv = get(inferredFnTy); + LUAU_ASSERT(inferredFtv); + + auto argIt = begin(inferredFtv->argTypes); + for (const auto& arg : fn->args) + { + if (argIt == end(inferredFtv->argTypes)) + break; + + if (arg->annotation) + { + TypeId inferredArgTy = *argIt; + TypeId annotatedArgTy = lookupAnnotation(arg->annotation); + + if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back())) + { + reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); + } + } + + ++argIt; + } + + visit(fn->body); + } + + void visit(AstExprTable* expr) + { + // TODO! + for (const AstExprTable::Item& item : expr->items) + { + if (item.key) + visit(item.key); + visit(item.value); + } + } + + void visit(AstExprUnary* expr) + { + visit(expr->expr); + + NotNull scope = stack.back(); + TypeId operandType = lookupType(expr->expr); + + if (get(operandType) || get(operandType) || get(operandType)) + return; + + if (auto it = kUnaryOpMetamethods.find(expr->op); it != kUnaryOpMetamethods.end()) + { + std::optional mm = findMetatableEntry(singletonTypes, module->errors, operandType, it->second, expr->location); + if (mm) + { + if (const FunctionTypeVar* ftv = get(follow(*mm))) + { + TypePackId expectedArgs = module->internalTypes.addTypePack({operandType}); + reportErrors(tryUnify(scope, expr->location, expectedArgs, ftv->argTypes)); + + if (std::optional ret = first(ftv->retTypes)) + { + if (expr->op == AstExprUnary::Op::Len) + { + reportErrors(tryUnify(scope, expr->location, follow(*ret), singletonTypes->numberType)); + } + } + else + { + reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); + } + } + + return; + } + } + + if (expr->op == AstExprUnary::Op::Len) + { + DenseHashSet seen{nullptr}; + int recursionCount = 0; + + if (!hasLength(operandType, seen, &recursionCount)) + { + reportError(NotATable{operandType}, expr->location); + } + } + else if (expr->op == AstExprUnary::Op::Minus) + { + reportErrors(tryUnify(scope, expr->location, operandType, singletonTypes->numberType)); + } + else if (expr->op == AstExprUnary::Op::Not) + { + } + else + { + LUAU_ASSERT(!"Unhandled unary operator"); + } + } + + void visit(AstExprBinary* expr) + { + visit(expr->left); + visit(expr->right); + + NotNull scope = stack.back(); + + bool isEquality = expr->op == AstExprBinary::Op::CompareEq || expr->op == AstExprBinary::Op::CompareNe; + bool isComparison = expr->op >= AstExprBinary::Op::CompareEq && expr->op <= AstExprBinary::Op::CompareGe; + bool isLogical = expr->op == AstExprBinary::Op::And || expr->op == AstExprBinary::Op::Or; + + TypeId leftType = lookupType(expr->left); + TypeId rightType = lookupType(expr->right); + + if (expr->op == AstExprBinary::Op::Or) + { + leftType = stripNil(singletonTypes, module->internalTypes, leftType); + } + + bool isStringOperation = isString(leftType) && isString(rightType); + + if (get(leftType) || get(leftType) || get(rightType) || get(rightType)) + return; + + if ((get(leftType) || get(leftType)) && !isEquality && !isLogical) + { + auto name = getIdentifierOfBaseVar(expr->left); + reportError(CannotInferBinaryOperation{expr->op, name, + isComparison ? CannotInferBinaryOperation::OpKind::Comparison : CannotInferBinaryOperation::OpKind::Operation}, + expr->location); + return; + } + + if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end()) + { + std::optional leftMt = getMetatable(leftType, singletonTypes); + std::optional rightMt = getMetatable(rightType, singletonTypes); + + bool matches = leftMt == rightMt; + if (isEquality && !matches) + { + auto testUnion = [&matches, singletonTypes = this->singletonTypes](const UnionTypeVar* utv, std::optional otherMt) { + for (TypeId option : utv) + { + if (getMetatable(follow(option), singletonTypes) == otherMt) + { + matches = true; + break; + } + } + }; + + if (const UnionTypeVar* utv = get(leftType); utv && rightMt) + { + testUnion(utv, rightMt); + } + + if (const UnionTypeVar* utv = get(rightType); utv && leftMt && !matches) + { + testUnion(utv, leftMt); + } + } + + if (!matches && isComparison) + { + reportError(GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, + expr->location); + + return; + } + + std::optional mm; + if (std::optional leftMm = findMetatableEntry(singletonTypes, module->errors, leftType, it->second, expr->left->location)) + mm = leftMm; + else if (std::optional rightMm = findMetatableEntry(singletonTypes, module->errors, rightType, it->second, expr->right->location)) + { + mm = rightMm; + std::swap(leftType, rightType); + } + + if (mm) + { + TypeId instantiatedMm = module->astOverloadResolvedTypes[expr]; + if (!instantiatedMm) + reportError(CodeTooComplex{}, expr->location); + + else if (const FunctionTypeVar* ftv = get(follow(instantiatedMm))) + { + TypePackId expectedArgs; + // For >= and > we invoke __lt and __le respectively with + // swapped argument ordering. + if (expr->op == AstExprBinary::Op::CompareGe || expr->op == AstExprBinary::Op::CompareGt) + { + expectedArgs = module->internalTypes.addTypePack({rightType, leftType}); + } + else + { + expectedArgs = module->internalTypes.addTypePack({leftType, rightType}); + } + + reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs)); + + if (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe || expr->op == AstExprBinary::CompareGe || + expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt) + { + TypePackId expectedRets = module->internalTypes.addTypePack({singletonTypes->booleanType}); + if (!isSubtype(ftv->retTypes, expectedRets, scope)) + { + reportError(GenericError{format("Metamethod '%s' must return type 'boolean'", it->second)}, expr->location); + } + } + else if (!first(ftv->retTypes)) + { + reportError(GenericError{format("Metamethod '%s' must return a value", it->second)}, expr->location); + } + } + else + { + reportError(CannotCallNonFunction{*mm}, expr->location); + } + + return; + } + // If this is a string comparison, or a concatenation of strings, we + // want to fall through to primitive behavior. + else if (!isEquality && !(isStringOperation && (expr->op == AstExprBinary::Op::Concat || isComparison))) + { + if (leftMt || rightMt) + { + if (isComparison) + { + reportError(GenericError{format( + "Types '%s' and '%s' cannot be compared with %s because neither type's metatable has a '%s' metamethod", + toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str(), it->second)}, + expr->location); + } + else + { + reportError(GenericError{format( + "Operator %s is not applicable for '%s' and '%s' because neither type's metatable has a '%s' metamethod", + toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str(), it->second)}, + expr->location); + } + + return; + } + else if (!leftMt && !rightMt && (get(leftType) || get(rightType))) + { + if (isComparison) + { + reportError(GenericError{format("Types '%s' and '%s' cannot be compared with %s because neither type has a metatable", + toString(leftType).c_str(), toString(rightType).c_str(), toString(expr->op).c_str())}, + expr->location); + } + else + { + reportError(GenericError{format("Operator %s is not applicable for '%s' and '%s' because neither type has a metatable", + toString(expr->op).c_str(), toString(leftType).c_str(), toString(rightType).c_str())}, + expr->location); + } + + return; + } + } + } + + switch (expr->op) + { + case AstExprBinary::Op::Add: + case AstExprBinary::Op::Sub: + case AstExprBinary::Op::Mul: + case AstExprBinary::Op::Div: + case AstExprBinary::Op::Pow: + case AstExprBinary::Op::Mod: + reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->numberType)); + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType)); + + break; + case AstExprBinary::Op::Concat: + reportErrors(tryUnify(scope, expr->left->location, leftType, singletonTypes->stringType)); + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType)); + + break; + case AstExprBinary::Op::CompareGe: + case AstExprBinary::Op::CompareGt: + case AstExprBinary::Op::CompareLe: + case AstExprBinary::Op::CompareLt: + if (isNumber(leftType)) + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->numberType)); + else if (isString(leftType)) + reportErrors(tryUnify(scope, expr->right->location, rightType, singletonTypes->stringType)); + else + reportError(GenericError{format("Types '%s' and '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), + toString(rightType).c_str(), toString(expr->op).c_str())}, + expr->location); + + break; + case AstExprBinary::Op::And: + case AstExprBinary::Op::Or: + case AstExprBinary::Op::CompareEq: + case AstExprBinary::Op::CompareNe: + break; + default: + // Unhandled AstExprBinary::Op possibility. + LUAU_ASSERT(false); + } + } + + void visit(AstExprTypeAssertion* expr) + { + visit(expr->expr); + visit(expr->annotation); + + TypeId annotationType = lookupAnnotation(expr->annotation); + TypeId computedType = lookupType(expr->expr); + + // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. + if (isSubtype(annotationType, computedType, stack.back())) + return; + + if (isSubtype(computedType, annotationType, stack.back())) + return; + + reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); + } + + void visit(AstExprIfElse* expr) + { + // TODO! + visit(expr->condition); + visit(expr->trueExpr); + visit(expr->falseExpr); + } + + void visit(AstExprError* expr) + { + // TODO! + for (AstExpr* e : expr->expressions) + visit(e); + } + + /** Extract a TypeId for the first type of the provided pack. + * + * Note that this may require modifying some types. I hope this doesn't cause problems! + */ + TypeId flattenPack(TypePackId pack) + { + pack = follow(pack); + + while (true) + { + auto tp = get(pack); + if (tp && tp->head.empty() && tp->tail) + pack = *tp->tail; + else + break; + } + + if (auto ty = first(pack)) + return *ty; + else if (auto vtp = get(pack)) + return vtp->ty; + else if (auto ftp = get(pack)) + { + TypeId result = module->internalTypes.addType(FreeTypeVar{ftp->scope}); + TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); + + TypePack& resultPack = asMutable(pack)->ty.emplace(); + resultPack.head.assign(1, result); + resultPack.tail = freeTail; + + return result; + } + else if (get(pack)) + return singletonTypes->errorRecoveryType(); + else + ice.ice("flattenPack got a weird pack!"); + } + + void visit(AstType* ty) + { + if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + else if (auto t = ty->as()) + return visit(t); + } + + void visit(AstTypeReference* ty) + { + for (const AstTypeOrPack& param : ty->parameters) + { + if (param.type) + visit(param.type); + else + visit(param.typePack); + } + + Scope* scope = findInnermostScope(ty->location); + LUAU_ASSERT(scope); + + std::optional alias = + (ty->prefix) ? scope->lookupImportedType(ty->prefix->value, ty->name.value) : scope->lookupType(ty->name.value); + + if (alias.has_value()) + { + size_t typesRequired = alias->typeParams.size(); + size_t packsRequired = alias->typePackParams.size(); + + bool hasDefaultTypes = std::any_of(alias->typeParams.begin(), alias->typeParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + + bool hasDefaultPacks = std::any_of(alias->typePackParams.begin(), alias->typePackParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + + if (!ty->hasParameterList) + { + if ((!alias->typeParams.empty() && !hasDefaultTypes) || (!alias->typePackParams.empty() && !hasDefaultPacks)) + { + reportError(GenericError{"Type parameter list is required"}, ty->location); + } + } + + size_t typesProvided = 0; + size_t extraTypes = 0; + size_t packsProvided = 0; + + for (const AstTypeOrPack& p : ty->parameters) + { + if (p.type) + { + if (packsProvided != 0) + { + reportError(GenericError{"Type parameters must come before type pack parameters"}, ty->location); + } + + if (typesProvided < typesRequired) + { + typesProvided += 1; + } + else + { + extraTypes += 1; + } + } + else if (p.typePack) + { + TypePackId tp = lookupPackAnnotation(p.typePack); + + if (typesProvided < typesRequired && size(tp) == 1 && finite(tp) && first(tp)) + { + typesProvided += 1; + } + else + { + packsProvided += 1; + } + } + } + + if (extraTypes != 0 && packsProvided == 0) + { + packsProvided += 1; + } + + for (size_t i = typesProvided; i < typesRequired; ++i) + { + if (alias->typeParams[i].defaultValue) + { + typesProvided += 1; + } + } + + for (size_t i = packsProvided; i < packsProvided; ++i) + { + if (alias->typePackParams[i].defaultValue) + { + packsProvided += 1; + } + } + + if (extraTypes == 0 && packsProvided + 1 == packsRequired) + { + packsProvided += 1; + } + + if (typesProvided != typesRequired || packsProvided != packsRequired) + { + reportError(IncorrectGenericParameterCount{ + /* name */ ty->name.value, + /* typeFun */ *alias, + /* actualParameters */ typesProvided, + /* actualPackParameters */ packsProvided, + }, + ty->location); + } + } + else + { + if (scope->lookupPack(ty->name.value)) + { + reportError( + SwappedGenericTypeParameter{ + ty->name.value, + SwappedGenericTypeParameter::Kind::Type, + }, + ty->location); + } + else + { + reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location); + } + } + } + + void visit(AstTypeTable* table) + { + // TODO! + + for (const AstTableProp& prop : table->props) + visit(prop.type); + + if (table->indexer) + { + visit(table->indexer->indexType); + visit(table->indexer->resultType); + } + } + + void visit(AstTypeFunction* ty) + { + // TODO! + + visit(ty->argTypes); + visit(ty->returnTypes); + } + + void visit(AstTypeTypeof* ty) + { + visit(ty->expr); + } + + void visit(AstTypeUnion* ty) + { + // TODO! + for (AstType* type : ty->types) + visit(type); + } + + void visit(AstTypeIntersection* ty) + { + // TODO! + for (AstType* type : ty->types) + visit(type); + } + + void visit(AstTypePack* pack) + { + if (auto p = pack->as()) + return visit(p); + else if (auto p = pack->as()) + return visit(p); + else if (auto p = pack->as()) + return visit(p); + } + + void visit(AstTypePackExplicit* tp) + { + // TODO! + for (AstType* type : tp->typeList.types) + visit(type); + + if (tp->typeList.tailType) + visit(tp->typeList.tailType); + } + + void visit(AstTypePackVariadic* tp) + { + // TODO! + visit(tp->variadicType); + } + + void visit(AstTypePackGeneric* tp) + { + Scope* scope = findInnermostScope(tp->location); + LUAU_ASSERT(scope); + + std::optional alias = scope->lookupPack(tp->genericName.value); + if (!alias.has_value()) + { + if (scope->lookupType(tp->genericName.value)) + { + reportError( + SwappedGenericTypeParameter{ + tp->genericName.value, + SwappedGenericTypeParameter::Kind::Pack, + }, + tp->location); + } + else + { + reportError(UnknownSymbol{tp->genericName.value, UnknownSymbol::Context::Type}, tp->location); + } + } + } + + template + bool isSubtype(TID subTy, TID superTy, NotNull scope) + { + TypeArena arena; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(subTy, superTy); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; + } + + template + ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy) + { + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; + u.useScopes = true; + u.tryUnify(subTy, superTy); + + return std::move(u.errors); + } + + void reportError(TypeErrorData data, const Location& location) + { + module->errors.emplace_back(location, sourceModule->name, std::move(data)); + + if (FFlag::DebugLuauLogSolverToJson) + logger->captureTypeCheckError(module->errors.back()); + } + + void reportError(TypeError e) + { + reportError(std::move(e.data), e.location); + } + + void reportErrors(ErrorVec errors) + { + for (TypeError e : errors) + reportError(std::move(e)); + } + + void checkIndexTypeFromType(TypeId denormalizedTy, const NormalizedType& norm, const std::string& prop, const Location& location) + { + bool foundOneProp = false; + std::vector typesMissingTheProp; + + auto fetch = [&](TypeId ty) { + if (!normalizer.isInhabited(ty)) + return; + + bool found = hasIndexTypeFromType(ty, prop, location); + foundOneProp |= found; + if (!found) + typesMissingTheProp.push_back(ty); + }; + + fetch(norm.tops); + fetch(norm.booleans); + for (TypeId ty : norm.classes) + fetch(ty); + fetch(norm.errors); + fetch(norm.nils); + fetch(norm.numbers); + if (!norm.strings.isNever()) + fetch(singletonTypes->stringType); + fetch(norm.threads); + for (TypeId ty : norm.tables) + fetch(ty); + if (norm.functions.isTop) + fetch(singletonTypes->functionType); + else if (!norm.functions.isNever()) + { + if (norm.functions.parts->size() == 1) + fetch(norm.functions.parts->front()); + else + { + std::vector parts; + parts.insert(parts.end(), norm.functions.parts->begin(), norm.functions.parts->end()); + fetch(module->internalTypes.addType(IntersectionTypeVar{std::move(parts)})); + } + } + for (const auto& [tyvar, intersect] : norm.tyvars) + { + if (get(intersect->tops)) + { + TypeId ty = normalizer.typeFromNormal(*intersect); + fetch(module->internalTypes.addType(IntersectionTypeVar{{tyvar, ty}})); + } + else + fetch(tyvar); + } + + if (!typesMissingTheProp.empty()) + { + if (foundOneProp) + reportError(TypeError{location, MissingUnionProperty{denormalizedTy, typesMissingTheProp, prop}}); + else + reportError(TypeError{location, UnknownProperty{denormalizedTy, prop}}); + } + } + + bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location) + { + if (get(ty) || get(ty) || get(ty)) + return true; + + if (isString(ty)) + { + std::optional mtIndex = Luau::findMetatableEntry(singletonTypes, module->errors, singletonTypes->stringType, "__index", location); + LUAU_ASSERT(mtIndex); + ty = *mtIndex; + } + + if (getTableType(ty)) + return bool(findTablePropertyRespectingMeta(singletonTypes, module->errors, ty, prop, location)); + else if (const ClassTypeVar* cls = get(ty)) + return bool(lookupClassProp(cls, prop)); + else if (const UnionTypeVar* utv = get(ty)) + ice.ice("getIndexTypeFromTypeHelper cannot take a UnionTypeVar"); + else if (const IntersectionTypeVar* itv = get(ty)) + return std::any_of(begin(itv), end(itv), [&](TypeId part) { + return hasIndexTypeFromType(part, prop, location); + }); + else + return false; + } +}; + +void check(NotNull singletonTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module) +{ + TypeChecker2 typeChecker{singletonTypes, logger, &sourceModule, module}; + + typeChecker.visit(sourceModule.root); +} + +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 2216881b..aa738ad9 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1,50 +1,59 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" +#include "Luau/ApplyTypeFunction.h" +#include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/Instantiation.h" #include "Luau/ModuleResolver.h" +#include "Luau/Normalize.h" #include "Luau/Parser.h" +#include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/Substitution.h" +#include "Luau/TimeTrace.h" #include "Luau/TopoSortStatements.h" #include "Luau/ToString.h" +#include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" -#include #include +#include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) -LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 0) -LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 0) -LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) -LUAU_FASTFLAGVARIABLE(LuauIndexTablesWithIndexers, false) -LUAU_FASTFLAGVARIABLE(LuauGenericFunctions, false) -LUAU_FASTFLAGVARIABLE(LuauGenericVariadicsUnification, false) +LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) +LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) +LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) +LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) +LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) -LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false) -LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) -LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(LuauImprovedTypeGuardPredicate2, false) -LUAU_FASTFLAG(LuauTraceRequireLookupChild) -LUAU_FASTFLAG(DebugLuauTrackOwningArena) -LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) -LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) -LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) -LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) -LUAU_FASTFLAGVARIABLE(LuauFixTableTypeAliasClone, false) -LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false) -LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false) -LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) -LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) -LUAU_FASTFLAGVARIABLE(LuauAddMissingFollow, false) -LUAU_FASTFLAGVARIABLE(LuauTypeGuardPeelsAwaySubclasses, false) -LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) -LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false) -LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) -LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) +LUAU_FASTFLAG(LuauTypeNormalization2) +LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) +LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. +LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) +LUAU_FASTFLAGVARIABLE(LuauNilIterator, false) +LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) +LUAU_FASTFLAGVARIABLE(LuauTypeInferMissingFollows, false) +LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) +LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) +LUAU_FASTFLAGVARIABLE(LuauFollowInLvalueIndexCheck, false) +LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) +LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) +LUAU_FASTFLAGVARIABLE(LuauOptionalNextKey, false) +LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) +LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) +LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) +LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) +LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) +LUAU_FASTFLAGVARIABLE(LuauDeclareClassPrototype, false) +LUAU_FASTFLAG(LuauUninhabitedSubAnything2) +LUAU_FASTFLAGVARIABLE(LuauCallableClasses, false) namespace Luau { @@ -59,9 +68,7 @@ static void defaultLuauPrintLine(const std::string& s) printf("%s\n", s.c_str()); } -using PrintLineProc = decltype(&defaultLuauPrintLine); - -static PrintLineProc luauPrintLine = &defaultLuauPrintLine; +PrintLineProc luauPrintLine = &defaultLuauPrintLine; void setPrintLine(PrintLineProc pl) { @@ -207,50 +214,84 @@ static bool isMetamethod(const Name& name) { return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || - name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode"; + name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; } -TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler) +size_t HashBoolNamePair::operator()(const std::pair& pair) const +{ + return std::hash()(pair.first) ^ std::hash()(pair.second); +} + +TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull singletonTypes, InternalErrorReporter* iceHandler) : resolver(resolver) + , singletonTypes(singletonTypes) , iceHandler(iceHandler) - , nilType(singletonTypes.nilType) - , numberType(singletonTypes.numberType) - , stringType(singletonTypes.stringType) - , booleanType( - FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.booleanType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Boolean))) - , threadType(FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.threadType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Thread))) - , anyType(singletonTypes.anyType) - , errorType(singletonTypes.errorType) - , optionalNumberType(globalTypes.addType(UnionTypeVar{{numberType, nilType}})) - , anyTypePack(globalTypes.addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}, true})) - , errorTypePack(globalTypes.addTypePack(TypePackVar{Unifiable::Error{}})) + , unifierState(iceHandler) + , normalizer(nullptr, singletonTypes, NotNull{&unifierState}) + , nilType(singletonTypes->nilType) + , numberType(singletonTypes->numberType) + , stringType(singletonTypes->stringType) + , booleanType(singletonTypes->booleanType) + , threadType(singletonTypes->threadType) + , anyType(singletonTypes->anyType) + , unknownType(singletonTypes->unknownType) + , neverType(singletonTypes->neverType) + , anyTypePack(singletonTypes->anyTypePack) + , neverTypePack(singletonTypes->neverTypePack) + , uninhabitableTypePack(singletonTypes->uninhabitableTypePack) + , duplicateTypeAliases{{false, {}}} { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); - globalScope->exportedTypeBindings["any"] = TypeFun{{}, anyType}; - globalScope->exportedTypeBindings["nil"] = TypeFun{{}, nilType}; - globalScope->exportedTypeBindings["number"] = TypeFun{{}, numberType}; - globalScope->exportedTypeBindings["string"] = TypeFun{{}, stringType}; - globalScope->exportedTypeBindings["boolean"] = TypeFun{{}, booleanType}; - globalScope->exportedTypeBindings["thread"] = TypeFun{{}, threadType}; + globalScope->addBuiltinTypeBinding("any", TypeFun{{}, anyType}); + globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, nilType}); + globalScope->addBuiltinTypeBinding("number", TypeFun{{}, numberType}); + globalScope->addBuiltinTypeBinding("string", TypeFun{{}, stringType}); + globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, booleanType}); + globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, threadType}); + if (FFlag::LuauUnknownAndNeverType) + { + globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, unknownType}); + globalScope->addBuiltinTypeBinding("never", TypeFun{{}, neverType}); + } } ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) { - currentModule.reset(new Module()); + try + { + return checkWithoutRecursionCheck(module, mode, environmentScope); + } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(module.root->location); + return std::move(currentModule); + } +} + +ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope) +{ + LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + + currentModule.reset(new Module); currentModule->type = module.type; + currentModule->allocator = module.allocator; + currentModule->names = module.names; iceHandler->moduleName = module.name; + normalizer.arena = ¤tModule->internalTypes; + + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr moduleScope = std::make_shared(parentScope); if (module.cyclic) moduleScope->returnType = addTypePack(TypePack{{anyType}, std::nullopt}); - else if (FFlag::LuauRankNTypes) - moduleScope->returnType = freshTypePack(moduleScope); else - moduleScope->returnType = DEPRECATED_freshTypePack(moduleScope, true); + moduleScope->returnType = freshTypePack(moduleScope); moduleScope->varargPack = anyTypePack; @@ -262,25 +303,52 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona if (prepareModuleScope) prepareModuleScope(module.name, currentModule->getModuleScope()); - checkBlock(moduleScope, *module.root); + try + { + checkBlock(moduleScope, *module.root); + } + catch (const TimeLimitError&) + { + currentModule->timeout = true; + } - if (get(FFlag::LuauAddMissingFollow ? follow(moduleScope->returnType) : moduleScope->returnType)) + if (FFlag::DebugLuauSharedSelf) + { + for (auto& [ty, scope] : deferredQuantification) + Luau::quantify(ty, scope->level); + deferredQuantification.clear(); + } + + if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); else moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); + moduleScope->returnType = anyifyModuleReturnTypePackGenerics(moduleScope->returnType); + for (auto& [_, typeFun] : moduleScope->exportedTypeBindings) typeFun.type = anyify(moduleScope, typeFun.type, Location{}); prepareErrorsForDisplay(currentModule->errors); - bool encounteredFreeType = currentModule->clonePublicInterface(); - if (encounteredFreeType) + if (FFlag::LuauTypeNormalization2) { - reportError(TypeError{module.root->location, - GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); + // Clear the normalizer caches, since they contain types from the internal type surface + normalizer.clearCaches(); + normalizer.arena = nullptr; } + currentModule->clonePublicInterface(singletonTypes, *iceHandler); + + // Clear unifier cache since it's keyed off internal types that get deallocated + // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. + unifierState.cachedUnify.clear(); + unifierState.cachedUnifyError.clear(); + unifierState.skipCacheForType.clear(); + + duplicateTypeAliases.clear(); + incorrectClassDefinitions.clear(); + return std::move(currentModule); } @@ -322,7 +390,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) check(scope, *typealias); else if (auto global = program.as()) { - TypeId globalType = (FFlag::LuauRankNTypes ? resolveType(scope, *global->type) : resolveType(scope, *global->type, true)); + TypeId globalType = resolveType(scope, *global->type); Name globalName(global->name.value); currentModule->declaredGlobals[globalName] = globalType; @@ -348,6 +416,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) } else ice("Unknown AstStat"); + + if (finishTime && TimeTrace::getClock() > *finishTime) + throw TimeLimitError(iceHandler->moduleName); } // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. @@ -365,6 +436,69 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) reportErrorCodeTooComplex(block.location); return; } + try + { + checkBlockWithoutRecursionCheck(scope, block); + } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(block.location); + return; + } +} + +struct InplaceDemoter : TypeVarOnceVisitor +{ + TypeLevel newLevel; + TypeArena* arena; + + InplaceDemoter(TypeLevel level, TypeArena* arena) + : TypeVarOnceVisitor(/* skipBoundTypes= */ true) + , newLevel(level) + , arena(arena) + { + } + + bool demote(TypeId ty) + { + if (auto level = getMutableLevel(ty)) + { + if (level->subsumesStrict(newLevel)) + { + *level = newLevel; + return true; + } + } + + return false; + } + + bool visit(TypeId ty) override + { + if (ty->owningArena != arena) + return false; + return demote(ty); + } + + bool visit(TypePackId tp, const FreeTypePack& ftpRef) override + { + if (tp->owningArena != arena) + return false; + + FreeTypePack* ftp = &const_cast(ftpRef); + if (ftp->level.subsumesStrict(newLevel)) + { + ftp->level = newLevel; + return true; + } + + return false; + } +}; + +void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) +{ + int subLevel = 0; std::vector sorted(block.body.data, block.body.data + block.body.size); toposort(sorted); @@ -372,7 +506,14 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) for (const auto& stat : sorted) { if (const auto& typealias = stat->as()) - check(scope, *typealias, true); + { + prototype(scope, *typealias, subLevel); + ++subLevel; + } + else if (const auto& declaredClass = stat->as(); FFlag::LuauDeclareClassPrototype && declaredClass) + { + prototype(scope, *declaredClass); + } } auto protoIter = sorted.begin(); @@ -395,8 +536,6 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } }; - int subLevel = 0; - while (protoIter != sorted.end()) { // protoIter walks forward @@ -412,7 +551,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // ``` // These both call each other, so `f` will be ordered before `g`, so the call to `g` // is typechecked before `g` has had its body checked. For this reason, there's three - // types for each functuion: before its body is checked, during checking its body, + // types for each function: before its body is checked, during checking its body, // and after its body is checked. // // We currently treat the before-type and the during-type as the same, @@ -429,7 +568,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end // ``` - if (containsFunctionCall(**protoIter)) + if (containsFunctionCallOrReturn(**protoIter)) { while (checkIter != protoIter) { @@ -443,18 +582,55 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } else if (auto fun = (*protoIter)->as()) { - auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt); + std::optional selfType; + std::optional expectedType; + + if (FFlag::DebugLuauSharedSelf) + { + if (auto name = fun->name->as()) + { + TypeId baseTy = checkExpr(scope, *name->expr).type; + tablify(baseTy); + + if (!fun->func->self) + expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, /* addErrors= */ false); + else if (auto ttv = getMutableTableType(baseTy)) + { + if (!baseTy->persistent && ttv->state != TableState::Sealed && !ttv->selfTy) + { + ttv->selfTy = anyIfNonstrict(freshType(ttv->level)); + deferredQuantification.push_back({baseTy, scope}); + } + + selfType = ttv->selfTy; + } + } + } + else + { + if (!fun->func->self) + { + if (auto name = fun->name->as()) + { + TypeId exprTy = checkExpr(scope, *name->expr).type; + expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false); + } + } + } + + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, selfType, expectedType); auto [funTy, funScope] = pair; functionDecls[*protoIter] = pair; ++subLevel; - TypeId leftType = checkFunctionName(scope, *fun->name); - unify(leftType, funTy, fun->location); + TypeId leftType = follow(checkFunctionName(scope, *fun->name, funScope->level)); + + unify(funTy, leftType, scope, fun->location); } else if (auto fun = (*protoIter)->as()) { - auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt); + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt, std::nullopt); auto [funTy, funScope] = pair; functionDecls[*protoIter] = pair; @@ -483,13 +659,22 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { + if (typealias->name == kParseNameError) + continue; + auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; Name name = typealias->name.value; + + if (FFlag::LuauReportShadowedTypeAlias && duplicateTypeAliases.contains({typealias->exported, name})) + continue; + TypeId type = bindings[name].type; - if (get(FFlag::LuauAddMissingFollow ? follow(type) : type)) + if (get(follow(type))) { - *asMutable(type) = ErrorTypeVar{}; + TypeVar* mty = asMutable(follow(type)); + mty->reassign(*errorRecoveryType(anyType)); + reportError(TypeError{typealias->location, OccursCheckFailed{}}); } } @@ -542,10 +727,10 @@ static std::optional tryGetTypeGuardPredicate(const AstExprBinary& ex void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) { - ExprResult result = checkExpr(scope, *statement.condition); + WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr ifScope = childScope(scope, statement.thenbody->location); - reportErrors(resolve(result.predicates, ifScope, true)); + resolve(result.predicates, ifScope, true); check(ifScope, *statement.thenbody); if (statement.elsebody) @@ -556,35 +741,29 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) } } -ErrorVec TypeChecker::canUnify(TypeId left, TypeId right, const Location& location) -{ - return canUnify_(left, right, location); -} - -ErrorVec TypeChecker::canUnify(TypePackId left, TypePackId right, const Location& location) -{ - return canUnify_(left, right, location); -} - -ErrorVec TypeChecker::canUnify(const std::vector>& seen, TypeId superTy, TypeId subTy, const Location& location) -{ - Unifier state = mkUnifier(seen, location); - return state.canUnify(superTy, subTy); -} - template -ErrorVec TypeChecker::canUnify_(Id superTy, Id subTy, const Location& location) +ErrorVec TypeChecker::canUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location) { - Unifier state = mkUnifier(location); - return state.canUnify(superTy, subTy); + Unifier state = mkUnifier(scope, location); + return state.canUnify(subTy, superTy); +} + +ErrorVec TypeChecker::canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location) +{ + return canUnify_(subTy, superTy, scope, location); +} + +ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location) +{ + return canUnify_(subTy, superTy, scope, location); } void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) { - ExprResult result = checkExpr(scope, *statement.condition); + WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr whileScope = childScope(scope, statement.body->location); - reportErrors(resolve(result.predicates, whileScope, true)); + resolve(result.predicates, whileScope, true); check(whileScope, *statement.body); } @@ -597,39 +776,92 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) checkExpr(repScope, *statement.condition); } +struct Demoter : Substitution +{ + Demoter(TypeArena* arena) + : Substitution(TxnLog::empty(), arena) + { + } + + bool isDirty(TypeId ty) override + { + return get(ty); + } + + bool isDirty(TypePackId tp) override + { + return get(tp); + } + + bool ignoreChildren(TypeId ty) override + { + if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + return true; + + return false; + } + + TypeId clean(TypeId ty) override + { + auto ftv = get(ty); + LUAU_ASSERT(ftv); + return addType(FreeTypeVar{demotedLevel(ftv->level)}); + } + + TypePackId clean(TypePackId tp) override + { + auto ftp = get(tp); + LUAU_ASSERT(ftp); + return addTypePack(TypePackVar{FreeTypePack{demotedLevel(ftp->level)}}); + } + + TypeLevel demotedLevel(TypeLevel level) + { + return TypeLevel{level.level + 5000, level.subLevel}; + } + + void demote(std::vector>& expectedTypes) + { + for (std::optional& ty : expectedTypes) + { + if (ty) + ty = substitute(*ty); + } + } +}; + void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; + expectedTypes.reserve(return_.list.size); - if (FFlag::LuauInferReturnAssertAssign) + TypePackIterator expectedRetCurr = begin(scope->returnType); + TypePackIterator expectedRetEnd = end(scope->returnType); + + for (size_t i = 0; i < return_.list.size; ++i) { - expectedTypes.reserve(return_.list.size); - - TypePackIterator expectedRetCurr = begin(scope->returnType); - TypePackIterator expectedRetEnd = end(scope->returnType); - - for (size_t i = 0; i < return_.list.size; ++i) + if (expectedRetCurr != expectedRetEnd) { - if (expectedRetCurr != expectedRetEnd) - { - expectedTypes.push_back(*expectedRetCurr); - ++expectedRetCurr; - } - else if (auto expectedArgsTail = expectedRetCurr.tail()) - { - if (const VariadicTypePack* vtp = get(follow(*expectedArgsTail))) - expectedTypes.push_back(vtp->ty); - } + expectedTypes.push_back(*expectedRetCurr); + ++expectedRetCurr; + } + else if (auto expectedArgsTail = expectedRetCurr.tail()) + { + if (const VariadicTypePack* vtp = get(follow(*expectedArgsTail))) + expectedTypes.push_back(vtp->ty); } } + Demoter demoter{¤tModule->internalTypes}; + demoter.demote(expectedTypes); + TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; // HACK: Nonstrict mode gets a bit too smart and strict for us when we // start typechecking everything across module boundaries. if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType)) { - ErrorVec errors = tryUnify(scope->returnType, retPack, return_.location); + ErrorVec errors = tryUnify(retPack, scope->returnType, scope, return_.location); if (!errors.empty()) currentModule->getModuleScope()->returnType = addTypePack({anyType}); @@ -637,62 +869,65 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) return; } - unify(scope->returnType, retPack, return_.location, CountMismatch::Context::Return); -} - -ErrorVec TypeChecker::tryUnify(TypeId left, TypeId right, const Location& location) -{ - return tryUnify_(left, right, location); -} - -ErrorVec TypeChecker::tryUnify(TypePackId left, TypePackId right, const Location& location) -{ - return tryUnify_(left, right, location); + unify(retPack, scope->returnType, scope, return_.location, CountMismatch::Context::Return); } template -ErrorVec TypeChecker::tryUnify_(Id left, Id right, const Location& location) +ErrorVec TypeChecker::tryUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location) { - Unifier state = mkUnifier(location); - state.tryUnify(left, right); + Unifier state = mkUnifier(scope, location); - if (!state.errors.empty()) - state.log.rollback(); + if (FFlag::DebugLuauFreezeDuringUnification) + freeze(currentModule->internalTypes); + + state.tryUnify(subTy, superTy); + + if (FFlag::DebugLuauFreezeDuringUnification) + unfreeze(currentModule->internalTypes); + + if (state.errors.empty()) + state.log.commit(); return state.errors; } +ErrorVec TypeChecker::tryUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location) +{ + return tryUnify_(subTy, superTy, scope, location); +} + +ErrorVec TypeChecker::tryUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location) +{ + return tryUnify_(subTy, superTy, scope, location); +} + void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) { std::vector> expectedTypes; + expectedTypes.reserve(assign.vars.size); - if (FFlag::LuauInferReturnAssertAssign) + ScopePtr moduleScope = currentModule->getModuleScope(); + + for (size_t i = 0; i < assign.vars.size; ++i) { - expectedTypes.reserve(assign.vars.size); + AstExpr* dest = assign.vars.data[i]; - ScopePtr moduleScope = currentModule->getModuleScope(); - - for (size_t i = 0; i < assign.vars.size; ++i) + if (auto a = dest->as()) { - AstExpr* dest = assign.vars.data[i]; - - if (auto a = dest->as()) - { - // AstExprLocal l-values will have to be checked again because their type might have been mutated during checkExprList later - expectedTypes.push_back(scope->lookup(a->local)); - } - else if (auto a = dest->as()) - { - // AstExprGlobal l-values lookup is inlined here to avoid creating a global binding before checkExprList - if (auto it = moduleScope->bindings.find(a->name); it != moduleScope->bindings.end()) - expectedTypes.push_back(it->second.typeId); - else - expectedTypes.push_back(std::nullopt); - } + // AstExprLocal l-values will have to be checked again because their type might have been mutated during checkExprList later + expectedTypes.push_back(scope->lookup(a->local)); + } + else if (auto a = dest->as()) + { + // AstExprGlobal l-values lookup is inlined here to avoid creating a global binding before checkExprList + if (auto it = moduleScope->bindings.find(a->name); it != moduleScope->bindings.end()) + expectedTypes.push_back(it->second.typeId); else - { - expectedTypes.push_back(checkLValue(scope, *dest)); - } + expectedTypes.push_back(std::nullopt); + } + else + { + expectedTypes.push_back(checkLValue(scope, *dest)); } } @@ -708,7 +943,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) AstExpr* dest = assign.vars.data[i]; TypeId left = nullptr; - if (!FFlag::LuauInferReturnAssertAssign || dest->is() || dest->is()) + if (dest->is() || dest->is()) left = checkLValue(scope, *dest); else left = *expectedTypes[i]; @@ -731,25 +966,25 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) } else if (auto tail = valueIter.tail()) { - if (get(*tail)) - right = errorType; - else if (auto vtp = get(*tail)) + TypePackId tailPack = follow(*tail); + if (get(tailPack)) + right = errorRecoveryType(scope); + else if (auto vtp = get(tailPack)) right = vtp->ty; - else if (get(*tail)) + else if (get(tailPack)) { - *asMutable(*tail) = TypePack{{left}}; - growingPack = getMutable(*tail); + *asMutable(tailPack) = TypePack{{left}}; + growingPack = getMutable(tailPack); } } if (right) { - if (FFlag::LuauGenericFunctions && !maybeGeneric(left) && isGeneric(right)) - right = instantiate(scope, right, loc); - - if (!FFlag::LuauGenericFunctions && get(FFlag::LuauAddMissingFollow ? follow(left) : left) && - get(FFlag::LuauAddMissingFollow ? follow(right) : right)) - right = instantiate(scope, right, loc); + if (!FFlag::LuauInstantiateInSubtyping) + { + if (!maybeGeneric(left) && isGeneric(right)) + right = instantiate(scope, right, loc); + } // Setting a table entry to nil doesn't mean nil is the type of the indexer, it is just deleting the entry const TableTypeVar* destTableTypeReceivingNil = nullptr; @@ -759,10 +994,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) if (!destTableTypeReceivingNil || !destTableTypeReceivingNil->indexer) { // In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any typevar. - if (isNonstrictMode() && get(FFlag::LuauAddMissingFollow ? follow(left) : left) && !get(follow(right))) - unify(left, anyType, loc); + if (isNonstrictMode() && get(follow(left)) && !get(follow(right))) + unify(anyType, left, scope, loc); else - unify(left, right, loc); + unify(right, left, scope, loc); } } } @@ -777,7 +1012,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assi TypeId result = checkBinaryOperation(scope, expr, left, right); - unify(left, result, assign.location); + unify(result, left, scope, assign.location); } void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) @@ -808,7 +1043,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) if (annotation) { - ty = (FFlag::LuauRankNTypes ? resolveType(scope, *annotation) : resolveType(scope, *annotation, true)); + ty = resolveType(scope, *annotation); // If the annotation type has an error, treat it as if there was no annotation if (get(follow(ty))) @@ -816,31 +1051,46 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) } if (!ty) - ty = rhsIsTable ? (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)) - : isNonstrictMode() ? anyType : (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + ty = rhsIsTable ? freshType(scope) : isNonstrictMode() ? anyType : freshType(scope); varBindings.emplace_back(vars[i], Binding{ty, vars[i]->location}); variableTypes.push_back(ty); expectedTypes.push_back(ty); - if (FFlag::LuauGenericFunctions) + // with FFlag::LuauInstantiateInSubtyping enabled, we shouldn't need to produce instantiateGenerics at all. + if (!FFlag::LuauInstantiateInSubtyping) instantiateGenerics.push_back(annotation != nullptr && !maybeGeneric(ty)); - else - instantiateGenerics.push_back(annotation != nullptr && get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)); } if (local.values.size > 0) { - TypePackId variablePack = addTypePack(variableTypes, FFlag::LuauRankNTypes ? freshTypePack(scope) : DEPRECATED_freshTypePack(scope, true)); + TypePackId variablePack = addTypePack(variableTypes, freshTypePack(scope)); TypePackId valuePack = checkExprList(scope, local.location, local.values, /* substituteFreeForNil= */ true, instantiateGenerics, expectedTypes).type; - Unifier state = mkUnifier(local.location); - state.ctx = CountMismatch::Result; - state.tryUnify(variablePack, valuePack); + // If the expression list only contains one expression and it's a function call or is otherwise within parentheses, use FunctionResult. + // Otherwise, we'll want to use ExprListResult to make the error messaging more general. + CountMismatch::Context ctx = FFlag::LuauBetterMessagingOnCountMismatch ? CountMismatch::ExprListResult : CountMismatch::FunctionResult; + if (FFlag::LuauBetterMessagingOnCountMismatch) + { + if (local.values.size == 1) + { + AstExpr* e = local.values.data[0]; + while (auto group = e->as()) + e = group->expr; + if (e->is()) + ctx = CountMismatch::FunctionResult; + } + } + + Unifier state = mkUnifier(scope, local.location); + state.ctx = ctx; + state.tryUnify(valuePack, variablePack); reportErrors(state.errors); + state.log.commit(); + // In the code 'local T = {}', we wish to ascribe the name 'T' to the type of the table for error-reporting purposes. // We also want to do this for 'local T = setmetatable(...)'. if (local.vars.size == 1 && local.values.size == 1) @@ -910,7 +1160,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) TypeId loopVarType = numberType; if (expr.var->annotation) - unify(resolveType(scope, *expr.var->annotation), loopVarType, expr.location); + unify(loopVarType, resolveType(scope, *expr.var->annotation), scope, expr.location); loopScope->bindings[expr.var] = {loopVarType, expr.var->location}; @@ -920,11 +1170,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) if (!expr.to) ice("Bad AstStatFor has no to expr"); - unify(loopVarType, checkExpr(loopScope, *expr.from).type, expr.from->location); - unify(loopVarType, checkExpr(loopScope, *expr.to).type, expr.to->location); + unify(checkExpr(loopScope, *expr.from).type, loopVarType, scope, expr.from->location); + unify(checkExpr(loopScope, *expr.to).type, loopVarType, scope, expr.to->location); if (expr.step) - unify(loopVarType, checkExpr(loopScope, *expr.step).type, expr.step->location); + unify(checkExpr(loopScope, *expr.step).type, loopVarType, scope, expr.step->location); check(loopScope, *expr.body); } @@ -951,10 +1201,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) AstExpr* firstValue = forin.values.data[0]; // next is a function that takes Table and an optional index of type K - // next(t: Table, index: K | nil) -> (K, V) + // next(t: Table, index: K | nil) -> (K?, V) // however, pairs and ipairs are quite messy, but they both share the same types // pairs returns 'next, t, nil', thus the type would be - // pairs(t: Table) -> ((Table, K | nil) -> (K, V), Table, K | nil) + // pairs(t: Table) -> ((Table, K | nil) -> (K?, V), Table, K | nil) // ipairs returns 'next, t, 0', thus ipairs will also share the same type as pairs, except K = number // // we can also define our own custom iterators by by returning a wrapped coroutine that calls coroutine.yield @@ -972,44 +1222,88 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) { AstExprCall* exprCall = firstValue->as(); callRetPack = checkExprPack(scope, *exprCall).type; - if (!FFlag::LuauRankNTypes) - callRetPack = DEPRECATED_instantiate(scope, callRetPack, exprCall->location); callRetPack = follow(callRetPack); if (get(callRetPack)) { iterTy = freshType(scope); - unify(addTypePack({{iterTy}, freshTypePack(scope)}), callRetPack, forin.location); + unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), scope, forin.location); } else if (get(callRetPack) || !first(callRetPack)) { for (TypeId var : varTypes) - unify(var, errorType, forin.location); + unify(errorRecoveryType(scope), var, scope, forin.location); return check(loopScope, *forin.body); } else { iterTy = *first(callRetPack); - if (FFlag::LuauRankNTypes) - iterTy = instantiate(scope, iterTy, exprCall->location); + iterTy = instantiate(scope, iterTy, exprCall->location); } } else { - iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location)); + iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); + } + + if (FFlag::LuauNilIterator) + iterTy = stripFromNilAndReport(iterTy, firstValue->location); + + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location, /* addErrors= */ true)) + { + // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions + // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments + // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types + for (TypeId var : varTypes) + unify(anyType, var, scope, forin.location); + + return check(loopScope, *forin.body); + } + + if (const TableTypeVar* iterTable = get(iterTy)) + { + // TODO: note that this doesn't cleanly handle iteration over mixed tables and tables without an indexer + // this behavior is more or less consistent with what we do for pairs(), but really both are pretty wrong and need revisiting + if (iterTable->indexer) + { + if (varTypes.size() > 0) + unify(iterTable->indexer->indexType, varTypes[0], scope, forin.location); + + if (varTypes.size() > 1) + unify(iterTable->indexer->indexResultType, varTypes[1], scope, forin.location); + + for (size_t i = 2; i < varTypes.size(); ++i) + unify(nilType, varTypes[i], scope, forin.location); + } + else if (isNonstrictMode()) + { + for (TypeId var : varTypes) + unify(anyType, var, scope, forin.location); + } + else + { + TypeId varTy = errorRecoveryType(loopScope); + + for (TypeId var : varTypes) + unify(varTy, var, scope, forin.location); + + reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); + } + + return check(loopScope, *forin.body); } const FunctionTypeVar* iterFunc = get(iterTy); if (!iterFunc) { - TypeId varTy = get(iterTy) ? anyType : errorType; + TypeId varTy = get(iterTy) ? anyType : errorRecoveryType(loopScope); for (TypeId var : varTypes) - unify(var, varTy, forin.location); + unify(varTy, var, scope, forin.location); - if (!get(iterTy) && !get(iterTy) && !get(iterTy)) - reportError(TypeError{firstValue->location, TypeMismatch{globalScope->bindings[AstName{"next"}].typeId, iterTy}}); + if (!get(iterTy) && !get(iterTy) && !get(iterTy) && !get(iterTy)) + reportError(firstValue->location, CannotCallNonFunction{iterTy}); return check(loopScope, *forin.body); } @@ -1031,27 +1325,69 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) argPack = addTypePack(TypePack{}); } - Unifier state = mkUnifier(firstValue->location); - checkArgumentList(loopScope, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); + Unifier state = mkUnifier(loopScope, firstValue->location); + checkArgumentList(loopScope, *firstValue, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); + + state.log.commit(); reportErrors(state.errors); } - TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); - - if (forin.values.size >= 2) + if (FFlag::LuauOptionalNextKey) { - AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + TypePackId retPack = iterFunc->retTypes; - Position start = firstValue->location.begin; - Position end = values[forin.values.size - 1]->location.end; - AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + if (forin.values.size >= 2) + { + AstArray arguments{forin.values.data + 1, forin.values.size - 1}; - TypePackId retPack = checkExprPack(scope, exprCall).type; - unify(varPack, retPack, forin.location); + Position start = firstValue->location.begin; + Position end = values[forin.values.size - 1]->location.end; + AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + + retPack = checkExprPack(scope, exprCall).type; + } + + // We need to remove 'nil' from the set of options of the first return value + // Because for loop stops when it gets 'nil', this result is never actually assigned to the first variable + if (std::optional fty = first(retPack); fty && !varTypes.empty()) + { + TypeId keyTy = follow(*fty); + + if (get(keyTy)) + { + if (std::optional ty = tryStripUnionFromNil(keyTy)) + keyTy = *ty; + } + + unify(keyTy, varTypes.front(), scope, forin.location); + + // We have already handled the first variable type, make it match in the pack check + varTypes.front() = *fty; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); + + unify(retPack, varPack, scope, forin.location); } else - unify(varPack, iterFunc->retType, forin.location); + { + TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); + + if (forin.values.size >= 2) + { + AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + + Position start = firstValue->location.begin; + Position end = values[forin.values.size - 1]->location.end; + AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + + TypePackId retPack = checkExprPack(scope, exprCall).type; + unify(retPack, varPack, scope, forin.location); + } + else + unify(iterFunc->retTypes, varPack, scope, forin.location); + } check(loopScope, *forin.body); } @@ -1076,7 +1412,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); // If in nonstrict mode and allowing redefinition of global function, restore the previous definition type - // in case this function has a differing signature. The signature discrepency will be caught in checkBlock. + // in case this function has a differing signature. The signature discrepancy will be caught in checkBlock. if (previouslyDefined) globalBindings[name] = oldBinding; else @@ -1093,53 +1429,53 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; return; } - else if (function.func->self) + else if (auto name = function.name->as()) { - AstExprIndexName* indexName = function.name->as(); - if (!indexName) - ice("member function declaration has malformed name expression"); + TypeId exprTy = checkExpr(scope, *name->expr).type; + TableTypeVar* ttv = getMutableTableType(exprTy); - TypeId selfTy = checkExpr(scope, *indexName->expr).type; - TableTypeVar* tableSelf = getMutableTableType(selfTy); - if (!tableSelf) + if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false)) { - if (isTableIntersection(selfTy)) - reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); - else if (!get(selfTy) && !get(selfTy)) - reportError(TypeError{function.location, OnlyTablesCanHaveMethods{selfTy}}); + if (ttv || isTableIntersection(exprTy)) + reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); + else + reportError(TypeError{function.location, OnlyTablesCanHaveMethods{exprTy}}); } - else if (tableSelf->state == TableState::Sealed) - reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); ty = follow(ty); - if (tableSelf && !selfTy->persistent) - tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation}; + if (ttv && ttv->state != TableState::Sealed) + ttv->props[name->index.value] = {ty, /* deprecated */ false, {}, name->indexLocation}; - const FunctionTypeVar* funTy = get(ty); - if (!funTy) - ice("Methods should be functions"); + if (function.func->self) + { + const FunctionTypeVar* funTy = get(ty); + if (!funTy) + ice("Methods should be functions"); - std::optional arg0 = first(funTy->argTypes); - if (!arg0) - ice("Methods should always have at least 1 argument (self)"); + std::optional arg0 = first(funTy->argTypes); + if (!arg0) + ice("Methods should always have at least 1 argument (self)"); + } checkFunctionBody(funScope, ty, *function.func); - if (tableSelf && !selfTy->persistent) - tableSelf->props[indexName->index.value] = { - follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation}; + if (FFlag::LuauUnknownAndNeverType) + { + InplaceDemoter demoter{funScope->level, ¤tModule->internalTypes}; + demoter.traverse(ty); + } + + if (ttv && ttv->state != TableState::Sealed) + ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } else { - auto [leftType, leftTypeBinding] = checkLValueBinding(scope, *function.name); + LUAU_ASSERT(function.name->is()); + + ty = follow(ty); checkFunctionBody(funScope, ty, *function.func); - - unify(leftType, ty, function.location); - - if (leftTypeBinding) - *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); } } @@ -1151,18 +1487,17 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - if (FFlag::LuauGenericFunctions) - scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; - else - scope->bindings[function.name] = {quantify(scope, ty, function.name->location), function.name->location}; + scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; } -void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare) +void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias) { - // This function should be called at most twice for each type alias. - // Once with forwardDeclare, and once without. Name name = typealias.name.value; + // If the alias is missing a name, we can't do anything with it. Ignore it. + if (name == kParseNameError) + return; + std::optional binding; if (auto it = scope->exportedTypeBindings.find(name); it != scope->exportedTypeBindings.end()) binding = it->second; @@ -1171,102 +1506,165 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias auto& bindingsMap = typealias.exported ? scope->exportedTypeBindings : scope->privateTypeBindings; - if (forwardDeclare) + // If the first pass failed (this should mean a duplicate definition), the second pass isn't going to be + // interesting. + if (duplicateTypeAliases.find({typealias.exported, name})) + return; + + // By now this alias must have been `prototype()`d first. + if (!binding) + ice("Not predeclared"); + + ScopePtr aliasScope = childScope(scope, typealias.location); + aliasScope->level = scope->level.incr(); + + for (auto param : binding->typeParams) { - if (binding) + auto generic = get(param.ty); + LUAU_ASSERT(generic); + aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, param.ty}; + } + + for (auto param : binding->typePackParams) + { + auto generic = get(param.tp); + LUAU_ASSERT(generic); + aliasScope->privateTypePackBindings[generic->name] = param.tp; + } + + TypeId ty = resolveType(aliasScope, *typealias.type); + if (auto ttv = getMutable(follow(ty))) + { + // If the table is already named and we want to rename the type function, we have to bind new alias to a copy + // Additionally, we can't modify types that come from other modules + if (ttv->name || follow(ty)->owningArena != ¤tModule->internalTypes) { - Location location = scope->typeAliasLocations[name]; - reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); - bindingsMap[name] = TypeFun{binding->typeParams, errorType}; + bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), + binding->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal(ttv->instantiatedTypePackParams.begin(), ttv->instantiatedTypePackParams.end(), binding->typePackParams.begin(), + binding->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; + }); + + // Copy can be skipped if this is an identical alias + if (!ttv->name || ttv->name != name || !sameTys || !sameTps) + { + // This is a shallow clone, original recursive links to self are not updated + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; + clone.definitionModuleName = ttv->definitionModuleName; + clone.name = name; + + for (auto param : binding->typeParams) + clone.instantiatedTypeParams.push_back(param.ty); + + for (auto param : binding->typePackParams) + clone.instantiatedTypePackParams.push_back(param.tp); + + ty = addType(std::move(clone)); + } + } + else + { + ttv->name = name; + + ttv->instantiatedTypeParams.clear(); + for (auto param : binding->typeParams) + ttv->instantiatedTypeParams.push_back(param.ty); + + ttv->instantiatedTypePackParams.clear(); + for (auto param : binding->typePackParams) + ttv->instantiatedTypePackParams.push_back(param.tp); + } + } + else if (auto mtv = getMutable(follow(ty))) + { + // We can't modify types that come from other modules + if (follow(ty)->owningArena == ¤tModule->internalTypes) + mtv->syntheticName = name; + } + + TypeId& bindingType = bindingsMap[name].type; + + if (unify(ty, bindingType, aliasScope, typealias.location)) + bindingType = ty; +} + +void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel) +{ + Name name = typealias.name.value; + + // If the alias is missing a name, we can't do anything with it. Ignore it. + if (name == kParseNameError) + return; + + std::optional binding; + if (auto it = scope->exportedTypeBindings.find(name); it != scope->exportedTypeBindings.end()) + binding = it->second; + else if (auto it = scope->privateTypeBindings.find(name); it != scope->privateTypeBindings.end()) + binding = it->second; + + auto& bindingsMap = typealias.exported ? scope->exportedTypeBindings : scope->privateTypeBindings; + + if (binding) + { + Location location = scope->typeAliasLocations[name]; + reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); + + if (!FFlag::LuauReportShadowedTypeAlias) + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; + + duplicateTypeAliases.insert({typealias.exported, name}); + } + else if (FFlag::LuauReportShadowedTypeAlias) + { + if (globalScope->builtinTypeNames.contains(name)) + { + reportError(typealias.location, DuplicateTypeDefinition{name}); + duplicateTypeAliases.insert({typealias.exported, name}); } else { ScopePtr aliasScope = childScope(scope, typealias.location); + aliasScope->level = scope->level.incr(); + aliasScope->level.subLevel = subLevel; - std::vector generics; - for (AstName generic : typealias.generics) - { - Name n = generic.value; + auto [generics, genericPacks] = + createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); - // These generics are the only thing that will ever be added to aliasScope, so we can be certain that - // a collision can only occur when two generic typevars have the same name. - if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) - { - // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. - reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); - } - - TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - TypeId& cached = scope->typeAliasParameters[n]; - if (!cached) - cached = addType(GenericTypeVar{aliasScope->level, n}); - g = cached; - } - else - g = addType(GenericTypeVar{aliasScope->level, n}); - generics.push_back(g); - aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; - } - - TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + TypeId ty = freshType(aliasScope); FreeTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); ftv->forwardedTypeAlias = true; - bindingsMap[name] = {std::move(generics), ty}; + bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; + + scope->typeAliasLocations[name] = typealias.location; } } else { - if (!binding) - ice("Not predeclared"); - ScopePtr aliasScope = childScope(scope, typealias.location); + aliasScope->level = scope->level.incr(); + aliasScope->level.subLevel = subLevel; - for (TypeId ty : binding->typeParams) - { - auto generic = get(ty); - LUAU_ASSERT(generic); - aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; - } + auto [generics, genericPacks] = + createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); - TypeId ty = (FFlag::LuauRankNTypes ? resolveType(aliasScope, *typealias.type) : resolveType(aliasScope, *typealias.type, true)); - if (auto ttv = getMutable(follow(ty))) - { - // If the table is already named and we want to rename the type function, we have to bind new alias to a copy - if (ttv->name) - { - // Copy can be skipped if this is an identical alias - if (!FFlag::LuauFixTableTypeAliasClone || ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams) - { - // This is a shallow clone, original recursive links to self are not updated - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; + TypeId ty = freshType(aliasScope); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - - clone.name = name; - clone.instantiatedTypeParams = binding->typeParams; - - ty = addType(std::move(clone)); - } - } - else - { - ttv->name = name; - ttv->instantiatedTypeParams = binding->typeParams; - } - } - else if (auto mtv = getMutable(follow(ty))) - mtv->syntheticName = name; - - unify(bindingsMap[name].type, ty, typealias.location); + scope->typeAliasLocations[name] = typealias.location; } } -void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) +void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { + LUAU_ASSERT(FFlag::LuauDeclareClassPrototype); + std::optional superTy = std::nullopt; if (declaredClass.superName) { @@ -1276,101 +1674,195 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar if (!lookupType) { reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type}); + incorrectClassDefinitions.insert(&declaredClass); return; } // We don't have generic classes, so this assertion _should_ never be hit. - LUAU_ASSERT(lookupType->typeParams.size() == 0); + LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); superTy = lookupType->type; - if (FFlag::LuauAddMissingFollow) + if (!get(follow(*superTy))) { - if (!get(follow(*superTy))) - { - reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", - superName.c_str(), declaredClass.name.value)}); - - return; - } - } - else - { - if (const ClassTypeVar* superCtv = get(*superTy); !superCtv) - { - reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", - superName.c_str(), declaredClass.name.value)}); - - return; - } + reportError(declaredClass.location, + GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); + incorrectClassDefinitions.insert(&declaredClass); + return; } } Name className(declaredClass.name.value); - TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {})); + TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); ClassTypeVar* ctv = getMutable(classTy); - TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); - TableTypeVar* metatable = getMutable(metaTy); ctv->metatable = metaTy; - scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; +} - for (const AstDeclaredClassProp& prop : declaredClass.props) +void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) +{ + if (FFlag::LuauDeclareClassPrototype) { - Name propName(prop.name.value); - TypeId propTy = resolveType(scope, *prop.ty); + Name className(declaredClass.name.value); - bool assignToMetatable = isMetamethod(propName); + // Don't bother checking if the class definition was incorrect + if (incorrectClassDefinitions.find(&declaredClass)) + return; - // Function types always take 'self', but this isn't reflected in the - // parsed annotation. Add it here. - if (prop.isMethod) + std::optional binding; + if (auto it = scope->exportedTypeBindings.find(className); it != scope->exportedTypeBindings.end()) + binding = it->second; + + // This class definition must have been `prototype()`d first. + if (!binding) + ice("Class not predeclared"); + + TypeId classTy = binding->type; + ClassTypeVar* ctv = getMutable(classTy); + + if (!ctv->metatable) + ice("No metatable for declared class"); + + TableTypeVar* metatable = getMutable(*ctv->metatable); + for (const AstDeclaredClassProp& prop : declaredClass.props) { - if (FunctionTypeVar* ftv = getMutable(propTy)) + Name propName(prop.name.value); + TypeId propTy = resolveType(scope, *prop.ty); + + bool assignToMetatable = isMetamethod(propName); + Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; + + // Function types always take 'self', but this isn't reflected in the + // parsed annotation. Add it here. + if (prop.isMethod) { - ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); - ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); + if (FunctionTypeVar* ftv = getMutable(propTy)) + { + ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); + ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); + ftv->hasSelf = true; + } } - } - if (ctv->props.count(propName) == 0) - { - if (assignToMetatable) - metatable->props[propName] = {propTy}; - else - ctv->props[propName] = {propTy}; - } - else - { - TypeId currentTy = assignToMetatable ? metatable->props[propName].type : ctv->props[propName].type; - - // We special-case this logic to keep the intersection flat; otherwise we - // would create a ton of nested intersection types. - if (const IntersectionTypeVar* itv = get(currentTy)) + if (assignTo.count(propName) == 0) { - std::vector options = itv->parts; - options.push_back(propTy); - TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); - - if (assignToMetatable) - metatable->props[propName] = {newItv}; - else - ctv->props[propName] = {newItv}; - } - else if (get(currentTy)) - { - TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); - - if (assignToMetatable) - metatable->props[propName] = {intersection}; - else - ctv->props[propName] = {intersection}; + assignTo[propName] = {propTy}; } else { - reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + TypeId currentTy = assignTo[propName].type; + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionTypeVar* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); + + assignTo[propName] = {newItv}; + } + else if (get(currentTy)) + { + TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); + + assignTo[propName] = {intersection}; + } + else + { + reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } + } + } + } + else + { + std::optional superTy = std::nullopt; + if (declaredClass.superName) + { + Name superName = Name(declaredClass.superName->value); + std::optional lookupType = scope->lookupType(superName); + + if (!lookupType) + { + reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type}); + return; + } + + // We don't have generic classes, so this assertion _should_ never be hit. + LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); + superTy = lookupType->type; + + if (!get(follow(*superTy))) + { + reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", + superName.c_str(), declaredClass.name.value)}); + return; + } + } + + Name className(declaredClass.name.value); + + TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); + + ClassTypeVar* ctv = getMutable(classTy); + TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); + TableTypeVar* metatable = getMutable(metaTy); + + ctv->metatable = metaTy; + + scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; + + for (const AstDeclaredClassProp& prop : declaredClass.props) + { + Name propName(prop.name.value); + TypeId propTy = resolveType(scope, *prop.ty); + + bool assignToMetatable = isMetamethod(propName); + Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; + + // Function types always take 'self', but this isn't reflected in the + // parsed annotation. Add it here. + if (prop.isMethod) + { + if (FunctionTypeVar* ftv = getMutable(propTy)) + { + ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); + ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); + ftv->hasSelf = true; + } + } + + if (assignTo.count(propName) == 0) + { + assignTo[propName] = {propTy}; + } + else + { + TypeId currentTy = assignTo[propName].type; + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionTypeVar* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); + + assignTo[propName] = {newItv}; + } + else if (get(currentTy)) + { + TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); + + assignTo[propName] = {intersection}; + } + else + { + reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } } } } @@ -1380,11 +1872,23 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo { ScopePtr funScope = childFunctionScope(scope, global.location); - auto [generics, genericPacks] = createGenericTypes(funScope, global, global.generics, global.genericPacks); + auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks); + + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); - TypeId fnType = addType(FunctionTypeVar{funScope->level, generics, genericPacks, argPack, retPack}); + TypeId fnType = addType(FunctionTypeVar{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack}); FunctionTypeVar* ftv = getMutable(fnType); ftv->argNames.reserve(global.paramNames.size); @@ -1397,27 +1901,37 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(expr.location); - return {errorType}; + return {errorRecoveryType(scope)}; } - ExprResult result; + WithPredicate result; if (auto a = expr.as()) - result = checkExpr(scope, *a->expr); + result = checkExpr(scope, *a->expr, expectedType); else if (expr.is()) result = {nilType}; - else if (expr.is()) - result = {booleanType}; + else if (const AstExprConstantBool* bexpr = expr.as()) + { + if (forceSingleton || (expectedType && maybeSingleton(*expectedType))) + result = {singletonType(bexpr->value)}; + else + result = {booleanType}; + } + else if (const AstExprConstantString* sexpr = expr.as()) + { + if (forceSingleton || (expectedType && maybeSingleton(*expectedType))) + result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; + else + result = {stringType}; + } else if (expr.is()) result = {numberType}; - else if (expr.is()) - result = {stringType}; else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -1437,40 +1951,22 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) - result = checkExpr(scope, *a); + result = checkExpr(scope, *a, FFlag::LuauBinaryNeedsExpectedTypesToo ? expectedType : std::nullopt); else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) - { - if (FFlag::LuauIfElseExpressionAnalysisSupport) - { - result = checkExpr(scope, *a); - } - else - { - // Note: When the fast flag is disabled we can't skip the handling of AstExprIfElse - // because we would generate an ICE. We also can't use the default value - // of result, because it will lead to a compiler crash. - // Note: LuauIfElseExpressionBaseSupport can be used to disable parser support - // for if-else expressions which will mean this node type is never created. - result = {anyType}; - } - } + result = checkExpr(scope, *a, expectedType); + else if (auto a = expr.as()) + result = checkExpr(scope, *a); else ice("Unhandled AstExpr?"); result.type = follow(result.type); - if (FFlag::LuauStoreMatchingOverloadFnType) - { - currentModule->astTypes.try_emplace(&expr, result.type); - } - else - { + if (!currentModule->astTypes.find(&expr)) currentModule->astTypes[&expr] = result.type; - } if (expectedType) currentModule->astExpectedTypes[&expr] = *expectedType; @@ -1478,7 +1974,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& return result; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLocal& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLocal& expr) { std::optional lvalue = tryGetLValue(expr); LUAU_ASSERT(lvalue); // Guaranteed to not be nullopt - AstExprLocal is an LValue. @@ -1489,10 +1985,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLo // TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint // ice("AstExprLocal exists but no binding definition for it?", expr.location); reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}}); - return {errorType}; + return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) { std::optional lvalue = tryGetLValue(expr); LUAU_ASSERT(lvalue); // Guaranteed to not be nullopt - AstExprGlobal is an LValue. @@ -1501,66 +1997,73 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGl return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}}); - return {errorType}; + return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) { TypePackId varargPack = checkExprPack(scope, expr).type; if (get(varargPack)) { - std::vector types = flatten(varargPack).first; - return {!types.empty() ? types[0] : nilType}; + if (std::optional ty = first(varargPack)) + return {*ty}; + + return {nilType}; } - else if (auto ftp = get(varargPack)) + else if (get(varargPack)) { - TypeId head = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, ftp->DEPRECATED_canBeGeneric)); - TypePackId tail = (FFlag::LuauRankNTypes ? freshTypePack(scope) : DEPRECATED_freshTypePack(scope, ftp->DEPRECATED_canBeGeneric)); + TypeId head = freshType(scope); + TypePackId tail = freshTypePack(scope); *asMutable(varargPack) = TypePack{{head}, tail}; return {head}; } if (get(varargPack)) - return {errorType}; + return {errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) return {vtp->ty}; - else if (FFlag::LuauGenericVariadicsUnification && get(varargPack)) + else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); - return {errorType}; + return {errorRecoveryType(scope)}; } else ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCall& expr) { - ExprResult result = checkExprPack(scope, expr); + WithPredicate result = checkExprPack(scope, expr); TypePackId retPack = follow(result.type); if (auto pack = get(retPack)) { return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; } - else if (auto ftp = get(retPack)) + else if (const FreeTypePack* ftp = get(retPack)) { - TypeId head = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, ftp->DEPRECATED_canBeGeneric)); - TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); - unify(retPack, pack, expr.location); + TypeId head = freshType(scope->level); + TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope->level)}}); + unify(pack, retPack, scope, expr.location); return {head, std::move(result.predicates)}; } if (get(retPack)) - return {errorType, std::move(result.predicates)}; + return {errorRecoveryType(scope), std::move(result.predicates)}; else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) - ice("Unexpected abstract type pack!"); + { + if (FFlag::LuauReturnAnyInsteadOfICE) + return {anyType, std::move(result.predicates)}; + else + ice("Unexpected abstract type pack!", expr.location); + } else - ice("Unknown TypePack type!"); + ice("Unknown TypePack type!", expr.location); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) { Name name = expr.index.value; @@ -1571,80 +2074,87 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn if (std::optional ty = resolveLValue(scope, *lvalue)) return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; - if (FFlag::LuauExtraNilRecovery) - lhsType = stripFromNilAndReport(lhsType, expr.expr->location); + lhsType = stripFromNilAndReport(lhsType, expr.expr->location); - if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) + if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, /* addErrors= */ true)) return {*ty}; - if (!FFlag::LuauMissingUnionPropertyError) - reportError(expr.indexLocation, UnknownProperty{lhsType, expr.index.value}); - - if (!FFlag::LuauExtraNilRecovery) - { - // Try to recover using a union without 'nil' options - if (std::optional strippedUnion = tryStripUnionFromNil(lhsType)) - { - if (std::optional ty = getIndexTypeFromType(scope, *strippedUnion, name, expr.location, false)) - return {*ty}; - } - } - - return {errorType}; + return {errorRecoveryType(scope)}; } -std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) +std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors) { ErrorVec errors; - auto result = Luau::findTablePropertyRespectingMeta(errors, globalScope, lhsType, name, location); - reportErrors(errors); + auto result = Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location); + if (addErrors) + reportErrors(errors); return result; } -std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location) +std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors) { ErrorVec errors; - auto result = Luau::findMetatableEntry(errors, globalScope, type, entry, location); - reportErrors(errors); + auto result = Luau::findMetatableEntry(singletonTypes, errors, type, entry, location); + if (addErrors) + reportErrors(errors); return result; } std::optional TypeChecker::getIndexTypeFromType( const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) +{ + size_t errorCount = currentModule->errors.size(); + + std::optional result = getIndexTypeFromTypeImpl(scope, type, name, location, addErrors); + + if (!addErrors) + LUAU_ASSERT(errorCount == currentModule->errors.size()); + + return result; +} + +std::optional TypeChecker::getIndexTypeFromTypeImpl( + const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) { type = follow(type); - if (get(type) || get(type)) + if (get(type) || get(type) || get(type)) return type; tablify(type); - const PrimitiveTypeVar* primitiveType = get(type); - if (primitiveType && primitiveType->type == PrimitiveTypeVar::String) + if (isString(type)) { - if (std::optional mtIndex = findMetatableEntry(type, "__index", location)) - type = *mtIndex; + std::optional mtIndex = findMetatableEntry(stringType, "__index", location, addErrors); + LUAU_ASSERT(mtIndex); + type = *mtIndex; } if (TableTypeVar* tableType = getMutableTableType(type)) { - const auto& it = tableType->props.find(name); - if (it != tableType->props.end()) + if (auto it = tableType->props.find(name); it != tableType->props.end()) return it->second.type; else if (auto indexer = tableType->indexer) { - tryUnify(indexer->indexType, stringType, location); - return indexer->indexResultType; + // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. + ErrorVec errors = tryUnify(stringType, indexer->indexType, scope, location); + + if (errors.empty()) + return indexer->indexResultType; + + if (addErrors) + reportError(location, UnknownProperty{type, name}); + + return std::nullopt; } else if (tableType->state == TableState::Free) { - TypeId result = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId result = freshType(tableType->level); tableType->props[name] = {result}; return result; } - auto found = findTablePropertyRespectingMeta(type, name, location); - if (found) + if (auto found = findTablePropertyRespectingMeta(type, name, location, addErrors)) return *found; } else if (const ClassTypeVar* cls = get(type)) @@ -1655,61 +2165,43 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (const UnionTypeVar* utv = get(type)) { - if (FFlag::LuauMissingUnionPropertyError) + std::vector goodOptions; + std::vector badOptions; + + for (TypeId t : utv) { - std::vector goodOptions; - std::vector badOptions; + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - for (TypeId t : utv) - { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + // Not needed when we normalize types. + if (get(follow(t))) + return t; - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) - goodOptions.push_back(*ty); - else - badOptions.push_back(t); - } - - if (!badOptions.empty()) - { - if (addErrors) - { - if (goodOptions.empty()) - reportError(location, UnknownProperty{type, name}); - else - reportError(location, MissingUnionProperty{type, badOptions, name}); - } - return std::nullopt; - } - - std::vector result = reduceUnion(goodOptions); - - if (result.size() == 1) - return result[0]; - - return addType(UnionTypeVar{std::move(result)}); + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false)) + goodOptions.push_back(*ty); + else + badOptions.push_back(t); } - else + + if (!badOptions.empty()) { - std::vector options; - - for (TypeId t : utv->options) + if (addErrors) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) - options.push_back(*ty); + if (goodOptions.empty()) + reportError(location, UnknownProperty{type, name}); else - return std::nullopt; + reportError(location, MissingUnionProperty{type, badOptions, name}); } - - std::vector result = reduceUnion(options); - - if (result.size() == 1) - return result[0]; - - return addType(UnionTypeVar{std::move(result)}); + return std::nullopt; } + + std::vector result = reduceUnion(goodOptions); + if (FFlag::LuauUnknownAndNeverType && result.empty()) + return neverType; + + if (result.size() == 1) + return result[0]; + + return addType(UnionTypeVar{std::move(result)}); } else if (const IntersectionTypeVar* itv = get(type)) { @@ -1719,78 +2211,35 @@ std::optional TypeChecker::getIndexTypeFromType( { RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false)) parts.push_back(*ty); } // If no parts of the intersection had the property we looked up for, it never existed at all. if (parts.empty()) { - if (FFlag::LuauMissingUnionPropertyError && addErrors) + if (addErrors) reportError(location, UnknownProperty{type, name}); return std::nullopt; } - // TODO(amccord): Write some logic to correctly handle intersections. CLI-34659 - std::vector result = reduceUnion(parts); + if (parts.size() == 1) + return parts[0]; - if (result.size() == 1) - return result[0]; - - return addType(IntersectionTypeVar{result}); + return addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. } - if (FFlag::LuauMissingUnionPropertyError && addErrors) + if (addErrors) reportError(location, UnknownProperty{type, name}); return std::nullopt; } -std::vector TypeChecker::reduceUnion(const std::vector& types) -{ - std::set s; - - for (TypeId t : types) - { - if (const UnionTypeVar* utv = get(follow(t))) - { - std::vector r = reduceUnion(utv->options); - for (TypeId ty : r) - s.insert(ty); - } - else - s.insert(t); - } - - // If any of them are ErrorTypeVars/AnyTypeVars, decay into them. - for (TypeId t : s) - { - t = follow(t); - if (get(t) || get(t)) - return {t}; - } - - std::vector r(s.begin(), s.end()); - std::sort(r.begin(), r.end()); - return r; -} - std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) { if (const UnionTypeVar* utv = get(ty)) { - bool hasNil = false; - - for (TypeId option : utv) - { - if (isNil(option)) - { - hasNil = true; - break; - } - } - - if (!hasNil) + if (!std::any_of(begin(utv), end(utv), isNil)) return ty; std::vector result; @@ -1812,26 +2261,37 @@ std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) { - if (isOptional(ty)) + ty = follow(ty); + + if (auto utv = get(ty)) { - if (std::optional strippedUnion = tryStripUnionFromNil(follow(ty))) - { - reportError(location, OptionalValueAccess{ty}); - return follow(*strippedUnion); - } + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + } + + if (std::optional strippedUnion = tryStripUnionFromNil(ty)) + { + reportError(location, OptionalValueAccess{ty}); + return follow(*strippedUnion); } return ty; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) { - return {checkLValue(scope, expr)}; + TypeId ty = checkLValue(scope, expr); + + if (std::optional lvalue = tryGetLValue(expr)) + if (std::optional refiTy = resolveLValue(scope, *lvalue)) + return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}}; + + return {ty}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) { - auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, expectedType); + auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, std::nullopt, expectedType); checkFunctionBody(funScope, funTy, expr); @@ -1870,7 +2330,10 @@ TypeId TypeChecker::checkExprTable( indexer = expectedTable->indexer; if (indexer) - unify(indexer->indexResultType, valueType, value->location); + { + unify(numberType, indexer->indexType, scope, value->location); + unify(valueType, indexer->indexResultType, scope, value->location); + } else indexer = TableIndexer{numberType, anyIfNonstrict(valueType)}; } @@ -1882,6 +2345,24 @@ TypeId TypeChecker::checkExprTable( if (isNonstrictMode() && !getTableType(exprType) && !get(exprType)) exprType = anyType; + if (expectedTable) + { + auto it = expectedTable->props.find(key->value.data); + if (it != expectedTable->props.end()) + { + Property expectedProp = it->second; + ErrorVec errors = tryUnify(exprType, expectedProp.type, scope, k->location); + if (errors.empty()) + exprType = expectedProp.type; + } + else if (expectedTable->indexer && maybeString(expectedTable->indexer->indexType)) + { + ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, scope, k->location); + if (errors.empty()) + exprType = expectedTable->indexer->indexResultType; + } + } + props[key->value.data] = {exprType, /* deprecated */ false, {}, k->location}; } else @@ -1891,8 +2372,8 @@ TypeId TypeChecker::checkExprTable( if (indexer) { - unify(indexer->indexType, keyType, k->location); - unify(indexer->indexResultType, valueType, value->location); + unify(keyType, indexer->indexType, scope, k->location); + unify(valueType, indexer->indexResultType, scope, value->location); } else if (isNonstrictMode()) { @@ -1906,18 +2387,25 @@ TypeId TypeChecker::checkExprTable( } } - TableState state = (expr.items.size == 0 || isNonstrictMode()) ? TableState::Unsealed : TableState::Sealed; + TableState state = TableState::Unsealed; TableTypeVar table = TableTypeVar{std::move(props), indexer, scope->level, state}; table.definitionModuleName = currentModuleName; return addType(table); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionCounter _rc(&checkRecursionCount); + if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) + { + reportErrorCodeTooComplex(expr.location); + return {errorRecoveryType(scope)}; + } + std::vector> fieldTypes(expr.items.size); const TableTypeVar* expectedTable = nullptr; + const UnionTypeVar* expectedUnion = nullptr; std::optional expectedIndexType; std::optional expectedIndexResultType; @@ -1936,6 +2424,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa } } } + else if (const UnionTypeVar* utv = get(follow(*expectedType))) + expectedUnion = utv; } for (size_t i = 0; i < expr.items.size; ++i) @@ -1957,6 +2447,27 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa { if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; + else if (expectedIndexType && maybeString(*expectedIndexType)) + expectedResultType = expectedIndexResultType; + } + else if (expectedUnion) + { + std::vector expectedResultTypes; + for (TypeId expectedOption : expectedUnion) + { + if (const TableTypeVar* ttv = get(follow(expectedOption))) + { + if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end()) + expectedResultTypes.push_back(prop->second.type); + else if (ttv->indexer && maybeString(ttv->indexer->indexType)) + expectedResultTypes.push_back(ttv->indexer->indexResultType); + } + } + + if (expectedResultTypes.size() == 1) + expectedResultType = expectedResultTypes[0]; + else if (expectedResultTypes.size() > 1) + expectedResultType = addType(UnionTypeVar{expectedResultTypes}); } } else @@ -1978,9 +2489,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa return {checkExprTable(scope, expr, fieldTypes, expectedType)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) { - ExprResult result = checkExpr(scope, *expr.expr); + WithPredicate result = checkExpr(scope, *expr.expr); TypeId operandType = follow(result.type); switch (expr.op) @@ -1989,60 +2500,81 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn return {booleanType, {NotPredicate{std::move(result.predicates)}}}; case AstExprUnary::Minus: { - const bool operandIsAny = get(operandType) || get(operandType); + const bool operandIsAny = get(operandType) || get(operandType) || get(operandType); if (operandIsAny) return {operandType}; if (typeCouldHaveMetatable(operandType)) { - if (auto fnt = findMetatableEntry(operandType, "__unm", expr.location)) + if (auto fnt = findMetatableEntry(operandType, "__unm", expr.location, /* addErrors= */ true)) { TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); - TypePackId retType = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + TypePackId retTypePack = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); - Unifier state = mkUnifier(expr.location); - state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + Unifier state = mkUnifier(scope, expr.location); + state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); + state.log.commit(); + reportErrors(state.errors); + + TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) - return {errorType}; + retType = errorRecoveryType(retType); - return {first(retType).value_or(nilType)}; + return {retType}; } reportError(expr.location, GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); - return {errorType}; + return {errorRecoveryType(scope)}; } - reportErrors(tryUnify(numberType, operandType, expr.location)); + reportErrors(tryUnify(operandType, numberType, scope, expr.location)); return {numberType}; } case AstExprUnary::Len: + { tablify(operandType); - if (FFlag::LuauExtraNilRecovery) - operandType = stripFromNilAndReport(operandType, expr.location); + operandType = stripFromNilAndReport(operandType, expr.location); - if (get(operandType)) - return {errorType}; - - if (get(operandType)) - return {numberType}; // Not strictly correct: metatables permit overriding this - - if (auto p = get(operandType)) + // # operator is guaranteed to return number + if ((FFlag::LuauNeverTypesAndOperatorsInference && get(operandType)) || get(operandType) || + get(operandType)) { - if (p->type == PrimitiveTypeVar::String) + if (FFlag::LuauNeverTypesAndOperatorsInference) return {numberType}; + else + return {!FFlag::LuauUnknownAndNeverType ? errorRecoveryType(scope) : operandType}; } - if (!getTableType(operandType)) + DenseHashSet seen{nullptr}; + + if (typeCouldHaveMetatable(operandType)) + { + if (auto fnt = findMetatableEntry(operandType, "__len", expr.location, /* addErrors= */ true)) + { + TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); + TypePackId arguments = addTypePack({operandType}); + TypePackId retTypePack = addTypePack({numberType}); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); + + Unifier state = mkUnifier(scope, expr.location); + state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); + state.log.commit(); + + reportErrors(state.errors); + } + } + + if (!hasLength(operandType, seen, &recursionCount)) reportError(TypeError{expr.location, NotATable{operandType}}); return {numberType}; - + } default: ice("Unknown AstExprUnary " + std::to_string(int(expr.op))); } @@ -2080,20 +2612,26 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op) } } -TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes) +TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes) { + a = follow(a); + b = follow(b); + if (unifyFreeTypes && (get(a) || get(b))) { - if (unify(a, b, location)) + if (unify(b, a, scope, location)) return a; - return errorType; + return errorRecoveryType(anyType); } if (*a == *b) return a; std::vector types = reduceUnion({a, b}); + if (FFlag::LuauUnknownAndNeverType && types.empty()) + return neverType; + if (types.size() == 1) return types[0]; @@ -2117,6 +2655,48 @@ static std::optional getIdentifierOfBaseVar(AstExpr* node) return std::nullopt; } +/** Return true if comparison between the types a and b should be permitted with + * the == or ~= operators. + * + * Two types are considered eligible for equality testing if it is possible for + * the test to ever succeed. In other words, we test to see whether the two + * types have any overlap at all. + * + * In order to make things work smoothly with the greedy solver, this function + * exempts any and FreeTypeVars from this requirement. + * + * This function does not (yet?) take into account extra Lua restrictions like + * that two tables can only be compared if they have the same metatable. That + * is presently handled by the caller. + * + * @return True if the types are comparable. False if they are not. + * + * If an internal recursion limit is reached while performing this test, the + * function returns std::nullopt. + */ +static std::optional areEqComparable(NotNull arena, NotNull normalizer, TypeId a, TypeId b) +{ + a = follow(a); + b = follow(b); + + auto isExempt = [](TypeId t) { + return isNil(t) || get(t); + }; + + if (isExempt(a) || isExempt(b)) + return true; + + TypeId c = arena->addType(IntersectionTypeVar{{a, b}}); + const NormalizedType* n = normalizer->normalize(c); + if (!n) + return std::nullopt; + + if (FFlag::LuauUninhabitedSubAnything2) + return normalizer->isInhabited(n); + else + return isInhabited_DEPRECATED(*n); +} + TypeId TypeChecker::checkRelationalOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates) { @@ -2125,7 +2705,7 @@ TypeId TypeChecker::checkRelationalOperation( if (!isNonstrictMode() && !isOrOp) return ty; - if (auto i = get(ty)) + if (get(ty)) { std::optional cleaned = tryStripUnionFromNil(ty); @@ -2146,7 +2726,7 @@ TypeId TypeChecker::checkRelationalOperation( // If we know nothing at all about the lhs type, we can usually say nothing about the result. // The notable exception to this is the equality and inequality operators, which always produce a boolean. - const bool lhsIsAny = get(lhsType) || get(lhsType); + const bool lhsIsAny = get(lhsType) || get(lhsType) || get(lhsType); // Peephole check for `cond and a or b -> type(a)|type(b)` // TODO: Kill this when singleton types arrive. :( @@ -2154,16 +2734,9 @@ TypeId TypeChecker::checkRelationalOperation( { if (expr.op == AstExprBinary::Or && subexp->op == AstExprBinary::And) { - if (FFlag::LuauSlightlyMoreFlexibleBinaryPredicates) - { - ScopePtr subScope = childScope(scope, subexp->location); - reportErrors(resolve(predicates, subScope, true)); - return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); - } - else - { - return unionOfTypes(rhsType, checkExpr(scope, *subexp->right).type, expr.location); - } + ScopePtr subScope = childScope(scope, subexp->location); + resolve(predicates, subScope, true); + return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), subScope, expr.location); } } @@ -2176,7 +2749,7 @@ TypeId TypeChecker::checkRelationalOperation( if (isNonstrictMode() && (isNil(lhsType) || isNil(rhsType))) return booleanType; - const bool rhsIsAny = get(rhsType) || get(rhsType); + const bool rhsIsAny = get(rhsType) || get(rhsType) || get(rhsType); if (lhsIsAny || rhsIsAny) return booleanType; @@ -2187,61 +2760,129 @@ TypeId TypeChecker::checkRelationalOperation( case AstExprBinary::CompareGe: case AstExprBinary::CompareLe: { + if (FFlag::LuauNeverTypesAndOperatorsInference) + { + // If one of the operand is never, it doesn't make sense to unify these. + if (get(lhsType) || get(rhsType)) + return booleanType; + } + + if (FFlag::LuauIntersectionTestForEquality && isEquality) + { + // Unless either type is free or any, an equality comparison is only + // valid when the intersection of the two operands is non-empty. + // + // eg it is okay to compare string? == number? because the two types + // have nil in common, but string == number is not allowed. + std::optional eqTestResult = areEqComparable(NotNull{¤tModule->internalTypes}, NotNull{&normalizer}, lhsType, rhsType); + if (!eqTestResult) + { + reportErrorCodeTooComplex(expr.location); + return errorRecoveryType(booleanType); + } + + if (!*eqTestResult) + { + reportError( + expr.location, GenericError{format("Type %s cannot be compared with %s", toString(lhsType).c_str(), toString(rhsType).c_str())}); + return errorRecoveryType(booleanType); + } + } + /* Subtlety here: * We need to do this unification first, but there are situations where we don't actually want to * report any problems that might have been surfaced as a result of this step because we might already * have a better, more descriptive error teed up. */ - Unifier state = mkUnifier(expr.location); - if (!FFlag::LuauEqConstraint || !isEquality) - state.tryUnify(lhsType, rhsType); + Unifier state = mkUnifier(scope, expr.location); + if (!isEquality) + { + state.tryUnify(rhsType, lhsType); + state.log.commit(); + } - bool needsMetamethod = !isEquality; + const bool needsMetamethod = !isEquality; TypeId leftType = follow(lhsType); if (get(leftType) || get(leftType) || get(leftType) || get(leftType)) { reportErrors(state.errors); - const PrimitiveTypeVar* ptv = get(leftType); - if (!isEquality && state.errors.empty() && (get(leftType) || (ptv && ptv->type == PrimitiveTypeVar::Boolean))) + if (!isEquality && state.errors.empty() && (get(leftType) || isBoolean(leftType))) + { reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())}); + } return booleanType; } std::string metamethodName = opToMetaTableEntry(expr.op); - std::optional leftMetatable = - isString(lhsType) ? std::nullopt : getMetatable(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType); - std::optional rightMetatable = - isString(rhsType) ? std::nullopt : getMetatable(FFlag::LuauAddMissingFollow ? follow(rhsType) : rhsType); + std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType), singletonTypes); + std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType), singletonTypes); - if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) + if (leftMetatable != rightMetatable) { - reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); - return errorType; + bool matches = false; + if (isEquality) + { + if (const UnionTypeVar* utv = get(leftType); utv && rightMetatable) + { + for (TypeId leftOption : utv) + { + if (getMetatable(follow(leftOption), singletonTypes) == rightMetatable) + { + matches = true; + break; + } + } + } + + if (!matches) + { + if (const UnionTypeVar* utv = get(rhsType); utv && leftMetatable) + { + for (TypeId rightOption : utv) + { + if (getMetatable(follow(rightOption), singletonTypes) == leftMetatable) + { + matches = true; + break; + } + } + } + } + } + + if (!matches) + { + reportError( + expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + return errorRecoveryType(booleanType); + } } if (leftMetatable) { - std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location); + std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true); if (metamethod) { if (const FunctionTypeVar* ftv = get(*metamethod)) { if (isEquality) { - Unifier state = mkUnifier(expr.location); - state.tryUnify(ftv->retType, addTypePack({booleanType})); + Unifier state = mkUnifier(scope, expr.location); + state.tryUnify(addTypePack({booleanType}), ftv->retTypes); if (!state.errors.empty()) { reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())}); - return errorType; + return errorRecoveryType(booleanType); } + + state.log.commit(); } } @@ -2249,7 +2890,9 @@ TypeId TypeChecker::checkRelationalOperation( TypeId actualFunctionType = addType(FunctionTypeVar(scope->level, addTypePack({lhsType, rhsType}), addTypePack({booleanType}))); state.tryUnify( - instantiate(scope, *metamethod, expr.location), instantiate(scope, actualFunctionType, expr.location), /*isFunctionCall*/ true); + instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true); + + state.log.commit(); reportErrors(state.errors); return booleanType; @@ -2258,34 +2901,22 @@ TypeId TypeChecker::checkRelationalOperation( { reportError( expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}); - return errorType; + return errorRecoveryType(booleanType); } } - if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && (!FFlag::LuauEqConstraint || !isEquality)) + if (get(follow(lhsType)) && !isEquality) { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); - return errorType; + return errorRecoveryType(booleanType); } if (needsMetamethod) { reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable", toString(lhsType).c_str(), toString(expr.op).c_str())}); - return errorType; - } - - if (!FFlag::LuauEqConstraint) - { - if (isEquality) - { - ErrorVec errVec = tryUnify(rhsType, lhsType, expr.location); - if (!state.errors.empty() && !errVec.empty()) - reportError(expr.location, TypeMismatch{lhsType, rhsType}); - } - else - reportErrors(state.errors); + return errorRecoveryType(booleanType); } return booleanType; @@ -2293,12 +2924,54 @@ TypeId TypeChecker::checkRelationalOperation( case AstExprBinary::And: if (lhsIsAny) + { return lhsType; - return unionOfTypes(rhsType, booleanType, expr.location, false); + } + else if (FFlag::LuauTryhardAnd) + { + // If lhs is free, we can't tell which 'falsy' components it has, if any + if (get(lhsType)) + return unionOfTypes(addType(UnionTypeVar{{nilType, singletonType(false)}}), rhsType, scope, expr.location, false); + + auto [oty, notNever] = pickTypesFromSense(lhsType, false, neverType); // Filter out falsy types + + if (notNever) + { + LUAU_ASSERT(oty); + return unionOfTypes(*oty, rhsType, scope, expr.location, false); + } + else + { + return rhsType; + } + } + else + { + return unionOfTypes(rhsType, booleanType, scope, expr.location, false); + } case AstExprBinary::Or: if (lhsIsAny) + { return lhsType; - return unionOfTypes(lhsType, rhsType, expr.location); + } + else if (FFlag::LuauTryhardAnd) + { + auto [oty, notNever] = pickTypesFromSense(lhsType, true, neverType); // Filter out truthy types + + if (notNever) + { + LUAU_ASSERT(oty); + return unionOfTypes(*oty, rhsType, scope, expr.location); + } + else + { + return rhsType; + } + } + else + { + return unionOfTypes(lhsType, rhsType, scope, expr.location); + } default: LUAU_ASSERT(0); ice(format("checkRelationalOperation called with incorrect binary expression '%s'", toString(expr.op).c_str()), expr.location); @@ -2330,13 +3003,15 @@ TypeId TypeChecker::checkBinaryOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - return errorType; + // We will fall-through to the `return anyType` check below. } // If we know nothing at all about the lhs type, we can usually say nothing about the result. // The notable exception to this is the equality and inequality operators, which always produce a boolean. - const bool lhsIsAny = get(lhsType) || get(lhsType); - const bool rhsIsAny = get(rhsType) || get(rhsType); + const bool lhsIsAny = get(lhsType) || get(lhsType) || + (FFlag::LuauUnknownAndNeverType && FFlag::LuauNeverTypesAndOperatorsInference && get(lhsType)); + const bool rhsIsAny = get(rhsType) || get(rhsType) || + (FFlag::LuauUnknownAndNeverType && FFlag::LuauNeverTypesAndOperatorsInference && get(rhsType)); if (lhsIsAny) return lhsType; @@ -2353,31 +3028,52 @@ TypeId TypeChecker::checkBinaryOperation( } if (get(rhsType)) - unify(lhsType, rhsType, expr.location); + unify(rhsType, lhsType, scope, expr.location); if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType)) { auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId { TypeId actualFunctionType = instantiate(scope, fnt, expr.location); TypePackId arguments = addTypePack({lhst, rhst}); - TypePackId retType = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + TypePackId retTypePack = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); - Unifier state = mkUnifier(expr.location); - state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + Unifier state = mkUnifier(scope, expr.location); + state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); reportErrors(state.errors); + bool hasErrors = !state.errors.empty(); - if (!state.errors.empty()) - return errorType; + if (hasErrors) + { + // If there are unification errors, the return type may still be unknown + // so we loosen the argument types to see if that helps. + TypePackId fallbackArguments = freshTypePack(scope); + TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack)); + state.errors.clear(); + state.log.clear(); - return first(retType).value_or(nilType); + state.tryUnify(actualFunctionType, fallbackFunctionType, /*isFunctionCall*/ true); + + if (state.errors.empty()) + state.log.commit(); + } + else + { + state.log.commit(); + } + + TypeId retType = first(retTypePack).value_or(nilType); + if (hasErrors) + retType = errorRecoveryType(retType); + + return retType; }; std::string op = opToMetaTableEntry(expr.op); - if (auto fnt = findMetatableEntry(lhsType, op, expr.location)) + if (auto fnt = findMetatableEntry(lhsType, op, expr.location, /* addErrors= */ true)) return checkMetatableCall(*fnt, lhsType, rhsType); - if (auto fnt = findMetatableEntry(rhsType, op, expr.location)) + if (auto fnt = findMetatableEntry(rhsType, op, expr.location, /* addErrors= */ true)) { // Note the intentionally reversed arguments here. return checkMetatableCall(*fnt, rhsType, lhsType); @@ -2385,14 +3081,15 @@ TypeId TypeChecker::checkBinaryOperation( reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(), toString(lhsType).c_str(), toString(rhsType).c_str())}); - return errorType; + + return errorRecoveryType(scope); } switch (expr.op) { case AstExprBinary::Concat: - reportErrors(tryUnify(addType(UnionTypeVar{{stringType, numberType}}), lhsType, expr.left->location)); - reportErrors(tryUnify(addType(UnionTypeVar{{stringType, numberType}}), rhsType, expr.right->location)); + reportErrors(tryUnify(lhsType, addType(UnionTypeVar{{stringType, numberType}}), scope, expr.left->location)); + reportErrors(tryUnify(rhsType, addType(UnionTypeVar{{stringType, numberType}}), scope, expr.right->location)); return stringType; case AstExprBinary::Add: case AstExprBinary::Sub: @@ -2400,8 +3097,8 @@ TypeId TypeChecker::checkBinaryOperation( case AstExprBinary::Div: case AstExprBinary::Mod: case AstExprBinary::Pow: - reportErrors(tryUnify(numberType, lhsType, expr.left->location)); - reportErrors(tryUnify(numberType, rhsType, expr.right->location)); + reportErrors(tryUnify(lhsType, numberType, scope, expr.left->location)); + reportErrors(tryUnify(rhsType, numberType, scope, expr.right->location)); return numberType; default: // These should have been handled with checkRelationalOperation @@ -2410,46 +3107,41 @@ TypeId TypeChecker::checkBinaryOperation( } } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional expectedType) { if (expr.op == AstExprBinary::And) { - ExprResult lhs = checkExpr(scope, *expr.left); + auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left, expectedType); - // We can't just report errors here. - // This function can be called from AstStatLocal or from AstStatIf, or even from AstExprBinary (and others). - // For now, ignore the errors returned by the predicate resolver. - // We may need an extra property for each predicate set that indicates it has been resolved. - // Requires a slight modification to the data structure. ScopePtr innerScope = childScope(scope, expr.location); - resolve(lhs.predicates, innerScope, true); + resolve(lhsPredicates, innerScope, true); - ExprResult rhs = checkExpr(innerScope, *expr.right); - if (!FFlag::LuauSlightlyMoreFlexibleBinaryPredicates) - resolve(rhs.predicates, innerScope, true); + auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right, expectedType); - return {checkBinaryOperation(innerScope, expr, lhs.type, rhs.type), {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; + return {checkBinaryOperation(scope, expr, lhsTy, rhsTy), {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } - else if (FFlag::LuauOrPredicate && expr.op == AstExprBinary::Or) + else if (expr.op == AstExprBinary::Or) { - ExprResult lhs = checkExpr(scope, *expr.left); + auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left, expectedType); ScopePtr innerScope = childScope(scope, expr.location); - resolve(lhs.predicates, innerScope, false); + resolve(lhsPredicates, innerScope, false); - ExprResult rhs = checkExpr(innerScope, *expr.right); + auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right, expectedType); - // Because of C++, I'm not sure if lhs.predicates was not moved out by the time we call checkBinaryOperation. - TypeId result = checkBinaryOperation(innerScope, expr, lhs.type, rhs.type, lhs.predicates); - return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; + // Because of C++, I'm not sure if lhsPredicates was not moved out by the time we call checkBinaryOperation. + TypeId result = checkBinaryOperation(scope, expr, lhsTy, rhsTy, lhsPredicates); + return {result, {OrPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } - else if (FFlag::LuauEqConstraint && (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe)) + else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) { if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; - ExprResult lhs = checkExpr(scope, *expr.left); - ExprResult rhs = checkExpr(scope, *expr.right); + // For these, passing expectedType is worse than simply forcing them, because their implementation + // may inadvertently check if expectedTypes exist first and use it, instead of forceSingleton first. + WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); + WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); PredicateVec predicates; @@ -2466,49 +3158,32 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi } else { - // Once we have EqPredicate, we should break this else branch into its' own branch. - // For now, fall through is intentional. - if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) - { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; - } - - ExprResult lhs = checkExpr(scope, *expr.left); - ExprResult rhs = checkExpr(scope, *expr.right); + // Expected types are not useful for other binary operators. + WithPredicate lhs = checkExpr(scope, *expr.left); + WithPredicate rhs = checkExpr(scope, *expr.right); // Intentionally discarding predicates with other operators. return {checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)}; } } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) { - ExprResult result; - TypeId annotationType; + TypeId annotationType = resolveType(scope, *expr.annotation); + WithPredicate result = checkExpr(scope, *expr.expr, annotationType); - if (FFlag::LuauInferReturnAssertAssign) - { - annotationType = (FFlag::LuauRankNTypes ? resolveType(scope, *expr.annotation) : resolveType(scope, *expr.annotation, true)); - result = checkExpr(scope, *expr.expr, annotationType); - } - else - { - result = checkExpr(scope, *expr.expr); - annotationType = (FFlag::LuauRankNTypes ? resolveType(scope, *expr.annotation) : resolveType(scope, *expr.annotation, true)); - } + // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. + if (canUnify(annotationType, result.type, scope, expr.location).empty()) + return {annotationType, std::move(result.predicates)}; - ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); - if (!errorVec.empty()) - { - reportErrors(errorVec); - return {errorType, std::move(result.predicates)}; - } + if (canUnify(result.type, annotationType, scope, expr.location).empty()) + return {annotationType, std::move(result.predicates)}; - return {annotationType, std::move(result.predicates)}; + reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); + return {errorRecoveryType(annotationType), std::move(result.predicates)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) { const size_t oldSize = currentModule->errors.size(); @@ -2519,35 +3194,44 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr // any type errors that may arise from it are going to be useless. currentModule->errors.resize(oldSize); - return {errorType}; + return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) { - ExprResult result = checkExpr(scope, *expr.condition); + WithPredicate result = checkExpr(scope, *expr.condition); + ScopePtr trueScope = childScope(scope, expr.trueExpr->location); - reportErrors(resolve(result.predicates, trueScope, true)); - ExprResult trueType = checkExpr(trueScope, *expr.trueExpr); + resolve(result.predicates, trueScope, true); + WithPredicate trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); ScopePtr falseScope = childScope(scope, expr.falseExpr->location); - // Don't report errors for this scope to avoid potentially duplicating errors reported for the first scope. resolve(result.predicates, falseScope, false); - ExprResult falseType = checkExpr(falseScope, *expr.falseExpr); + WithPredicate falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); - unify(trueType.type, falseType.type, expr.location); + if (falseType.type == trueType.type) + return {trueType.type}; - // TODO: normalize(UnionTypeVar{{trueType, falseType}}) - // For now both trueType and falseType must be the same type. - return {trueType.type}; + std::vector types = reduceUnion({trueType.type, falseType.type}); + if (FFlag::LuauUnknownAndNeverType && types.empty()) + return {neverType}; + return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; +} + +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprInterpString& expr) +{ + for (AstExpr* expr : expr.expressions) + checkExpr(scope, *expr); + + return {stringType}; } TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) { - auto [ty, binding] = checkLValueBinding(scope, expr); - return ty; + return checkLValueBinding(scope, expr); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr) { if (auto a = expr.as()) return checkLValueBinding(scope, *a); @@ -2561,22 +3245,25 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { for (AstExpr* expr : a->expressions) checkExpr(scope, *expr); - return std::pair(errorType, nullptr); + return errorRecoveryType(scope); } else ice("Unexpected AST node in checkLValue", expr.location); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr) { if (std::optional ty = scope->lookup(expr.local)) - return {*ty, nullptr}; + { + ty = follow(*ty); + return get(*ty) ? unknownType : *ty; + } reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); - return {errorType, nullptr}; + return errorRecoveryType(scope); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) { Name name = expr.name.value; ScopePtr moduleScope = currentModule->getModuleScope(); @@ -2584,78 +3271,79 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope const auto it = moduleScope->bindings.find(expr.name); if (it != moduleScope->bindings.end()) - return std::pair(it->second.typeId, &it->second.typeId); - - if (isNonstrictMode() || FFlag::LuauSecondTypecheckKnowsTheDataModel) { - TypeId result = (FFlag::LuauGenericFunctions && FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(moduleScope, true)); - - Binding& binding = moduleScope->bindings[expr.name]; - binding = {result, expr.location}; - - // If we're in strict mode, we want to report defining a global as an error, - // but still add it to the bindings, so that autocomplete includes it in completions. - if (!isNonstrictMode()) - reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); - - return std::pair(result, &binding.typeId); + TypeId ty = follow(it->second.typeId); + return get(ty) ? unknownType : ty; } - reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); - return std::pair(errorType, nullptr); + TypeId result = freshType(scope); + Binding& binding = moduleScope->bindings[expr.name]; + binding = {result, expr.location}; + + // If we're in strict mode, we want to report defining a global as an error, + // but still add it to the bindings, so that autocomplete includes it in completions. + if (!isNonstrictMode()) + reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); + + return result; } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) { TypeId lhs = checkExpr(scope, *expr.expr).type; if (get(lhs) || get(lhs)) - return std::pair(lhs, nullptr); + return lhs; + + if (get(lhs)) + return unknownType; tablify(lhs); Name name = expr.index.value; - if (FFlag::LuauExtraNilRecovery) - lhs = stripFromNilAndReport(lhs, expr.expr->location); + lhs = stripFromNilAndReport(lhs, expr.expr->location); if (TableTypeVar* lhsTable = getMutableTableType(lhs)) { const auto& it = lhsTable->props.find(name); if (it != lhsTable->props.end()) { - return std::pair(it->second.type, &it->second.type); + return it->second.type; } else if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { - TypeId theType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; property.type = theType; property.location = expr.indexLocation; - return std::pair(theType, &property.type); + return theType; } else if (auto indexer = lhsTable->indexer) { - Unifier state = mkUnifier(expr.location); - state.tryUnify(indexer->indexType, stringType); + Unifier state = mkUnifier(scope, expr.location); + state.tryUnify(stringType, indexer->indexType); + TypeId retType = indexer->indexResultType; if (!state.errors.empty()) { - state.log.rollback(); - reportError(expr.location, UnknownProperty{lhs, name}); - return std::pair(errorType, nullptr); - } - return std::pair(indexer->indexResultType, nullptr); + reportError(expr.location, UnknownProperty{lhs, name}); + retType = errorRecoveryType(retType); + } + else + state.log.commit(); + + return retType; } else if (lhsTable->state == TableState::Sealed) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorType, nullptr); + return errorRecoveryType(scope); } else { reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}}); - return std::pair(errorType, nullptr); + return errorRecoveryType(scope); } } else if (const ClassTypeVar* lhsClass = get(lhs)) @@ -2664,44 +3352,49 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return std::pair(errorType, nullptr); + return errorRecoveryType(scope); } - return std::pair(prop->type, nullptr); + return prop->type; } else if (get(lhs)) { - if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, false)) - return std::pair(*ty, nullptr); + if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, /* addErrors= */ false)) + return *ty; // If intersection has a table part, report that it cannot be extended just as a sealed table if (isTableIntersection(lhs)) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorType, nullptr); + return errorRecoveryType(scope); } } reportError(TypeError{expr.location, NotATable{lhs}}); - return std::pair(errorType, nullptr); + return errorRecoveryType(scope); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) { TypeId exprType = checkExpr(scope, *expr.expr).type; tablify(exprType); - if (FFlag::LuauExtraNilRecovery) - exprType = stripFromNilAndReport(exprType, expr.expr->location); + exprType = stripFromNilAndReport(exprType, expr.expr->location); TypeId indexType = checkExpr(scope, *expr.index).type; + if (FFlag::LuauFollowInLvalueIndexCheck) + exprType = follow(exprType); + if (get(exprType) || get(exprType)) - return std::pair(exprType, nullptr); + return exprType; + + if (get(exprType)) + return unknownType; AstExprConstantString* value = expr.index->as(); - if (value && FFlag::LuauClassPropertyAccessAsString) + if (value) { if (const ClassTypeVar* exprClass = get(exprType)) { @@ -2709,9 +3402,19 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); - return std::pair(errorType, nullptr); + return errorRecoveryType(scope); } - return std::pair(prop->type, nullptr); + return prop->type; + } + } + else if (FFlag::LuauAllowIndexClassParameters) + { + if (const ClassTypeVar* exprClass = get(exprType)) + { + if (isNonstrictMode()) + return unknownType; + reportError(TypeError{expr.location, DynamicPropertyLookupOnClassesUnsafe{exprType}}); + return errorRecoveryType(scope); } } @@ -2719,11 +3422,8 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { - if (FFlag::LuauExtraNilRecovery) - reportError(TypeError{expr.expr->location, NotATable{exprType}}); - else - reportError(TypeError{expr.location, NotATable{exprType}}); - return std::pair(errorType, nullptr); + reportError(TypeError{expr.expr->location, NotATable{exprType}}); + return errorRecoveryType(scope); } if (value) @@ -2731,47 +3431,49 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope const auto& it = exprTable->props.find(value->value.data); if (it != exprTable->props.end()) { - return std::pair(it->second.type, &it->second.type); + return it->second.type; } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { - TypeId resultType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId resultType = freshType(scope); Property& property = exprTable->props[value->value.data]; property.type = resultType; property.location = expr.index->location; - return std::pair(resultType, &property.type); + return resultType; } } if (exprTable->indexer) { const TableIndexer& indexer = *exprTable->indexer; - unify(indexer.indexType, indexType, expr.index->location); - return std::pair(indexer.indexResultType, nullptr); + unify(indexType, indexer.indexType, scope, expr.index->location); + return indexer.indexResultType; } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { - TypeId resultType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId resultType = freshType(exprTable->level); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; - return std::pair(resultType, nullptr); - } - else if (FFlag::LuauIndexTablesWithIndexers) - { - // We allow t[x] where x:string for tables without an indexer - unify(indexType, stringType, expr.location); - return std::pair(anyType, nullptr); + return resultType; } else { + /* + * If we use [] indexing to fetch a property from a sealed table that has no indexer, we have no idea if it will + * work, so we just mint a fresh type, return that, and hope for the best. + */ TypeId resultType = freshType(scope); - return std::pair(resultType, nullptr); + return resultType; } } // Answers the question: "Can I define another function with this name?" // Primarily about detecting duplicates. -TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) +TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level) { + auto freshTy = [&]() { + return freshType(level); + }; + if (auto globalName = funName.as()) { const ScopePtr& globalScope = currentModule->getModuleScope(); @@ -2781,11 +3483,11 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) if (isNonstrictMode()) return globalScope->bindings[name].typeId; - return errorType; + return errorRecoveryType(scope); } else { - TypeId ty = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId ty = freshTy(); globalScope->bindings[name] = {ty, funName.location}; return ty; } @@ -2795,45 +3497,36 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Symbol name = localName->local; Binding& binding = scope->bindings[name]; if (binding.typeId == nullptr) - binding = {(FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)), funName.location}; + binding = {freshTy(), funName.location}; return binding.typeId; } else if (auto indexName = funName.as()) { TypeId lhsType = checkExpr(scope, *indexName->expr).type; - if (get(lhsType) || get(lhsType)) - return lhsType; - TableTypeVar* ttv = getMutableTableType(lhsType); - if (!ttv) + + if (!ttv || ttv->state == TableState::Sealed) { - if (!isTableIntersection(lhsType)) - reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); + if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, /* addErrors= */ false)) + return *ty; - return errorType; + return errorRecoveryType(scope); } - // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check - if (lhsType->persistent || ttv->state == TableState::Sealed) - return errorType; - Name name = indexName->index.value; if (ttv->props.count(name)) - return errorType; + return ttv->props[name].type; Property& property = ttv->props[name]; - property.type = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + property.type = freshTy(); property.location = indexName->indexLocation; - ttv->methodDefinitionLocations[name] = funName.location; return property.type; } else if (funName.is()) - { - return errorType; - } + return errorRecoveryType(scope); else { ice("Unexpected AST node type", funName.location); @@ -2851,8 +3544,8 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) // `(X) -> Y...`, but after typechecking the body, we cam unify `Y...` with `X` // to get type `(X) -> X`, then we quantify the free types to get the final // generic type `(a) -> a`. -std::pair TypeChecker::checkFunctionSignature( - const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::optional originalName, std::optional expectedType) +std::pair TypeChecker::checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, + std::optional originalName, std::optional selfType, std::optional expectedType) { ScopePtr funScope = childFunctionScope(scope, expr.location, subLevel); @@ -2886,54 +3579,32 @@ std::pair TypeChecker::checkFunctionSignature( } } } - - // We do not infer type binders, so if a generic function is required we do not propagate - if (expectedFunctionType && !(expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty())) - expectedFunctionType = nullptr; } - std::vector generics; - std::vector genericPacks; - - if (FFlag::LuauGenericFunctions) - { - std::tie(generics, genericPacks) = createGenericTypes(funScope, expr, expr.generics, expr.genericPacks); - } + auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); TypePackId retPack; - if (expr.hasReturnAnnotation) - { - if (FFlag::LuauGenericFunctions) - retPack = resolveTypePack(funScope, expr.returnAnnotation); - else - retPack = resolveTypePack(scope, expr.returnAnnotation); - } + if (expr.returnAnnotation) + retPack = resolveTypePack(funScope, *expr.returnAnnotation); else if (isNonstrictMode()) retPack = anyTypePack; - else if (expectedFunctionType) + else if (expectedFunctionType && expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty()) { - auto [head, tail] = flatten(expectedFunctionType->retType); + auto [head, tail] = flatten(expectedFunctionType->retTypes); // Do not infer 'nil' as function return type if (!tail && head.size() == 1 && isNil(head[0])) - retPack = FFlag::LuauGenericFunctions ? freshTypePack(funScope) : freshTypePack(scope); + retPack = freshTypePack(funScope); else retPack = addTypePack(head, tail); } - else if (FFlag::LuauGenericFunctions) - retPack = freshTypePack(funScope); else - retPack = freshTypePack(scope); + retPack = freshTypePack(funScope); if (expr.vararg) { if (expr.varargAnnotation) - { - if (FFlag::LuauGenericFunctions) - funScope->varargPack = resolveTypePack(funScope, *expr.varargAnnotation); - else - funScope->varargPack = resolveTypePack(scope, *expr.varargAnnotation); - } + funScope->varargPack = resolveTypePack(funScope, *expr.varargAnnotation); else { if (expectedFunctionType && !isNonstrictMode()) @@ -2967,12 +3638,25 @@ std::pair TypeChecker::checkFunctionSignature( funScope->returnType = retPack; - if (expr.self) + if (FFlag::DebugLuauSharedSelf) { - // TODO: generic self types: CLI-39906 - TypeId selfType = anyIfNonstrict(freshType(funScope)); - funScope->bindings[expr.self] = {selfType, expr.self->location}; - argTypes.push_back(selfType); + if (expr.self) + { + // TODO: generic self types: CLI-39906 + TypeId selfTy = anyIfNonstrict(selfType ? *selfType : freshType(funScope)); + funScope->bindings[expr.self] = {selfTy, expr.self->location}; + argTypes.push_back(selfTy); + } + } + else + { + if (expr.self) + { + // TODO: generic self types: CLI-39906 + TypeId selfType = anyIfNonstrict(freshType(funScope)); + funScope->bindings[expr.self] = {selfType, expr.self->location}; + argTypes.push_back(selfType); + } } // Prepare expected argument type iterators if we have an expected function type @@ -2990,7 +3674,7 @@ std::pair TypeChecker::checkFunctionSignature( if (local->annotation) { - argType = resolveType((FFlag::LuauGenericFunctions ? funScope : scope), *local->annotation); + argType = resolveType(funScope, *local->annotation); // If the annotation type has an error, treat it as if there was no annotation if (get(follow(argType))) @@ -3003,9 +3687,6 @@ std::pair TypeChecker::checkFunctionSignature( if (expectedArgsCurr != expectedArgsEnd) { argType = *expectedArgsCurr; - - if (!FFlag::LuauInferFunctionArgsFix) - ++expectedArgsCurr; } else if (auto expectedArgsTail = expectedArgsCurr.tail()) { @@ -3021,7 +3702,7 @@ std::pair TypeChecker::checkFunctionSignature( funScope->bindings[local] = {argType, local->location}; argTypes.push_back(argType); - if (FFlag::LuauInferFunctionArgsFix && expectedArgsCurr != expectedArgsEnd) + if (expectedArgsCurr != expectedArgsEnd) ++expectedArgsCurr; } @@ -3033,7 +3714,34 @@ std::pair TypeChecker::checkFunctionSignature( defn.varargLocation = expr.vararg ? std::make_optional(expr.varargLocation) : std::nullopt; defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0)); - TypeId funTy = addType(FunctionTypeVar(funScope->level, generics, genericPacks, argPack, retPack, std::move(defn), bool(expr.self))); + std::vector genericTys; + // if we have a generic expected function type and no generics, we should use the expected ones. + if (expectedFunctionType && generics.empty()) + { + genericTys = expectedFunctionType->generics; + } + else + { + genericTys.reserve(generics.size()); + for (const GenericTypeDefinition& generic : generics) + genericTys.push_back(generic.ty); + } + + std::vector genericTps; + // if we have a generic expected function type and no generic typepacks, we should use the expected ones. + if (expectedFunctionType && genericPacks.empty()) + { + genericTps = expectedFunctionType->genericPacks; + } + else + { + genericTps.reserve(genericPacks.size()); + for (const GenericTypePackDefinition& generic : genericPacks) + genericTps.push_back(generic.tp); + } + + TypeId funTy = + addType(FunctionTypeVar(funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, std::move(defn), bool(expr.self))); FunctionTypeVar* ftv = getMutable(funTy); @@ -3052,7 +3760,7 @@ static bool allowsNoReturnValues(const TypePackId tp) { for (TypeId ty : tp) { - if (!get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) + if (!get(follow(ty))) { return false; } @@ -3076,39 +3784,62 @@ static Location getEndLocation(const AstExprFunction& function) void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstExprFunction& function) { + LUAU_TIMETRACE_SCOPE("TypeChecker::checkFunctionBody", "TypeChecker"); + + if (function.debugname.value) + LUAU_TIMETRACE_ARGUMENT("name", function.debugname.value); + else + LUAU_TIMETRACE_ARGUMENT("line", std::to_string(function.location.begin.line).c_str()); + if (FunctionTypeVar* funTy = getMutable(ty)) { check(scope, *function.body); // We explicitly don't follow here to check if we have a 'true' free type instead of bound one - if (FFlag::LuauAddMissingFollow ? get_if(&funTy->retType->ty) : get(funTy->retType)) - *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; + if (get_if(&funTy->retTypes->ty)) + *asMutable(funTy->retTypes) = TypePack{{}, std::nullopt}; bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; - if (reachesImplicitReturn && !allowsNoReturnValues(follow(funTy->retType))) + if (reachesImplicitReturn && !allowsNoReturnValues(follow(funTy->retTypes))) { // If we're in nonstrict mode we want to only report this missing return // statement if there are type annotations on the function. In strict mode // we report it regardless. - if (!isNonstrictMode() || function.hasReturnAnnotation) + if (!isNonstrictMode() || function.returnAnnotation) { - reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retType}); + reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retTypes}); } } + + if (!currentModule->astTypes.find(&function)) + currentModule->astTypes[&function] = ty; } else ice("Checking non functional type"); } -ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) +WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) +{ + if (FFlag::LuauUnknownAndNeverType) + { + WithPredicate result = checkExprPackHelper(scope, expr); + if (containsNever(result.type)) + return {uninhabitableTypePack}; + return result; + } + else + return checkExprPackHelper(scope, expr); +} + +WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr) { if (auto a = expr.as()) - return checkExprPack(scope, *a); + return checkExprPackHelper(scope, *a); else if (expr.is()) { if (!scope->varargPack) - return {addTypePack({addType(ErrorTypeVar())})}; + return {errorRecoveryTypePack(scope)}; return {*scope->varargPack}; } @@ -3119,47 +3850,34 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } } -// Returns the minimum number of arguments the argument list can accept. -static size_t getMinParameterCount(TypePackId tp) -{ - size_t minCount = 0; - size_t optionalCount = 0; - - auto it = begin(tp); - auto endIter = end(tp); - - while (it != endIter) - { - TypeId ty = *it; - if (isOptional(ty)) - ++optionalCount; - else - { - minCount += optionalCount; - optionalCount = 0; - minCount++; - } - - ++it; - } - - return minCount; -} - -void TypeChecker::checkArgumentList( - const ScopePtr& scope, Unifier& state, TypePackId argPack, TypePackId paramPack, const std::vector& argLocations) +void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId argPack, TypePackId paramPack, + const std::vector& argLocations) { /* Important terminology refresher: - * A function requires paramaters. + * A function requires parameters. * To call a function, you supply arguments. */ - TypePackIterator argIter = begin(argPack); - TypePackIterator paramIter = begin(paramPack); + TypePackIterator argIter = begin(argPack, &state.log); + TypePackIterator paramIter = begin(paramPack, &state.log); TypePackIterator endIter = end(argPack); // Important subtlety: All end TypePackIterators are equivalent size_t paramIndex = 0; - size_t minParams = getMinParameterCount(paramPack); + auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack, &funName]() { + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + + std::string namePath; + + if (std::optional path = getFunctionNameAsString(funName)) + namePath = *path; + + auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); + state.reportError(TypeError{location, + CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath}}); + }; while (true) { @@ -3170,33 +3888,28 @@ void TypeChecker::checkArgumentList( std::optional argTail = argIter.tail(); std::optional paramTail = paramIter.tail(); - // If we hit the end of both type packs simultaneously, then there are definitely no further type - // errors to report. All we need to do is tie up any free tails. - // - // If one side has a free tail and the other has none at all, we create an empty pack and bind the - // free tail to that. + // If we hit the end of both type packs simultaneously, we have to unify them. + // But if one side has a free tail and the other has none at all, we create an empty pack and bind the free tail to that. if (argTail) { - if (get(*argTail)) + if (state.log.getMutable(state.log.follow(*argTail))) { if (paramTail) - state.tryUnify(*argTail, *paramTail); + state.tryUnify(*paramTail, *argTail); else - { - state.log(*argTail); - *asMutable(*argTail) = TypePack{{}}; - } + state.log.replace(*argTail, TypePackVar(TypePack{{}})); + } + else if (paramTail) + { + state.tryUnify(*argTail, *paramTail); } } else if (paramTail) { // argTail is definitely empty - if (get(*paramTail)) - { - state.log(*paramTail); - *asMutable(*paramTail) = TypePack{{}}; - } + if (state.log.getMutable(state.log.follow(*paramTail))) + state.log.replace(*paramTail, TypePackVar(TypePack{{}})); } return; @@ -3209,28 +3922,29 @@ void TypeChecker::checkArgumentList( if (argIter.tail()) { TypePackId tail = *argIter.tail(); - if (get(tail)) + if (state.log.getMutable(tail)) { // Unify remaining parameters so we don't leave any free-types hanging around. - TypeId argTy = errorType; while (paramIter != endIter) { - state.tryUnify(*paramIter, argTy); + state.tryUnify(errorRecoveryType(anyType), *paramIter); ++paramIter; } return; } - else if (auto vtp = get(tail)) + else if (auto vtp = state.log.getMutable(tail)) { + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. while (paramIter != endIter) { - state.tryUnify(*paramIter, vtp->ty); + state.tryUnify(vtp->ty, *paramIter); ++paramIter; } return; } - else if (get(tail)) + else if (state.log.getMutable(tail)) { std::vector rest; rest.reserve(std::distance(paramIter, endIter)); @@ -3241,7 +3955,10 @@ void TypeChecker::checkArgumentList( } TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); - state.tryUnify(tail, varPack); + if (FFlag::LuauReturnsFromCallsitesAreNotWidened) + state.tryUnify(tail, varPack); + else + state.tryUnify(varPack, tail); return; } } @@ -3249,19 +3966,27 @@ void TypeChecker::checkArgumentList( // If any remaining unfulfilled parameters are nonoptional, this is a problem. while (paramIter != endIter) { - TypeId t = follow(*paramIter); + TypeId t = state.log.follow(*paramIter); if (isOptional(t)) { } // ok - else if (get(t)) - { - } // ok - else if (isNonstrictMode() && get(t)) + else if (state.log.getMutable(t)) { } // ok else { - state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); + + std::optional tail = flatten(paramPack, state.log).second; + bool isVariadic = tail && Luau::isVariadic(*tail); + + std::string namePath; + + if (std::optional path = getFunctionNameAsString(funName)) + namePath = *path; + + state.reportError(TypeError{ + funName.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); return; } ++paramIter; @@ -3274,24 +3999,22 @@ void TypeChecker::checkArgumentList( { while (argIter != endIter) { - unify(*argIter, errorType, state.location); + // The use of unify here is deliberate. We don't want this unification + // to be undoable. + unify(errorRecoveryType(scope), *argIter, scope, state.location); ++argIter; } - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + reportCountMismatchError(); return; } - TypePackId tail = *paramIter.tail(); + TypePackId tail = state.log.follow(*paramIter.tail()); - if (get(tail)) + if (state.log.getMutable(tail)) { // Function is variadic. Ok. return; } - else if (auto vtp = get(tail)) + else if (auto vtp = state.log.getMutable(tail)) { // Function is variadic and requires that all subsequent parameters // be compatible with a type. @@ -3303,14 +4026,14 @@ void TypeChecker::checkArgumentList( if (argIndex < argLocations.size()) location = argLocations[argIndex]; - unify(vtp->ty, *argIter, location); + unify(*argIter, vtp->ty, scope, location); ++argIter; ++argIndex; } return; } - else if (FFlag::LuauGenericVariadicsUnification && get(tail)) + else if (state.log.getMutable(tail)) { // Create a type pack out of the remaining argument types // and unify it with the tail. @@ -3323,33 +4046,27 @@ void TypeChecker::checkArgumentList( } TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); - state.tryUnify(tail, varPack); - return; - } - else if (get(tail)) - { - state.log(tail); - *asMutable(tail) = TypePack{}; + state.tryUnify(varPack, tail); return; } - else if (FFlag::LuauRankNTypes && get(tail)) + else if (state.log.getMutable(tail)) { - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - // TODO: Better error message? - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.log.replace(tail, TypePackVar(TypePack{{}})); + return; + } + else if (state.log.getMutable(tail)) + { + reportCountMismatchError(); return; } } else { - if (FFlag::LuauRankNTypes) - unifyWithInstantiationIfNeeded(scope, *paramIter, *argIter, state); + if (FFlag::LuauInstantiateInSubtyping) + state.tryUnify(*argIter, *paramIter, /*isFunctionCall*/ false); else - state.tryUnify(*paramIter, *argIter, /*isFunctionCall*/ false); + unifyWithInstantiationIfNeeded(*argIter, *paramIter, scope, state); ++argIter; ++paramIter; } @@ -3358,7 +4075,7 @@ void TypeChecker::checkArgumentList( } } -ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr) { // evaluate type of function // decompose an intersection into its component overloads @@ -3366,7 +4083,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A // For each overload // Compare parameter and argument types // Report any errors (also speculate dot vs colon warnings!) - // If there are no errors, return the resulting return type + // Return the resulting return type (even if there are errors) + // If there are no matching overloads, unify with (a...) -> (b...) and return b... TypeId selfType = nullptr; TypeId functionType = nullptr; @@ -3379,47 +4097,17 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A ice("method call expression has no 'self'"); selfType = checkExpr(scope, *indexExpr->expr).type; - if (!FFlag::LuauRankNTypes) - instantiate(scope, selfType, expr.func->location); + selfType = stripFromNilAndReport(selfType, expr.func->location); - if (FFlag::LuauExtraNilRecovery) - selfType = stripFromNilAndReport(selfType, expr.func->location); - - if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true)) + if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, /* addErrors= */ true)) { functionType = *propTy; actualFunctionType = instantiate(scope, functionType, expr.func->location); } else { - if (!FFlag::LuauMissingUnionPropertyError) - reportError(indexExpr->indexLocation, UnknownProperty{selfType, indexExpr->index.value}); - - if (!FFlag::LuauExtraNilRecovery) - { - // Try to recover using a union without 'nil' options - if (std::optional strippedUnion = tryStripUnionFromNil(selfType)) - { - if (std::optional propTy = getIndexTypeFromType(scope, *strippedUnion, indexExpr->index.value, expr.location, false)) - { - selfType = *strippedUnion; - - functionType = *propTy; - actualFunctionType = instantiate(scope, functionType, expr.func->location); - } - } - - if (!actualFunctionType) - { - functionType = errorType; - actualFunctionType = errorType; - } - } - else - { - functionType = errorType; - actualFunctionType = errorType; - } + functionType = errorRecoveryType(scope); + actualFunctionType = functionType; } } else @@ -3428,7 +4116,15 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A actualFunctionType = instantiate(scope, functionType, expr.func->location); } - actualFunctionType = follow(actualFunctionType); + TypePackId retPack; + if (auto free = get(actualFunctionType)) + { + retPack = freshTypePack(free->level); + TypePackId freshArgPack = freshTypePack(free->level); + asMutable(actualFunctionType)->ty.emplace(free->level, freshArgPack, retPack); + } + else + retPack = freshTypePack(scope->level); // checkExpr will log the pre-instantiated type of the function. // That's not nearly as interesting as the instantiated type, which will include details about how @@ -3438,22 +4134,33 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector overloads = flattenIntersection(actualFunctionType); - TypePackId retPack = freshTypePack(scope->level); - std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); - ExprResult argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); - TypePackId argList = argListResult.type; - TypePackId argPack = (FFlag::LuauRankNTypes ? argList : DEPRECATED_instantiate(scope, argList, expr.location)); + WithPredicate argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); + TypePackId argPack = argListResult.type; if (get(argPack)) - return ExprResult{errorTypePack}; + return {errorRecoveryTypePack(scope)}; - TypePack* args = getMutable(argPack); - LUAU_ASSERT(args != nullptr); + TypePack* args = nullptr; + if (FFlag::LuauUnknownAndNeverType) + { + if (expr.self) + { + argPack = addTypePack(TypePack{{selfType}, argPack}); + argListResult.type = argPack; + } + args = getMutable(argPack); + LUAU_ASSERT(args); + } + else + { + args = getMutable(argPack); + LUAU_ASSERT(args != nullptr); - if (expr.self) - args->head.insert(args->head.begin(), selfType); + if (expr.self) + args->head.insert(args->head.begin(), selfType); + } std::vector argLocations; argLocations.reserve(expr.args.size + 1); @@ -3465,26 +4172,38 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector errors; // errors encountered for each overload std::vector overloadsThatMatchArgCount; + std::vector overloadsThatDont; for (TypeId fn : overloads) { fn = follow(fn); - if (auto ret = checkCallOverload(scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, errors)) + if (auto ret = checkCallOverload( + scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) return *ret; } if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) return {retPack}; - return reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + + const FunctionTypeVar* overload = nullptr; + if (!overloadsThatMatchArgCount.empty()) + overload = get(overloadsThatMatchArgCount[0]); + if (!overload && !overloadsThatDont.empty()) + overload = get(overloadsThatDont[0]); + if (overload) + return {errorRecoveryTypePack(overload->retTypes)}; + + return {errorRecoveryTypePack(retPack)}; } std::vector> TypeChecker::getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall) { std::vector> expectedTypes; - auto assignOption = [this, &expectedTypes](size_t index, std::optional ty) { + auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { if (index == expectedTypes.size()) { expectedTypes.push_back(ty); @@ -3499,8 +4218,11 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } else { - std::vector result = reduceUnion({*el, *ty}); - el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); + std::vector result = reduceUnion({*el, ty}); + if (FFlag::LuauUnknownAndNeverType && result.empty()) + el = neverType; + else + el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); } } }; @@ -3519,7 +4241,8 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st if (argsTail) { - if (const VariadicTypePack* vtp = get(follow(*argsTail))) + argsTail = follow(*argsTail); + if (const VariadicTypePack* vtp = get(*argsTail)) { while (index < argumentCount) assignOption(index++, vtp->ty); @@ -3528,66 +4251,85 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } } + Demoter demoter{¤tModule->internalTypes}; + demoter.demote(expectedTypes); + return expectedTypes; } -std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& errors) +std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, + std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { - if (FFlag::LuauExtraNilRecovery) - fn = stripFromNilAndReport(fn, expr.func->location); + LUAU_ASSERT(argLocations); + + fn = stripFromNilAndReport(fn, expr.func->location); if (get(fn)) { - unify(argPack, anyTypePack, expr.location); + unify(anyTypePack, argPack, scope, expr.location); return {{anyTypePack}}; } if (get(fn)) { - return {{addTypePack(TypePackVar{Unifiable::Error{}})}}; + return {{errorRecoveryTypePack(scope)}}; } - if (get(fn)) + if (get(fn)) + return {{uninhabitableTypePack}}; + + if (auto ftv = get(fn)) { // fn is one of the overloads of actualFunctionType, which // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - unify(r, fn, expr.location); + + UnifierOptions options; + options.isFunctionCall = true; + unify(r, fn, scope, expr.location, options); + return {{retPack}}; } + std::vector metaArgLocations; + + // Might be a callable table or class + std::optional callTy = std::nullopt; + if (const MetatableTypeVar* mttv = get(fn)) + { + callTy = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false); + } + else if (const ClassTypeVar* ctv = get(fn); FFlag::LuauCallableClasses && ctv && ctv->metatable) + { + callTy = getIndexTypeFromType(scope, *ctv->metatable, "__call", expr.func->location, /* addErrors= */ false); + } + + if (callTy) + { + // Construct arguments with 'self' added in front + TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); + + TypePack* metaCallArgs = getMutable(metaCallArgPack); + metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); + + metaArgLocations = *argLocations; + metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + + fn = instantiate(scope, *callTy, expr.func->location); + + argPack = metaCallArgPack; + args = metaCallArgs; + argLocations = &metaArgLocations; + } + const FunctionTypeVar* ftv = get(fn); if (!ftv) { - // Might be a callable table - if (const MetatableTypeVar* mttv = get(fn)) - { - if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) - { - // Construct arguments with 'self' added in front - TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); - - TypePack* metaCallArgs = getMutable(metaCallArgPack); - metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); - - std::vector metaArgLocations = argLocations; - metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); - - TypeId fn = *ty; - if (FFlag::LuauRankNTypes) - fn = instantiate(scope, fn, expr.func->location); - - return checkCallOverload( - scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, overloadsThatMatchArgCount, errors); - } - } - reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); - unify(retPack, errorTypePack, expr.func->location); - return {{errorTypePack}}; + unify(errorRecoveryTypePack(scope), retPack, scope, expr.func->location); + return {{errorRecoveryTypePack(retPack)}}; } // When this function type has magic functions and did return something, we select that overload instead. @@ -3595,21 +4337,20 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (ftv->magicFunction) { // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 - if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) + if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) return *ret; } - Unifier state = mkUnifier(expr.location); + Unifier state = mkUnifier(scope, expr.location); // Unify return types - checkArgumentList(scope, state, retPack, ftv->retType, /*argLocations*/ {}); + checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, /*argLocations*/ {}); if (!state.errors.empty()) { - state.log.rollback(); return {}; } - checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); + checkArgumentList(scope, *expr.func, state, argPack, ftv->argTypes, *argLocations); if (!state.errors.empty()) { @@ -3629,40 +4370,16 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (!argMismatch) overloadsThatMatchArgCount.push_back(fn); + else + overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); - state.log.rollback(); } else { - if (isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) - { - // If we are running in nonstrict mode, passing fewer arguments than the function is declared to take AND - // the function is declared with colon notation AND we use dot notation, warn. - auto [providedArgs, providedTail] = flatten(argPack); + state.log.commit(); - // If we have a variadic tail, we can't say how many arguments were actually provided - if (!providedTail) - { - std::vector actualArgs = flatten(ftv->argTypes).first; - - size_t providedCount = providedArgs.size(); - size_t requiredCount = actualArgs.size(); - - // Ignore optional arguments - while (providedCount < requiredCount && requiredCount != 0 && isOptional(actualArgs[requiredCount - 1])) - requiredCount--; - - if (providedCount < requiredCount) - { - int requiredExtraNils = int(requiredCount - providedCount); - reportError(TypeError{expr.func->location, FunctionRequiresSelf{requiredExtraNils}}); - } - } - } - - if (FFlag::LuauStoreMatchingOverloadFnType) - currentModule->astOverloadResolvedTypes[&expr] = fn; + currentModule->astOverloadResolvedTypes[&expr] = fn; // We select this overload return {{retPack}}; @@ -3686,21 +4403,21 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal std::vector editedParamList(args->head.begin() + 1, args->head.end()); TypePackId editedArgPack = addTypePack(TypePack{editedParamList}); - Unifier editedState = mkUnifier(expr.location); - checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + Unifier editedState = mkUnifier(scope, expr.location); + checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations); if (editedState.errors.empty()) { + editedState.log.commit(); + reportError(TypeError{expr.location, FunctionDoesNotTakeSelf{}}); // This is a little bit suspect: If this overload would work with a . replaced by a : // we eagerly assume that that's what you actually meant and we commit to it. // This could be incorrect if the function has an additional overload that // actually works. - // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); + // checkArgumentList(scope, editedState, retPack, ftv->retTypes, retLocations, CountMismatch::Return); return true; } - else - editedState.log.rollback(); } else if (ftv->hasSelf) { @@ -3716,22 +4433,22 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal editedArgList.insert(editedArgList.begin(), checkExpr(scope, *indexName->expr).type); TypePackId editedArgPack = addTypePack(TypePack{editedArgList}); - Unifier editedState = mkUnifier(expr.location); + Unifier editedState = mkUnifier(scope, expr.location); - checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations); if (editedState.errors.empty()) { + editedState.log.commit(); + reportError(TypeError{expr.location, FunctionRequiresSelf{}}); // This is a little bit suspect: If this overload would work with a : replaced by a . // we eagerly assume that that's what you actually meant and we commit to it. // This could be incorrect if the function has an additional overload that // actually works. - // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); + // checkArgumentList(scope, editedState, retPack, ftv->retTypes, retLocations, CountMismatch::Return); return true; } - else - editedState.log.rollback(); } } } @@ -3739,14 +4456,14 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal return false; } -ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, - TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, - const std::vector& overloadsThatMatchArgCount, const std::vector& errors) +void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, + const std::vector& errors) { if (overloads.size() == 1) { reportErrors(std::get<0>(errors.front())); - return {errorTypePack}; + return; } std::vector overloadTypes = overloadsThatMatchArgCount; @@ -3775,22 +4492,25 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr // If only one overload matched, we don't need this error because we provided the previous errors. if (overloadsThatMatchArgCount.size() == 1) - return {errorTypePack}; + return; } std::string s; for (size_t i = 0; i < overloadTypes.size(); ++i) { - TypeId overload = overloadTypes[i]; - Unifier state = mkUnifier(expr.location); + TypeId overload = FFlag::LuauTypeInferMissingFollows ? follow(overloadTypes[i]) : overloadTypes[i]; + Unifier state = mkUnifier(scope, expr.location); // Unify return types if (const FunctionTypeVar* ftv = get(overload)) { - checkArgumentList(scope, state, retPack, ftv->retType, {}); - checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); + checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, {}); + checkArgumentList(scope, *expr.func, state, argPack, ftv->argTypes, argLocations); } + if (state.errors.empty()) + state.log.commit(); + if (i > 0) s += "; "; @@ -3798,8 +4518,6 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr s += "and "; s += toString(overload); - - state.log.rollback(); } if (overloadsThatMatchArgCount.size() == 0) @@ -3808,12 +4526,13 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr reportError(expr.func->location, ExtraInformation{"Other overloads are also not viable: " + s}); // No viable overload - return {errorTypePack}; + return; } -ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, +WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil, const std::vector& instantiateGenerics, const std::vector>& expectedTypes) { + bool uninhabitable = false; TypePackId pack = addTypePack(TypePack{}); PredicateVec predicates; // At the moment we will be pushing all predicate sets into this. Do we need some way to split them up? @@ -3830,39 +4549,80 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L size_t lastIndex = exprs.size - 1; tp->head.reserve(lastIndex); - Unifier state = mkUnifier(location); + Unifier state = mkUnifier(scope, location); + + std::vector inverseLogs; for (size_t i = 0; i < exprs.size; ++i) { AstExpr* expr = exprs.data[i]; + std::optional expectedType = i < expectedTypes.size() ? expectedTypes[i] : std::nullopt; if (i == lastIndex && (expr->is() || expr->is())) { auto [typePack, exprPredicates] = checkExprPack(scope, *expr); insert(exprPredicates); + if (FFlag::LuauUnknownAndNeverType && containsNever(typePack)) + { + // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, + // ...never) + uninhabitable = true; + continue; + } + else if (std::optional firstTy = first(typePack)) + { + if (!currentModule->astTypes.find(expr)) + currentModule->astTypes[expr] = follow(*firstTy); + } + + if (expectedType) + currentModule->astExpectedTypes[expr] = *expectedType; + tp->tail = typePack; } else { - std::optional expectedType = i < expectedTypes.size() ? expectedTypes[i] : std::nullopt; auto [type, exprPredicates] = checkExpr(scope, *expr, expectedType); insert(exprPredicates); + if (FFlag::LuauUnknownAndNeverType && get(type)) + { + // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, + // ...never) + uninhabitable = true; + continue; + } + TypeId actualType = substituteFreeForNil && expr->is() ? freshType(scope) : type; - if (instantiateGenerics.size() > i && instantiateGenerics[i] && (FFlag::LuauGenericFunctions || get(actualType))) - actualType = instantiate(scope, actualType, expr->location); + if (!FFlag::LuauInstantiateInSubtyping) + { + if (instantiateGenerics.size() > i && instantiateGenerics[i]) + actualType = instantiate(scope, actualType, expr->location); + } if (expectedType) - state.tryUnify(*expectedType, actualType); + { + state.tryUnify(actualType, *expectedType); + + // Ugly: In future iterations of the loop, we might need the state of the unification we + // just performed. There's not a great way to pass that into checkExpr. Instead, we store + // the inverse of the current log, and commit it. When we're done, we'll commit all the + // inverses. This isn't optimal, and a better solution is welcome here. + inverseLogs.push_back(state.log.inverse()); + state.log.commit(); + } tp->head.push_back(actualType); } } - state.log.rollback(); + for (TxnLog& log : inverseLogs) + log.commit(); + if (FFlag::LuauUnknownAndNeverType && uninhabitable) + return {uninhabitableTypePack}; return {pack, predicates}; } @@ -3885,39 +4645,60 @@ std::optional TypeChecker::matchRequire(const AstExprCall& call) TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location) { + LUAU_TIMETRACE_SCOPE("TypeChecker::checkRequire", "TypeChecker"); + LUAU_TIMETRACE_ARGUMENT("moduleInfo", moduleInfo.name.c_str()); + + if (moduleInfo.name.empty()) + { + if (currentModule->mode == Mode::Strict) + { + reportError(TypeError{location, UnknownRequire{}}); + return errorRecoveryType(anyType); + } + + return anyType; + } + + // Types of requires that transitively refer to current module have to be replaced with 'any' + std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); + + for (const auto& [location, path] : requireCycles) + { + if (!path.empty() && path.front() == humanReadableName) + return anyType; + } + ModulePtr module = resolver->getModule(moduleInfo.name); if (!module) { // There are two reasons why we might fail to find the module: // either the file does not exist or there's a cycle. If there's a cycle // we will already have reported the error. - if (!resolver->moduleExists(moduleInfo.name) && (FFlag::LuauTraceRequireLookupChild ? !moduleInfo.optional : true)) - { - std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(TypeError{location, UnknownRequire{reportedModulePath}}); - } + if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional) + reportError(TypeError{location, UnknownRequire{humanReadableName}}); - return errorType; + return errorRecoveryType(scope); } if (module->type != SourceCode::Module) { - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); - return errorType; + return errorRecoveryType(scope); } - std::optional moduleType = first(module->getModuleScope()->returnType); + TypePackId modulePack = module->getModuleScope()->returnType; + + if (get(modulePack)) + return errorRecoveryType(scope); + + std::optional moduleType = first(modulePack); if (!moduleType) { - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); - return errorType; + return errorRecoveryType(scope); } - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - return clone(*moduleType, currentModule->internalTypes, seenTypes, seenTypePacks); + return *moduleType; } void TypeChecker::tablify(TypeId type) @@ -3936,69 +4717,81 @@ TypeId TypeChecker::anyIfNonstrict(TypeId ty) const return ty; } -bool TypeChecker::unify(TypeId left, TypeId right, const Location& location) +bool TypeChecker::unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location) { - Unifier state = mkUnifier(location); - state.tryUnify(left, right); + UnifierOptions options; + return unify(subTy, superTy, scope, location, options); +} + +bool TypeChecker::unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options) +{ + Unifier state = mkUnifier(scope, location); + state.tryUnify(subTy, superTy, options.isFunctionCall); + + state.log.commit(); reportErrors(state.errors); return state.errors.empty(); } -bool TypeChecker::unify(TypePackId left, TypePackId right, const Location& location, CountMismatch::Context ctx) +bool TypeChecker::unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location, CountMismatch::Context ctx) { - Unifier state = mkUnifier(location); + Unifier state = mkUnifier(scope, location); state.ctx = ctx; - state.tryUnify(left, right); + state.tryUnify(subTy, superTy); + + state.log.commit(); reportErrors(state.errors); return state.errors.empty(); } -bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location) +bool TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location) { - LUAU_ASSERT(FFlag::LuauRankNTypes); - Unifier state = mkUnifier(location); - unifyWithInstantiationIfNeeded(scope, left, right, state); + Unifier state = mkUnifier(scope, location); + unifyWithInstantiationIfNeeded(subTy, superTy, scope, state); + + state.log.commit(); reportErrors(state.errors); return state.errors.empty(); } -void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state) +void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, Unifier& state) { - LUAU_ASSERT(FFlag::LuauRankNTypes); - if (!maybeGeneric(right)) + LUAU_ASSERT(!FFlag::LuauInstantiateInSubtyping); + + if (!maybeGeneric(subTy)) // Quick check to see if we definitely can't instantiate - state.tryUnify(left, right, /*isFunctionCall*/ false); - else if (!maybeGeneric(left) && isGeneric(right)) + state.tryUnify(subTy, superTy, /*isFunctionCall*/ false); + else if (!maybeGeneric(superTy) && isGeneric(subTy)) { // Quick check to see if we definitely have to instantiate - TypeId instantiated = instantiate(scope, right, state.location); - state.tryUnify(left, instantiated, /*isFunctionCall*/ false); + TypeId instantiated = instantiate(scope, subTy, state.location); + state.tryUnify(instantiated, superTy, /*isFunctionCall*/ false); } else { // First try unifying with the original uninstantiated type // but if that fails, try the instantiated one. Unifier child = state.makeChildUnifier(); - child.tryUnify(left, right, /*isFunctionCall*/ false); + child.tryUnify(subTy, superTy, /*isFunctionCall*/ false); if (!child.errors.empty()) { - TypeId instantiated = instantiate(scope, right, state.location); - if (right == instantiated) + TypeId instantiated = instantiate(scope, subTy, state.location, &child.log); + if (subTy == instantiated) { // Instantiating the argument made no difference, so just report any child errors state.log.concat(std::move(child.log)); + state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end()); } else { - child.log.rollback(); - state.tryUnify(left, instantiated, /*isFunctionCall*/ false); + state.tryUnify(instantiated, superTy, /*isFunctionCall*/ false); } } else @@ -4008,329 +4801,109 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId l } } -bool Instantiation::isDirty(TypeId ty) -{ - if (FFlag::LuauRankNTypes) - { - if (get(ty)) - return true; - else - return false; - } - - if (const FunctionTypeVar* ftv = get(ty)) - return !ftv->generics.empty() || !ftv->genericPacks.empty(); - else if (const TableTypeVar* ttv = get(ty)) - return ttv->state == TableState::Generic; - else if (get(ty)) - return true; - else - return false; -} - -bool Instantiation::isDirty(TypePackId tp) -{ - if (FFlag::LuauRankNTypes) - return false; - - if (get(tp)) - return true; - else - return false; -} - -bool Instantiation::ignoreChildren(TypeId ty) -{ - LUAU_ASSERT(FFlag::LuauRankNTypes); - if (get(ty)) - return true; - else - return false; -} - -TypeId Instantiation::clean(TypeId ty) -{ - LUAU_ASSERT(isDirty(ty)); - - if (const FunctionTypeVar* ftv = get(ty)) - { - FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; - clone.magicFunction = ftv->magicFunction; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - TypeId result = addType(std::move(clone)); - - if (FFlag::LuauRankNTypes) - { - // Annoyingly, we have to do this even if there are no generics, - // to replace any generic tables. - replaceGenerics.level = level; - replaceGenerics.currentModule = currentModule; - replaceGenerics.generics.assign(ftv->generics.begin(), ftv->generics.end()); - replaceGenerics.genericPacks.assign(ftv->genericPacks.begin(), ftv->genericPacks.end()); - - // TODO: What to do if this returns nullopt? - // We don't have access to the error-reporting machinery - result = replaceGenerics.substitute(result).value_or(result); - } - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; - } - else if (const TableTypeVar* ttv = get(ty)) - { - LUAU_ASSERT(!FFlag::LuauRankNTypes); - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - TypeId result = addType(std::move(clone)); - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; - } - else - { - LUAU_ASSERT(!FFlag::LuauRankNTypes); - TypeId result = addType(FreeTypeVar{level}); - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; - } -} - -TypePackId Instantiation::clean(TypePackId tp) -{ - LUAU_ASSERT(!FFlag::LuauRankNTypes); - return addTypePack(TypePackVar(FreeTypePack{level})); -} - -bool ReplaceGenerics::ignoreChildren(TypeId ty) -{ - LUAU_ASSERT(FFlag::LuauRankNTypes); - if (const FunctionTypeVar* ftv = get(ty)) - // We aren't recursing in the case of a generic function which - // binds the same generics. This can happen if, for example, there's recursive types. - // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. - // It's OK to use vector equality here, since we always generate fresh generics - // whenever we quantify, so the vectors overlap if and only if they are equal. - return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); - else - return false; -} - -bool ReplaceGenerics::isDirty(TypeId ty) -{ - LUAU_ASSERT(FFlag::LuauRankNTypes); - if (const TableTypeVar* ttv = get(ty)) - return ttv->state == TableState::Generic; - else if (get(ty)) - return std::find(generics.begin(), generics.end(), ty) != generics.end(); - else - return false; -} - -bool ReplaceGenerics::isDirty(TypePackId tp) -{ - LUAU_ASSERT(FFlag::LuauRankNTypes); - if (get(tp)) - return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); - else - return false; -} - -TypeId ReplaceGenerics::clean(TypeId ty) -{ - LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = get(ty)) - { - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - return addType(std::move(clone)); - } - else - return addType(FreeTypeVar{level}); -} - -TypePackId ReplaceGenerics::clean(TypePackId tp) -{ - LUAU_ASSERT(isDirty(tp)); - return addTypePack(TypePackVar(FreeTypePack{level})); -} - -bool Quantification::isDirty(TypeId ty) -{ - if (const TableTypeVar* ttv = get(ty)) - return level.subsumes(ttv->level) && ((ttv->state == TableState::Free) || (ttv->state == TableState::Unsealed)); - else if (const FreeTypeVar* ftv = get(ty)) - return level.subsumes(ftv->level); - else - return false; -} - -bool Quantification::isDirty(TypePackId tp) -{ - if (const FreeTypePack* ftv = get(tp)) - return level.subsumes(ftv->level); - else - return false; -} - -TypeId Quantification::clean(TypeId ty) -{ - LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = get(ty)) - { - TableState state = (ttv->state == TableState::Unsealed ? TableState::Sealed : TableState::Generic); - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, state}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - return addType(std::move(clone)); - } - else - { - TypeId generic = addType(GenericTypeVar{level}); - generics.push_back(generic); - return generic; - } -} - -TypePackId Quantification::clean(TypePackId tp) -{ - LUAU_ASSERT(isDirty(tp)); - TypePackId genericPack = addTypePack(TypePackVar(GenericTypePack{level})); - genericPacks.push_back(genericPack); - return genericPack; -} - -bool Anyification::isDirty(TypeId ty) -{ - if (const TableTypeVar* ttv = get(ty)) - return (ttv->state == TableState::Free); - else if (get(ty)) - return true; - else - return false; -} - -bool Anyification::isDirty(TypePackId tp) -{ - if (get(tp)) - return true; - else - return false; -} - -TypeId Anyification::clean(TypeId ty) -{ - LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = get(ty)) - { - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - return addType(std::move(clone)); - } - else - return anyType; -} - -TypePackId Anyification::clean(TypePackId tp) -{ - LUAU_ASSERT(isDirty(tp)); - return anyTypePack; -} - TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location) { ty = follow(ty); - const FunctionTypeVar* ftv = get(ty); - if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty()) - return ty; - - quantification.level = scope->level; - quantification.generics.clear(); - quantification.genericPacks.clear(); - quantification.currentModule = currentModule; - - std::optional qty = quantification.substitute(ty); - - if (!qty.has_value()) + if (FFlag::DebugLuauSharedSelf) { - reportError(location, UnificationTooComplex{}); - return errorType; + if (auto ftv = get(ty)) + Luau::quantify(ty, scope->level); + else if (auto ttv = getTableType(ty); ttv && ttv->selfTy) + Luau::quantify(ty, scope->level); + } + else + { + const FunctionTypeVar* ftv = get(ty); + + if (ftv) + Luau::quantify(ty, scope->level); } - if (ty == *qty) - return ty; - - FunctionTypeVar* qftv = getMutable(*qty); - LUAU_ASSERT(qftv); - qftv->generics = quantification.generics; - qftv->genericPacks = quantification.genericPacks; - return *qty; + return ty; } -TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location) +TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { - instantiation.level = scope->level; - instantiation.currentModule = currentModule; + ty = follow(ty); + + const FunctionTypeVar* ftv = get(ty); + if (ftv && ftv->hasNoGenerics) + return ty; + + Instantiation instantiation{log, ¤tModule->internalTypes, scope->level, /*scope*/ nullptr}; + + if (instantiationChildLimit) + instantiation.childLimit = *instantiationChildLimit; + std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; else { reportError(location, UnificationTooComplex{}); - return errorType; - } -} - -TypePackId TypeChecker::DEPRECATED_instantiate(const ScopePtr& scope, TypePackId ty, Location location) -{ - LUAU_ASSERT(!FFlag::LuauRankNTypes); - instantiation.level = scope->level; - instantiation.currentModule = currentModule; - std::optional instantiated = instantiation.substitute(ty); - if (instantiated.has_value()) - return *instantiated; - else - { - reportError(location, UnificationTooComplex{}); - return errorTypePack; + return errorRecoveryType(scope); } } TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { - anyification.anyType = anyType; - anyification.anyTypePack = anyTypePack; - anyification.currentModule = currentModule; + Anyification anyification{¤tModule->internalTypes, scope, singletonTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); + if (anyification.normalizationTooComplex) + reportError(location, NormalizationTooComplex{}); if (any.has_value()) return *any; else { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(anyType); } } TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location) { - anyification.anyType = anyType; - anyification.anyTypePack = anyTypePack; - anyification.currentModule = currentModule; + Anyification anyification{¤tModule->internalTypes, scope, singletonTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) return *any; else { reportError(location, UnificationTooComplex{}); - return errorTypePack; + return errorRecoveryTypePack(anyTypePack); } } +TypePackId TypeChecker::anyifyModuleReturnTypePackGenerics(TypePackId tp) +{ + tp = follow(tp); + + if (const VariadicTypePack* vtp = get(tp)) + { + TypeId ty = FFlag::LuauTypeInferMissingFollows ? follow(vtp->ty) : vtp->ty; + return get(ty) ? anyTypePack : tp; + } + + if (!get(follow(tp))) + return tp; + + std::vector resultTypes; + std::optional resultTail; + + TypePackIterator it = begin(tp); + + for (TypePackIterator e = end(tp); it != e; ++it) + { + TypeId ty = follow(*it); + resultTypes.push_back(get(ty) ? anyType : ty); + } + + if (std::optional tail = it.tail()) + resultTail = anyifyModuleReturnTypePackGenerics(*tail); + + return addTypePack(resultTypes, resultTail); +} + void TypeChecker::reportError(const TypeError& error) { if (currentModule->mode == Mode::NoCheck) @@ -4389,7 +4962,7 @@ void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& d } }; - if (auto ttv = getTableType(follow(utk->table))) + if (auto ttv = getTableType(utk->table)) accumulate(ttv->props); else if (auto ctv = get(follow(utk->table))) { @@ -4423,9 +4996,10 @@ ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& } // Creates a new Scope and carries forward the varargs from the parent. -ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location, int subLevel) +ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location) { - ScopePtr scope = std::make_shared(parent, subLevel); + ScopePtr scope = std::make_shared(parent); + scope->level = parent->level; scope->varargPack = parent->varargPack; currentModule->scopes.push_back(std::make_pair(location, scope)); @@ -4455,14 +5029,9 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) }); } -Unifier TypeChecker::mkUnifier(const Location& location) +Unifier TypeChecker::mkUnifier(const ScopePtr& scope, const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, iceHandler}; -} - -Unifier TypeChecker::mkUnifier(const std::vector>& seen, const Location& location) -{ - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, iceHandler}; + return Unifier{NotNull{&normalizer}, currentModule->mode, NotNull{scope.get()}, location, Variance::Covariant}; } TypeId TypeChecker::freshType(const ScopePtr& scope) @@ -4472,24 +5041,65 @@ TypeId TypeChecker::freshType(const ScopePtr& scope) TypeId TypeChecker::freshType(TypeLevel level) { - return currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level))); + return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } -TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric) +TypeId TypeChecker::singletonType(bool value) { - return DEPRECATED_freshType(scope->level, canBeGeneric); + return value ? singletonTypes->trueType : singletonTypes->falseType; } -TypeId TypeChecker::DEPRECATED_freshType(TypeLevel level, bool canBeGeneric) +TypeId TypeChecker::singletonType(std::string value) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level, canBeGeneric))); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + // TODO: cache singleton types + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(StringSingleton{std::move(value)}))); } -std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) +TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope) +{ + return singletonTypes->errorRecoveryType(); +} + +TypeId TypeChecker::errorRecoveryType(TypeId guess) +{ + return singletonTypes->errorRecoveryType(guess); +} + +TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope) +{ + return singletonTypes->errorRecoveryTypePack(); +} + +TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) +{ + return singletonTypes->errorRecoveryTypePack(guess); +} + +TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense, TypeId emptySetTy) +{ + return [this, sense, emptySetTy](TypeId ty) -> std::optional { + // any/error/free gets a special pass unconditionally because they can't be decided. + if (get(ty) || get(ty) || get(ty)) + return ty; + + // maps boolean primitive to the corresponding singleton equal to sense + if (isPrim(ty, PrimitiveTypeVar::Boolean)) + return singletonType(sense); + + // if we have boolean singleton, eliminate it if the sense doesn't match with that singleton + if (auto boolean = get(get(ty))) + return boolean->value == sense ? std::optional(ty) : std::nullopt; + + // if we have nil, eliminate it if sense is true, otherwise take it + if (isNil(ty)) + return sense ? std::nullopt : std::optional(ty); + + // at this point, anything else is kept if sense is true, or replaced by nil + return sense ? ty : emptySetTy; + }; +} + +std::optional TypeChecker::filterMapImpl(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); if (!types.empty()) @@ -4497,29 +5107,33 @@ std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predic return std::nullopt; } -TypeId TypeChecker::addType(const UnionTypeVar& utv) +std::pair, bool> TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) { - LUAU_ASSERT(utv.options.size() > 1); + if (FFlag::LuauUnknownAndNeverType) + { + TypeId ty = filterMapImpl(type, predicate).value_or(neverType); + return {ty, !bool(get(ty))}; + } + else + { + std::optional ty = filterMapImpl(type, predicate); + return {ty, bool(ty)}; + } +} - return addTV(TypeVar(utv)); +std::pair, bool> TypeChecker::pickTypesFromSense(TypeId type, bool sense, TypeId emptySetTy) +{ + return filterMap(type, mkTruthyPredicate(sense, emptySetTy)); } TypeId TypeChecker::addTV(TypeVar&& tv) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(std::move(tv)); } TypePackId TypeChecker::addTypePack(TypePackVar&& tv) { - TypePackId allocated = currentModule->internalTypes.typePacks.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addTypePack(std::move(tv)); } TypePackId TypeChecker::addTypePack(TypePack&& tp) @@ -4552,43 +5166,38 @@ TypePackId TypeChecker::freshTypePack(TypeLevel level) return addTypePack(TypePackVar(FreeTypePack(level))); } -TypePackId TypeChecker::DEPRECATED_freshTypePack(const ScopePtr& scope, bool canBeGeneric) +TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation) { - return DEPRECATED_freshTypePack(scope->level, canBeGeneric); + TypeId ty = resolveTypeWorker(scope, annotation); + currentModule->astResolvedTypes[&annotation] = ty; + return ty; } -TypePackId TypeChecker::DEPRECATED_freshTypePack(TypeLevel level, bool canBeGeneric) +TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& annotation) { - return addTypePack(TypePackVar(FreeTypePack(level, canBeGeneric))); -} - -TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation, bool DEPRECATED_canBeGeneric) -{ - if (DEPRECATED_canBeGeneric) - LUAU_ASSERT(!FFlag::LuauRankNTypes); - if (const auto& lit = annotation.as()) { std::optional tf; - if (lit->hasPrefix) - tf = scope->lookupImportedType(lit->prefix.value, lit->name.value); + if (lit->prefix) + tf = scope->lookupImportedType(lit->prefix->value, lit->name.value); else if (FFlag::DebugLuauMagicTypes && lit->name == "_luau_ice") ice("_luau_ice encountered", lit->location); else if (FFlag::DebugLuauMagicTypes && lit->name == "_luau_print") { - if (lit->generics.size != 1) + if (lit->parameters.size != 1 || !lit->parameters.data[0].type) { reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(anyType); } ToStringOptions opts; opts.exhaustive = true; opts.maxTableLength = 0; + opts.useLineBreaks = true; - TypeId param = resolveType(scope, *lit->generics.data[0]); + TypeId param = resolveType(scope, *lit->parameters.data[0].type); luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str())); return param; } @@ -4598,12 +5207,12 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (!tf) { - if (lit->name == Parser::errorName) - return addType(ErrorTypeVar{}); + if (lit->name == kParseNameError) + return errorRecoveryType(scope); std::string typeName; - if (lit->hasPrefix) - typeName = std::string(lit->prefix.value) + "."; + if (lit->prefix) + typeName = std::string(lit->prefix->value) + "."; typeName += lit->name.value; if (scope->lookupPack(typeName)) @@ -4611,31 +5220,164 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else reportError(TypeError{annotation.location, UnknownSymbol{typeName, UnknownSymbol::Type}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); } - if (lit->generics.size == 0 && tf->typeParams.empty()) + if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty()) return tf->type; - else if (lit->generics.size != tf->typeParams.size()) - { - reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->generics.size}}); - return addType(ErrorTypeVar{}); - } - else - { - std::vector typeParams; - for (AstType* paramAnnot : lit->generics) - typeParams.push_back(resolveType(scope, *paramAnnot)); - if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) + bool parameterCountErrorReported = false; + bool hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + bool hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + + if (!lit->hasParameterList) + { + if ((!tf->typeParams.empty() && !hasDefaultTypes) || (!tf->typePackParams.empty() && !hasDefaultPacks)) { - // If the generic parameters and the type arguments are the same, we are about to - // perform an identity substitution, which we can just short-circuit. - return tf->type; + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + parameterCountErrorReported = true; + } + } + + std::vector typeParams; + std::vector extraTypes; + std::vector typePackParams; + + for (size_t i = 0; i < lit->parameters.size; ++i) + { + if (AstType* type = lit->parameters.data[i].type) + { + TypeId ty = resolveType(scope, *type); + + if (typeParams.size() < tf->typeParams.size() || tf->typePackParams.empty()) + typeParams.push_back(ty); + else if (typePackParams.empty()) + extraTypes.push_back(ty); + else + reportError(TypeError{annotation.location, GenericError{"Type parameters must come before type pack parameters"}}); + } + else if (AstTypePack* typePack = lit->parameters.data[i].typePack) + { + TypePackId tp = resolveTypePack(scope, *typePack); + + // If we have collected an implicit type pack, materialize it + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we need more regular types, we can use single element type packs to fill those in + if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) + typeParams.push_back(*first(tp)); + else + typePackParams.push_back(tp); + } + } + + // If we still haven't meterialized an implicit type pack, do it now + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + size_t typesProvided = typeParams.size(); + size_t typesRequired = tf->typeParams.size(); + + size_t packsProvided = typePackParams.size(); + size_t packsRequired = tf->typePackParams.size(); + + bool notEnoughParameters = + (typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired); + bool hasDefaultParameters = hasDefaultTypes || hasDefaultPacks; + + // Add default type and type pack parameters if that's required and it's possible + if (notEnoughParameters && hasDefaultParameters) + { + // 'applyTypeFunction' is used to substitute default types that reference previous generic types + ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes}; + + for (size_t i = 0; i < typesProvided; ++i) + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i]; + + if (typesProvided < typesRequired) + { + for (size_t i = typesProvided; i < typesRequired; ++i) + { + TypeId defaultTy = tf->typeParams[i].defaultValue.value_or(nullptr); + + if (!defaultTy) + break; + + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTy); + + if (!maybeInstantiated.has_value()) + { + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryType(scope); + } + + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = *maybeInstantiated; + typeParams.push_back(*maybeInstantiated); + } } - return instantiateTypeFun(scope, *tf, typeParams, annotation.location); + for (size_t i = 0; i < packsProvided; ++i) + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = typePackParams[i]; + + if (packsProvided < packsRequired) + { + for (size_t i = packsProvided; i < packsRequired; ++i) + { + TypePackId defaultTp = tf->typePackParams[i].defaultValue.value_or(nullptr); + + if (!defaultTp) + break; + + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTp); + + if (!maybeInstantiated.has_value()) + { + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryTypePack(scope); + } + + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = *maybeInstantiated; + typePackParams.push_back(*maybeInstantiated); + } + } } + + // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack + if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) + typePackParams.push_back(addTypePack({})); + + if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) + { + if (!parameterCountErrorReported) + reportError( + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + + // Pad the types out with error recovery types + while (typeParams.size() < tf->typeParams.size()) + typeParams.push_back(errorRecoveryType(scope)); + while (typePackParams.size() < tf->typePackParams.size()) + typePackParams.push_back(errorRecoveryTypePack(scope)); + } + + bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal( + typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; + }); + + // If the generic parameters and the type arguments are the same, we are about to + // perform an identity substitution, which we can just short-circuit. + if (sameTys && sameTps) + return tf->type; + + return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); } else if (const auto& table = annotation.as()) { @@ -4643,37 +5385,38 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation std::optional tableIndexer; for (const auto& prop : table->props) - props[prop.name.value] = {resolveType(scope, *prop.type, DEPRECATED_canBeGeneric)}; + props[prop.name.value] = {resolveType(scope, *prop.type)}; if (const auto& indexer = table->indexer) - tableIndexer = TableIndexer( - resolveType(scope, *indexer->indexType, DEPRECATED_canBeGeneric), resolveType(scope, *indexer->resultType, DEPRECATED_canBeGeneric)); + tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); - return addType(TableTypeVar{ - props, tableIndexer, scope->level, - TableState::Sealed // FIXME: probably want a way to annotate other kinds of tables maybe - }); + TableTypeVar ttv{props, tableIndexer, scope->level, TableState::Sealed}; + ttv.definitionModuleName = currentModuleName; + return addType(std::move(ttv)); } else if (const auto& func = annotation.as()) { ScopePtr funcScope = childScope(scope, func->location); + funcScope->level = scope->level.incr(); - std::vector generics; - std::vector genericPacks; - - if (FFlag::LuauGenericFunctions) - { - std::tie(generics, genericPacks) = createGenericTypes(funcScope, annotation, func->generics, func->genericPacks); - } - - // TODO: better error message CLI-39912 - if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !DEPRECATED_canBeGeneric && (generics.size() > 0 || genericPacks.size() > 0)) - reportError(TypeError{annotation.location, GenericError{"generic function where only monotypes are allowed"}}); + auto [generics, genericPacks] = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); TypePackId argTypes = resolveTypePack(funcScope, func->argTypes); TypePackId retTypes = resolveTypePack(funcScope, func->returnTypes); - TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(generics), std::move(genericPacks), argTypes, retTypes}); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + + TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes}); FunctionTypeVar* ftv = getMutable(fnType); @@ -4691,16 +5434,13 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (auto typeOf = annotation.as()) { TypeId ty = checkExpr(scope, *typeOf->expr).type; - // TODO: better error message CLI-39912 - if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !DEPRECATED_canBeGeneric && isGeneric(ty)) - reportError(TypeError{annotation.location, GenericError{"typeof produced a polytype where only monotypes are allowed"}}); return ty; } else if (const auto& un = annotation.as()) { std::vector types; for (AstType* ann : un->types) - types.push_back(resolveType(scope, *ann, DEPRECATED_canBeGeneric)); + types.push_back(resolveType(scope, *ann)); return addType(UnionTypeVar{types}); } @@ -4708,18 +5448,24 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { std::vector types; for (AstType* ann : un->types) - types.push_back(resolveType(scope, *ann, DEPRECATED_canBeGeneric)); + types.push_back(resolveType(scope, *ann)); return addType(IntersectionTypeVar{types}); } - else if (annotation.is()) + else if (const auto& tsb = annotation.as()) { - return addType(ErrorTypeVar{}); + return singletonType(tsb->value); } + else if (const auto& tss = annotation.as()) + { + return singletonType(std::string(tss->value.data, tss->value.size)); + } + else if (annotation.is()) + return errorRecoveryType(scope); else { reportError(TypeError{annotation.location, GenericError{"Unknown type annotation?"}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); } } @@ -4744,9 +5490,10 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypeList TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation) { + TypePackId result; if (const AstTypePackVariadic* variadic = annotation.as()) { - return addTypePack(TypePackVar{VariadicTypePack{resolveType(scope, *variadic->variadicType)}}); + result = addTypePack(TypePackVar{VariadicTypePack{resolveType(scope, *variadic->variadicType)}}); } else if (const AstTypePackGeneric* generic = annotation.as()) { @@ -4760,149 +5507,114 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack else reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}}); - return addTypePack(TypePackVar{Unifiable::Error{}}); + result = errorRecoveryTypePack(scope); } + else + { + result = *genericTy; + } + } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + std::vector types; - return *genericTy; + for (auto type : explicitTp->typeList.types) + types.push_back(resolveType(scope, *type)); + + if (auto tailType = explicitTp->typeList.tailType) + result = addTypePack(types, resolveTypePack(scope, *tailType)); + else + result = addTypePack(types); } else { ice("Unknown AstTypePack kind"); } + + currentModule->astResolvedTypePacks[&annotation] = result; + return result; } -bool ApplyTypeFunction::isDirty(TypeId ty) +TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, + const std::vector& typePackParams, const Location& location) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics. - if (get(ty)) - return true; - else if (const FreeTypeVar* ftv = get(ty)) - { - if (FFlag::LuauRecursiveTypeParameterRestriction && ftv->forwardedTypeAlias) - encounteredForwardedType = true; - return false; - } - else - return false; -} - -bool ApplyTypeFunction::isDirty(TypePackId tp) -{ - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics. - if (get(tp)) - return true; - else - return false; -} - -TypeId ApplyTypeFunction::clean(TypeId ty) -{ - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics by free type variables. - TypeId& arg = arguments[ty]; - if (arg) - return arg; - else - return addType(FreeTypeVar{level}); -} - -TypePackId ApplyTypeFunction::clean(TypePackId tp) -{ - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics by free type variables. - return addTypePack(FreeTypePack{level}); -} - -TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const Location& location) -{ - if (tf.typeParams.empty()) + if (tf.typeParams.empty() && tf.typePackParams.empty()) return tf.type; - applyTypeFunction.arguments.clear(); + ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes}; + for (size_t i = 0; i < tf.typeParams.size(); ++i) - applyTypeFunction.arguments[tf.typeParams[i]] = typeParams[i]; - applyTypeFunction.currentModule = currentModule; - applyTypeFunction.level = scope->level; - applyTypeFunction.encounteredForwardedType = false; + applyTypeFunction.typeArguments[tf.typeParams[i].ty] = typeParams[i]; + + for (size_t i = 0; i < tf.typePackParams.size(); ++i) + applyTypeFunction.typePackArguments[tf.typePackParams[i].tp] = typePackParams[i]; + std::optional maybeInstantiated = applyTypeFunction.substitute(tf.type); if (!maybeInstantiated.has_value()) { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } - if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType) + if (applyTypeFunction.encounteredForwardedType) { reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}}); - return errorType; + return errorRecoveryType(scope); } TypeId instantiated = *maybeInstantiated; - if (FFlag::LuauCloneCorrectlyBeforeMutatingTableType) + TypeId target = follow(instantiated); + bool needsClone = follow(tf.type) == target; + bool shouldMutate = getTableType(tf.type); + TableTypeVar* ttv = getMutableTableType(target); + + if (shouldMutate && ttv && needsClone) { - // TODO: CLI-46926 it's a bad idea to rename the type whether we follow through the BoundTypeVar or not - TypeId target = FFlag::LuauFollowInTypeFunApply ? follow(instantiated) : instantiated; - - bool needsClone = follow(tf.type) == target; - TableTypeVar* ttv = getMutableTableType(target); - - if (ttv && needsClone) + // Substitution::clone is a shallow clone. If this is a metatable type, we + // want to mutate its table, so we need to explicitly clone that table as + // well. If we don't, we will mutate another module's type surface and cause + // a use-after-free. + if (get(target)) { - // Substitution::clone is a shallow clone. If this is a metatable type, we - // want to mutate its table, so we need to explicitly clone that table as - // well. If we don't, we will mutate another module's type surface and cause - // a use-after-free. - if (get(target)) - { - instantiated = applyTypeFunction.clone(tf.type); - MetatableTypeVar* mtv = getMutable(instantiated); - mtv->table = applyTypeFunction.clone(mtv->table); - ttv = getMutable(mtv->table); - } - if (get(target)) - { - instantiated = applyTypeFunction.clone(tf.type); - ttv = getMutable(instantiated); - } + instantiated = applyTypeFunction.clone(tf.type); + MetatableTypeVar* mtv = getMutable(instantiated); + mtv->table = applyTypeFunction.clone(mtv->table); + ttv = getMutable(mtv->table); } - - if (ttv) + if (get(target)) { - ttv->instantiatedTypeParams = typeParams; + instantiated = applyTypeFunction.clone(tf.type); + ttv = getMutable(instantiated); } } - else - { - if (TableTypeVar* ttv = getMutableTableType(instantiated)) - { - if (follow(tf.type) == instantiated) - { - // This can happen if a type alias has generics that it does not use at all. - // ex type FooBar = { a: number } - instantiated = applyTypeFunction.clone(tf.type); - ttv = getMutableTableType(instantiated); - } - ttv->instantiatedTypeParams = typeParams; - } + if (shouldMutate && ttv) + { + ttv->instantiatedTypeParams = typeParams; + ttv->instantiatedTypePackParams = typePackParams; + ttv->definitionModuleName = currentModuleName; } return instantiated; } -std::pair, std::vector> TypeChecker::createGenericTypes( - const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) +GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, + const AstArray& genericNames, const AstArray& genericPackNames, bool useCache) { - std::vector generics; - for (const AstName& generic : genericNames) + LUAU_ASSERT(scope->parent); + + const TypeLevel level = levelOpt.value_or(scope->level); + + std::vector generics; + + for (const AstGenericType& generic : genericNames) { - Name n = generic.value; + std::optional defaultValue; + + if (generic.defaultValue) + defaultValue = resolveType(scope, *generic.defaultValue); + + Name n = generic.name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic typevars have the same name. @@ -4912,15 +5624,33 @@ std::pair, std::vector> TypeChecker::createGener reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypeId g = addType(Unifiable::Generic{scope->level, n}); - generics.push_back(g); + TypeId g; + if (useCache) + { + TypeId& cached = scope->parent->typeAliasTypeParameters[n]; + if (!cached) + cached = addType(GenericTypeVar{level, n}); + g = cached; + } + else + { + g = addType(Unifiable::Generic{level, n}); + } + + generics.push_back({g, defaultValue}); scope->privateTypeBindings[n] = TypeFun{{}, g}; } - std::vector genericPacks; - for (const AstName& genericPack : genericPackNames) + std::vector genericPacks; + + for (const AstGenericTypePack& genericPack : genericPackNames) { - Name n = genericPack.value; + std::optional defaultValue; + + if (genericPack.defaultValue) + defaultValue = resolveTypePack(scope, *genericPack.defaultValue); + + Name n = genericPack.name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic typevars have the same name. @@ -4930,50 +5660,188 @@ std::pair, std::vector> TypeChecker::createGener reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypePackId g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); - genericPacks.push_back(g); - scope->privateTypePackBindings[n] = g; + TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; + if (!cached) + cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); + + genericPacks.push_back({cached, defaultValue}); + scope->privateTypePackBindings[n] = cached; } return {generics, genericPacks}; } +void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate) +{ + const LValue* target = &lvalue; + std::optional key; // If set, we know we took the base of the lvalue path and should be walking down each option of the base's type. + + auto ty = resolveLValue(scope, *target); + if (!ty) + return; // Do nothing. An error was already reported. + + // If the provided lvalue is a local or global, then that's without a doubt the target. + // However, if there is a base lvalue, then we'll want that to be the target iff the base is a union type. + if (auto base = baseof(lvalue)) + { + std::optional baseTy = resolveLValue(scope, *base); + if (baseTy && get(follow(*baseTy))) + { + ty = baseTy; + target = base; + key = lvalue; + } + } + + // If we do not have a key, it means we're not trying to discriminate anything, so it's a simple matter of just filtering for a subset. + if (!key) + { + auto [result, ok] = filterMap(*ty, predicate); + if (FFlag::LuauUnknownAndNeverType) + { + addRefinement(refis, *target, *result); + } + else + { + if (ok) + addRefinement(refis, *target, *result); + else + addRefinement(refis, *target, errorRecoveryType(scope)); + } + + return; + } + + // Otherwise, we'll want to walk each option of ty, get its index type, and filter that. + auto utv = get(follow(*ty)); + LUAU_ASSERT(utv); + + std::unordered_set viableTargetOptions; + std::unordered_set viableChildOptions; // There may be additional refinements that apply. We add those here too. + + for (TypeId option : utv) + { + std::optional discriminantTy; + if (auto field = Luau::get(*key)) // need to fully qualify Luau::get because of ADL. + discriminantTy = getIndexTypeFromType(scope, option, field->key, Location(), /* addErrors= */ false); + else + LUAU_ASSERT(!"Unhandled LValue alternative?"); + + if (!discriminantTy) + return; // Do nothing. An error was already reported, as per usual. + + auto [result, ok] = filterMap(*discriminantTy, predicate); + if (FFlag::LuauUnknownAndNeverType) + { + if (!get(*result)) + { + viableTargetOptions.insert(option); + viableChildOptions.insert(*result); + } + } + else + { + if (ok) + { + viableTargetOptions.insert(option); + viableChildOptions.insert(*result); + } + } + } + + auto intoType = [this](const std::unordered_set& s) -> std::optional { + if (s.empty()) + return std::nullopt; + + // TODO: allocate UnionTypeVar and just normalize. + std::vector options(s.begin(), s.end()); + if (options.size() == 1) + return options[0]; + + return addType(UnionTypeVar{std::move(options)}); + }; + + if (std::optional viableTargetType = intoType(viableTargetOptions)) + addRefinement(refis, *target, *viableTargetType); + + if (std::optional viableChildType = intoType(viableChildOptions)) + addRefinement(refis, lvalue, *viableChildType); +} + std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LValue& lvalue) { - std::string path = toString(lvalue); - auto [symbol, keys] = getFullName(lvalue); + // We want to be walking the Scope parents. + // We'll also want to walk up the LValue path. As we do this, we need to save each LValue because we must walk back. + // For example: + // There exists an entry t.x. + // We are asked to look for t.x.y. + // We need to search in the provided Scope. Find t.x.y first. + // We fail to find t.x.y. Try t.x. We found it. Now we must return the type of the property y from the mapped-to type of t.x. + // If we completely fail to find the Symbol t but the Scope has that entry, then we should walk that all the way through and terminate. + const Symbol symbol = getBaseSymbol(lvalue); ScopePtr currentScope = scope; while (currentScope) { - if (auto it = currentScope->refinements.find(path); it != currentScope->refinements.end()) - return it->second; + std::optional found; - // Should not be using scope->lookup. This is already recursive. - if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) + const LValue* topLValue = nullptr; + + for (topLValue = &lvalue; topLValue; topLValue = baseof(*topLValue)) { - std::optional currentTy = it->second.typeId; - - for (std::string key : keys) + if (auto it = currentScope->refinements.find(*topLValue); it != currentScope->refinements.end()) { - // TODO: This function probably doesn't need Location at all, or at least should hide the argument. - currentTy = getIndexTypeFromType(scope, *currentTy, key, Location(), false); - if (!currentTy) - break; + found = it->second; + break; } - - return currentTy; } - currentScope = currentScope->parent; + if (!found) + { + // Should not be using scope->lookup. This is already recursive. + if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) + found = it->second.typeId; + else + { + // Nothing exists in this Scope. Just skip and try the parent one. + currentScope = currentScope->parent; + continue; + } + } + + // We need to walk the l-value path in reverse, so we collect components into a vector + std::vector childKeys; + + for (const LValue* curr = &lvalue; curr != topLValue; curr = baseof(*curr)) + childKeys.push_back(curr); + + for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) + { + const LValue& key = **it; + + // Symbol can happen. Skip. + if (get(key)) + continue; + else if (auto field = get(key)) + { + found = getIndexTypeFromType(scope, *found, field->key, Location(), /* addErrors= */ false); + if (!found) + return std::nullopt; // Turns out this type doesn't have the property at all. We're done. + } + else + LUAU_ASSERT(!"New LValue alternative not handled here."); + } + + return found; } + // No entry for it at all. Can happen when LValue root is a global. return std::nullopt; } std::optional TypeChecker::resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue) { - if (auto it = refis.find(toString(lvalue)); it != refis.end()) + if (auto it = refis.find(lvalue); it != refis.end()) return it->second; else return resolveLValue(scope, lvalue); @@ -4987,97 +5855,63 @@ static bool isUndecidable(TypeId ty) return get(ty) || get(ty) || get(ty); } -ErrorVec TypeChecker::resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense) { - ErrorVec errVec; - resolve(predicates, errVec, scope->refinements, scope, sense); - return errVec; + resolve(predicates, scope->refinements, scope, sense); } -void TypeChecker::resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +void TypeChecker::resolve(const PredicateVec& predicates, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { for (const Predicate& c : predicates) - resolve(c, errVec, refis, scope, sense, fromOr); + resolve(c, refis, scope, sense, fromOr); } -void TypeChecker::resolve(const Predicate& predicate, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +void TypeChecker::resolve(const Predicate& predicate, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { if (auto truthyP = get(predicate)) - resolve(*truthyP, errVec, refis, scope, sense, fromOr); + resolve(*truthyP, refis, scope, sense, fromOr); else if (auto andP = get(predicate)) - resolve(*andP, errVec, refis, scope, sense); + resolve(*andP, refis, scope, sense); else if (auto orP = get(predicate)) - resolve(*orP, errVec, refis, scope, sense); + resolve(*orP, refis, scope, sense); else if (auto notP = get(predicate)) - resolve(notP->predicates, errVec, refis, scope, !sense, fromOr); + resolve(notP->predicates, refis, scope, !sense, fromOr); else if (auto isaP = get(predicate)) - resolve(*isaP, errVec, refis, scope, sense); + resolve(*isaP, refis, scope, sense); else if (auto typeguardP = get(predicate)) - { - if (FFlag::LuauImprovedTypeGuardPredicate2) - resolve(*typeguardP, errVec, refis, scope, sense); - else - DEPRECATED_resolve(*typeguardP, errVec, refis, scope, sense); - } - else if (auto eqP = get(predicate); eqP && FFlag::LuauEqConstraint) - resolve(*eqP, errVec, refis, scope, sense); + resolve(*typeguardP, refis, scope, sense); + else if (auto eqP = get(predicate)) + resolve(*eqP, refis, scope, sense); else ice("Unhandled predicate kind"); } -void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +void TypeChecker::resolve(const TruthyPredicate& truthyP, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { - auto predicate = [sense](TypeId option) -> std::optional { - if (isUndecidable(option) || isBoolean(option) || isNil(option) != sense) - return option; - - return std::nullopt; - }; - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (!ty) - return; - - // This is a hack. :( - // Without this, the expression 'a or b' might refine 'b' to be falsy. - // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. - if (fromOr) + if (ty && fromOr) return addRefinement(refis, truthyP.lvalue, *ty); - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, truthyP.lvalue, *result); + refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense, nilType)); } -void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense) { - if (FFlag::LuauOrPredicate) + if (!sense) { - if (!sense) - { - OrPredicate orP{ - {NotPredicate{std::move(andP.lhs)}}, - {NotPredicate{std::move(andP.rhs)}}, - }; + OrPredicate orP{ + {NotPredicate{std::move(andP.lhs)}}, + {NotPredicate{std::move(andP.rhs)}}, + }; - return resolve(orP, errVec, refis, scope, !sense); - } - - resolve(andP.lhs, errVec, refis, scope, sense); - resolve(andP.rhs, errVec, refis, scope, sense); + return resolve(orP, refis, scope, !sense); } - else - { - // And predicate is currently not resolvable when sense is false. 'not (a and b)' is synonymous with '(not a) or (not b)'. - // TODO: implement environment merging to permit this case. - if (!sense) - return; - resolve(andP.lhs, errVec, refis, scope, sense); - resolve(andP.rhs, errVec, refis, scope, sense); - } + resolve(andP.lhs, refis, scope, sense); + resolve(andP.rhs, refis, scope, sense); } -void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const OrPredicate& orP, RefinementMap& refis, const ScopePtr& scope, bool sense) { if (!sense) { @@ -5086,115 +5920,73 @@ void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMa {NotPredicate{std::move(orP.rhs)}}, }; - return resolve(andP, errVec, refis, scope, !sense); + return resolve(andP, refis, scope, !sense); } - ErrorVec discarded; - RefinementMap leftRefis; - resolve(orP.lhs, errVec, leftRefis, scope, sense); + resolve(orP.lhs, leftRefis, scope, sense); RefinementMap rightRefis; - resolve(orP.lhs, discarded, rightRefis, scope, !sense); - resolve(orP.rhs, errVec, rightRefis, scope, sense, true); // :( + resolve(orP.lhs, rightRefis, scope, !sense); + resolve(orP.rhs, rightRefis, scope, sense, true); // :( merge(refis, leftRefis); merge(refis, rightRefis); } -void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense) { auto predicate = [&](TypeId option) -> std::optional { - if (FFlag::LuauTypeGuardPeelsAwaySubclasses) - { - // This by itself is not truly enough to determine that A is stronger than B or vice versa. - // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. - // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) - bool optionIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); - bool targetIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); + // This by itself is not truly enough to determine that A is stronger than B or vice versa. + bool optionIsSubtype = canUnify(option, isaP.ty, scope, isaP.location).empty(); + bool targetIsSubtype = canUnify(isaP.ty, option, scope, isaP.location).empty(); - // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. - if (!optionIsSubtype && targetIsSubtype) + // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. + if (!optionIsSubtype && targetIsSubtype) + return sense ? isaP.ty : option; + + // If A is a subset of B, then if sense is true we pick A, otherwise we eliminate A. + if (optionIsSubtype && !targetIsSubtype) + return sense ? std::optional(option) : std::nullopt; + + // If neither has any relationship, we only return A if sense is false. + if (!optionIsSubtype && !targetIsSubtype) + return sense ? std::nullopt : std::optional(option); + + // If both are subtypes, then we're in one of the two situations: + // 1. Instanceâ‚ <: Instanceâ‚‚ ∧ Instanceâ‚‚ <: Instanceâ‚ + // 2. any <: Instance ∧ Instance <: any + // Right now, we have to look at the types to see if they were undecidables. + // By this point, we also know free tables are also subtypes and supertypes. + if (optionIsSubtype && targetIsSubtype) + { + // We can only have (any, Instance) because the rhs is never undecidable right now. + // So we can just return the right hand side immediately. + + // typeof(x) == "Instance" where x : any + auto ttv = get(option); + if (isUndecidable(option) || (ttv && ttv->state == TableState::Free)) return sense ? isaP.ty : option; - // If A is a subset of B, then if sense is true we pick A, otherwise we eliminate A. - if (optionIsSubtype && !targetIsSubtype) - return sense ? std::optional(option) : std::nullopt; - - // If neither has any relationship, we only return A if sense is false. - if (!optionIsSubtype && !targetIsSubtype) - return sense ? std::nullopt : std::optional(option); - - // If both are subtypes, then we're in one of the two situations: - // 1. Instanceâ‚ <: Instanceâ‚‚ ∧ Instanceâ‚‚ <: Instanceâ‚ - // 2. any <: Instance ∧ Instance <: any - // Right now, we have to look at the types to see if they were undecidables. - // By this point, we also know free tables are also subtypes and supertypes. - if (optionIsSubtype && targetIsSubtype) - { - // We can only have (any, Instance) because the rhs is never undecidable right now. - // So we can just return the right hand side immediately. - - // typeof(x) == "Instance" where x : any - auto ttv = get(option); - if (isUndecidable(option) || (ttv && ttv->state == TableState::Free)) - return sense ? isaP.ty : option; - - // typeof(x) == "Instance" where x : Instance - if (sense) - return isaP.ty; - } - } - else if (FFlag::LuauImprovedTypeGuardPredicate2) - { - auto lctv = get(option); - auto rctv = get(isaP.ty); - - if (isSubclass(lctv, rctv) == sense) - return option; - - if (isSubclass(rctv, lctv) == sense) - return isaP.ty; - - if (canUnify(option, isaP.ty, isaP.location).empty() == sense) + // typeof(x) == "Instance" where x : Instance + if (sense) return isaP.ty; } - else - { - auto lctv = get(option); - auto rctv = get(isaP.ty); - if (lctv && rctv) - { - if (isSubclass(lctv, rctv) == sense) - return option; - else if (isSubclass(rctv, lctv) == sense) - return isaP.ty; - } - } - - return std::nullopt; + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; }; - std::optional ty = resolveLValue(refis, scope, isaP.lvalue); - if (!ty) - return; - - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, isaP.lvalue, *result); - else - { - addRefinement(refis, isaP.lvalue, errorType); - errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); - } + refineLValue(isaP.lvalue, refis, scope, predicate); } -void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& refis, const ScopePtr& scope, bool sense) { // Rewrite the predicate 'type(foo) == "vector"' to be 'typeof(foo) == "Vector3"'. They're exactly identical. // This allows us to avoid writing in edge cases. if (!typeguardP.isTypeof && typeguardP.kind == "vector") - return resolve(TypeGuardPredicate{std::move(typeguardP.lvalue), typeguardP.location, "Vector3", true}, errVec, refis, scope, sense); + return resolve(TypeGuardPredicate{std::move(typeguardP.lvalue), typeguardP.location, "Vector3", true}, refis, scope, sense); std::optional ty = resolveLValue(refis, scope, typeguardP.lvalue); if (!ty) @@ -5207,150 +5999,149 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return; } - using ConditionFunc = bool(TypeId); - using SenseToTypeIdPredicate = std::function; - auto mkFilter = [](ConditionFunc f, std::optional other = std::nullopt) -> SenseToTypeIdPredicate { - return [f, other](bool sense) -> TypeIdPredicate { - return [f, other, sense](TypeId ty) -> std::optional { - if (f(ty) == sense) - return ty; + auto refine = [this, &lvalue = typeguardP.lvalue, &refis, &scope, sense](bool(f)(TypeId), std::optional mapsTo = std::nullopt) { + TypeIdPredicate predicate = [f, mapsTo, sense](TypeId ty) -> std::optional { + if (FFlag::LuauUnknownAndNeverType && sense && get(ty)) + return mapsTo.value_or(ty); - if (isUndecidable(ty)) - return other.value_or(ty); + if (f(ty) == sense) + return ty; - return std::nullopt; - }; + if (isUndecidable(ty)) + return mapsTo.value_or(ty); + + return std::nullopt; }; + + refineLValue(lvalue, refis, scope, predicate); }; // Note: "vector" never happens here at this point, so we don't have to write something for it. - // clang-format off - static const std::unordered_map primitives{ - // Trivial primitives. - {"nil", mkFilter(isNil, nilType)}, // This can still happen when sense is false! - {"string", mkFilter(isString, stringType)}, - {"number", mkFilter(isNumber, numberType)}, - {"boolean", mkFilter(isBoolean, booleanType)}, - {"thread", mkFilter(isThread, threadType)}, - - // Non-trivial primitives. - {"table", mkFilter([](TypeId ty) -> bool { return isTableIntersection(ty) || get(ty) || get(ty); })}, - {"function", mkFilter([](TypeId ty) -> bool { return isOverloadedFunction(ty) || get(ty); })}, - - // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. - {"userdata", mkFilter([](TypeId ty) -> bool { return get(ty); })}, - }; - // clang-format on - - if (auto it = primitives.find(typeguardP.kind); it != primitives.end()) + if (typeguardP.kind == "nil") + return refine(isNil, nilType); // This can still happen when sense is false! + else if (typeguardP.kind == "string") + return refine(isString, stringType); + else if (typeguardP.kind == "number") + return refine(isNumber, numberType); + else if (typeguardP.kind == "boolean") + return refine(isBoolean, booleanType); + else if (typeguardP.kind == "thread") + return refine(isThread, threadType); + else if (typeguardP.kind == "table") { - if (std::optional result = filterMap(*ty, it->second(sense))) - addRefinement(refis, typeguardP.lvalue, *result); - else - { - addRefinement(refis, typeguardP.lvalue, errorType); - if (sense) - errVec.push_back( - TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); - } - - return; + return refine([](TypeId ty) -> bool { + return isTableIntersection(ty) || get(ty) || get(ty); + }); + } + else if (typeguardP.kind == "function") + { + return refine([](TypeId ty) -> bool { + return isOverloadedFunction(ty) || get(ty); + }); + } + else if (typeguardP.kind == "userdata") + { + // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. + return refine([](TypeId ty) -> bool { + return get(ty); + }); } - auto fail = [&](const TypeErrorData& err) { - errVec.push_back(TypeError{typeguardP.location, err}); - addRefinement(refis, typeguardP.lvalue, errorType); - }; - if (!typeguardP.isTypeof) - return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); auto typeFun = globalScope->lookupType(typeguardP.kind); - if (!typeFun || !typeFun->typeParams.empty()) - return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + if (!typeFun || !typeFun->typeParams.empty() || !typeFun->typePackParams.empty()) + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); TypeId type = follow(typeFun->type); // We're only interested in the root class of any classes. if (auto ctv = get(type); !ctv || ctv->parent) - return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. // Until then, we rewrite this to be the same as using IsA. - return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense); + return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, refis, scope, sense); } -void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) -{ - if (!sense) - return; - - static std::vector primitives{ - "string", "number", "boolean", "nil", "thread", - "table", // no op. Requires special handling. - "function", // no op. Requires special handling. - "userdata", // no op. Requires special handling. - }; - - if (auto typeFun = globalScope->lookupType(typeguardP.kind); typeFun && typeFun->typeParams.empty()) - { - if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end()) - addRefinement(refis, typeguardP.lvalue, typeFun->type); - else if (typeguardP.isTypeof) - addRefinement(refis, typeguardP.lvalue, typeFun->type); - } -} - -void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. - auto options = [](TypeId ty) -> std::vector { if (auto utv = get(follow(ty))) return std::vector(begin(utv), end(utv)); return {ty}; }; - if (FFlag::LuauWeakEqConstraint) - { - if (!sense && isNil(eqP.type)) - resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false); - - return; - } - - std::optional ty = resolveLValue(refis, scope, eqP.lvalue); - if (!ty) - return; - - std::vector lhs = options(*ty); std::vector rhs = options(eqP.type); - if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) - { - addRefinement(refis, eqP.lvalue, eqP.type); - return; - } - else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - std::unordered_set set; - for (TypeId left : lhs) - { - for (TypeId right : rhs) + auto predicate = [&](TypeId option) -> std::optional { + if (!sense && isNil(eqP.type)) + return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; + + if (maybeSingleton(eqP.type)) { - // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) - set.insert(left); + if (FFlag::LuauImplicitElseRefinement) + { + bool optionIsSubtype = canUnify(option, eqP.type, scope, eqP.location).empty(); + bool targetIsSubtype = canUnify(eqP.type, option, scope, eqP.location).empty(); + + // terminology refresher: + // - option is the type of the expression `x`, and + // - eqP.type is the type of the expression `"hello"` + // + // "hello" == x where + // x : "hello" | "world" -> x : "hello" + // x : number | string -> x : "hello" + // x : number -> x : never + // + // "hello" ~= x where + // x : "hello" | "world" -> x : "world" + // x : number | string -> x : number | string + // x : number -> x : number + + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional nope = std::nullopt; + + if (sense) + { + if (optionIsSubtype && !targetIsSubtype) + return option; + else if (!optionIsSubtype && targetIsSubtype) + return FFlag::LuauTypeInferMissingFollows ? follow(eqP.type) : eqP.type; + else if (!optionIsSubtype && !targetIsSubtype) + return nope; + else if (optionIsSubtype && targetIsSubtype) + return FFlag::LuauTypeInferMissingFollows ? follow(eqP.type) : eqP.type; + } + else + { + bool isOptionSingleton = get(option); + if (!isOptionSingleton) + return option; + else if (optionIsSubtype && targetIsSubtype) + return nope; + } + } + else + { + if (!sense || canUnify(eqP.type, option, scope, eqP.location).empty()) + return sense ? eqP.type : option; + + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; + } } - } - if (set.empty()) - return; + return option; + }; - std::vector viable(set.begin(), set.end()); - TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); - addRefinement(refis, eqP.lvalue, result); + refineLValue(eqP.lvalue, refis, scope, predicate); } bool TypeChecker::isNonstrictMode() const @@ -5364,9 +6155,16 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp TypePack* expectedPack = getMutable(expectedTypePack); LUAU_ASSERT(expectedPack); for (size_t i = 0; i < expectedLength; ++i) - expectedPack->head.push_back(FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + expectedPack->head.push_back(freshType(scope)); - unify(expectedTypePack, tp, location); + size_t oldErrorsSize = currentModule->errors.size(); + + unify(tp, expectedTypePack, scope, location); + + // HACK: tryUnify would undo the changes to the expectedTypePack if the length mismatches, but + // we want to tie up free types to be error types, so we do this instead. + if (FFlag::LuauUnknownAndNeverType) + currentModule->errors.resize(oldErrorsSize); for (TypeId& tp : expectedPack->head) tp = follow(tp); @@ -5379,119 +6177,4 @@ std::vector> TypeChecker::getScopes() const return currentModule->scopes; } -Scope::Scope(TypePackId returnType) - : parent(nullptr) - , returnType(returnType) - , level(TypeLevel()) -{ -} - -Scope::Scope(const ScopePtr& parent, int subLevel) - : parent(parent) - , returnType(parent->returnType) - , level(parent->level.incr()) -{ - level.subLevel = subLevel; -} - -std::optional Scope::lookup(const Symbol& name) -{ - Scope* scope = this; - - while (scope) - { - auto it = scope->bindings.find(name); - if (it != scope->bindings.end()) - return it->second.typeId; - - scope = scope->parent.get(); - } - - return std::nullopt; -} - -std::optional Scope::lookupType(const Name& name) -{ - const Scope* scope = this; - while (true) - { - auto it = scope->exportedTypeBindings.find(name); - if (it != scope->exportedTypeBindings.end()) - return it->second; - - it = scope->privateTypeBindings.find(name); - if (it != scope->privateTypeBindings.end()) - return it->second; - - if (scope->parent) - scope = scope->parent.get(); - else - return std::nullopt; - } -} - -std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) -{ - const Scope* scope = this; - while (scope) - { - auto it = scope->importedTypeBindings.find(moduleAlias); - if (it == scope->importedTypeBindings.end()) - { - scope = scope->parent.get(); - continue; - } - - auto it2 = it->second.find(name); - if (it2 == it->second.end()) - { - scope = scope->parent.get(); - continue; - } - - return it2->second; - } - - return std::nullopt; -} - -std::optional Scope::lookupPack(const Name& name) -{ - const Scope* scope = this; - while (true) - { - auto it = scope->privateTypePackBindings.find(name); - if (it != scope->privateTypePackBindings.end()) - return it->second; - - if (scope->parent) - scope = scope->parent.get(); - else - return std::nullopt; - } -} - -std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) -{ - Scope* scope = this; - - while (scope) - { - for (const auto& [n, binding] : scope->bindings) - { - if (n.local && n.local->name == name.c_str()) - return binding; - else if (n.global.value && n.global == name.c_str()) - return binding; - } - - scope = scope->parent.get(); - - if (!traverseScopeChain) - break; - } - - return std::nullopt; -} - } // namespace Luau diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 5970f304..0f75c3ef 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -1,11 +1,23 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypePack.h" +#include "Luau/Error.h" +#include "Luau/TxnLog.h" + #include +LUAU_FASTFLAGVARIABLE(LuauTxnLogTypePackIterator, false) + namespace Luau { +BlockedTypePack::BlockedTypePack() + : index(++nextIndex) +{ +} + +size_t BlockedTypePack::nextIndex = 0; + TypePackVar::TypePackVar(const TypePackVariant& tp) : ty(tp) { @@ -34,15 +46,31 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp) return *this; } +TypePackVar& TypePackVar::operator=(const TypePackVar& rhs) +{ + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); + + reassign(rhs); + + return *this; +} + TypePackIterator::TypePackIterator(TypePackId typePack) - : currentTypePack(follow(typePack)) - , tp(get(currentTypePack)) + : TypePackIterator(typePack, TxnLog::empty()) +{ +} + +TypePackIterator::TypePackIterator(TypePackId typePack, const TxnLog* log) + : currentTypePack(FFlag::LuauTxnLogTypePackIterator ? log->follow(typePack) : follow(typePack)) + , tp(FFlag::LuauTxnLogTypePackIterator ? log->get(currentTypePack) : get(currentTypePack)) , currentIndex(0) + , log(log) { while (tp && tp->head.empty()) { - currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; - tp = currentTypePack ? get(currentTypePack) : nullptr; + currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; + tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; } } @@ -53,8 +81,9 @@ TypePackIterator& TypePackIterator::operator++() ++currentIndex; while (tp && currentIndex >= tp->head.size()) { - currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; - tp = currentTypePack ? get(currentTypePack) : nullptr; + currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; + tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; + currentIndex = 0; } @@ -95,9 +124,14 @@ TypePackIterator begin(TypePackId tp) return TypePackIterator{tp}; } +TypePackIterator begin(TypePackId tp, const TxnLog* log) +{ + return TypePackIterator{tp, log}; +} + TypePackIterator end(TypePackId tp) { - return FFlag::LuauAddMissingFollow ? TypePackIterator{} : TypePackIterator{nullptr}; + return TypePackIterator{}; } bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) @@ -160,8 +194,15 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) TypePackId follow(TypePackId tp) { - auto advance = [](TypePackId ty) -> std::optional { - if (const Unifiable::Bound* btv = get>(ty)) + return follow(tp, [](TypePackId t) { + return t; + }); +} + +TypePackId follow(TypePackId tp, std::function mapper) +{ + auto advance = [&mapper](TypePackId ty) -> std::optional { + if (const Unifiable::Bound* btv = get>(mapper(ty))) return btv->boundTo; else return std::nullopt; @@ -196,32 +237,46 @@ TypePackId follow(TypePackId tp) cycleTester = nullptr; if (tp == cycleTester) - throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); + throw InternalCompilerError("Luau::follow detected a TypeVar cycle!!"); } } } -size_t size(TypePackId tp) +size_t size(TypePackId tp, TxnLog* log) { - if (auto pack = get(FFlag::LuauAddMissingFollow ? follow(tp) : tp)) - return size(*pack); + tp = log ? log->follow(tp) : follow(tp); + if (auto pack = get(tp)) + return size(*pack, log); else return 0; } -size_t size(const TypePack& tp) +bool finite(TypePackId tp, TxnLog* log) +{ + tp = log ? log->follow(tp) : follow(tp); + + if (auto pack = get(tp)) + return pack->tail ? finite(*pack->tail, log) : true; + + if (get(tp)) + return false; + + return true; +} + +size_t size(const TypePack& tp, TxnLog* log) { size_t result = tp.head.size(); if (tp.tail) { - const TypePack* tail = get(FFlag::LuauAddMissingFollow ? follow(*tp.tail) : *tp.tail); + const TypePack* tail = get(log ? log->follow(*tp.tail) : follow(*tp.tail)); if (tail) - result += size(*tail); + result += size(*tail, log); } return result; } -std::optional first(TypePackId tp) +std::optional first(TypePackId tp, bool ignoreHiddenVariadics) { auto it = begin(tp); auto endIter = end(tp); @@ -231,13 +286,23 @@ std::optional first(TypePackId tp) if (auto tail = it.tail()) { - if (auto vtp = get(*tail)) + if (auto vtp = get(*tail); vtp && (!vtp->hidden || !ignoreHiddenVariadics)) return vtp->ty; } return std::nullopt; } +TypePackVar* asMutable(TypePackId tp) +{ + return const_cast(tp); +} + +TypePack* asMutable(const TypePack* tp) +{ + return const_cast(tp); +} + bool isEmpty(TypePackId tp) { tp = follow(tp); @@ -264,14 +329,70 @@ std::pair, std::optional> flatten(TypePackId tp) return {res, iter.tail()}; } -TypePackVar* asMutable(TypePackId tp) +std::pair, std::optional> flatten(TypePackId tp, const TxnLog& log) { - return const_cast(tp); + tp = log.follow(tp); + + std::vector flattened; + std::optional tail = std::nullopt; + + TypePackIterator it(tp, &log); + + for (; it != end(tp); ++it) + { + flattened.push_back(*it); + } + + tail = it.tail(); + + return {flattened, tail}; } -TypePack* asMutable(const TypePack* tp) +bool isVariadic(TypePackId tp) { - return const_cast(tp); + return isVariadic(tp, *TxnLog::empty()); +} + +bool isVariadic(TypePackId tp, const TxnLog& log) +{ + std::optional tail = flatten(tp, log).second; + + if (!tail) + return false; + + return isVariadicTail(*tail, log); +} + +bool isVariadicTail(TypePackId tp, const TxnLog& log, bool includeHiddenVariadics) +{ + if (log.get(tp)) + return true; + + if (auto vtp = log.get(tp); vtp && (includeHiddenVariadics || !vtp->hidden)) + return true; + + return false; +} + +bool containsNever(TypePackId tp) +{ + auto it = begin(tp); + auto endIt = end(tp); + + while (it != endIt) + { + if (get(follow(*it))) + return true; + ++it; + } + + if (auto tail = it.tail()) + { + if (auto vtp = get(*tail); vtp && get(follow(vtp->ty))) + return true; + } + + return false; } } // namespace Luau diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index b9f50978..7478ac22 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -1,46 +1,34 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeUtils.h" +#include "Luau/Normalize.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" -LUAU_FASTFLAG(LuauStringMetatable) +#include namespace Luau { -std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globalScope, TypeId type, std::string entry, Location location) +std::optional findMetatableEntry( + NotNull singletonTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location) { type = follow(type); - if (!FFlag::LuauStringMetatable) - { - if (const PrimitiveTypeVar* primType = get(type)) - { - if (primType->type != PrimitiveTypeVar::String || "__index" != entry) - return std::nullopt; - - auto it = globalScope->bindings.find(AstName{"string"}); - if (it != globalScope->bindings.end()) - return it->second.typeId; - else - return std::nullopt; - } - } - - std::optional metatable = getMetatable(type); + std::optional metatable = getMetatable(type, singletonTypes); if (!metatable) return std::nullopt; TypeId unwrapped = follow(*metatable); if (get(unwrapped)) - return singletonTypes.anyType; + return singletonTypes->anyType; const TableTypeVar* mtt = getTableType(unwrapped); if (!mtt) { - errors.push_back(TypeError{location, GenericError{"Metatable was not a table."}}); + errors.push_back(TypeError{location, GenericError{"Metatable was not a table"}}); return std::nullopt; } @@ -51,7 +39,8 @@ std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globa return std::nullopt; } -std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const ScopePtr& globalScope, TypeId ty, Name name, Location location) +std::optional findTablePropertyRespectingMeta( + NotNull singletonTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location) { if (get(ty)) return ty; @@ -63,10 +52,17 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const Sc return it->second.type; } - std::optional mtIndex = findMetatableEntry(errors, globalScope, ty, "__index", location); + std::optional mtIndex = findMetatableEntry(singletonTypes, errors, ty, "__index", location); + int count = 0; while (mtIndex) { TypeId index = follow(*mtIndex); + + if (count >= 100) + return std::nullopt; + + ++count; + if (const auto& itt = getTableType(index)) { const auto& fit = itt->props.find(name); @@ -75,21 +71,216 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const Sc } else if (const auto& itf = get(index)) { - std::optional r = first(follow(itf->retType)); + std::optional r = first(follow(itf->retTypes)); if (!r) - return singletonTypes.nilType; + return singletonTypes->nilType; else return *r; } else if (get(index)) - return singletonTypes.anyType; + return singletonTypes->anyType; else errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); - mtIndex = findMetatableEntry(errors, globalScope, *mtIndex, "__index", location); + mtIndex = findMetatableEntry(singletonTypes, errors, *mtIndex, "__index", location); } return std::nullopt; } +std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics) +{ + size_t minCount = 0; + size_t optionalCount = 0; + + auto it = begin(tp, log); + auto endIter = end(tp); + + while (it != endIter) + { + TypeId ty = *it; + if (isOptional(ty)) + ++optionalCount; + else + { + minCount += optionalCount; + optionalCount = 0; + minCount++; + } + + ++it; + } + + if (it.tail() && isVariadicTail(*it.tail(), *log, includeHiddenVariadics)) + return {minCount, std::nullopt}; + else + return {minCount, minCount + optionalCount}; +} + +TypePack extendTypePack(TypeArena& arena, NotNull singletonTypes, TypePackId pack, size_t length) +{ + TypePack result; + + while (true) + { + pack = follow(pack); + + if (const TypePack* p = get(pack)) + { + size_t i = 0; + while (i < p->head.size() && result.head.size() < length) + { + result.head.push_back(p->head[i]); + ++i; + } + + if (result.head.size() == length) + { + if (i == p->head.size()) + result.tail = p->tail; + else + { + TypePackId newTail = arena.addTypePack(TypePack{}); + TypePack* newTailPack = getMutable(newTail); + + newTailPack->head.insert(newTailPack->head.begin(), p->head.begin() + i, p->head.end()); + newTailPack->tail = p->tail; + + result.tail = newTail; + } + + return result; + } + else if (p->tail) + { + pack = *p->tail; + continue; + } + else + { + // There just aren't enough types in this pack to satisfy the request. + return result; + } + } + else if (const VariadicTypePack* vtp = get(pack)) + { + while (result.head.size() < length) + result.head.push_back(vtp->ty); + result.tail = pack; + return result; + } + else if (FreeTypePack* ftp = getMutable(pack)) + { + // If we need to get concrete types out of a free pack, we choose to + // interpret this as proof that the pack must have at least 'length' + // elements. We mint fresh types for each element we're extracting + // and rebind the free pack to be a TypePack containing them. We + // also have to create a new tail. + + TypePack newPack; + newPack.tail = arena.freshTypePack(ftp->scope); + + while (result.head.size() < length) + { + newPack.head.push_back(arena.freshType(ftp->scope)); + result.head.push_back(newPack.head.back()); + } + + asMutable(pack)->ty.emplace(std::move(newPack)); + + return result; + } + else if (const Unifiable::Error* etp = getMutable(pack)) + { + while (result.head.size() < length) + result.head.push_back(singletonTypes->errorRecoveryType()); + + result.tail = pack; + return result; + } + else + { + // If the pack is blocked or generic, we can't extract. + // Return whatever we've got with this pack as the tail. + result.tail = pack; + return result; + } + } +} + +std::vector reduceUnion(const std::vector& types) +{ + std::vector result; + for (TypeId t : types) + { + t = follow(t); + if (get(t)) + continue; + + if (get(t) || get(t)) + return {t}; + + if (const UnionTypeVar* utv = get(t)) + { + for (TypeId ty : utv) + { + ty = follow(ty); + if (get(ty)) + continue; + if (get(ty) || get(ty)) + return {ty}; + + if (result.end() == std::find(result.begin(), result.end(), ty)) + result.push_back(ty); + } + } + else if (std::find(result.begin(), result.end(), t) == result.end()) + result.push_back(t); + } + + return result; +} + +static std::optional tryStripUnionFromNil(TypeArena& arena, TypeId ty) +{ + if (const UnionTypeVar* utv = get(ty)) + { + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + + std::vector result; + + for (TypeId option : utv) + { + if (!isNil(option)) + result.push_back(option); + } + + if (result.empty()) + return std::nullopt; + + return result.size() == 1 ? result[0] : arena.addType(UnionTypeVar{std::move(result)}); + } + + return std::nullopt; +} + +TypeId stripNil(NotNull singletonTypes, TypeArena& arena, TypeId ty) +{ + ty = follow(ty); + + if (get(ty)) + { + std::optional cleaned = tryStripUnionFromNil(arena, ty); + + // If there is no union option without 'nil' + if (!cleaned) + return singletonTypes->nilType; + + return follow(*cleaned); + } + + return follow(ty); +} + } // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 111f4f53..159e7712 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -3,8 +3,10 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" +#include "Luau/ConstraintSolver.h" #include "Luau/DenseHash.h" #include "Luau/Error.h" +#include "Luau/RecursionCounter.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" @@ -17,37 +19,59 @@ #include #include +LUAU_FASTFLAG(DebugLuauFreezeArena) + LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) -LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2) -LUAU_FASTFLAGVARIABLE(LuauToStringFollowsBoundTo, false) -LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAGVARIABLE(LuauStringMetatable, false) -LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) +LUAU_FASTINT(LuauTypeInferRecursionLimit) +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) +LUAU_FASTFLAGVARIABLE(LuauNewLibraryTypeNames, false) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) namespace Luau { -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionFormat(MagicFunctionCallContext context); + +static std::optional> magicFunctionGmatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context); + +static std::optional> magicFunctionMatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionMatch(MagicFunctionCallContext context); + +static std::optional> magicFunctionFind( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionFind(MagicFunctionCallContext context); TypeId follow(TypeId t) { - auto advance = [](TypeId ty) -> std::optional { - if (auto btv = get>(ty)) + return follow(t, [](TypeId t) { + return t; + }); +} + +TypeId follow(TypeId t, std::function mapper) +{ + auto advance = [&mapper](TypeId ty) -> std::optional { + if (auto btv = get>(mapper(ty))) return btv->boundTo; - else if (auto ttv = get(ty)) + else if (auto ttv = get(mapper(ty))) return ttv->boundTo; else return std::nullopt; }; - auto force = [](TypeId ty) { - if (auto ltv = FFlag::LuauAddMissingFollow ? get_if(&ty->ty) : get(ty)) + auto force = [&mapper](TypeId ty) { + if (auto ltv = get_if(&mapper(ty)->ty)) { TypeId res = ltv->thunk(); if (get(res)) - throw std::runtime_error("Lazy TypeVar cannot resolve to another Lazy TypeVar"); + throw InternalCompilerError("Lazy TypeVar cannot resolve to another Lazy TypeVar"); *asMutable(ty) = BoundTypeVar(res); } @@ -85,7 +109,7 @@ TypeId follow(TypeId t) cycleTester = nullptr; if (t == cycleTester) - throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); + throw InternalCompilerError("Luau::follow detected a TypeVar cycle!!"); } } } @@ -135,7 +159,13 @@ bool isNil(TypeId ty) bool isBoolean(TypeId ty) { - return isPrim(ty, PrimitiveTypeVar::Boolean); + if (isPrim(ty, PrimitiveTypeVar::Boolean) || get(get(follow(ty)))) + return true; + + if (auto utv = get(follow(ty))) + return std::all_of(begin(utv), end(utv), isBoolean); + + return false; } bool isNumber(TypeId ty) @@ -143,9 +173,32 @@ bool isNumber(TypeId ty) return isPrim(ty, PrimitiveTypeVar::Number); } +// Returns true when ty is a subtype of string bool isString(TypeId ty) { - return isPrim(ty, PrimitiveTypeVar::String); + ty = follow(ty); + + if (isPrim(ty, PrimitiveTypeVar::String) || get(get(ty))) + return true; + + if (auto utv = get(ty)) + return std::all_of(begin(utv), end(utv), isString); + + return false; +} + +// Returns true when ty is a supertype of string +bool maybeString(TypeId ty) +{ + ty = follow(ty); + + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) + return true; + + if (auto utv = get(ty)) + return std::any_of(begin(utv), end(utv), maybeString); + + return false; } bool isThread(TypeId ty) @@ -158,62 +211,25 @@ bool isOptional(TypeId ty) if (isNil(ty)) return true; - if (!get(follow(ty))) + ty = follow(ty); + + if (get(ty) || (FFlag::LuauUnknownAndNeverType && get(ty))) + return true; + + auto utv = get(ty); + if (!utv) return false; - std::unordered_set seen; - std::deque queue{ty}; - while (!queue.empty()) - { - TypeId current = follow(queue.front()); - queue.pop_front(); - - if (seen.count(current)) - continue; - - seen.insert(current); - - if (isNil(current)) - return true; - - if (auto u = get(current)) - { - for (TypeId option : u->options) - { - if (isNil(option)) - return true; - - queue.push_back(option); - } - } - } - - return false; + return std::any_of(begin(utv), end(utv), isOptional); } bool isTableIntersection(TypeId ty) { - if (FFlag::LuauImprovedTypeGuardPredicate2) - { - if (!get(follow(ty))) - return false; - - std::vector parts = flattenIntersection(ty); - return std::all_of(parts.begin(), parts.end(), getTableType); - } - else - { - if (const IntersectionTypeVar* itv = get(ty)) - { - for (TypeId part : itv->parts) - { - if (getTableType(follow(part))) - return true; - } - } - + if (!get(follow(ty))) return false; - } + + std::vector parts = flattenIntersection(ty); + return std::all_of(parts.begin(), parts.end(), getTableType); } bool isOverloadedFunction(TypeId ty) @@ -229,28 +245,32 @@ bool isOverloadedFunction(TypeId ty) return std::all_of(parts.begin(), parts.end(), isFunction); } -std::optional getMetatable(TypeId type) +std::optional getMetatable(TypeId type, NotNull singletonTypes) { + type = follow(type); + if (const MetatableTypeVar* mtType = get(type)) return mtType->metatable; else if (const ClassTypeVar* classType = get(type)) return classType->metatable; - else if (const PrimitiveTypeVar* primitiveType = get(type); - FFlag::LuauStringMetatable && primitiveType && primitiveType->metatable) + else if (isString(type)) { - LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); - return primitiveType->metatable; + auto ptv = get(singletonTypes->stringType); + LUAU_ASSERT(ptv && ptv->metatable); + return ptv->metatable; } - else - return std::nullopt; + + return std::nullopt; } const TableTypeVar* getTableType(TypeId type) { + type = follow(type); + if (const TableTypeVar* ttv = get(type)) return ttv; else if (const MetatableTypeVar* mtv = get(type)) - return get(mtv->table); + return get(follow(mtv->table)); else return nullptr; } @@ -267,7 +287,7 @@ const std::string* getName(TypeId type) { if (mtv->syntheticName) return &*mtv->syntheticName; - type = mtv->table; + type = follow(mtv->table); } if (auto ttv = get(type)) @@ -281,6 +301,29 @@ const std::string* getName(TypeId type) return nullptr; } +std::optional getDefinitionModuleName(TypeId type) +{ + type = follow(type); + + if (auto ttv = get(type)) + { + if (!ttv->definitionModuleName.empty()) + return ttv->definitionModuleName; + } + else if (auto ftv = get(type)) + { + if (ftv->definition) + return ftv->definition->definitionModuleName; + } + else if (auto ctv = get(type)) + { + if (!ctv->definitionModuleName.empty()) + return ctv->definitionModuleName; + } + + return std::nullopt; +} + bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub) { std::unordered_set superTypes; @@ -301,6 +344,8 @@ bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub) // then instantiate U if `isGeneric(U)` is true, and `maybeGeneric(T)` is false. bool isGeneric(TypeId ty) { + LUAU_ASSERT(!FFlag::LuauInstantiateInSubtyping); + ty = follow(ty); if (auto ftv = get(ty)) return ftv->generics.size() > 0 || ftv->genericPacks.size() > 0; @@ -312,9 +357,33 @@ bool isGeneric(TypeId ty) bool maybeGeneric(TypeId ty) { + LUAU_ASSERT(!FFlag::LuauInstantiateInSubtyping); + + if (FFlag::LuauMaybeGenericIntersectionTypes) + { + ty = follow(ty); + + if (get(ty)) + return true; + + if (auto ttv = get(ty)) + { + // TODO: recurse on table types CLI-39914 + (void)ttv; + return true; + } + + if (auto itv = get(ty)) + { + return std::any_of(begin(itv), end(itv), maybeGeneric); + } + + return isGeneric(ty); + } + ty = follow(ty); - if (auto ftv = get(ty)) - return FFlag::LuauRankNTypes || ftv->DEPRECATED_canBeGeneric; + if (get(ty)) + return true; else if (auto ttv = get(ty)) { // TODO: recurse on table types CLI-39914 @@ -325,49 +394,146 @@ bool maybeGeneric(TypeId ty) return isGeneric(ty); } -FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) - : argTypes(argTypes) - , retType(retType) - , definition(std::move(defn)) - , hasSelf(hasSelf) +bool maybeSingleton(TypeId ty) +{ + ty = follow(ty); + if (get(ty)) + return true; + if (const UnionTypeVar* utv = get(ty)) + for (TypeId option : utv) + if (get(follow(option))) + return true; + return false; +} + +bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) +{ + RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit); + + ty = follow(ty); + + if (seen.contains(ty)) + return true; + + if (isString(ty) || get(ty) || get(ty) || get(ty)) + return true; + + if (auto uty = get(ty)) + { + seen.insert(ty); + + for (TypeId part : uty->options) + { + if (!hasLength(part, seen, recursionCount)) + return false; + } + + return true; + } + + if (auto ity = get(ty)) + { + seen.insert(ty); + + for (TypeId part : ity->parts) + { + if (hasLength(part, seen, recursionCount)) + return true; + } + + return false; + } + + return false; +} + +BlockedTypeVar::BlockedTypeVar() + : index(++nextIndex) { } -FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) - : level(level) +int BlockedTypeVar::nextIndex = 0; + +PendingExpansionTypeVar::PendingExpansionTypeVar( + std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments) + : prefix(prefix) + , name(name) + , typeArguments(typeArguments) + , packArguments(packArguments) + , index(++nextIndex) +{ +} + +size_t PendingExpansionTypeVar::nextIndex = 0; + +FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) + : definition(std::move(defn)) , argTypes(argTypes) - , retType(retType) - , definition(std::move(defn)) + , retTypes(retTypes) , hasSelf(hasSelf) { } -FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, +FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) + : definition(std::move(defn)) + , level(level) + , argTypes(argTypes) + , retTypes(retTypes) + , hasSelf(hasSelf) +{ +} + +FunctionTypeVar::FunctionTypeVar( + TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) + : definition(std::move(defn)) + , level(level) + , scope(scope) + , argTypes(argTypes) + , retTypes(retTypes) + , hasSelf(hasSelf) +{ +} + +FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) - : generics(generics) + : definition(std::move(defn)) + , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) - , retType(retType) - , definition(std::move(defn)) + , retTypes(retTypes) , hasSelf(hasSelf) { } FunctionTypeVar::FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, - TypePackId retType, std::optional defn, bool hasSelf) - : level(level) + TypePackId retTypes, std::optional defn, bool hasSelf) + : definition(std::move(defn)) , generics(generics) , genericPacks(genericPacks) + , level(level) , argTypes(argTypes) - , retType(retType) - , definition(std::move(defn)) + , retTypes(retTypes) , hasSelf(hasSelf) { } -TableTypeVar::TableTypeVar(TableState state, TypeLevel level) +FunctionTypeVar::FunctionTypeVar(TypeLevel level, Scope* scope, std::vector generics, std::vector genericPacks, + TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) + : definition(std::move(defn)) + , generics(generics) + , genericPacks(genericPacks) + , level(level) + , scope(scope) + , argTypes(argTypes) + , retTypes(retTypes) + , hasSelf(hasSelf) +{ +} + +TableTypeVar::TableTypeVar(TableState state, TypeLevel level, Scope* scope) : state(state) , level(level) + , scope(scope) { } @@ -379,6 +545,15 @@ TableTypeVar::TableTypeVar(const Props& props, const std::optional { } +TableTypeVar::TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, Scope* scope, TableState state) + : props(props) + , indexer(indexer) + , state(state) + , level(level) + , scope(scope) +{ +} + // Test TypeVars for equivalence // More complex than we'd like because TypeVars can self-reference. @@ -405,7 +580,7 @@ bool areEqual(SeenSet& seen, const FunctionTypeVar& lhs, const FunctionTypeVar& if (!areEqual(seen, *lhs.argTypes, *rhs.argTypes)) return false; - if (!areEqual(seen, *lhs.retType, *rhs.retType)) + if (!areEqual(seen, *lhs.retTypes, *rhs.retTypes)) return false; return true; @@ -453,6 +628,9 @@ bool areEqual(SeenSet& seen, const TableTypeVar& lhs, const TableTypeVar& rhs) static bool areEqual(SeenSet& seen, const MetatableTypeVar& lhs, const MetatableTypeVar& rhs) { + if (areSeen(seen, &lhs, &rhs)) + return true; + return areEqual(seen, *lhs.table, *rhs.table) && areEqual(seen, *lhs.metatable, *rhs.metatable); } @@ -559,31 +737,67 @@ TypeVar& TypeVar::operator=(TypeVariant&& rhs) return *this; } +TypeVar& TypeVar::operator=(const TypeVar& rhs) +{ + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); + + reassign(rhs); + + return *this; +} + TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); SingletonTypes::SingletonTypes() : arena(new TypeArena) - , nilType_{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true} - , numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true} - , stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true} - , booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true} - , threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true} - , anyType_{AnyTypeVar{}} - , errorType_{ErrorTypeVar{}} + , debugFreezeArena(FFlag::DebugLuauFreezeArena) + , nilType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true})) + , numberType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true})) + , stringType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true})) + , booleanType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true})) + , threadType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true})) + , functionType(arena->addType(TypeVar{PrimitiveTypeVar{PrimitiveTypeVar::Function}, /*persistent*/ true})) + , trueType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true})) + , falseType(arena->addType(TypeVar{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true})) + , anyType(arena->addType(TypeVar{AnyTypeVar{}, /*persistent*/ true})) + , unknownType(arena->addType(TypeVar{UnknownTypeVar{}, /*persistent*/ true})) + , neverType(arena->addType(TypeVar{NeverTypeVar{}, /*persistent*/ true})) + , errorType(arena->addType(TypeVar{ErrorTypeVar{}, /*persistent*/ true})) + , falsyType(arena->addType(TypeVar{UnionTypeVar{{falseType, nilType}}, /*persistent*/ true})) + , truthyType(arena->addType(TypeVar{NegationTypeVar{falsyType}, /*persistent*/ true})) + , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) + , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) + , uninhabitableTypePack(arena->addTypePack({neverType}, neverTypePack)) + , errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true})) { TypeId stringMetatable = makeStringMetatable(); - stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, makeStringMetatable()}; + asMutable(stringType)->ty = PrimitiveTypeVar{PrimitiveTypeVar::String, stringMetatable}; persist(stringMetatable); + persist(uninhabitableTypePack); + freeze(*arena); } +SingletonTypes::~SingletonTypes() +{ + // Destroy the arena with the same memory management flags it was created with + bool prevFlag = FFlag::DebugLuauFreezeArena; + FFlag::DebugLuauFreezeArena.value = debugFreezeArena; + + unfreeze(*arena); + arena.reset(nullptr); + + FFlag::DebugLuauFreezeArena.value = prevFlag; +} + TypeId SingletonTypes::makeStringMetatable() { const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}}); const TypeId optionalString = arena->addType(UnionTypeVar{{nilType, stringType}}); - const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, &booleanType_}}); + const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, booleanType}}); const TypePackId oneStringPack = arena->addTypePack({stringType}); const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); @@ -591,6 +805,7 @@ TypeId SingletonTypes::makeStringMetatable() FunctionTypeVar formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; formatFTV.magicFunction = &magicFunctionFormat; const TypeId formatFn = arena->addType(formatFTV); + attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); const TypePackId emptyPack = arena->addTypePack({}); const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); @@ -604,23 +819,35 @@ TypeId SingletonTypes::makeStringMetatable() const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); const TypeId gmatchFunc = makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})}); + attachMagicFunction(gmatchFunc, magicFunctionGmatch); + attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); + + const TypeId matchFunc = arena->addType( + FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); + attachMagicFunction(matchFunc, magicFunctionMatch); + attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); + + const TypeId findFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); + attachMagicFunction(findFunc, magicFunctionFind); + attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); TableTypeVar::Props stringLib = { {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, - {"char", {arena->addType(FunctionTypeVar{arena->addTypePack(TypePack{{numberType}, numberVariadicList}), arena->addTypePack({stringType})})}}, - {"find", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber, optionalBoolean}, {}, {optionalNumber, optionalNumber})}}, + {"char", {arena->addType(FunctionTypeVar{numberVariadicList, arena->addTypePack({stringType})})}}, + {"find", {findFunc}}, {"format", {formatFn}}, // FIXME {"gmatch", {gmatchFunc}}, {"gsub", {gsubFunc}}, {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, {"lower", {stringToStringType}}, - {"match", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber}, {}, {optionalString})}}, + {"match", {matchFunc}}, {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, {"reverse", {stringToStringType}}, {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalString}, {}, - {arena->addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, TypeLevel{}})})}}, + {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, + {arena->addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, {"pack", {arena->addType(FunctionTypeVar{ arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack, @@ -636,10 +863,31 @@ TypeId SingletonTypes::makeStringMetatable() TypeId tableType = arena->addType(TableTypeVar{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + if (TableTypeVar* ttv = getMutable(tableType)) + ttv->name = FFlag::LuauNewLibraryTypeNames ? "typeof(string)" : "string"; + return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } -SingletonTypes singletonTypes; +TypeId SingletonTypes::errorRecoveryType() +{ + return errorType; +} + +TypePackId SingletonTypes::errorRecoveryTypePack() +{ + return errorTypePack; +} + +TypeId SingletonTypes::errorRecoveryType(TypeId guess) +{ + return guess; +} + +TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) +{ + return guess; +} void persist(TypeId ty) { @@ -660,10 +908,12 @@ void persist(TypeId ty) else if (auto ftv = get(t)) { persist(ftv->argTypes); - persist(ftv->retType); + persist(ftv->retTypes); } else if (auto ttv = get(t)) { + LUAU_ASSERT(ttv->state != TableState::Free && ttv->state != TableState::Unsealed); + for (const auto& [_name, prop] : ttv->props) queue.push_back(prop.type); @@ -688,6 +938,19 @@ void persist(TypeId ty) for (TypeId opt : itv->parts) queue.push_back(opt); } + else if (auto mtv = get(t)) + { + queue.push_back(mtv->table); + queue.push_back(mtv->metatable); + } + else if (get(t) || get(t) || get(t) || get(t) || get(t) || + get(t)) + { + } + else + { + LUAU_ASSERT(!"TypeId is not supported in a persist call"); + } } } @@ -705,365 +968,19 @@ void persist(TypePackId tp) if (p->tail) persist(*p->tail); } -} - -namespace -{ - -struct StateDot -{ - StateDot(ToDotOptions opts) - : opts(opts) + else if (auto vtp = get(tp)) { + persist(vtp->ty); } - - ToDotOptions opts; - - std::unordered_set seenTy; - std::unordered_set seenTp; - std::unordered_map tyToIndex; - std::unordered_map tpToIndex; - int nextIndex = 1; - std::string result; - - bool canDuplicatePrimitive(TypeId ty); - - void visitChildren(TypeId ty, int index); - void visitChildren(TypePackId ty, int index); - - void visitChild(TypeId ty, int parentIndex, const char* linkName = nullptr); - void visitChild(TypePackId tp, int parentIndex, const char* linkName = nullptr); - - void startNode(int index); - void finishNode(); - - void startNodeLabel(); - void finishNodeLabel(TypeId ty); - void finishNodeLabel(TypePackId tp); -}; - -bool StateDot::canDuplicatePrimitive(TypeId ty) -{ - if (get(ty)) - return false; - - return get(ty) || get(ty); -} - -void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) -{ - if (!tyToIndex.count(ty) || (opts.duplicatePrimitives && canDuplicatePrimitive(ty))) - tyToIndex[ty] = nextIndex++; - - int index = tyToIndex[ty]; - - if (parentIndex != 0) + else if (get(tp)) { - if (linkName) - formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, index, linkName); - else - formatAppend(result, "n%d -> n%d;\n", parentIndex, index); - } - - if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) - { - if (const PrimitiveTypeVar* ptv = get(ty)) - formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str()); - else if (const AnyTypeVar* atv = get(ty)) - formatAppend(result, "n%d [label=\"any\"];\n", index); } else { - visitChildren(ty, index); + LUAU_ASSERT(!"TypePackId is not supported in a persist call"); } } -void StateDot::visitChild(TypePackId tp, int parentIndex, const char* linkName) -{ - if (!tpToIndex.count(tp)) - tpToIndex[tp] = nextIndex++; - - if (linkName) - formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, tpToIndex[tp], linkName); - else - formatAppend(result, "n%d -> n%d;\n", parentIndex, tpToIndex[tp]); - - visitChildren(tp, tpToIndex[tp]); -} - -void StateDot::startNode(int index) -{ - formatAppend(result, "n%d [", index); -} - -void StateDot::finishNode() -{ - formatAppend(result, "];\n"); -} - -void StateDot::startNodeLabel() -{ - formatAppend(result, "label=\""); -} - -void StateDot::finishNodeLabel(TypeId ty) -{ - if (opts.showPointers) - formatAppend(result, "\n0x%p", ty); - // additional common attributes can be added here as well - result += "\""; -} - -void StateDot::finishNodeLabel(TypePackId tp) -{ - if (opts.showPointers) - formatAppend(result, "\n0x%p", tp); - // additional common attributes can be added here as well - result += "\""; -} - -void StateDot::visitChildren(TypeId ty, int index) -{ - if (seenTy.count(ty)) - return; - seenTy.insert(ty); - - startNode(index); - startNodeLabel(); - - if (const BoundTypeVar* btv = get(ty)) - { - formatAppend(result, "BoundTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(btv->boundTo, index); - } - else if (const FunctionTypeVar* ftv = get(ty)) - { - formatAppend(result, "FunctionTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(ftv->argTypes, index, "arg"); - visitChild(ftv->retType, index, "ret"); - } - else if (const TableTypeVar* ttv = get(ty)) - { - if (ttv->name) - formatAppend(result, "TableTypeVar %s", ttv->name->c_str()); - else if (ttv->syntheticName) - formatAppend(result, "TableTypeVar %s", ttv->syntheticName->c_str()); - else - formatAppend(result, "TableTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - if (ttv->boundTo) - return visitChild(*ttv->boundTo, index, "boundTo"); - - for (const auto& [name, prop] : ttv->props) - visitChild(prop.type, index, name.c_str()); - if (ttv->indexer) - { - visitChild(ttv->indexer->indexType, index, "[index]"); - visitChild(ttv->indexer->indexResultType, index, "[value]"); - } - for (TypeId itp : ttv->instantiatedTypeParams) - visitChild(itp, index, "typeParam"); - } - else if (const MetatableTypeVar* mtv = get(ty)) - { - formatAppend(result, "MetatableTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(mtv->table, index, "table"); - visitChild(mtv->metatable, index, "metatable"); - } - else if (const UnionTypeVar* utv = get(ty)) - { - formatAppend(result, "UnionTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId opt : utv->options) - visitChild(opt, index); - } - else if (const IntersectionTypeVar* itv = get(ty)) - { - formatAppend(result, "IntersectionTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId part : itv->parts) - visitChild(part, index); - } - else if (const GenericTypeVar* gtv = get(ty)) - { - if (gtv->explicitName) - formatAppend(result, "GenericTypeVar %s", gtv->name.c_str()); - else - formatAppend(result, "GenericTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const FreeTypeVar* ftv = get(ty)) - { - formatAppend(result, "FreeTypeVar %d", ftv->index); - finishNodeLabel(ty); - finishNode(); - } - else if (const AnyTypeVar* atv = get(ty)) - { - formatAppend(result, "AnyTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const PrimitiveTypeVar* ptv = get(ty)) - { - formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str()); - finishNodeLabel(ty); - finishNode(); - } - else if (const ErrorTypeVar* etv = get(ty)) - { - formatAppend(result, "ErrorTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const ClassTypeVar* ctv = get(ty)) - { - formatAppend(result, "ClassTypeVar %s", ctv->name.c_str()); - finishNodeLabel(ty); - finishNode(); - - for (const auto& [name, prop] : ctv->props) - visitChild(prop.type, index, name.c_str()); - - if (ctv->parent) - visitChild(*ctv->parent, index, "[parent]"); - - if (ctv->metatable) - visitChild(*ctv->metatable, index, "[metatable]"); - } - else - { - LUAU_ASSERT(!"unknown type kind"); - finishNodeLabel(ty); - finishNode(); - } -} - -void StateDot::visitChildren(TypePackId tp, int index) -{ - if (seenTp.count(tp)) - return; - seenTp.insert(tp); - - startNode(index); - startNodeLabel(); - - if (const BoundTypePack* btp = get(tp)) - { - formatAppend(result, "BoundTypePack %d", index); - finishNodeLabel(tp); - finishNode(); - - visitChild(btp->boundTo, index); - } - else if (const TypePack* tpp = get(tp)) - { - formatAppend(result, "TypePack %d", index); - finishNodeLabel(tp); - finishNode(); - - for (TypeId tv : tpp->head) - visitChild(tv, index); - if (tpp->tail) - visitChild(*tpp->tail, index, "tail"); - } - else if (const VariadicTypePack* vtp = get(tp)) - { - formatAppend(result, "VariadicTypePack %d", index); - finishNodeLabel(tp); - finishNode(); - - visitChild(vtp->ty, index); - } - else if (const FreeTypePack* ftp = get(tp)) - { - formatAppend(result, "FreeTypePack %d", ftp->index); - finishNodeLabel(tp); - finishNode(); - } - else if (const GenericTypePack* gtp = get(tp)) - { - if (gtp->explicitName) - formatAppend(result, "GenericTypePack %s", gtp->name.c_str()); - else - formatAppend(result, "GenericTypePack %d", gtp->index); - finishNodeLabel(tp); - finishNode(); - } - else if (const Unifiable::Error* etp = get(tp)) - { - formatAppend(result, "ErrorTypePack %d", index); - finishNodeLabel(tp); - finishNode(); - } - else - { - LUAU_ASSERT(!"unknown type pack kind"); - finishNodeLabel(tp); - finishNode(); - } -} - -} // namespace - -std::string toDot(TypeId ty, const ToDotOptions& opts) -{ - StateDot state{opts}; - - state.result = "digraph graphname {\n"; - state.visitChild(ty, 0); - state.result += "}"; - - return state.result; -} - -std::string toDot(TypePackId tp, const ToDotOptions& opts) -{ - StateDot state{opts}; - - state.result = "digraph graphname {\n"; - state.visitChild(tp, 0); - state.result += "}"; - - return state.result; -} - -std::string toDot(TypeId ty) -{ - return toDot(ty, {}); -} - -std::string toDot(TypePackId tp) -{ - return toDot(tp, {}); -} - -void dumpDot(TypeId ty) -{ - printf("%s\n", toDot(ty).c_str()); -} - -void dumpDot(TypePackId tp) -{ - printf("%s\n", toDot(tp).c_str()); -} - const TypeLevel* getLevel(TypeId ty) { ty = follow(ty); @@ -1083,161 +1000,15 @@ TypeLevel* getMutableLevel(TypeId ty) return const_cast(getLevel(ty)); } -struct QVarFinder +std::optional getLevel(TypePackId tp) { - mutable DenseHashSet seen; + tp = follow(tp); - QVarFinder() - : seen(nullptr) - { - } - - bool hasSeen(const void* tv) const - { - if (seen.contains(tv)) - return true; - - seen.insert(tv); - return false; - } - - bool hasGeneric(TypeId tid) const - { - if (hasSeen(&tid->ty)) - return false; - - return Luau::visit(*this, tid->ty); - } - - bool hasGeneric(TypePackId tp) const - { - if (hasSeen(&tp->ty)) - return false; - - return Luau::visit(*this, tp->ty); - } - - bool operator()(const Unifiable::Free&) const - { - return false; - } - - bool operator()(const Unifiable::Bound& bound) const - { - return hasGeneric(bound.boundTo); - } - - bool operator()(const Unifiable::Generic&) const - { - return true; - } - bool operator()(const Unifiable::Error&) const - { - return false; - } - bool operator()(const PrimitiveTypeVar&) const - { - return false; - } - - bool operator()(const FunctionTypeVar& ftv) const - { - if (hasGeneric(ftv.argTypes)) - return true; - return hasGeneric(ftv.retType); - } - - bool operator()(const TableTypeVar& ttv) const - { - if (ttv.state == TableState::Generic) - return true; - - if (ttv.indexer) - { - if (hasGeneric(ttv.indexer->indexType)) - return true; - if (hasGeneric(ttv.indexer->indexResultType)) - return true; - } - - for (const auto& [_name, prop] : ttv.props) - { - if (hasGeneric(prop.type)) - return true; - } - - return false; - } - - bool operator()(const MetatableTypeVar& mtv) const - { - return hasGeneric(mtv.table) || hasGeneric(mtv.metatable); - } - - bool operator()(const ClassTypeVar& ctv) const - { - for (const auto& [name, prop] : ctv.props) - { - if (hasGeneric(prop.type)) - return true; - } - - if (ctv.parent) - return hasGeneric(*ctv.parent); - - return false; - } - - bool operator()(const AnyTypeVar&) const - { - return false; - } - - bool operator()(const UnionTypeVar& utv) const - { - for (TypeId tid : utv.options) - if (hasGeneric(tid)) - return true; - - return false; - } - - bool operator()(const IntersectionTypeVar& utv) const - { - for (TypeId tid : utv.parts) - if (hasGeneric(tid)) - return true; - - return false; - } - - bool operator()(const LazyTypeVar&) const - { - return false; - } - - bool operator()(const Unifiable::Bound& bound) const - { - return hasGeneric(bound.boundTo); - } - - bool operator()(const TypePack& pack) const - { - for (TypeId ty : pack.head) - if (hasGeneric(ty)) - return true; - - if (pack.tail) - return hasGeneric(*pack.tail); - - return false; - } - - bool operator()(const VariadicTypePack& pack) const - { - return hasGeneric(pack.ty); - } -}; + if (auto ftv = get(tp)) + return ftv->level; + else + return std::nullopt; +} const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) { @@ -1274,104 +1045,14 @@ bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent) return false; } -bool hasGeneric(TypeId ty) +const std::vector& getTypes(const UnionTypeVar* utv) { - return Luau::visit(QVarFinder{}, ty->ty); + return utv->options; } -bool hasGeneric(TypePackId tp) +const std::vector& getTypes(const IntersectionTypeVar* itv) { - return Luau::visit(QVarFinder{}, tp->ty); -} - -UnionTypeVarIterator::UnionTypeVarIterator(const UnionTypeVar* utv) -{ - LUAU_ASSERT(utv); - - if (!utv->options.empty()) - stack.push_front({utv, 0}); - - seen.insert(utv); -} - -UnionTypeVarIterator& UnionTypeVarIterator::operator++() -{ - advance(); - descend(); - return *this; -} - -UnionTypeVarIterator UnionTypeVarIterator::operator++(int) -{ - UnionTypeVarIterator copy = *this; - ++copy; - return copy; -} - -bool UnionTypeVarIterator::operator!=(const UnionTypeVarIterator& rhs) -{ - return !(*this == rhs); -} - -bool UnionTypeVarIterator::operator==(const UnionTypeVarIterator& rhs) -{ - if (!stack.empty() && !rhs.stack.empty()) - return stack.front() == rhs.stack.front(); - - return stack.empty() && rhs.stack.empty(); -} - -const TypeId& UnionTypeVarIterator::operator*() -{ - LUAU_ASSERT(!stack.empty()); - - descend(); - - auto [utv, currentIndex] = stack.front(); - LUAU_ASSERT(utv); - LUAU_ASSERT(currentIndex < utv->options.size()); - - const TypeId& ty = utv->options[currentIndex]; - LUAU_ASSERT(!get(follow(ty))); - return ty; -} - -void UnionTypeVarIterator::advance() -{ - while (!stack.empty()) - { - auto& [utv, currentIndex] = stack.front(); - ++currentIndex; - - if (currentIndex >= utv->options.size()) - stack.pop_front(); - else - break; - } -} - -void UnionTypeVarIterator::descend() -{ - while (!stack.empty()) - { - auto [utv, currentIndex] = stack.front(); - if (auto innerUnion = get(follow(utv->options[currentIndex]))) - { - // If we're about to descend into a cyclic UnionTypeVar, we should skip over this. - // Ideally this should never happen, but alas it does from time to time. :( - if (seen.find(innerUnion) != seen.end()) - advance(); - else - { - seen.insert(innerUnion); - stack.push_front({innerUnion, 0}); - } - - continue; - } - - break; - } + return itv->parts; } UnionTypeVarIterator begin(const UnionTypeVar* utv) @@ -1384,27 +1065,19 @@ UnionTypeVarIterator end(const UnionTypeVar* utv) return UnionTypeVarIterator{}; } -static std::vector DEPRECATED_filterMap(TypeId type, TypeIdPredicate predicate) +IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv) { - std::vector result; - - if (auto utv = get(follow(type))) - { - for (TypeId option : utv) - { - if (auto out = predicate(follow(option))) - result.push_back(*out); - } - } - else if (auto out = predicate(follow(type))) - return {*out}; - - return result; + return IntersectionTypeVarIterator{itv}; } -static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) +IntersectionTypeVarIterator end(const IntersectionTypeVar* itv) { - const char* options = "cdiouxXeEfgGqs"; + return IntersectionTypeVarIterator{}; +} + +static std::vector parseFormatString(NotNull singletonTypes, const char* data, size_t size) +{ + const char* options = "cdiouxXeEfgGqs*"; std::vector result; @@ -1418,28 +1091,30 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha continue; // we just ignore all characters (including flags/precision) up until first alphabetic character - while (i < size && !(data[i] > 0 && isalpha(data[i]))) + while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*'))) i++; if (i == size) break; if (data[i] == 'q' || data[i] == 's') - result.push_back(typechecker.stringType); + result.push_back(singletonTypes->stringType); + else if (data[i] == '*') + result.push_back(singletonTypes->unknownType); else if (strchr(options, data[i])) - result.push_back(typechecker.numberType); + result.push_back(singletonTypes->numberType); else - result.push_back(typechecker.errorType); + result.push_back(singletonTypes->errorRecoveryType(singletonTypes->anyType)); } } return result; } -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -1458,33 +1133,380 @@ std::optional> magicFunctionFormat( if (!fmt) return std::nullopt; - std::vector expected = parseFormatString(typechecker, fmt->value.data, fmt->value.size); + std::vector expected = parseFormatString(typechecker.singletonTypes, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(paramPack); - const size_t dataOffset = 1; + size_t paramOffset = 1; + size_t dataOffset = expr.self ? 0 : 1; // unify the prefix one argument at a time - for (size_t i = 0; i < expected.size() && i + dataOffset < params.size(); ++i) + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) { - Location location = expr.args.data[std::min(i, expr.args.size - 1)]->location; + Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - typechecker.unify(expected[i], params[i + dataOffset], location); + typechecker.unify(params[i + paramOffset], expected[i], scope, location); } // if we know the argument count or if we have too many arguments for sure, we can issue an error - const size_t actualParamSize = params.size() - dataOffset; + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string - if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) - typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - return ExprResult{arena.addTypePack({typechecker.stringType})}; + return WithPredicate{arena.addTypePack({typechecker.stringType})}; +} + +static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) +{ + TypeArena* arena = context.solver->arena; + + AstExprConstantString* fmt = nullptr; + if (auto index = context.callSite->func->as(); index && context.callSite->self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } + + if (!context.callSite->self && context.callSite->args.size > 0) + fmt = context.callSite->args.data[0]->as(); + + if (!fmt) + return false; + + std::vector expected = parseFormatString(context.solver->singletonTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(context.arguments); + + size_t paramOffset = 1; + + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + context.solver->unify(params[i + paramOffset], expected[i], context.solver->rootScope); + } + + // if we know the argument count or if we have too many arguments for sure, we can issue an error + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string + + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); + + TypePackId resultPack = arena->addTypePack({context.solver->singletonTypes->stringType}); + asMutable(context.result)->ty.emplace(resultPack); + + return true; +} + +static std::vector parsePatternString(NotNull singletonTypes, const char* data, size_t size) +{ + std::vector result; + int depth = 0; + bool parsingSet = false; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + ++i; + if (!parsingSet && i < size && data[i] == 'b') + i += 2; + } + else if (!parsingSet && data[i] == '[') + { + parsingSet = true; + if (i + 1 < size && data[i + 1] == ']') + i += 1; + } + else if (parsingSet && data[i] == ']') + { + parsingSet = false; + } + else if (data[i] == '(') + { + if (parsingSet) + continue; + + if (i + 1 < size && data[i + 1] == ')') + { + i++; + result.push_back(singletonTypes->numberType); + continue; + } + + ++depth; + result.push_back(singletonTypes->stringType); + } + else if (data[i] == ')') + { + if (parsingSet) + continue; + + --depth; + + if (depth < 0) + break; + } + } + + if (depth != 0 || parsingSet) + return std::vector(); + + if (result.empty()) + result.push_back(singletonTypes->stringType); + + return result; +} + +static std::optional> magicFunctionGmatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() != 2) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t index = expr.self ? 0 : 1; + if (expr.args.size > index) + pattern = expr.args.data[index]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypePackId emptyPack = arena.addTypePack({}); + const TypePackId returnList = arena.addTypePack(returnTypes); + const TypeId iteratorType = arena.addType(FunctionTypeVar{emptyPack, returnList}); + return WithPredicate{arena.addTypePack({iteratorType})}; +} + +static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() != 2) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t index = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > index) + pattern = context.callSite->args.data[index]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(params[0], context.solver->singletonTypes->stringType, context.solver->rootScope); + + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId returnList = arena->addTypePack(returnTypes); + const TypeId iteratorType = arena->addType(FunctionTypeVar{emptyPack, returnList}); + const TypePackId resTypePack = arena->addTypePack({iteratorType}); + asMutable(context.result)->ty.emplace(resTypePack); + + return true; +} + +static std::optional> magicFunctionMatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() < 2 || params.size() > 3) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = expr.self ? 0 : 1; + if (expr.args.size > patternIndex) + pattern = expr.args.data[patternIndex]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}}); + + size_t initIndex = expr.self ? 1 : 2; + if (params.size() == 3 && expr.args.size > initIndex) + typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); + + const TypePackId returnList = arena.addTypePack(returnTypes); + return WithPredicate{returnList}; +} + +static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 3) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(params[0], context.solver->singletonTypes->stringType, context.solver->rootScope); + + const TypeId optionalNumber = arena->addType(UnionTypeVar{{context.solver->singletonTypes->nilType, context.solver->singletonTypes->numberType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() == 3 && context.callSite->args.size > initIndex) + context.solver->unify(params[2], optionalNumber, context.solver->rootScope); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + + return true; +} + +static std::optional> magicFunctionFind( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() < 2 || params.size() > 4) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = expr.self ? 0 : 1; + if (expr.args.size > patternIndex) + pattern = expr.args.data[patternIndex]->as(); + + if (!pattern) + return std::nullopt; + + bool plain = false; + size_t plainIndex = expr.self ? 2 : 3; + if (expr.args.size > plainIndex) + { + AstExprConstantBool* p = expr.args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + } + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}}); + const TypeId optionalBoolean = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.booleanType}}); + + size_t initIndex = expr.self ? 1 : 2; + if (params.size() >= 3 && expr.args.size > initIndex) + typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); + + if (params.size() == 4 && expr.args.size > plainIndex) + typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena.addTypePack(returnTypes); + return WithPredicate{returnList}; +} + +static bool dcrMagicFunctionFind(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 4) + return false; + + TypeArena* arena = context.solver->arena; + NotNull singletonTypes = context.solver->singletonTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + bool plain = false; + size_t plainIndex = context.callSite->self ? 2 : 3; + if (context.callSite->args.size > plainIndex) + { + AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + } + + context.solver->unify(params[0], singletonTypes->stringType, context.solver->rootScope); + + const TypeId optionalNumber = arena->addType(UnionTypeVar{{singletonTypes->nilType, singletonTypes->numberType}}); + const TypeId optionalBoolean = arena->addType(UnionTypeVar{{singletonTypes->nilType, singletonTypes->booleanType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() >= 3 && context.callSite->args.size > initIndex) + context.solver->unify(params[2], optionalNumber, context.solver->rootScope); + + if (params.size() == 4 && context.callSite->args.size > plainIndex) + context.solver->unify(params[3], optionalBoolean, context.solver->rootScope); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + return true; } std::vector filterMap(TypeId type, TypeIdPredicate predicate) { - if (!FFlag::LuauTypeGuardPeelsAwaySubclasses) - return DEPRECATED_filterMap(type, predicate); - type = follow(type); if (auto utv = get(type)) @@ -1502,4 +1524,84 @@ std::vector filterMap(TypeId type, TypeIdPredicate predicate) return {}; } +static Tags* getTags(TypeId ty) +{ + ty = follow(ty); + + if (auto ftv = getMutable(ty)) + return &ftv->tags; + else if (auto ttv = getMutable(ty)) + return &ttv->tags; + else if (auto ctv = getMutable(ty)) + return &ctv->tags; + + return nullptr; +} + +void attachTag(TypeId ty, const std::string& tagName) +{ + if (auto tags = getTags(ty)) + tags->push_back(tagName); + else + LUAU_ASSERT(!"This TypeId does not support tags"); +} + +void attachTag(Property& prop, const std::string& tagName) +{ + prop.tags.push_back(tagName); +} + +// We would ideally not expose this because it could cause a footgun. +// If the Base class has a tag and you ask if Derived has that tag, it would return false. +// Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it. +bool hasTag(const Tags& tags, const std::string& tagName) +{ + return std::find(tags.begin(), tags.end(), tagName) != tags.end(); +} + +bool hasTag(TypeId ty, const std::string& tagName) +{ + ty = follow(ty); + + // We special case classes because getTags only returns a pointer to one vector of tags. + // But classes has multiple vector of tags, represented throughout the hierarchy. + if (auto ctv = get(ty)) + { + while (ctv) + { + if (hasTag(ctv->tags, tagName)) + return true; + else if (!ctv->parent) + return false; + + ctv = get(*ctv->parent); + LUAU_ASSERT(ctv); + } + } + else if (auto tags = getTags(ty)) + return hasTag(*tags, tagName); + + return false; +} + +bool hasTag(const Property& prop, const std::string& tagName) +{ + return hasTag(prop.tags, tagName); +} + +bool TypeFun::operator==(const TypeFun& rhs) const +{ + return type == rhs.type && typeParams == rhs.typeParams && typePackParams == rhs.typePackParams; +} + +bool GenericTypeDefinition::operator==(const GenericTypeDefinition& rhs) const +{ + return ty == rhs.ty && defaultValue == rhs.defaultValue; +} + +bool GenericTypePackDefinition::operator==(const GenericTypePackDefinition& rhs) const +{ + return tp == rhs.tp && defaultValue == rhs.defaultValue; +} + } // namespace Luau diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index f037351e..4dc26219 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -7,6 +7,9 @@ #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN #endif +#ifndef NOMINMAX +#define NOMINMAX +#endif #include const size_t kPageSize = 4096; @@ -14,8 +17,12 @@ const size_t kPageSize = 4096; #include #include +#if defined(__FreeBSD__) && !(_POSIX_C_SOURCE >= 200112L) +const size_t kPageSize = getpagesize(); +#else const size_t kPageSize = sysconf(_SC_PAGESIZE); #endif +#endif #include @@ -24,27 +31,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) namespace Luau { -static void* systemAllocateAligned(size_t size, size_t align) -{ -#ifdef _WIN32 - return _aligned_malloc(size, align); -#elif defined(__ANDROID__) // for Android 4.1 - return memalign(align, size); -#else - void* ptr; - return posix_memalign(&ptr, align, size) == 0 ? ptr : 0; -#endif -} - -static void systemDeallocateAligned(void* ptr) -{ -#ifdef _WIN32 - _aligned_free(ptr); -#else - free(ptr); -#endif -} - static size_t pageAlign(size_t size) { return (size + kPageSize - 1) & ~(kPageSize - 1); @@ -52,18 +38,37 @@ static size_t pageAlign(size_t size) void* pagedAllocate(size_t size) { - if (FFlag::DebugLuauFreezeArena) - return systemAllocateAligned(pageAlign(size), kPageSize); - else + // By default we use operator new/delete instead of malloc/free so that they can be overridden externally + if (!FFlag::DebugLuauFreezeArena) + { return ::operator new(size, std::nothrow); + } + + // On Windows, VirtualAlloc results in 64K granularity allocations; we allocate in chunks of ~32K so aligned_malloc is a little more efficient + // On Linux, we must use mmap because using regular heap results in mprotect() fragmenting the page table and us bumping into 64K mmap limit. +#ifdef _WIN32 + return _aligned_malloc(size, kPageSize); +#elif defined(__FreeBSD__) + return aligned_alloc(kPageSize, size); +#else + return mmap(nullptr, pageAlign(size), PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0); +#endif } -void pagedDeallocate(void* ptr) +void pagedDeallocate(void* ptr, size_t size) { - if (FFlag::DebugLuauFreezeArena) - systemDeallocateAligned(ptr); - else - ::operator delete(ptr); + // By default we use operator new/delete instead of malloc/free so that they can be overridden externally + if (!FFlag::DebugLuauFreezeArena) + return ::operator delete(ptr); + +#ifdef _WIN32 + _aligned_free(ptr); +#elif defined(__FreeBSD__) + free(ptr); +#else + int rc = munmap(ptr, size); + LUAU_ASSERT(rc == 0); +#endif } void pagedFreeze(void* ptr, size_t size) diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index cef07833..e0cc1414 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -1,60 +1,79 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Unifiable.h" -LUAU_FASTFLAG(LuauRankNTypes) +LUAU_FASTFLAG(LuauTypeNormalization2); namespace Luau { namespace Unifiable { +static int nextIndex = 0; + Free::Free(TypeLevel level) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , level(level) { } -Free::Free(TypeLevel level, bool DEPRECATED_canBeGeneric) - : index(++nextIndex) - , level(level) - , DEPRECATED_canBeGeneric(DEPRECATED_canBeGeneric) +Free::Free(Scope* scope) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + , scope(scope) { - LUAU_ASSERT(!FFlag::LuauRankNTypes); } -int Free::nextIndex = 0; +Free::Free(Scope* scope, TypeLevel level) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + , level(level) + , scope(scope) +{ +} + +int Free::DEPRECATED_nextIndex = 0; Generic::Generic() - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , name("g" + std::to_string(index)) - , explicitName(false) { } Generic::Generic(TypeLevel level) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , level(level) , name("g" + std::to_string(index)) - , explicitName(false) { } Generic::Generic(const Name& name) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , name(name) , explicitName(true) { } +Generic::Generic(Scope* scope) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + , scope(scope) +{ +} + Generic::Generic(TypeLevel level, const Name& name) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , level(level) , name(name) , explicitName(true) { } -int Generic::nextIndex = 0; +Generic::Generic(Scope* scope, const Name& name) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) + , scope(scope) + , name(name) + , explicitName(true) +{ +} + +int Generic::DEPRECATED_nextIndex = 0; Error::Error() : index(++nextIndex) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 89c3f80c..42882005 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -2,28 +2,318 @@ #include "Luau/Unifier.h" #include "Luau/Common.h" +#include "Luau/Instantiation.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" +#include "Luau/StringUtils.h" +#include "Luau/TimeTrace.h" +#include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" #include -LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 0); -LUAU_FASTFLAGVARIABLE(LuauLogTableTypeVarBoundTo, false) -LUAU_FASTFLAG(LuauGenericFunctions) -LUAU_FASTFLAGVARIABLE(LuauDontMutatePersistentFunctions, false) -LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauStringMetatable) -LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) -LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) -LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) -LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) +LUAU_FASTFLAG(LuauErrorRecoveryType); +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauReportTypeMismatchForTypePackUnificationFailure, false) +LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); +LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) +LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) +LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); +LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner2, false) +LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) +LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) +LUAU_FASTFLAG(LuauTxnLogTypePackIterator) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauNegatedFunctionTypes) namespace Luau { +struct PromoteTypeLevels final : TypeVarOnceVisitor +{ + TxnLog& log; + const TypeArena* typeArena = nullptr; + TypeLevel minLevel; + + Scope* outerScope = nullptr; + bool useScopes; + + PromoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, Scope* outerScope, bool useScopes) + : log(log) + , typeArena(typeArena) + , minLevel(minLevel) + , outerScope(outerScope) + , useScopes(useScopes) + { + } + + template + void promote(TID ty, T* t) + { + if (FFlag::DebugLuauDeferredConstraintResolution && !t) + return; + + LUAU_ASSERT(t); + + if (useScopes) + { + if (subsumesStrict(outerScope, t->scope)) + log.changeScope(ty, NotNull{outerScope}); + } + else + { + if (minLevel.subsumesStrict(t->level)) + { + log.changeLevel(ty, minLevel); + } + } + } + + bool visit(TypeId ty) override + { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (ty->owningArena != typeArena) + return false; + + return true; + } + + bool visit(TypePackId tp) override + { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (tp->owningArena != typeArena) + return false; + + return true; + } + + bool visit(TypeId ty, const FreeTypeVar&) override + { + // Surprise, it's actually a BoundTypeVar that hasn't been committed yet. + // Calling getMutable on this will trigger an assertion. + if (!log.is(ty)) + return true; + + promote(ty, log.getMutable(ty)); + return true; + } + + bool visit(TypeId ty, const FunctionTypeVar&) override + { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (ty->owningArena != typeArena) + return false; + + // Surprise, it's actually a BoundTypePack that hasn't been committed yet. + // Calling getMutable on this will trigger an assertion. + if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) + return true; + + promote(ty, log.getMutable(ty)); + return true; + } + + bool visit(TypeId ty, const TableTypeVar& ttv) override + { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (ty->owningArena != typeArena) + return false; + + if (ttv.state != TableState::Free && ttv.state != TableState::Generic) + return true; + + // Surprise, it's actually a BoundTypePack that hasn't been committed yet. + // Calling getMutable on this will trigger an assertion. + if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) + return true; + + promote(ty, log.getMutable(ty)); + return true; + } + + bool visit(TypePackId tp, const FreeTypePack&) override + { + // Surprise, it's actually a BoundTypePack that hasn't been committed yet. + // Calling getMutable on this will trigger an assertion. + if (!log.is(tp)) + return true; + + promote(tp, log.getMutable(tp)); + return true; + } +}; + +static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, Scope* outerScope, bool useScopes, TypeId ty) +{ + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (ty->owningArena != typeArena) + return; + + PromoteTypeLevels ptl{log, typeArena, minLevel, outerScope, useScopes}; + ptl.traverse(ty); +} + +void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, Scope* outerScope, bool useScopes, TypePackId tp) +{ + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (tp->owningArena != typeArena) + return; + + PromoteTypeLevels ptl{log, typeArena, minLevel, outerScope, useScopes}; + ptl.traverse(tp); +} + +struct SkipCacheForType final : TypeVarOnceVisitor +{ + SkipCacheForType(const DenseHashMap& skipCacheForType, const TypeArena* typeArena) + : skipCacheForType(skipCacheForType) + , typeArena(typeArena) + { + } + + bool visit(TypeId, const FreeTypeVar&) override + { + result = true; + return false; + } + + bool visit(TypeId, const BoundTypeVar&) override + { + result = true; + return false; + } + + bool visit(TypeId, const GenericTypeVar&) override + { + result = true; + return false; + } + + bool visit(TypeId ty, const TableTypeVar&) override + { + // Types from other modules don't contain mutable elements and are ok to cache + if (ty->owningArena != typeArena) + return false; + + TableTypeVar& ttv = *getMutable(ty); + + if (ttv.boundTo) + { + result = true; + return false; + } + + if (ttv.state != TableState::Sealed) + { + result = true; + return false; + } + + return true; + } + + bool visit(TypeId ty) override + { + // Types from other modules don't contain mutable elements and are ok to cache + if (ty->owningArena != typeArena) + return false; + + const bool* prev = skipCacheForType.find(ty); + + if (prev && *prev) + { + result = true; + return false; + } + + return true; + } + + bool visit(TypePackId tp) override + { + // Types from other modules don't contain mutable elements and are ok to cache + if (tp->owningArena != typeArena) + return false; + + return true; + } + + bool visit(TypePackId tp, const FreeTypePack&) override + { + result = true; + return false; + } + + bool visit(TypePackId tp, const BoundTypePack&) override + { + result = true; + return false; + } + + bool visit(TypePackId tp, const GenericTypePack&) override + { + result = true; + return false; + } + + const DenseHashMap& skipCacheForType; + const TypeArena* typeArena = nullptr; + bool result = false; +}; + +bool Widen::isDirty(TypeId ty) +{ + return log->is(ty); +} + +bool Widen::isDirty(TypePackId) +{ + return false; +} + +TypeId Widen::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + auto stv = log->getMutable(ty); + LUAU_ASSERT(stv); + + if (get(stv)) + return singletonTypes->stringType; + else + { + // If this assert trips, it's likely we now have number singletons. + LUAU_ASSERT(get(stv)); + return singletonTypes->booleanType; + } +} + +TypePackId Widen::clean(TypePackId) +{ + throw InternalCompilerError("Widen attempted to clean a dirty type pack?"); +} + +bool Widen::ignoreChildren(TypeId ty) +{ + if (FFlag::LuauClassTypeVarsInSubstitution && get(ty)) + return true; + + return !log->is(ty); +} + +TypeId Widen::operator()(TypeId ty) +{ + return substitute(ty).value_or(ty); +} + +TypePackId Widen::operator()(TypePackId tp) +{ + return substitute(tp).value_or(tp); +} + static std::optional hasUnificationTooComplex(const ErrorVec& errors) { auto isUnificationTooComplex = [](const TypeError& te) { @@ -37,160 +327,186 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler) - : types(types) - , mode(mode) - , globalScope(std::move(globalScope)) - , location(location) - , variance(variance) - , counters(std::make_shared()) - , iceHandler(iceHandler) +// Used for tagged union matching heuristic, returns first singleton type field +static std::optional> getTableMatchTag(TypeId type) { - LUAU_ASSERT(iceHandler); -} - -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters) - : types(types) - , mode(mode) - , globalScope(std::move(globalScope)) - , log(seen) - , location(location) - , variance(variance) - , counters(counters ? counters : std::make_shared()) - , iceHandler(iceHandler) -{ - LUAU_ASSERT(iceHandler); -} - -void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) -{ - counters->iterationCount = 0; - return tryUnify_(superTy, subTy, isFunctionCall, isIntersection); -} - -void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) -{ - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); - - ++counters->iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount) + if (auto ttv = getTableType(type)) { - errors.push_back(TypeError{location, UnificationTooComplex{}}); + for (auto&& [name, prop] : ttv->props) + { + if (auto sing = get(follow(prop.type))) + return {{name, sing}}; + } + } + + return std::nullopt; +} + +// TODO: Inline and clip with FFlag::DebugLuauDeferredConstraintResolution +template +static bool subsumes(bool useScopes, TY_A* left, TY_B* right) +{ + if (useScopes) + return subsumes(left->scope, right->scope); + else + return left->level.subsumes(right->level); +} + +TypeMismatch::Context Unifier::mismatchContext() +{ + switch (variance) + { + case Covariant: + return TypeMismatch::CovariantContext; + case Invariant: + return TypeMismatch::InvariantContext; + default: + LUAU_ASSERT(false); // This codepath should be unreachable. + return TypeMismatch::CovariantContext; + } +} + +Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog) + : types(normalizer->arena) + , singletonTypes(normalizer->singletonTypes) + , normalizer(normalizer) + , mode(mode) + , scope(scope) + , log(parentLog) + , location(location) + , variance(variance) + , sharedState(*normalizer->sharedState) +{ + normalize = FFlag::LuauSubtypeNormalizer; + LUAU_ASSERT(sharedState.iceHandler); +} + +void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) +{ + sharedState.counters.iterationCount = 0; + + tryUnify_(subTy, superTy, isFunctionCall, isIntersection); +} + +void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) +{ + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + + ++sharedState.counters.iterationCount; + + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) + { + reportError(location, UnificationTooComplex{}); return; } - superTy = follow(superTy); - subTy = follow(subTy); + superTy = log.follow(superTy); + subTy = log.follow(subTy); if (superTy == subTy) return; - auto l = getMutable(superTy); - auto r = getMutable(subTy); + auto superFree = log.getMutable(superTy); + auto subFree = log.getMutable(subTy); - if (l && r && l->level.subsumes(r->level)) + if (superFree && subFree && subsumes(useScopes, superFree, subFree)) { - occursCheck(subTy, superTy); - - if (!get(subTy)) - { - log(subTy); - *asMutable(subTy) = BoundTypeVar(superTy); - } - - if (!FFlag::LuauRankNTypes) - l->DEPRECATED_canBeGeneric &= r->DEPRECATED_canBeGeneric; + if (!occursCheck(subTy, superTy)) + log.replace(subTy, BoundTypeVar(superTy)); return; } - else if (l && r && FFlag::LuauGenericFunctions) + else if (superFree && subFree) { - log(superTy); - occursCheck(superTy, subTy); - if (!FFlag::LuauRankNTypes) - r->DEPRECATED_canBeGeneric &= l->DEPRECATED_canBeGeneric; - r->level = min(r->level, l->level); - *asMutable(superTy) = BoundTypeVar(subTy); - return; - } - else if (l) - { - occursCheck(superTy, subTy); - - // Unification can't change the level of a generic. - auto rightGeneric = get(subTy); - if (FFlag::LuauRankNTypes && rightGeneric && !rightGeneric->level.subsumes(l->level)) + if (!occursCheck(superTy, subTy)) { - // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); - return; - } - - if (!get(superTy)) - { - if (auto rightLevel = getMutableLevel(subTy)) + if (subsumes(useScopes, superFree, subFree)) { - if (!rightLevel->subsumes(l->level)) - *rightLevel = l->level; + log.changeLevel(subTy, superFree->level); } - log(superTy); - *asMutable(superTy) = BoundTypeVar(subTy); + log.replace(superTy, BoundTypeVar(subTy)); } + return; } - else if (r) + else if (superFree) { - occursCheck(subTy, superTy); + // Unification can't change the level of a generic. + auto subGeneric = log.getMutable(subTy); + if (subGeneric && !subsumes(useScopes, subGeneric, superFree)) + { + // TODO: a more informative error message? CLI-39912 + reportError(location, GenericError{"Generic subtype escaping scope"}); + return; + } + + if (!occursCheck(superTy, subTy)) + { + promoteTypeLevels(log, types, superFree->level, superFree->scope, useScopes, subTy); + + Widen widen{types, singletonTypes}; + log.replace(superTy, BoundTypeVar(widen(subTy))); + } + + return; + } + else if (subFree) + { + if (FFlag::LuauUnknownAndNeverType) + { + // Normally, if the subtype is free, it should not be bound to any, unknown, or error types. + // But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors. + if (log.get(superTy)) + return; + } // Unification can't change the level of a generic. - auto leftGeneric = get(superTy); - if (FFlag::LuauRankNTypes && leftGeneric && !leftGeneric->level.subsumes(r->level)) + auto superGeneric = log.getMutable(superTy); + if (superGeneric && !subsumes(useScopes, superGeneric, subFree)) { // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); + reportError(location, GenericError{"Generic supertype escaping scope"}); return; } - // This is the old code which is just wrong - auto wrongGeneric = get(subTy); // Guaranteed to be null - if (!FFlag::LuauRankNTypes && FFlag::LuauGenericFunctions && wrongGeneric && r->level.subsumes(wrongGeneric->level)) + if (!occursCheck(subTy, superTy)) { - // This code is unreachable! Should we just remove it? - // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); - return; - } - - // Check if we're unifying a monotype with a polytype - if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !r->DEPRECATED_canBeGeneric && isGeneric(superTy)) - { - // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Failed to unify a polytype with a monotype"}}); - return; - } - - if (!get(subTy)) - { - if (auto leftLevel = getMutableLevel(superTy)) - { - if (!leftLevel->subsumes(r->level)) - *leftLevel = r->level; - } - - log(subTy); - *asMutable(subTy) = BoundTypeVar(superTy); + promoteTypeLevels(log, types, subFree->level, subFree->scope, useScopes, superTy); + log.replace(subTy, BoundTypeVar(superTy)); } return; } - if (get(superTy) || get(superTy)) + if (get(superTy) || get(superTy) || get(superTy)) + return tryUnifyWithAny(subTy, superTy); + + if (get(subTy)) return tryUnifyWithAny(superTy, subTy); - if (get(subTy) || get(subTy)) - return tryUnifyWithAny(subTy, superTy); + if (log.get(subTy)) + return tryUnifyWithAny(superTy, subTy); + + if (log.get(subTy)) + return tryUnifyWithAny(superTy, subTy); + + auto& cache = sharedState.cachedUnify; + + // What if the types are immutable and we proved their relation before + bool cacheEnabled = !isFunctionCall && !isIntersection && variance == Invariant; + + if (cacheEnabled) + { + if (cache.contains({subTy, superTy})) + return; + + if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) + { + reportError(location, *error); + return; + } + } // If we have seen this pair of types before, we are currently recursing into cyclic types. // Here, we assume that the types unify. If they do not, we will find out as we roll back @@ -201,179 +517,655 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool log.pushSeen(superTy, subTy); - if (const UnionTypeVar* uv = get(subTy)) + size_t errorCount = errors.size(); + + if (const UnionTypeVar* subUnion = log.getMutable(subTy)) { - // A | B <: T if A <: T and B <: T - bool failed = false; - std::optional unificationTooComplex; - - size_t count = uv->options.size(); - size_t i = 0; - - for (TypeId type : uv->options) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTy, type); - - if (auto e = hasUnificationTooComplex(innerState.errors)) - unificationTooComplex = e; - else if (!innerState.errors.empty()) - failed = true; - - if (i != count - 1) - innerState.log.rollback(); - else - log.concat(std::move(innerState.log)); - - ++i; - } - - if (unificationTooComplex) - errors.push_back(*unificationTooComplex); - else if (failed) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + tryUnifyUnionWithType(subTy, subUnion, superTy); } - else if (const UnionTypeVar* uv = get(superTy)) + else if (const UnionTypeVar* uv = (FFlag::LuauSubtypeNormalizer ? nullptr : log.getMutable(superTy))) { - // T <: A | B if T <: A or T <: B - bool found = false; - std::optional unificationTooComplex; - - size_t startIndex = 0; - - if (FFlag::LuauUnionHeuristic) - { - const std::string* subName = getName(subTy); - if (subName) - { - for (size_t i = 0; i < uv->options.size(); ++i) - { - const std::string* optionName = getName(uv->options[i]); - if (optionName && *optionName == *subName) - { - startIndex = i; - break; - } - } - } - } - - for (size_t i = 0; i < uv->options.size(); ++i) - { - TypeId type = uv->options[(i + startIndex) % uv->options.size()]; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, subTy, isFunctionCall); - - if (innerState.errors.empty()) - { - found = true; - log.concat(std::move(innerState.log)); - break; - } - else if (auto e = hasUnificationTooComplex(innerState.errors)) - { - unificationTooComplex = e; - } - - innerState.log.rollback(); - } - - if (unificationTooComplex) - errors.push_back(*unificationTooComplex); - else if (!found) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); } - else if (const IntersectionTypeVar* uv = get(superTy)) + else if (const IntersectionTypeVar* uv = log.getMutable(superTy)) { - // T <: A & B if A <: T and B <: T - for (TypeId type : uv->parts) - { - tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); - } + tryUnifyTypeWithIntersection(subTy, superTy, uv); } - else if (const IntersectionTypeVar* uv = get(subTy)) + else if (const UnionTypeVar* uv = log.getMutable(superTy)) { - // A & B <: T if T <: A or T <: B - bool found = false; - std::optional unificationTooComplex; - - for (TypeId type : uv->parts) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTy, type, isFunctionCall); - - if (innerState.errors.empty()) - { - found = true; - log.concat(std::move(innerState.log)); - break; - } - else if (auto e = hasUnificationTooComplex(innerState.errors)) - { - unificationTooComplex = e; - } - - innerState.log.rollback(); - } - - if (unificationTooComplex) - errors.push_back(*unificationTooComplex); - else if (!found) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); } - else if (get(superTy) && get(subTy)) - tryUnifyPrimitives(superTy, subTy); + else if (const IntersectionTypeVar* uv = log.getMutable(subTy)) + { + tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); + } + else if (log.getMutable(superTy) && log.getMutable(subTy)) + tryUnifyPrimitives(subTy, superTy); - else if (get(superTy) && get(subTy)) - tryUnifyFunctions(superTy, subTy, isFunctionCall); + else if ((log.getMutable(superTy) || log.getMutable(superTy)) && log.getMutable(subTy)) + tryUnifySingletons(subTy, superTy); - else if (get(superTy) && get(subTy)) - tryUnifyTables(superTy, subTy, isIntersection); + else if (auto ptv = get(superTy); + FFlag::LuauNegatedFunctionTypes && ptv && ptv->type == PrimitiveTypeVar::Function && get(subTy)) + { + // Ok. Do nothing. forall functions F, F <: function + } + + else if (log.getMutable(superTy) && log.getMutable(subTy)) + tryUnifyFunctions(subTy, superTy, isFunctionCall); + + else if (log.getMutable(superTy) && log.getMutable(subTy)) + { + tryUnifyTables(subTy, superTy, isIntersection); + } + else if (FFlag::LuauScalarShapeSubtyping && log.get(superTy) && + (log.get(subTy) || log.get(subTy))) + { + tryUnifyScalarShape(subTy, superTy, /*reversed*/ false); + } + else if (FFlag::LuauScalarShapeSubtyping && log.get(subTy) && + (log.get(superTy) || log.get(superTy))) + { + tryUnifyScalarShape(subTy, superTy, /*reversed*/ true); + } // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. - else if (get(superTy)) - tryUnifyWithMetatable(superTy, subTy, /*reversed*/ false); - else if (get(subTy)) - tryUnifyWithMetatable(subTy, superTy, /*reversed*/ true); + else if (log.getMutable(superTy)) + tryUnifyWithMetatable(subTy, superTy, /*reversed*/ false); + else if (log.getMutable(subTy)) + tryUnifyWithMetatable(superTy, subTy, /*reversed*/ true); - else if (get(superTy)) - tryUnifyWithClass(superTy, subTy, /*reversed*/ false); + else if (log.getMutable(superTy)) + tryUnifyWithClass(subTy, superTy, /*reversed*/ false); // Unification of nonclasses with classes is almost, but not quite symmetrical. // The order in which we perform this test is significant in the case that both types are classes. - else if (get(subTy)) - tryUnifyWithClass(superTy, subTy, /*reversed*/ true); + else if (log.getMutable(subTy)) + tryUnifyWithClass(subTy, superTy, /*reversed*/ true); + + else if (log.get(superTy)) + tryUnifyTypeWithNegation(subTy, superTy); + + else if (log.get(subTy)) + tryUnifyNegationWithType(subTy, superTy); + + else if (FFlag::LuauUninhabitedSubAnything2 && !normalizer->isInhabited(subTy)) + {} else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); + + if (cacheEnabled) + cacheResult(subTy, superTy, errorCount); log.popSeen(superTy, subTy); } +void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, TypeId superTy) +{ + // A | B <: T if and only if A <: T and B <: T + bool failed = false; + std::optional unificationTooComplex; + std::optional firstFailedOption; + + std::vector logs; + + for (TypeId type : subUnion->options) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, superTy); + + if (FFlag::DebugLuauDeferredConstraintResolution) + logs.push_back(std::move(innerState.log)); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) + { + // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' + if (!firstFailedOption && !isNil(type)) + firstFailedOption = {innerState.errors.front()}; + + failed = true; + } + } + + if (FFlag::DebugLuauDeferredConstraintResolution) + log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); + else + { + // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. + auto tryBind = [this, subTy](TypeId superOption) { + superOption = log.follow(superOption); + + // just skip if the superOption is not free-ish. + auto ttv = log.getMutable(superOption); + if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) + return; + + // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype + // test is successful. + if (auto subUnion = get(subTy)) + { + if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) + return; + } + + // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. + // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. + if (log.haveSeen(subTy, superOption)) + { + // TODO: would it be nice for TxnLog::replace to do this? + if (log.is(superOption)) + log.bindTable(superOption, subTy); + else + log.replace(superOption, *subTy); + } + }; + + if (auto superUnion = log.getMutable(superTy)) + { + for (TypeId ty : superUnion) + tryBind(ty); + } + else + tryBind(superTy); + } + + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (failed) + { + if (firstFailedOption) + reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption, mismatchContext()}); + else + reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); + } +} + +void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTypeVar* uv, bool cacheEnabled, bool isFunctionCall) +{ + // T <: A | B if T <: A or T <: B + bool found = false; + std::optional unificationTooComplex; + + size_t failedOptionCount = 0; + std::optional failedOption; + + bool foundHeuristic = false; + size_t startIndex = 0; + + if (const std::string* subName = getName(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + const std::string* optionName = getName(uv->options[i]); + if (optionName && *optionName == *subName) + { + foundHeuristic = true; + startIndex = i; + break; + } + } + } + + if (auto subMatchTag = getTableMatchTag(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + auto optionMatchTag = getTableMatchTag(uv->options[i]); + if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) + { + foundHeuristic = true; + startIndex = i; + break; + } + } + } + + if (!foundHeuristic && cacheEnabled) + { + auto& cache = sharedState.cachedUnify; + + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[i]; + + if (cache.contains({subTy, type})) + { + startIndex = i; + break; + } + } + } + + std::vector logs; + + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[(i + startIndex) % uv->options.size()]; + Unifier innerState = makeChildUnifier(); + innerState.normalize = false; + innerState.tryUnify_(subTy, type, isFunctionCall); + + if (innerState.errors.empty()) + { + found = true; + if (FFlag::DebugLuauDeferredConstraintResolution) + logs.push_back(std::move(innerState.log)); + else + { + log.concat(std::move(innerState.log)); + break; + } + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + { + unificationTooComplex = e; + } + else if (!isNil(type)) + { + failedOptionCount++; + + if (!failedOption) + failedOption = {innerState.errors.front()}; + } + } + + if (FFlag::DebugLuauDeferredConstraintResolution) + log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); + + if (unificationTooComplex) + { + reportError(*unificationTooComplex); + } + else if (!found && normalize) + { + // It is possible that T <: A | B even though T normalize(subTy); + const NormalizedType* superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + reportError(location, UnificationTooComplex{}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + else + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + } + else if (!found) + { + if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + reportError( + location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption, mismatchContext()}); + else + reportError(location, TypeMismatch{superTy, subTy, "none of the union options are compatible", mismatchContext()}); + } +} + +void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionTypeVar* uv) +{ + std::optional unificationTooComplex; + std::optional firstFailedOption; + + std::vector logs; + + // T <: A & B if and only if T <: A and T <: B + for (TypeId type : uv->parts) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) + { + if (!firstFailedOption) + firstFailedOption = {innerState.errors.front()}; + } + + if (FFlag::DebugLuauDeferredConstraintResolution) + logs.push_back(std::move(innerState.log)); + else + log.concat(std::move(innerState.log)); + } + + if (FFlag::DebugLuauDeferredConstraintResolution) + log.concat(combineLogsIntoIntersection(std::move(logs))); + + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (firstFailedOption) + reportError(location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption, mismatchContext()}); +} + +void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) +{ + // A & B <: T if A <: T or B <: T + bool found = false; + std::optional unificationTooComplex; + + size_t startIndex = 0; + + if (cacheEnabled) + { + auto& cache = sharedState.cachedUnify; + + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[i]; + + if (cache.contains({type, superTy})) + { + startIndex = i; + break; + } + } + } + + std::vector logs; + + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; + Unifier innerState = makeChildUnifier(); + innerState.normalize = false; + innerState.tryUnify_(type, superTy, isFunctionCall); + + if (innerState.errors.empty()) + { + found = true; + if (FFlag::DebugLuauDeferredConstraintResolution) + logs.push_back(std::move(innerState.log)); + else + { + log.concat(std::move(innerState.log)); + break; + } + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + { + unificationTooComplex = e; + } + } + + if (FFlag::DebugLuauDeferredConstraintResolution) + log.concat(combineLogsIntoIntersection(std::move(logs))); + + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (!found && normalize) + { + // It is possible that A & B <: T even though A normalize(subTy); + const NormalizedType* superNorm = normalizer->normalize(superTy); + if (subNorm && superNorm) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); + else + reportError(location, UnificationTooComplex{}); + } + else if (!found) + { + reportError(location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible", mismatchContext()}); + } +} + +void Unifier::tryUnifyNormalizedTypes( + TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error) +{ + LUAU_ASSERT(FFlag::LuauSubtypeNormalizer); + + if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) + return; + else if (get(subNorm.tops)) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + + if (get(subNorm.errors)) + if (!get(superNorm.errors)) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + + if (get(subNorm.booleans)) + { + if (!get(superNorm.booleans)) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + } + else if (const SingletonTypeVar* stv = get(subNorm.booleans)) + { + if (!get(superNorm.booleans) && stv != get(superNorm.booleans)) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + } + + if (get(subNorm.nils)) + if (!get(superNorm.nils)) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + + if (get(subNorm.numbers)) + if (!get(superNorm.numbers)) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + + if (!isSubtype(subNorm.strings, superNorm.strings)) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + + if (get(subNorm.threads)) + if (!get(superNorm.errors)) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + + for (TypeId subClass : subNorm.classes) + { + bool found = false; + const ClassTypeVar* subCtv = get(subClass); + for (TypeId superClass : superNorm.classes) + { + const ClassTypeVar* superCtv = get(superClass); + if (isSubclass(subCtv, superCtv)) + { + found = true; + break; + } + } + if (!found) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + } + + for (TypeId subTable : subNorm.tables) + { + bool found = false; + for (TypeId superTable : superNorm.tables) + { + Unifier innerState = makeChildUnifier(); + if (get(superTable)) + innerState.tryUnifyWithMetatable(subTable, superTable, /* reversed */ false); + else if (get(subTable)) + innerState.tryUnifyWithMetatable(superTable, subTable, /* reversed */ true); + else + innerState.tryUnifyTables(subTable, superTable); + if (innerState.errors.empty()) + { + found = true; + log.concat(std::move(innerState.log)); + break; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + return reportError(*e); + } + if (!found) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + } + + if (!subNorm.functions.isNever()) + { + if (superNorm.functions.isNever()) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + for (TypeId superFun : *superNorm.functions.parts) + { + Unifier innerState = makeChildUnifier(); + const FunctionTypeVar* superFtv = get(superFun); + if (!superFtv) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); + innerState.tryUnify_(tgt, superFtv->retTypes); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else if (auto e = hasUnificationTooComplex(innerState.errors)) + return reportError(*e); + else + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + } + } + + for (auto& [tyvar, subIntersect] : subNorm.tyvars) + { + auto found = superNorm.tyvars.find(tyvar); + if (found == superNorm.tyvars.end()) + tryUnifyNormalizedTypes(subTy, superTy, *subIntersect, superNorm, reason, error); + else + tryUnifyNormalizedTypes(subTy, superTy, *subIntersect, *found->second, reason, error); + if (!errors.empty()) + return; + } +} + +TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args) +{ + if (overloads.isNever()) + { + reportError(location, CannotCallNonFunction{function}); + return singletonTypes->errorRecoveryTypePack(); + } + + std::optional result; + const FunctionTypeVar* firstFun = nullptr; + for (TypeId overload : *overloads.parts) + { + if (const FunctionTypeVar* ftv = get(overload)) + { + // TODO: instantiate generics? + if (ftv->generics.empty() && ftv->genericPacks.empty()) + { + if (!firstFun) + firstFun = ftv; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(args, ftv->argTypes); + if (innerState.errors.empty()) + { + log.concat(std::move(innerState.log)); + if (result) + { + if (FFlag::LuauOverloadedFunctionSubtypingPerf) + { + innerState.log.clear(); + innerState.tryUnify_(*result, ftv->retTypes); + } + if (FFlag::LuauOverloadedFunctionSubtypingPerf && innerState.errors.empty()) + log.concat(std::move(innerState.log)); + // Annoyingly, since we don't support intersection of generic type packs, + // the intersection may fail. We rather arbitrarily use the first matching overload + // in that case. + else if (std::optional intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes)) + result = intersect; + } + else + result = ftv->retTypes; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + { + reportError(*e); + return singletonTypes->errorRecoveryTypePack(args); + } + } + } + } + + if (result) + return *result; + else if (firstFun) + { + // TODO: better error reporting? + // The logic for error reporting overload resolution + // is currently over in TypeInfer.cpp, should we move it? + reportError(location, GenericError{"No matching overload."}); + return singletonTypes->errorRecoveryTypePack(firstFun->retTypes); + } + else + { + reportError(location, CannotCallNonFunction{function}); + return singletonTypes->errorRecoveryTypePack(); + } +} + +bool Unifier::canCacheResult(TypeId subTy, TypeId superTy) +{ + bool* superTyInfo = sharedState.skipCacheForType.find(superTy); + + if (superTyInfo && *superTyInfo) + return false; + + bool* subTyInfo = sharedState.skipCacheForType.find(subTy); + + if (subTyInfo && *subTyInfo) + return false; + + auto skipCacheFor = [this](TypeId ty) { + SkipCacheForType visitor{sharedState.skipCacheForType, types}; + visitor.traverse(ty); + + sharedState.skipCacheForType[ty] = visitor.result; + + return visitor.result; + }; + + if (!superTyInfo && skipCacheFor(superTy)) + return false; + + if (!subTyInfo && skipCacheFor(subTy)) + return false; + + return true; +} + +void Unifier::cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount) +{ + if (errors.size() == prevErrorCount) + { + if (canCacheResult(subTy, superTy)) + sharedState.cachedUnify.insert({subTy, superTy}); + } + else if (errors.size() == prevErrorCount + 1) + { + if (canCacheResult(subTy, superTy)) + sharedState.cachedUnifyError[{subTy, superTy}] = errors.back().data; + } +} + struct WeirdIter { TypePackId packId; - const TypePack* pack; + TxnLog& log; + TypePack* pack; size_t index; bool growing; TypeLevel level; + Scope* scope = nullptr; - WeirdIter(TypePackId packId) + WeirdIter(TypePackId packId, TxnLog& log) : packId(packId) - , pack(get(packId)) + , log(log) + , pack(log.getMutable(packId)) , index(0) , growing(false) { while (pack && pack->head.empty() && pack->tail) { packId = *pack->tail; - pack = get(packId); + pack = log.getMutable(packId); } } WeirdIter(const WeirdIter&) = default; - const TypeId& operator*() + TypeId& operator*() { LUAU_ASSERT(good()); return pack->head[index]; @@ -397,8 +1189,8 @@ struct WeirdIter if (pack->tail) { - packId = follow(*pack->tail); - pack = get(packId); + packId = log.follow(*pack->tail); + pack = log.getMutable(packId); index = 0; } @@ -407,73 +1199,95 @@ struct WeirdIter bool canGrow() const { - return nullptr != get(packId); + return nullptr != log.getMutable(packId); } void grow(TypePackId newTail) { LUAU_ASSERT(canGrow()); - level = get(packId)->level; - *asMutable(packId) = Unifiable::Bound(newTail); + LUAU_ASSERT(log.getMutable(newTail)); + + level = log.getMutable(packId)->level; + scope = log.getMutable(packId)->scope; + log.replace(packId, BoundTypePack(newTail)); packId = newTail; - pack = get(newTail); + pack = log.getMutable(newTail); index = 0; growing = true; } + + void pushType(TypeId ty) + { + LUAU_ASSERT(pack); + PendingTypePack* pendingPack = log.queue(packId); + if (TypePack* pending = getMutable(pendingPack)) + { + pending->head.push_back(ty); + // We've potentially just replaced the TypePack* that we need to look + // in. We need to replace pack. + pack = pending; + } + else + { + LUAU_ASSERT(!"Pending state for this pack was not a TypePack"); + } + } }; -ErrorVec Unifier::canUnify(TypeId superTy, TypeId subTy) +ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) { Unifier s = makeChildUnifier(); - s.tryUnify_(superTy, subTy); - s.log.rollback(); + s.tryUnify_(subTy, superTy); + return s.errors; } -ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall) +ErrorVec Unifier::canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall) { Unifier s = makeChildUnifier(); - s.tryUnify_(superTy, subTy, isFunctionCall); - s.log.rollback(); + s.tryUnify_(subTy, superTy, isFunctionCall); + return s.errors; } -void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall) +void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { - counters->iterationCount = 0; - return tryUnify_(superTp, subTp, isFunctionCall); + sharedState.counters.iterationCount = 0; + + tryUnify_(subTp, superTp, isFunctionCall); } /* * This is quite tricky: we are walking two rope-like structures and unifying corresponding elements. * If one is longer than the other, but the short end is free, we grow it to the required length. */ -void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall) +void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); - ++counters->iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount) + ++sharedState.counters.iterationCount; + + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) { - errors.push_back(TypeError{location, UnificationTooComplex{}}); + reportError(location, UnificationTooComplex{}); return; } - superTp = follow(superTp); - subTp = follow(subTp); + superTp = log.follow(superTp); + subTp = log.follow(subTp); - while (auto r = get(subTp)) + while (auto tp = log.getMutable(subTp)) { - if (r->head.empty() && r->tail) - subTp = follow(*r->tail); + if (tp->head.empty() && tp->tail) + subTp = log.follow(*tp->tail); else break; } - while (auto l = get(superTp)) + while (auto tp = log.getMutable(superTp)) { - if (l->head.empty() && l->tail) - superTp = follow(*l->tail); + if (tp->head.empty() && tp->tail) + superTp = log.follow(*tp->tail); else break; } @@ -481,57 +1295,50 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal if (superTp == subTp) return; - if (get(superTp)) - { - occursCheck(superTp, subTp); + if (log.haveSeen(superTp, subTp)) + return; - if (!get(superTp)) + if (log.getMutable(superTp)) + { + if (!occursCheck(superTp, subTp)) { - log(superTp); - *asMutable(superTp) = Unifiable::Bound(subTp); + Widen widen{types, singletonTypes}; + log.replace(superTp, Unifiable::Bound(widen(subTp))); } } - - else if (get(subTp)) + else if (log.getMutable(subTp)) { - occursCheck(subTp, superTp); - - if (!get(subTp)) + if (!occursCheck(subTp, superTp)) { - log(subTp); - *asMutable(subTp) = Unifiable::Bound(superTp); + log.replace(subTp, Unifiable::Bound(superTp)); } } - - else if (get(superTp)) - tryUnifyWithAny(superTp, subTp); - - else if (get(subTp)) + else if (log.getMutable(superTp)) tryUnifyWithAny(subTp, superTp); - - else if (get(superTp)) - tryUnifyVariadics(superTp, subTp, false); - else if (get(subTp)) - tryUnifyVariadics(subTp, superTp, true); - - else if (get(superTp) && get(subTp)) + else if (log.getMutable(subTp)) + tryUnifyWithAny(superTp, subTp); + else if (log.getMutable(superTp)) + tryUnifyVariadics(subTp, superTp, false); + else if (log.getMutable(subTp)) + tryUnifyVariadics(superTp, subTp, true); + else if (log.getMutable(superTp) && log.getMutable(subTp)) { - auto l = get(superTp); - auto r = get(subTp); + auto superTpv = log.getMutable(superTp); + auto subTpv = log.getMutable(subTp); // If the size of two heads does not match, but both packs have free tail // We set the sentinel variable to say so to avoid growing it forever. - auto [superTypes, superTail] = flatten(superTp); - auto [subTypes, subTail] = flatten(subTp); + auto [superTypes, superTail] = flatten(superTp, log); + auto [subTypes, subTail] = flatten(subTp, log); - bool noInfiniteGrowth = - (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); + bool noInfiniteGrowth = (superTypes.size() != subTypes.size()) && (superTail && log.getMutable(*superTail)) && + (subTail && log.getMutable(*subTail)); - auto superIter = WeirdIter{superTp}; - auto subIter = WeirdIter{subTp}; + auto superIter = WeirdIter(superTp, log); + auto subIter = WeirdIter(subTp, log); - auto mkFreshType = [this](TypeLevel level) { - return types->freshType(level); + auto mkFreshType = [this](Scope* scope, TypeLevel level) { + return types->freshType(scope, level); }; const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); @@ -546,14 +1353,22 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal ++loopCount; if (superIter.good() && subIter.growing) - asMutable(subIter.pack)->head.push_back(mkFreshType(subIter.level)); + { + subIter.pushType(mkFreshType(subIter.scope, subIter.level)); + } if (subIter.good() && superIter.growing) - asMutable(superIter.pack)->head.push_back(mkFreshType(superIter.level)); + { + superIter.pushType(mkFreshType(superIter.scope, superIter.level)); + } if (superIter.good() && subIter.good()) { - tryUnify_(*superIter, *subIter); + tryUnify_(*subIter, *superIter); + + if (!errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; + superIter.advance(); subIter.advance(); continue; @@ -562,30 +1377,49 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - const bool lFreeTail = l->tail && get(FFlag::LuauAddMissingFollow ? follow(*l->tail) : *l->tail) != nullptr; - const bool rFreeTail = r->tail && get(FFlag::LuauAddMissingFollow ? follow(*r->tail) : *r->tail) != nullptr; - if (lFreeTail && rFreeTail) - tryUnify_(*l->tail, *r->tail); + if (!FFlag::LuauTxnLogTypePackIterator && subTpv->tail && superTpv->tail) + { + tryUnify_(*subTpv->tail, *superTpv->tail); + break; + } + + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; + if (FFlag::LuauTxnLogTypePackIterator && lFreeTail && rFreeTail) + { + tryUnify_(*subTpv->tail, *superTpv->tail); + } else if (lFreeTail) - tryUnify_(*l->tail, emptyTp); + { + tryUnify_(emptyTp, *superTpv->tail); + } else if (rFreeTail) - tryUnify_(*r->tail, emptyTp); + { + tryUnify_(emptyTp, *subTpv->tail); + } + else if (FFlag::LuauTxnLogTypePackIterator && subTpv->tail && superTpv->tail) + { + if (log.getMutable(superIter.packId)) + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + else if (log.getMutable(subIter.packId)) + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + else + tryUnify_(*subTpv->tail, *superTpv->tail); + } break; } // If both tails are free, bind one to the other and call it a day if (superIter.canGrow() && subIter.canGrow()) - return tryUnify_(*superIter.pack->tail, *subIter.pack->tail); + return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); // If just one side is free on its tail, grow it to fit the other side. // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. if (superIter.canGrow()) superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); - else if (subIter.canGrow()) subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); - else { // A union type including nil marks an optional argument @@ -600,22 +1434,15 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal continue; } - // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && get(FFlag::LuauAddMissingFollow ? follow(*superIter) : *superIter)) + if (log.getMutable(superIter.packId)) { - superIter.advance(); - continue; - } - - if (get(superIter.packId)) - { - tryUnifyVariadics(superIter.packId, subIter.packId, false, int(subIter.index)); + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); return; } - if (get(subIter.packId)) + if (log.getMutable(subIter.packId)) { - tryUnifyVariadics(subIter.packId, superIter.packId, true, int(superIter.index)); + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); return; } @@ -626,23 +1453,23 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal } // This is a bit weird because we don't actually know expected vs actual. We just know - // subtype vs supertype. If we are checking a return value, we swap these to produce - // the expected error message. + // subtype vs supertype. If we are checking the values returned by a function, we swap + // these to produce the expected error message. size_t expectedSize = size(superTp); size_t actualSize = size(subTp); - if (ctx == CountMismatch::Result || ctx == CountMismatch::Return) + if (ctx == CountMismatch::FunctionResult || ctx == CountMismatch::ExprListResult) std::swap(expectedSize, actualSize); - errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + reportError(location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}); while (superIter.good()) { - tryUnify_(singletonTypes.errorType, *superIter); + tryUnify_(*superIter, singletonTypes->errorRecoveryType()); superIter.advance(); } while (subIter.good()) { - tryUnify_(singletonTypes.errorType, *subIter); + tryUnify_(*subIter, singletonTypes->errorRecoveryType()); subIter.advance(); } @@ -653,46 +1480,103 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + if (FFlag::LuauReportTypeMismatchForTypePackUnificationFailure) + reportError(location, TypePackMismatch{subTp, superTp}); + else + reportError(location, GenericError{"Failed to unify type packs"}); } } -void Unifier::tryUnifyPrimitives(TypeId superTy, TypeId subTy) +void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) { - const PrimitiveTypeVar* lp = get(superTy); - const PrimitiveTypeVar* rp = get(subTy); - if (!lp || !rp) + const PrimitiveTypeVar* superPrim = get(superTy); + const PrimitiveTypeVar* subPrim = get(subTy); + if (!superPrim || !subPrim) ice("passed non primitive types to unifyPrimitives"); - if (lp->type != rp->type) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + if (superPrim->type != subPrim->type) + reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); } -void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall) +void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) { - FunctionTypeVar* lf = getMutable(superTy); - FunctionTypeVar* rf = getMutable(subTy); - if (!lf || !rf) + const PrimitiveTypeVar* superPrim = get(superTy); + const SingletonTypeVar* superSingleton = get(superTy); + const SingletonTypeVar* subSingleton = get(subTy); + + if ((!superPrim && !superSingleton) || !subSingleton) + ice("passed non singleton/primitive types to unifySingletons"); + + if (superSingleton && *superSingleton == *subSingleton) + return; + + if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get(subSingleton) && variance == Covariant) + return; + + if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) + return; + + reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); +} + +void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) +{ + FunctionTypeVar* superFunction = log.getMutable(superTy); + FunctionTypeVar* subFunction = log.getMutable(subTy); + + if (!superFunction || !subFunction) ice("passed non-function types to unifyFunction"); - size_t numGenerics = lf->generics.size(); - if (FFlag::LuauGenericFunctions && numGenerics != rf->generics.size()) + size_t numGenerics = superFunction->generics.size(); + size_t numGenericPacks = superFunction->genericPacks.size(); + + bool shouldInstantiate = (numGenerics == 0 && subFunction->generics.size() > 0) || (numGenericPacks == 0 && subFunction->genericPacks.size() > 0); + + // TODO: This is unsound when the context is invariant, but the annotation burden without allowing it and without + // read-only properties is too high for lua-apps. Read-only properties _should_ resolve their issue by allowing + // generic methods in tables to be marked read-only. + if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate) { - numGenerics = std::min(lf->generics.size(), rf->generics.size()); - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + Instantiation instantiation{&log, types, scope->level, scope}; + + std::optional instantiated = instantiation.substitute(subTy); + if (instantiated.has_value()) + { + subFunction = log.getMutable(*instantiated); + + if (!subFunction) + ice("instantiation made a function type into a non-function type in unifyFunction"); + + numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); + numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); + } + else + { + reportError(location, UnificationTooComplex{}); + } + } + else if (numGenerics != subFunction->generics.size()) + { + numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); + + reportError(location, TypeMismatch{superTy, subTy, "different number of generic type parameters", mismatchContext()}); } - size_t numGenericPacks = lf->genericPacks.size(); - if (FFlag::LuauGenericFunctions && numGenericPacks != rf->genericPacks.size()) + if (numGenericPacks != subFunction->genericPacks.size()) { - numGenericPacks = std::min(lf->genericPacks.size(), rf->genericPacks.size()); - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); + + reportError(location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters", mismatchContext()}); } - if (FFlag::LuauGenericFunctions) + for (size_t i = 0; i < numGenerics; i++) { - for (size_t i = 0; i < numGenerics; i++) - log.pushSeen(lf->generics[i], rf->generics[i]); + log.pushSeen(superFunction->generics[i], subFunction->generics[i]); + } + + for (size_t i = 0; i < numGenericPacks; i++) + { + log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); } CountMismatch::Context context = ctx; @@ -701,40 +1585,60 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal { Unifier innerState = makeChildUnifier(); - ctx = CountMismatch::Arg; - innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + innerState.ctx = CountMismatch::Arg; + innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); - ctx = CountMismatch::Result; - innerState.tryUnify_(lf->retType, rf->retType); + bool reported = !innerState.errors.empty(); - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + if (auto e = hasUnificationTooComplex(innerState.errors)) + reportError(*e); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + reportError(location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front(), mismatchContext()}); + else if (!innerState.errors.empty()) + reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); + + innerState.ctx = CountMismatch::FunctionResult; + innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); + + if (!reported) + { + if (auto e = hasUnificationTooComplex(innerState.errors)) + reportError(*e); + else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) + reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front(), mismatchContext()}); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + reportError(location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front(), mismatchContext()}); + else if (!innerState.errors.empty()) + reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); + } log.concat(std::move(innerState.log)); } else { ctx = CountMismatch::Arg; - tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); - ctx = CountMismatch::Result; - tryUnify_(lf->retType, rf->retType); + ctx = CountMismatch::FunctionResult; + tryUnify_(subFunction->retTypes, superFunction->retTypes); } - if (lf->definition && !rf->definition && (!FFlag::LuauDontMutatePersistentFunctions || !subTy->persistent)) - { - rf->definition = lf->definition; - } - else if (!lf->definition && rf->definition && (!FFlag::LuauDontMutatePersistentFunctions || !superTy->persistent)) - { - lf->definition = rf->definition; - } + // Updating the log may have invalidated the function pointers + superFunction = log.getMutable(superTy); + subFunction = log.getMutable(subTy); ctx = context; - if (FFlag::LuauGenericFunctions) + for (int i = int(numGenericPacks) - 1; 0 <= i; i--) { - for (int i = int(numGenerics) - 1; 0 <= i; i--) - log.popSeen(lf->generics[i], rf->generics[i]); + log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + } + + for (int i = int(numGenerics) - 1; 0 <= i; i--) + { + log.popSeen(superFunction->generics[i], subFunction->generics[i]); } } @@ -760,338 +1664,485 @@ struct Resetter } // namespace -void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) +void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { - std::unique_ptr resetter; + TableTypeVar* superTable = log.getMutable(superTy); + TableTypeVar* subTable = log.getMutable(subTy); + TableTypeVar* instantiatedSubTable = subTable; - resetter.reset(new Resetter{&variance}); - variance = Invariant; - - TableTypeVar* lt = getMutable(left); - TableTypeVar* rt = getMutable(right); - if (!lt || !rt) + if (!superTable || !subTable) ice("passed non-table types to unifyTables"); - if (lt->state == TableState::Sealed && rt->state == TableState::Sealed) - return tryUnifySealedTables(left, right, isIntersection); - else if ((lt->state == TableState::Sealed && rt->state == TableState::Unsealed) || - (lt->state == TableState::Unsealed && rt->state == TableState::Sealed)) - return tryUnifySealedTables(left, right, isIntersection); - else if ((lt->state == TableState::Sealed && rt->state == TableState::Generic) || - (lt->state == TableState::Generic && rt->state == TableState::Sealed)) - errors.push_back(TypeError{location, TypeMismatch{left, right}}); - else if ((lt->state == TableState::Free) != (rt->state == TableState::Free)) // one table is free and the other is not - { - TypeId freeTypeId = rt->state == TableState::Free ? right : left; - TypeId otherTypeId = rt->state == TableState::Free ? left : right; + std::vector missingProperties; + std::vector extraProperties; - return tryUnifyFreeTable(freeTypeId, otherTypeId); - } - else if (lt->state == TableState::Free && rt->state == TableState::Free) + if (FFlag::LuauInstantiateInSubtyping) { - tryUnifyFreeTable(left, right); - - // avoid creating a cycle when the types are already pointing at each other - if (follow(left) != follow(right)) + if (variance == Covariant && subTable->state == TableState::Generic && superTable->state != TableState::Generic) { - log(lt); - lt->boundTo = right; - } - return; - } - else if (lt->state != TableState::Sealed && rt->state != TableState::Sealed) - { - // All free tables are checked in one of the branches above - LUAU_ASSERT(lt->state != TableState::Free); - LUAU_ASSERT(rt->state != TableState::Free); + Instantiation instantiation{&log, types, subTable->level, scope}; - // Tables must have exactly the same props and their types must all unify - // I honestly have no idea if this is remotely close to reasonable. - for (const auto& [name, prop] : lt->props) - { - const auto& r = rt->props.find(name); - if (r == rt->props.end()) - errors.push_back(TypeError{location, UnknownProperty{right, name}}); + std::optional instantiated = instantiation.substitute(subTy); + if (instantiated.has_value()) + { + subTable = log.getMutable(*instantiated); + instantiatedSubTable = subTable; + + if (!subTable) + ice("instantiation made a table type into a non-table type in tryUnifyTables"); + } else - tryUnify_(prop.type, r->second.type); + { + reportError(location, UnificationTooComplex{}); + } + } + } + + // Optimization: First test that the property sets are compatible without doing any recursive unification + if (!subTable->indexer && subTable->state != TableState::Free) + { + for (const auto& [propName, superProp] : superTable->props) + { + auto subIter = subTable->props.find(propName); + + if (subIter == subTable->props.end() && subTable->state == TableState::Unsealed && !isOptional(superProp.type)) + missingProperties.push_back(propName); } - if (lt->indexer && rt->indexer) - tryUnify(*lt->indexer, *rt->indexer); - else if (lt->indexer) + if (!missingProperties.empty()) + { + reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)}); + return; + } + } + + // And vice versa if we're invariant + if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && superTable->state != TableState::Free) + { + for (const auto& [propName, subProp] : subTable->props) + { + auto superIter = superTable->props.find(propName); + + if (superIter == superTable->props.end()) + extraProperties.push_back(propName); + } + + if (!extraProperties.empty()) + { + reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}); + return; + } + } + + // Width subtyping: any property in the supertype must be in the subtype, + // and the types must agree. + for (const auto& [name, prop] : superTable->props) + { + const auto& r = subTable->props.find(name); + if (r != subTable->props.end()) + { + // TODO: read-only properties don't need invariance + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(r->second.type, prop.type); + + checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); + + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } + else if (subTable->indexer && maybeString(subTable->indexer->indexType)) + { + // TODO: read-only indexers don't need invariance + // TODO: really we should only allow this if prop.type is optional. + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(subTable->indexer->indexResultType, prop.type); + + checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); + + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } + else if (subTable->state == TableState::Unsealed && isOptional(prop.type)) + // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` + // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. + // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) + { + } + else if (subTable->state == TableState::Free) + { + PendingType* pendingSub = log.queue(subTy); + TableTypeVar* ttv = getMutable(pendingSub); + LUAU_ASSERT(ttv); + ttv->props[name] = prop; + subTable = ttv; + } + else + missingProperties.push_back(name); + + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(superTy) : superTy; + TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(subTy) : subTy; + + if (FFlag::LuauScalarShapeUnifyToMtOwner2) + { + // If one of the types stopped being a table altogether, we need to restart from the top + if ((superTy != superTyNew || subTy != subTyNew) && errors.empty()) + return tryUnify(subTy, superTy, false, isIntersection); + } + + // Otherwise, restart only the table unification + TableTypeVar* newSuperTable = log.getMutable(superTyNew); + TableTypeVar* newSubTable = log.getMutable(subTyNew); + + if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) + { + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; + } + } + + for (const auto& [name, prop] : subTable->props) + { + if (superTable->props.count(name)) + { + // If both lt and rt contain the property, then + // we're done since we already unified them above + } + else if (superTable->indexer && maybeString(superTable->indexer->indexType)) + { + // TODO: read-only indexers don't need invariance + // TODO: really we should only allow this if prop.type is optional. + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(superTable->indexer->indexResultType, prop.type); + + checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); + + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } + else if (superTable->state == TableState::Unsealed) + { + // TODO: this case is unsound when variance is Invariant, but without it lua-apps fails to typecheck. + // TODO: file a JIRA + // TODO: hopefully readonly/writeonly properties will fix this. + Property clone = prop; + clone.type = deeplyOptional(clone.type); + + PendingType* pendingSuper = log.queue(superTy); + TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + pendingSuperTtv->props[name] = clone; + superTable = pendingSuperTtv; + } + else if (variance == Covariant) + { + } + else if (superTable->state == TableState::Free) + { + PendingType* pendingSuper = log.queue(superTy); + TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + pendingSuperTtv->props[name] = prop; + superTable = pendingSuperTtv; + } + else + extraProperties.push_back(name); + + TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(superTy) : superTy; + TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(subTy) : subTy; + + if (FFlag::LuauScalarShapeUnifyToMtOwner2) + { + // If one of the types stopped being a table altogether, we need to restart from the top + if ((superTy != superTyNew || subTy != subTyNew) && errors.empty()) + return tryUnify(subTy, superTy, false, isIntersection); + } + + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTyNew); + TableTypeVar* newSubTable = log.getMutable(subTyNew); + if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) + { + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; + } + } + + // Unify indexers + if (superTable->indexer && subTable->indexer) + { + // TODO: read-only indexers don't need invariance + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + + innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); + + bool reported = !innerState.errors.empty(); + + checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); + + innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); + + if (!reported) + checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); + + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } + else if (superTable->indexer) + { + if (subTable->state == TableState::Unsealed || subTable->state == TableState::Free) { // passing/assigning a table without an indexer to something that has one // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. - if (rt->state == TableState::Unsealed) - rt->indexer = lt->indexer; - else - errors.push_back(TypeError{location, CannotExtendTable{right, CannotExtendTable::Indexer}}); + // TODO: we only need to do this if the supertype's indexer is read/write + // since that can add indexed elements. + log.changeIndexer(subTy, superTable->indexer); } } - else if (lt->state == TableState::Sealed) + else if (subTable->indexer && variance == Invariant) { - // lt is sealed and so it must be possible for rt to have precisely the same shape - // Verify that this is the case, then bind rt to lt. - ice("unsealed tables are not working yet", location); + // Symmetric if we are invariant + if (superTable->state == TableState::Unsealed || superTable->state == TableState::Free) + { + log.changeIndexer(superTy, subTable->indexer); + } + } + + // Changing the indexer can invalidate the table pointers. + if (FFlag::LuauScalarShapeUnifyToMtOwner2) + { + superTable = log.getMutable(log.follow(superTy)); + subTable = log.getMutable(log.follow(subTy)); + + if (!superTable || !subTable) + return; } - else if (rt->state == TableState::Sealed) - return tryUnifyTables(right, left, isIntersection); else - ice("tryUnifyTables"); -} - -void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId) -{ - TableTypeVar* freeTable = getMutable(freeTypeId); - TableTypeVar* otherTable = getMutable(otherTypeId); - if (!freeTable || !otherTable) - ice("passed non-table types to tryUnifyFreeTable"); - - // Any properties in freeTable must unify with those in otherTable. - // Then bind freeTable to otherTable. - for (const auto& [freeName, freeProp] : freeTable->props) { - if (auto otherProp = findTablePropertyRespectingMeta(otherTypeId, freeName)) - { - tryUnify_(*otherProp, freeProp.type); - - /* - * TypeVars are commonly cyclic, so it is entirely possible - * for unifying a property of a table to change the table itself! - * We need to check for this and start over if we notice this occurring. - * - * I believe this is guaranteed to terminate eventually because this will - * only happen when a free table is bound to another table. - */ - if (!get(freeTypeId) || !get(otherTypeId)) - return tryUnify_(freeTypeId, otherTypeId); - - if (freeTable->boundTo) - return tryUnify_(freeTypeId, otherTypeId); - } - else - { - // If the other table is also free, then we are learning that it has more - // properties than we previously thought. Else, it is an error. - if (otherTable->state == TableState::Free) - otherTable->props.insert({freeName, freeProp}); - else - errors.push_back(TypeError{location, UnknownProperty{otherTypeId, freeName}}); - } + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); } - if (freeTable->indexer && otherTable->indexer) + if (!missingProperties.empty()) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify(*freeTable->indexer, *otherTable->indexer); - - checkChildUnifierTypeMismatch(innerState.errors, freeTypeId, otherTypeId); - - log.concat(std::move(innerState.log)); - } - else if (otherTable->state == TableState::Free && freeTable->indexer) - freeTable->indexer = otherTable->indexer; - - if (!freeTable->boundTo && otherTable->state != TableState::Free) - { - if (FFlag::LuauLogTableTypeVarBoundTo) - log(freeTable); - else - log(freeTypeId); - freeTable->boundTo = otherTypeId; - } -} - -void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection) -{ - TableTypeVar* lt = getMutable(left); - TableTypeVar* rt = getMutable(right); - if (!lt || !rt) - ice("passed non-table types to unifySealedTables"); - - Unifier innerState = makeChildUnifier(); - - std::vector missingPropertiesInSuper; - bool isUnnamedTable = rt->name == std::nullopt && rt->syntheticName == std::nullopt; - bool errorReported = false; - - // Optimization: First test that the property sets are compatible without doing any recursive unification - if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer) - { - for (const auto& [propName, superProp] : lt->props) - { - auto subIter = rt->props.find(propName); - if (subIter == rt->props.end() && !isOptional(superProp.type)) - missingPropertiesInSuper.push_back(propName); - } - - if (!missingPropertiesInSuper.empty()) - { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingPropertiesInSuper)}}); - return; - } - } - - // Tables must have exactly the same props and their types must all unify - for (const auto& it : lt->props) - { - const auto& r = rt->props.find(it.first); - if (r == rt->props.end()) - { - if (FFlag::LuauSealedTableUnifyOptionalFix) - { - if (isOptional(it.second.type)) - continue; - } - else - { - if (get(it.second.type)) - { - const UnionTypeVar* possiblyOptional = get(it.second.type); - const std::vector& options = possiblyOptional->options; - if (options.end() != std::find_if(options.begin(), options.end(), isNil)) - continue; - } - } - - missingPropertiesInSuper.push_back(it.first); - - innerState.errors.push_back(TypeError{location, TypeMismatch{left, right}}); - } - else - { - if (isUnnamedTable && r->second.location) - { - size_t oldErrorSize = innerState.errors.size(); - Location old = innerState.location; - innerState.location = *r->second.location; - innerState.tryUnify_(it.second.type, r->second.type); - innerState.location = old; - - if (oldErrorSize != innerState.errors.size() && !errorReported) - { - errorReported = true; - errors.push_back(innerState.errors.back()); - } - } - else - { - innerState.tryUnify_(it.second.type, r->second.type); - } - } - } - - if (lt->indexer || rt->indexer) - { - if (lt->indexer && rt->indexer) - innerState.tryUnify(*lt->indexer, *rt->indexer); - else if (rt->state == TableState::Unsealed) - { - if (lt->indexer && !rt->indexer) - rt->indexer = lt->indexer; - } - else if (lt->state == TableState::Unsealed) - { - if (rt->indexer && !lt->indexer) - lt->indexer = rt->indexer; - } - else if (lt->indexer) - { - innerState.tryUnify_(lt->indexer->indexType, singletonTypes.stringType); - // We already try to unify properties in both tables. - // Skip those and just look for the ones remaining and see if they fit into the indexer. - for (const auto& [name, type] : rt->props) - { - const auto& it = lt->props.find(name); - if (it == lt->props.end()) - innerState.tryUnify_(lt->indexer->indexResultType, type.type); - } - } - else - innerState.errors.push_back(TypeError{location, TypeMismatch{left, right}}); - } - - log.concat(std::move(innerState.log)); - - if (errorReported) - return; - - if (!missingPropertiesInSuper.empty()) - { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingPropertiesInSuper)}}); + reportError(location, MissingProperties{superTy, subTy, std::move(missingProperties)}); return; } - // If the superTy/left is an immediate part of an intersection type, do not do extra-property check. - // Otherwise, we would falsely generate an extra-property-error for 's' in this code: - // local a: {n: number} & {s: string} = {n=1, s=""} - // When checking agaist the table '{n: number}'. - if (!isIntersection && lt->state != TableState::Unsealed && !lt->indexer) + if (!extraProperties.empty()) { - // Check for extra properties in the subTy - std::vector extraPropertiesInSub; + reportError(location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}); + return; + } - for (const auto& it : rt->props) + /* + * TypeVars are commonly cyclic, so it is entirely possible + * for unifying a property of a table to change the table itself! + * We need to check for this and start over if we notice this occurring. + * + * I believe this is guaranteed to terminate eventually because this will + * only happen when a free table is bound to another table. + */ + if (superTable->boundTo || subTable->boundTo) + return tryUnify_(subTy, superTy); + + if (superTable->state == TableState::Free) + { + log.bindTable(superTy, subTy); + } + else if (subTable->state == TableState::Free) + { + log.bindTable(subTy, superTy); + } +} + +void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) +{ + LUAU_ASSERT(FFlag::LuauScalarShapeSubtyping); + + TypeId osubTy = subTy; + TypeId osuperTy = superTy; + + if (FFlag::LuauUninhabitedSubAnything2 && !normalizer->isInhabited(subTy)) + return; + + if (reversed) + std::swap(subTy, superTy); + + TableTypeVar* superTable = log.getMutable(superTy); + + if (!superTable || superTable->state != TableState::Free) + return reportError(location, TypeMismatch{osuperTy, osubTy, mismatchContext()}); + + auto fail = [&](std::optional e) { + std::string reason = "The former's metatable does not satisfy the requirements."; + if (e) + reportError(location, TypeMismatch{osuperTy, osubTy, reason, *e, mismatchContext()}); + else + reportError(location, TypeMismatch{osuperTy, osubTy, reason, mismatchContext()}); + }; + + // Given t1 where t1 = { lower: (t1) -> (a, b...) } + // It should be the case that `string <: t1` iff `(subtype's metatable).__index <: t1` + if (auto metatable = getMetatable(subTy, singletonTypes)) + { + auto mttv = log.get(*metatable); + if (!mttv) + fail(std::nullopt); + + if (auto it = mttv->props.find("__index"); it != mttv->props.end()) { - const auto& r = lt->props.find(it.first); - if (r == lt->props.end()) + TypeId ty = it->second.type; + Unifier child = makeChildUnifier(); + child.tryUnify_(ty, superTy); + + if (FFlag::LuauScalarShapeUnifyToMtOwner2) { - if (FFlag::LuauSealedTableUnifyOptionalFix) - { - if (isOptional(it.second.type)) - continue; - } - else - { - if (get(it.second.type)) - { - const UnionTypeVar* possiblyOptional = get(it.second.type); - const std::vector& options = possiblyOptional->options; - if (options.end() != std::find_if(options.begin(), options.end(), isNil)) - continue; - } - } + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed + // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check + TypeId newSuperTy = child.log.follow(superTy); - extraPropertiesInSub.push_back(it.first); + if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) + { + log.replace(superTy, BoundTypeVar{subTy}); + return; + } + } + + if (auto e = hasUnificationTooComplex(child.errors)) + reportError(*e); + else if (!child.errors.empty()) + fail(child.errors.front()); + + log.concat(std::move(child.log)); + + if (FFlag::LuauScalarShapeUnifyToMtOwner2) + { + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype + if (child.errors.empty()) + log.replace(superTy, BoundTypeVar{subTy}); } - } - if (!extraPropertiesInSub.empty()) - { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraPropertiesInSub), MissingProperties::Extra}}); return; } + else + { + return fail(std::nullopt); + } } - checkChildUnifierTypeMismatch(innerState.errors, left, right); + reportError(location, TypeMismatch{osuperTy, osubTy, mismatchContext()}); + return; } -void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed) +TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) { - const MetatableTypeVar* lhs = get(metatable); - if (!lhs) + ty = follow(ty); + if (isOptional(ty)) + return ty; + else if (const TableTypeVar* ttv = get(ty)) + { + TypeId& result = seen[ty]; + if (result) + return result; + result = types->addType(*ttv); + TableTypeVar* resultTtv = getMutable(result); + for (auto& [name, prop] : resultTtv->props) + prop.type = deeplyOptional(prop.type, seen); + return types->addType(UnionTypeVar{{singletonTypes->nilType, result}}); + } + else + return types->addType(UnionTypeVar{{singletonTypes->nilType, ty}}); +} + +void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) +{ + const MetatableTypeVar* superMetatable = get(superTy); + if (!superMetatable) ice("tryUnifyMetatable invoked with non-metatable TypeVar"); - TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other}}; + TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, mismatchContext()}}; - if (const MetatableTypeVar* rhs = get(other)) + if (const MetatableTypeVar* subMetatable = log.getMutable(subTy)) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(lhs->table, rhs->table); - innerState.tryUnify_(lhs->metatable, rhs->metatable); + innerState.tryUnify_(subMetatable->table, superMetatable->table); + innerState.tryUnify_(subMetatable->metatable, superMetatable->metatable); - checkChildUnifierTypeMismatch(innerState.errors, reversed ? other : metatable, reversed ? metatable : other); + if (auto e = hasUnificationTooComplex(innerState.errors)) + reportError(*e); + else if (!innerState.errors.empty()) + reportError( + location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}); log.concat(std::move(innerState.log)); } - else if (TableTypeVar* rhs = getMutable(other)) + else if (TableTypeVar* subTable = log.getMutable(subTy)) { - switch (rhs->state) + switch (subTable->state) { case TableState::Free: { - tryUnify_(lhs->table, other); - rhs->boundTo = metatable; + if (FFlag::DebugLuauDeferredConstraintResolution) + { + Unifier innerState = makeChildUnifier(); + bool missingProperty = false; + + for (const auto& [propName, prop] : subTable->props) + { + if (std::optional mtPropTy = findTablePropertyRespectingMeta(superTy, propName)) + { + innerState.tryUnify(prop.type, *mtPropTy); + } + else + { + reportError(mismatchError); + missingProperty = true; + break; + } + } + + if (const TableTypeVar* superTable = log.get(log.follow(superMetatable->table))) + { + // TODO: Unify indexers. + } + + if (auto e = hasUnificationTooComplex(innerState.errors)) + reportError(*e); + else if (!innerState.errors.empty()) + reportError(TypeError{location, + TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}}); + else if (!missingProperty) + { + log.concat(std::move(innerState.log)); + log.bindTable(subTy, superTy); + } + } + else + { + tryUnify_(subTy, superMetatable->table); + log.bindTable(subTy, superTy); + } break; } @@ -1099,29 +2150,29 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse case TableState::Sealed: case TableState::Unsealed: case TableState::Generic: - errors.push_back(mismatchError); + reportError(mismatchError); } } - else if (get(other) || get(other)) + else if (log.getMutable(subTy) || log.getMutable(subTy)) { } else { - errors.push_back(mismatchError); + reportError(mismatchError); } } // Class unification is almost, but not quite symmetrical. We use the 'reversed' boolean to indicate which scenario we are evaluating. -void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) +void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) { if (reversed) std::swap(superTy, subTy); auto fail = [&]() { if (!reversed) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); else - errors.push_back(TypeError{location, TypeMismatch{subTy, superTy}}); + reportError(location, TypeMismatch{subTy, superTy, mismatchContext()}); }; const ClassTypeVar* superClass = get(superTy); @@ -1143,7 +2194,7 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) } ice("Illegal variance setting!"); } - else if (TableTypeVar* table = getMutable(subTy)) + else if (TableTypeVar* subTable = getMutable(subTy)) { /** * A free table is something whose shape we do not exactly know yet. @@ -1155,131 +2206,144 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) * * Tables that are not free are known to be actual tables. */ - if (table->state != TableState::Free) + if (subTable->state != TableState::Free) return fail(); bool ok = true; - for (const auto& [propName, prop] : table->props) + for (const auto& [propName, prop] : subTable->props) { const Property* classProp = lookupClassProp(superClass, propName); if (!classProp) { ok = false; - errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); - tryUnify_(prop.type, singletonTypes.errorType); + reportError(location, UnknownProperty{superTy, propName}); } else - tryUnify_(prop.type, classProp->type); + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(classProp->type, prop.type); + + checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); + + if (innerState.errors.empty()) + { + log.concat(std::move(innerState.log)); + } + else + { + ok = false; + } + } } - if (table->indexer) + if (subTable->indexer) { ok = false; std::string msg = "Class " + superClass->name + " does not have an indexer"; - errors.push_back(TypeError{location, GenericError{msg}}); + reportError(location, GenericError{msg}); } if (!ok) return; - log(table); - table->boundTo = superTy; + log.bindTable(subTy, superTy); } else return fail(); } -void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer) +void Unifier::tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy) { - tryUnify_(superIndexer.indexType, subIndexer.indexType); - tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType); + const NegationTypeVar* ntv = get(superTy); + if (!ntv) + ice("tryUnifyTypeWithNegation superTy must be a negation type"); + + const NormalizedType* subNorm = normalizer->normalize(subTy); + const NormalizedType* superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + return reportError(location, UnificationTooComplex{}); + + // T & queue, std::unordered_set& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) +void Unifier::tryUnifyNegationWithType(TypeId subTy, TypeId superTy) +{ + const NegationTypeVar* ntv = get(subTy); + if (!ntv) + ice("tryUnifyNegationWithType subTy must be a negation type"); + + // TODO: ~T & queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { while (true) { - if (FFlag::LuauAddMissingFollow) - a = follow(a); + a = state.log.follow(a); - if (seenTypePacks.count(a)) + if (seenTypePacks.find(a)) break; seenTypePacks.insert(a); - if (FFlag::LuauAddMissingFollow) + if (state.log.getMutable(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - else if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + state.log.replace(a, Unifiable::Bound{anyTypePack}); } - else + else if (auto tp = state.log.getMutable(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - - if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; } } } -void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool reversed, int subOffset) +void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool reversed, int subOffset) { - const VariadicTypePack* lv = get(superTp); - if (!lv) + const VariadicTypePack* superVariadic = log.getMutable(superTp); + + if (!superVariadic) ice("passed non-variadic pack to tryUnifyVariadics"); - if (const VariadicTypePack* rv = get(subTp)) - tryUnify_(reversed ? rv->ty : lv->ty, reversed ? lv->ty : rv->ty); - else if (get(subTp)) + if (const VariadicTypePack* subVariadic = FFlag::LuauTxnLogTypePackIterator ? log.get(subTp) : get(subTp)) { - TypePackIterator rIter = begin(subTp); - TypePackIterator rEnd = end(subTp); + tryUnify_(reversed ? superVariadic->ty : subVariadic->ty, reversed ? subVariadic->ty : superVariadic->ty); + } + else if (FFlag::LuauTxnLogTypePackIterator ? log.get(subTp) : get(subTp)) + { + TypePackIterator subIter = begin(subTp, &log); + TypePackIterator subEnd = end(subTp); - std::advance(rIter, subOffset); + std::advance(subIter, subOffset); - while (rIter != rEnd) + while (subIter != subEnd) { - tryUnify_(reversed ? *rIter : lv->ty, reversed ? lv->ty : *rIter); - ++rIter; + tryUnify_(reversed ? superVariadic->ty : *subIter, reversed ? *subIter : superVariadic->ty); + ++subIter; } - if (std::optional maybeTail = rIter.tail()) + if (std::optional maybeTail = subIter.tail()) { TypePackId tail = follow(*maybeTail); if (get(tail)) { - log(tail); - *asMutable(tail) = BoundTypePack{superTp}; + log.replace(tail, BoundTypePack(superTp)); } else if (const VariadicTypePack* vtp = get(tail)) { - tryUnify_(lv->ty, vtp->ty); + tryUnify_(vtp->ty, superVariadic->ty); } else if (get(tail)) { - errors.push_back(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); + reportError(location, GenericError{"Cannot unify variadic and generic packs"}); } else if (get(tail)) { @@ -1293,34 +2357,38 @@ void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool rever } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify variadic packs"}}); + reportError(location, GenericError{"Failed to unify variadic packs"}); } } -static void tryUnifyWithAny( - std::vector& queue, Unifier& state, std::unordered_set& seenTypePacks, TypeId anyType, TypePackId anyTypePack) +static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHashSet& seen, DenseHashSet& seenTypePacks, + const TypeArena* typeArena, TypeId anyType, TypePackId anyTypePack) { - std::unordered_set seen; - while (!queue.empty()) { - TypeId ty = follow(queue.back()); + TypeId ty = state.log.follow(queue.back()); queue.pop_back(); - if (seen.count(ty)) + + // Types from other modules don't have free types + if (ty->owningArena != typeArena) continue; + + if (seen.find(ty)) + continue; + seen.insert(ty); - if (get(ty)) + if (state.log.getMutable(ty)) { - state.log(ty); - *asMutable(ty) = BoundTypeVar{anyType}; + // TODO: Only bind if the anyType isn't any, unknown, or error (?) + state.log.replace(ty, BoundTypeVar{anyType}); } - else if (auto fun = get(ty)) + else if (auto fun = state.log.getMutable(ty)) { queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retTypes, anyTypePack); } - else if (auto table = get(ty)) + else if (auto table = state.log.getMutable(ty)) { for (const auto& [_name, prop] : table->props) queue.push_back(prop.type); @@ -1331,18 +2399,18 @@ static void tryUnifyWithAny( queue.push_back(table->indexer->indexResultType); } } - else if (auto mt = get(ty)) + else if (auto mt = state.log.getMutable(ty)) { queue.push_back(mt->table); queue.push_back(mt->metatable); } - else if (get(ty)) + else if (state.log.getMutable(ty)) { // ClassTypeVars never contain free typevars. } - else if (auto union_ = get(ty)) + else if (auto union_ = state.log.getMutable(ty)) queue.insert(queue.end(), union_->options.begin(), union_->options.end()); - else if (auto intersection = get(ty)) + else if (auto intersection = state.log.getMutable(ty)) queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); else { @@ -1350,205 +2418,202 @@ static void tryUnifyWithAny( } } -void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) +void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) { - LUAU_ASSERT(get(any) || get(any)); + LUAU_ASSERT(get(anyTy) || get(anyTy) || get(anyTy) || get(anyTy)); - const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); + // These types are not visited in general loop below + if (get(subTy) || get(subTy) || get(subTy)) + return; - const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + TypePackId anyTp; + if (FFlag::LuauUnknownAndNeverType) + anyTp = types->addTypePack(TypePackVar{VariadicTypePack{anyTy}}); + else + { + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes->anyType}}); + anyTp = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + } - std::unordered_set seenTypePacks; - std::vector queue = {ty}; + std::vector queue = {subTy}; - Luau::tryUnifyWithAny(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); + + Luau::tryUnifyWithAny( + queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, FFlag::LuauUnknownAndNeverType ? anyTy : singletonTypes->anyType, anyTp); } -void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) +void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) { - LUAU_ASSERT(get(any)); + LUAU_ASSERT(get(anyTp)); - const TypeId anyTy = singletonTypes.errorType; + const TypeId anyTy = singletonTypes->errorRecoveryType(); - std::unordered_set seenTypePacks; std::vector queue; - queueTypePack(queue, seenTypePacks, *this, ty, any); + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, seenTypePacks, anyTy, any); + queueTypePack(queue, sharedState.tempSeenTp, *this, subTy, anyTp); + + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, anyTy, anyTp); } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) { - return Luau::findTablePropertyRespectingMeta(errors, globalScope, lhsType, name, location); + return Luau::findTablePropertyRespectingMeta(singletonTypes, errors, lhsType, name, location); } -std::optional Unifier::findMetatableEntry(TypeId type, std::string entry) +TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) { - type = follow(type); - - if (!FFlag::LuauStringMetatable) - { - if (const PrimitiveTypeVar* primType = get(type)) - { - if (primType->type != PrimitiveTypeVar::String || "__index" != entry) - return std::nullopt; - - auto found = globalScope->bindings.find(AstName{"string"}); - if (found == globalScope->bindings.end()) - return std::nullopt; - else - return found->second.typeId; - } - } - - std::optional metatable = getMetatable(type); - if (!metatable) - return std::nullopt; - - TypeId unwrapped = follow(*metatable); - - if (get(unwrapped)) - return singletonTypes.anyType; - - const TableTypeVar* mtt = getTableType(unwrapped); - if (!mtt) - { - errors.push_back(TypeError{location, GenericError{"Metatable was not a table."}}); - return std::nullopt; - } - - auto it = mtt->props.find(entry); - if (it != mtt->props.end()) - return it->second.type; - else - return std::nullopt; + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + TxnLog result; + for (TxnLog& log : logs) + result.concatAsIntersections(std::move(log), NotNull{types}); + return result; } -void Unifier::occursCheck(TypeId needle, TypeId haystack) +TxnLog Unifier::combineLogsIntoUnion(std::vector logs) { - std::unordered_set seen; - return occursCheck(seen, needle, haystack); + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + TxnLog result; + for (TxnLog& log : logs) + result.concatAsUnion(std::move(log), NotNull{types}); + return result; } -void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeId haystack) +bool Unifier::occursCheck(TypeId needle, TypeId haystack) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + sharedState.tempSeenTy.clear(); - needle = follow(needle); - haystack = follow(haystack); + return occursCheck(sharedState.tempSeenTy, needle, haystack); +} - if (seen.end() != seen.find(haystack)) - return; +bool Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) +{ + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + + bool occurrence = false; + + auto check = [&](TypeId tv) { + if (occursCheck(seen, needle, tv)) + occurrence = true; + }; + + needle = log.follow(needle); + haystack = log.follow(haystack); + + if (seen.find(haystack)) + return false; seen.insert(haystack); - if (get(needle)) - return; + if (log.getMutable(needle)) + return false; - if (!get(needle)) + if (!log.getMutable(needle)) ice("Expected needle to be free"); if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); - log(needle); - *asMutable(needle) = ErrorTypeVar{}; - return; + reportError(location, OccursCheckFailed{}); + log.replace(needle, *singletonTypes->errorRecoveryType()); + + return true; } - auto check = [&](TypeId tv) { - occursCheck(seen, needle, tv); - }; - - if (get(haystack)) - return; - else if (auto a = get(haystack)) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - for (TypeId ty : a->argTypes) - check(ty); - - for (TypeId ty : a->retType) - check(ty); - } - } - else if (auto a = get(haystack)) + if (log.getMutable(haystack)) + return false; + else if (auto a = log.getMutable(haystack)) { for (TypeId ty : a->options) check(ty); } - else if (auto a = get(haystack)) + else if (auto a = log.getMutable(haystack)) { for (TypeId ty : a->parts) check(ty); } + + return occurrence; } -void Unifier::occursCheck(TypePackId needle, TypePackId haystack) +bool Unifier::occursCheck(TypePackId needle, TypePackId haystack) { - std::unordered_set seen; - return occursCheck(seen, needle, haystack); + sharedState.tempSeenTp.clear(); + + return occursCheck(sharedState.tempSeenTp, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen, TypePackId needle, TypePackId haystack) +bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) { - needle = follow(needle); - haystack = follow(haystack); + needle = log.follow(needle); + haystack = log.follow(haystack); - if (seen.find(haystack) != seen.end()) - return; + if (seen.find(haystack)) + return false; seen.insert(haystack); - if (get(needle)) - return; + if (log.getMutable(needle)) + return false; - if (!get(needle)) + if (!log.getMutable(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); - while (!get(haystack)) + while (!log.getMutable(haystack)) { if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); - log(needle); - *asMutable(needle) = ErrorTypeVar{}; - return; + reportError(location, OccursCheckFailed{}); + log.replace(needle, *singletonTypes->errorRecoveryTypePack()); + + return true; } - if (auto a = get(haystack)) + if (auto a = get(haystack); a && a->tail) { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - for (const auto& ty : a->head) - { - if (auto f = get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) - { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); - } - } - } - - if (a->tail) - { - haystack = follow(*a->tail); - continue; - } + haystack = log.follow(*a->tail); + continue; } + break; } + + return false; } Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters}; + Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; + u.normalize = normalize; + u.useScopes = useScopes; + return u; } +// A utility function that appends the given error to the unifier's error log. +// This allows setting a breakpoint wherever the unifier reports an error. +// +// Note: report error accepts its arguments by value intentionally to reduce the stack usage of functions which call `reportError`. +void Unifier::reportError(Location location, TypeErrorData data) +{ + errors.emplace_back(std::move(location), std::move(data)); +} + +// A utility function that appends the given error to the unifier's error log. +// This allows setting a breakpoint wherever the unifier reports an error. +// +// Note: to conserve stack space in calling functions it is generally preferred to call `Unifier::reportError(Location location, TypeErrorData data)` +// instead of this method. +void Unifier::reportError(TypeError err) +{ + errors.push_back(std::move(err)); +} + + bool Unifier::isNonstrictMode() const { return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck); @@ -1557,19 +2622,28 @@ bool Unifier::isNonstrictMode() const void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType) { if (auto e = hasUnificationTooComplex(innerErrors)) - errors.push_back(*e); + reportError(*e); else if (!innerErrors.empty()) - errors.push_back(TypeError{location, TypeMismatch{wantedType, givenType}}); + reportError(location, TypeMismatch{wantedType, givenType, mismatchContext()}); +} + +void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) +{ + if (auto e = hasUnificationTooComplex(innerErrors)) + reportError(*e); + else if (!innerErrors.empty()) + reportError(TypeError{location, + TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front(), mismatchContext()}}); } void Unifier::ice(const std::string& message, const Location& location) { - iceHandler->ice(message, location); + sharedState.iceHandler->ice(message, location); } void Unifier::ice(const std::string& message) { - iceHandler->ice(message); + sharedState.iceHandler->ice(message); } } // namespace Luau diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index df38cfec..aa87d9e8 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -5,6 +5,7 @@ #include #include +#include #include @@ -134,6 +135,10 @@ public: { return visit((class AstExpr*)node); } + virtual bool visit(class AstExprInterpString* node) + { + return visit((class AstExpr*)node); + } virtual bool visit(class AstExprError* node) { return visit((class AstExpr*)node); @@ -255,6 +260,14 @@ public: { return visit((class AstType*)node); } + virtual bool visit(class AstTypeSingletonBool* node) + { + return visit((class AstType*)node); + } + virtual bool visit(class AstTypeSingletonString* node) + { + return visit((class AstType*)node); + } virtual bool visit(class AstTypeError* node) { return visit((class AstType*)node); @@ -264,6 +277,10 @@ public: { return false; } + virtual bool visit(class AstTypePackExplicit* node) + { + return visit((class AstTypePack*)node); + } virtual bool visit(class AstTypePackVariadic* node) { return visit((class AstTypePack*)node); @@ -301,7 +318,7 @@ template struct AstArray { T* data; - std::size_t size; + size_t size; const T* begin() const { @@ -322,6 +339,20 @@ struct AstTypeList using AstArgumentName = std::pair; // TODO: remove and replace when we get a common struct for this pair instead of AstName +struct AstGenericType +{ + AstName name; + Location location; + AstType* defaultValue = nullptr; +}; + +struct AstGenericTypePack +{ + AstName name; + Location location; + AstTypePack* defaultValue = nullptr; +}; + extern int gAstRttiIndex; template @@ -448,16 +479,26 @@ public: bool value; }; +enum class ConstantNumberParseResult +{ + Ok, + Malformed, + BinOverflow, + HexOverflow, + DoublePrefix, +}; + class AstExprConstantNumber : public AstExpr { public: LUAU_RTTI(AstExprConstantNumber) - AstExprConstantNumber(const Location& location, double value); + AstExprConstantNumber(const Location& location, double value, ConstantNumberParseResult parseResult = ConstantNumberParseResult::Ok); void visit(AstVisitor* visitor) override; double value; + ConstantNumberParseResult parseResult; }; class AstExprConstantString : public AstExpr @@ -557,19 +598,18 @@ class AstExprFunction : public AstExpr public: LUAU_RTTI(AstExprFunction) - AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, - const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, - std::optional returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false, - std::optional argLocation = std::nullopt); + AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, + const AstName& debugname, const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, + bool hasEnd = false, const std::optional& argLocation = std::nullopt); void visit(AstVisitor* visitor) override; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstLocal* self; AstArray args; - bool hasReturnAnnotation; - AstTypeList returnAnnotation; + std::optional returnAnnotation; bool vararg = false; Location varargLocation; AstTypePack* varargAnnotation; @@ -697,6 +737,22 @@ public: AstExpr* falseExpr; }; +class AstExprInterpString : public AstExpr +{ +public: + LUAU_RTTI(AstExprInterpString) + + AstExprInterpString(const Location& location, const AstArray>& strings, const AstArray& expressions); + + void visit(AstVisitor* visitor) override; + + /// An interpolated string such as `foo{bar}baz` is represented as + /// an array of strings for "foo" and "bar", and an array of expressions for "baz". + /// `strings` will always have one more element than `expressions`. + AstArray> strings; + AstArray expressions; +}; + class AstStatBlock : public AstStat { public: @@ -714,7 +770,7 @@ class AstStatIf : public AstStat public: LUAU_RTTI(AstStatIf) - AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, bool hasThen, const Location& thenLocation, + AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, const std::optional& thenLocation, const std::optional& elseLocation, bool hasEnd); void visit(AstVisitor* visitor) override; @@ -723,12 +779,10 @@ public: AstStatBlock* thenbody; AstStat* elsebody; - bool hasThen = false; - Location thenLocation; + std::optional thenLocation; // Active for 'elseif' as well - bool hasElse = false; - Location elseLocation; + std::optional elseLocation; bool hasEnd = false; }; @@ -823,8 +877,7 @@ public: AstArray vars; AstArray values; - bool hasEqualsSign = false; - Location equalsSignLocation; + std::optional equalsSignLocation; }; class AstStatFor : public AstStat @@ -930,12 +983,14 @@ class AstStatTypeAlias : public AstStat public: LUAU_RTTI(AstStatTypeAlias) - AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, AstType* type, bool exported); + AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported); void visit(AstVisitor* visitor) override; AstName name; - AstArray generics; + AstArray generics; + AstArray genericPacks; AstType* type; bool exported; }; @@ -958,14 +1013,15 @@ class AstStatDeclareFunction : public AstStat public: LUAU_RTTI(AstStatDeclareFunction) - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes); + AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, + const AstTypeList& retTypes); void visit(AstVisitor* visitor) override; AstName name; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList params; AstArray paramNames; AstTypeList retTypes; @@ -1007,19 +1063,27 @@ public: } }; +// Don't have Luau::Variant available, it's a bit of an overhead, but a plain struct is nice to use +struct AstTypeOrPack +{ + AstType* type = nullptr; + AstTypePack* typePack = nullptr; +}; + class AstTypeReference : public AstType { public: LUAU_RTTI(AstTypeReference) - AstTypeReference(const Location& location, std::optional prefix, AstName name, const AstArray& generics = {}); + AstTypeReference(const Location& location, std::optional prefix, AstName name, bool hasParameterList = false, + const AstArray& parameters = {}); void visit(AstVisitor* visitor) override; - bool hasPrefix; - AstName prefix; + bool hasParameterList; + std::optional prefix; AstName name; - AstArray generics; + AstArray parameters; }; struct AstTableProp @@ -1054,13 +1118,13 @@ class AstTypeFunction : public AstType public: LUAU_RTTI(AstTypeFunction) - AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, - const AstArray>& argNames, const AstTypeList& returnTypes); + AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes); void visit(AstVisitor* visitor) override; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList argTypes; AstArray> argNames; AstTypeList returnTypes; @@ -1143,6 +1207,30 @@ public: unsigned messageIndex; }; +class AstTypeSingletonBool : public AstType +{ +public: + LUAU_RTTI(AstTypeSingletonBool) + + AstTypeSingletonBool(const Location& location, bool value); + + void visit(AstVisitor* visitor) override; + + bool value; +}; + +class AstTypeSingletonString : public AstType +{ +public: + LUAU_RTTI(AstTypeSingletonString) + + AstTypeSingletonString(const Location& location, const AstArray& value); + + void visit(AstVisitor* visitor) override; + + const AstArray value; +}; + class AstTypePack : public AstNode { public: @@ -1152,6 +1240,18 @@ public: } }; +class AstTypePackExplicit : public AstTypePack +{ +public: + LUAU_RTTI(AstTypePackExplicit) + + AstTypePackExplicit(const Location& location, AstTypeList typeList); + + void visit(AstVisitor* visitor) override; + + AstTypeList typeList; +}; + class AstTypePackVariadic : public AstTypePack { public: @@ -1177,6 +1277,16 @@ public: }; AstName getIdentifier(AstExpr*); +Location getLocation(const AstTypeList& typeList); + +template // AstNode, AstExpr, AstLocal, etc +Location getLocation(AstArray array) +{ + if (0 == array.size) + return {}; + + return Location{array.data[0]->location.begin, array.data[array.size - 1]->location.end}; +} #undef LUAU_RTTI @@ -1191,7 +1301,8 @@ struct hash size_t operator()(const Luau::AstName& value) const { // note: since operator== uses pointer identity, hashing function uses it as well - return value.value ? std::hash()(value.value) : 0; + // the hasher is the same as DenseHashPointer (DenseHash.h) + return (uintptr_t(value.value) >> 4) ^ (uintptr_t(value.value) >> 9); } }; diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h index 02924e88..b222feb0 100644 --- a/Ast/include/Luau/DenseHash.h +++ b/Ast/include/Luau/DenseHash.h @@ -5,17 +5,12 @@ #include #include -#include #include #include namespace Luau { -// Internal implementation of DenseHashSet and DenseHashMap -namespace detail -{ - struct DenseHashPointer { size_t operator()(const void* key) const @@ -24,6 +19,10 @@ struct DenseHashPointer } }; +// Internal implementation of DenseHashSet and DenseHashMap +namespace detail +{ + template using DenseHashDefault = std::conditional_t, DenseHashPointer, std::hash>; @@ -32,32 +31,128 @@ class DenseHashTable { public: class const_iterator; + class iterator; DenseHashTable(const Key& empty_key, size_t buckets = 0) - : count(0) + : data(nullptr) + , capacity(0) + , count(0) , empty_key(empty_key) { + // validate that equality operator is at least somewhat functional + LUAU_ASSERT(eq(empty_key, empty_key)); // buckets has to be power-of-two or zero LUAU_ASSERT((buckets & (buckets - 1)) == 0); - // don't move this to initializer list! this works around an MSVC codegen issue on AMD CPUs: - // https://developercommunity.visualstudio.com/t/stdvector-constructor-from-size-t-is-25-times-slow/1546547 if (buckets) - data.resize(buckets, ItemInterface::create(empty_key)); + { + data = static_cast(::operator new(sizeof(Item) * buckets)); + capacity = buckets; + + ItemInterface::fill(data, buckets, empty_key); + } + } + + ~DenseHashTable() + { + if (data) + destroy(); + } + + DenseHashTable(const DenseHashTable& other) + : data(nullptr) + , capacity(0) + , count(other.count) + , empty_key(other.empty_key) + { + if (other.capacity) + { + data = static_cast(::operator new(sizeof(Item) * other.capacity)); + + for (size_t i = 0; i < other.capacity; ++i) + { + new (&data[i]) Item(other.data[i]); + capacity = i + 1; // if Item copy throws, capacity will note the number of initialized objects for destroy() to clean up + } + } + } + + DenseHashTable(DenseHashTable&& other) + : data(other.data) + , capacity(other.capacity) + , count(other.count) + , empty_key(other.empty_key) + { + other.data = nullptr; + other.capacity = 0; + other.count = 0; + } + + DenseHashTable& operator=(DenseHashTable&& other) + { + if (this != &other) + { + if (data) + destroy(); + + data = other.data; + capacity = other.capacity; + count = other.count; + empty_key = other.empty_key; + + other.data = nullptr; + other.capacity = 0; + other.count = 0; + } + + return *this; + } + + DenseHashTable& operator=(const DenseHashTable& other) + { + if (this != &other) + { + DenseHashTable copy(other); + *this = std::move(copy); + } + + return *this; } void clear() { - data.clear(); + if (count == 0) + return; + + if (capacity > 32) + { + destroy(); + } + else + { + ItemInterface::destroy(data, capacity); + ItemInterface::fill(data, capacity, empty_key); + } + count = 0; } + void destroy() + { + ItemInterface::destroy(data, capacity); + + ::operator delete(data); + data = nullptr; + + capacity = 0; + } + Item* insert_unsafe(const Key& key) { // It is invalid to insert empty_key into the table since it acts as a "entry does not exist" marker LUAU_ASSERT(!eq(key, empty_key)); - size_t hashmod = data.size() - 1; + size_t hashmod = capacity - 1; size_t bucket = hasher(key) & hashmod; for (size_t probe = 0; probe <= hashmod; ++probe) @@ -89,12 +184,12 @@ public: const Item* find(const Key& key) const { - if (data.empty()) + if (count == 0) return 0; if (eq(key, empty_key)) return 0; - size_t hashmod = data.size() - 1; + size_t hashmod = capacity - 1; size_t bucket = hasher(key) & hashmod; for (size_t probe = 0; probe <= hashmod; ++probe) @@ -120,32 +215,30 @@ public: void rehash() { - size_t newsize = data.empty() ? 16 : data.size() * 2; - - if (data.empty() && data.capacity() >= newsize) - { - LUAU_ASSERT(count == 0); - data.resize(newsize, ItemInterface::create(empty_key)); - return; - } + size_t newsize = capacity == 0 ? 16 : capacity * 2; DenseHashTable newtable(empty_key, newsize); - for (size_t i = 0; i < data.size(); ++i) + for (size_t i = 0; i < capacity; ++i) { const Key& key = ItemInterface::getKey(data[i]); if (!eq(key, empty_key)) - *newtable.insert_unsafe(key) = data[i]; + { + Item* item = newtable.insert_unsafe(key); + *item = std::move(data[i]); + } } LUAU_ASSERT(count == newtable.count); - data.swap(newtable.data); + + std::swap(data, newtable.data); + std::swap(capacity, newtable.capacity); } void rehash_if_full() { - if (count >= data.size() * 3 / 4) + if (count >= capacity * 3 / 4) { rehash(); } @@ -155,7 +248,7 @@ public: { size_t start = 0; - while (start < data.size() && eq(ItemInterface::getKey(data[start]), empty_key)) + while (start < capacity && eq(ItemInterface::getKey(data[start]), empty_key)) start++; return const_iterator(this, start); @@ -163,7 +256,22 @@ public: const_iterator end() const { - return const_iterator(this, data.size()); + return const_iterator(this, capacity); + } + + iterator begin() + { + size_t start = 0; + + while (start < capacity && eq(ItemInterface::getKey(data[start]), empty_key)) + start++; + + return iterator(this, start); + } + + iterator end() + { + return iterator(this, capacity); } size_t size() const @@ -208,7 +316,7 @@ public: const_iterator& operator++() { - size_t size = set->data.size(); + size_t size = set->capacity; do { @@ -230,8 +338,68 @@ public: size_t index; }; + class iterator + { + public: + iterator() + : set(0) + , index(0) + { + } + + iterator(DenseHashTable* set, size_t index) + : set(set) + , index(index) + { + } + + MutableItem& operator*() const + { + return *reinterpret_cast(&set->data[index]); + } + + MutableItem* operator->() const + { + return reinterpret_cast(&set->data[index]); + } + + bool operator==(const iterator& other) const + { + return set == other.set && index == other.index; + } + + bool operator!=(const iterator& other) const + { + return set != other.set || index != other.index; + } + + iterator& operator++() + { + size_t size = set->capacity; + + do + { + index++; + } while (index < size && set->eq(ItemInterface::getKey(set->data[index]), set->empty_key)); + + return *this; + } + + iterator operator++(int) + { + iterator res = *this; + ++*this; + return res; + } + + private: + DenseHashTable* set; + size_t index; + }; + private: - std::vector data; + Item* data; + size_t capacity; size_t count; Key empty_key; Hash hasher; @@ -251,9 +419,16 @@ struct ItemInterfaceSet item = key; } - static Key create(const Key& key) + static void fill(Key* data, size_t count, const Key& key) { - return key; + for (size_t i = 0; i < count; ++i) + new (&data[i]) Key(key); + } + + static void destroy(Key* data, size_t count) + { + for (size_t i = 0; i < count; ++i) + data[i].~Key(); } }; @@ -270,9 +445,22 @@ struct ItemInterfaceMap item.first = key; } - static std::pair create(const Key& key) + static void fill(std::pair* data, size_t count, const Key& key) { - return std::pair(key, Value()); + for (size_t i = 0; i < count; ++i) + { + new (&data[i].first) Key(key); + new (&data[i].second) Value(); + } + } + + static void destroy(std::pair* data, size_t count) + { + for (size_t i = 0; i < count; ++i) + { + data[i].first.~Key(); + data[i].second.~Value(); + } } }; @@ -287,6 +475,7 @@ class DenseHashSet public: typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::iterator iterator; DenseHashSet(const Key& empty_key, size_t buckets = 0) : impl(empty_key, buckets) @@ -333,6 +522,16 @@ public: { return impl.end(); } + + iterator begin() + { + return impl.begin(); + } + + iterator end() + { + return impl.end(); + } }; // This is a faster alternative of unordered_map, but it does not implement the same interface (i.e. it does not support erasing and has @@ -345,6 +544,7 @@ class DenseHashMap public: typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::iterator iterator; DenseHashMap(const Key& empty_key, size_t buckets = 0) : impl(empty_key, buckets) @@ -398,10 +598,21 @@ public: { return impl.begin(); } + const_iterator end() const { return impl.end(); } + + iterator begin() + { + return impl.begin(); + } + + iterator end() + { + return impl.end(); + } }; } // namespace Luau diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index 460ef056..929402b3 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -6,6 +6,8 @@ #include "Luau/DenseHash.h" #include "Luau/Common.h" +#include + namespace Luau { @@ -61,6 +63,12 @@ struct Lexeme SkinnyArrow, DoubleColon, + InterpStringBegin, + InterpStringMid, + InterpStringEnd, + // An interpolated string with no expressions (like `x`) + InterpStringSimple, + AddAssign, SubAssign, MulAssign, @@ -80,6 +88,8 @@ struct Lexeme BrokenString, BrokenComment, BrokenUnicode, + BrokenInterpDoubleBrace, + Error, Reserved_BEGIN, @@ -173,7 +183,7 @@ public: } const Lexeme& next(); - const Lexeme& next(bool skipComments); + const Lexeme& next(bool skipComments, bool updatePrevLocation); void nextline(); Lexeme lookahead(); @@ -208,6 +218,11 @@ private: Lexeme readLongString(const Position& start, int sep, Lexeme::Type ok, Lexeme::Type broken); Lexeme readQuotedString(); + Lexeme readInterpolatedStringBegin(); + Lexeme readInterpolatedStringSection(Position start, Lexeme::Type formatType, Lexeme::Type endType); + + void readBackslashInString(); + std::pair readName(); Lexeme readNumber(const Position& start, unsigned int startOffset); @@ -231,6 +246,19 @@ private: bool skipComments; bool readNames; + + enum class BraceType + { + InterpolatedString, + Normal + }; + + std::vector braceStack; }; +inline bool isSpace(char ch) +{ + return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || ch == '\v' || ch == '\f'; +} + } // namespace Luau diff --git a/Ast/include/Luau/Location.h b/Ast/include/Luau/Location.h index d3c0a462..dbe36bec 100644 --- a/Ast/include/Luau/Location.h +++ b/Ast/include/Luau/Location.h @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include - namespace Luau { @@ -10,100 +8,36 @@ struct Position { unsigned int line, column; - Position(unsigned int line, unsigned int column) - : line(line) - , column(column) - { - } + Position(unsigned int line, unsigned int column); - bool operator==(const Position& rhs) const - { - return this->column == rhs.column && this->line == rhs.line; - } - bool operator!=(const Position& rhs) const - { - return !(*this == rhs); - } + bool operator==(const Position& rhs) const; + bool operator!=(const Position& rhs) const; + bool operator<(const Position& rhs) const; + bool operator>(const Position& rhs) const; + bool operator<=(const Position& rhs) const; + bool operator>=(const Position& rhs) const; - bool operator<(const Position& rhs) const - { - if (line == rhs.line) - return column < rhs.column; - else - return line < rhs.line; - } - - bool operator>(const Position& rhs) const - { - if (line == rhs.line) - return column > rhs.column; - else - return line > rhs.line; - } - - bool operator<=(const Position& rhs) const - { - return *this == rhs || *this < rhs; - } - - bool operator>=(const Position& rhs) const - { - return *this == rhs || *this > rhs; - } + void shift(const Position& start, const Position& oldEnd, const Position& newEnd); }; struct Location { Position begin, end; - Location() - : begin(0, 0) - , end(0, 0) - { - } + Location(); + Location(const Position& begin, const Position& end); + Location(const Position& begin, unsigned int length); + Location(const Location& begin, const Location& end); - Location(const Position& begin, const Position& end) - : begin(begin) - , end(end) - { - } + bool operator==(const Location& rhs) const; + bool operator!=(const Location& rhs) const; - Location(const Position& begin, unsigned int length) - : begin(begin) - , end(begin.line, begin.column + length) - { - } - - Location(const Location& begin, const Location& end) - : begin(begin.begin) - , end(end.end) - { - } - - bool operator==(const Location& rhs) const - { - return this->begin == rhs.begin && this->end == rhs.end; - } - bool operator!=(const Location& rhs) const - { - return !(*this == rhs); - } - - bool encloses(const Location& l) const - { - return begin <= l.begin && end >= l.end; - } - bool contains(const Position& p) const - { - return begin <= p && p < end; - } - bool containsClosed(const Position& p) const - { - return begin <= p && p <= end; - } + bool encloses(const Location& l) const; + bool overlaps(const Location& l) const; + bool contains(const Position& p) const; + bool containsClosed(const Position& p) const; + void extend(const Location& other); + void shift(const Position& start, const Position& oldEnd, const Position& newEnd); }; -std::string toString(const Position& position); -std::string toString(const Location& location); - } // namespace Luau diff --git a/Ast/include/Luau/ParseResult.h b/Ast/include/Luau/ParseResult.h new file mode 100644 index 00000000..9c0a9527 --- /dev/null +++ b/Ast/include/Luau/ParseResult.h @@ -0,0 +1,71 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/Location.h" +#include "Luau/Lexer.h" +#include "Luau/StringUtils.h" + +namespace Luau +{ + +class AstStatBlock; + +class ParseError : public std::exception +{ +public: + ParseError(const Location& location, const std::string& message); + + virtual const char* what() const throw(); + + const Location& getLocation() const; + const std::string& getMessage() const; + + static LUAU_NORETURN void raise(const Location& location, const char* format, ...) LUAU_PRINTF_ATTR(2, 3); + +private: + Location location; + std::string message; +}; + +class ParseErrors : public std::exception +{ +public: + ParseErrors(std::vector errors); + + virtual const char* what() const throw(); + + const std::vector& getErrors() const; + +private: + std::vector errors; + std::string message; +}; + +struct HotComment +{ + bool header; + Location location; + std::string content; +}; + +struct Comment +{ + Lexeme::Type type; // Comment, BlockComment, or BrokenComment + Location location; +}; + +struct ParseResult +{ + AstStatBlock* root; + size_t lines = 0; + + std::vector hotcomments; + std::vector errors; + + std::vector commentLocations; +}; + +static constexpr const char* kParseNameError = "%error-id%"; + +} // namespace Luau diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index e6ebd503..8b7eb73c 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -4,47 +4,18 @@ #include "Luau/Ast.h" #include "Luau/Lexer.h" #include "Luau/ParseOptions.h" +#include "Luau/ParseResult.h" #include "Luau/StringUtils.h" #include "Luau/DenseHash.h" #include "Luau/Common.h" #include #include +#include namespace Luau { -class ParseError : public std::exception -{ -public: - ParseError(const Location& location, const std::string& message); - - virtual const char* what() const throw(); - - const Location& getLocation() const; - const std::string& getMessage() const; - - static LUAU_NORETURN void raise(const Location& location, const char* format, ...) LUAU_PRINTF_ATTR(2, 3); - -private: - Location location; - std::string message; -}; - -class ParseErrors : public std::exception -{ -public: - ParseErrors(std::vector errors); - - virtual const char* what() const throw(); - - const std::vector& getErrors() const; - -private: - std::vector errors; - std::string message; -}; - template class TempVector { @@ -80,34 +51,17 @@ private: size_t size_; }; -struct Comment -{ - Lexeme::Type type; // Comment, BlockComment, or BrokenComment - Location location; -}; - -struct ParseResult -{ - AstStatBlock* root; - std::vector hotcomments; - std::vector errors; - - std::vector commentLocations; -}; - class Parser { public: static ParseResult parse( const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options = ParseOptions()); - static constexpr const char* errorName = "%error-id%"; - private: struct Name; struct Binding; - Parser(const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator); + Parser(const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator, const ParseOptions& options); bool blockFollow(const Lexeme& l); @@ -156,8 +110,10 @@ private: // for namelist in explist do block end | AstStat* parseFor(); - // function funcname funcbody | // funcname ::= Name {`.' Name} [`:' Name] + AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname); + + // function funcname funcbody AstStat* parseFunctionStat(); // local function Name funcbody | @@ -182,10 +138,12 @@ private: // var [`+=' | `-=' | `*=' | `/=' | `%=' | `^=' | `..='] exp AstStat* parseCompoundAssignment(AstExpr* initial, AstExprBinary::Op op); - // funcbody ::= `(' [parlist] `)' block end - // parlist ::= namelist [`,' `...'] | `...' + std::pair> prepareFunctionArguments(const Location& start, bool hasself, const TempVector& args); + + // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` TypeAnnotation] + // funcbody ::= funcbodyhead block end std::pair parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, std::optional localName); + bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); // explist ::= {exp `,'} exp void parseExprList(TempVector& result); @@ -195,7 +153,7 @@ private: // bindinglist ::= (binding | `...') {`,' bindinglist} // Returns the location of the vararg ..., or std::nullopt if the function is not vararg. - std::pair, AstTypePack*> parseBindingList(TempVector& result, bool allowDot3 = false); + std::tuple parseBindingList(TempVector& result, bool allowDot3 = false); AstType* parseOptionalTypeAnnotation(); @@ -218,13 +176,14 @@ private: AstTableIndexer* parseTableIndexerAnnotation(); - AstType* parseFunctionTypeAnnotation(); - AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); + AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); AstType* parseTableTypeAnnotation(); - AstType* parseSimpleTypeAnnotation(); + AstTypeOrPack parseSimpleTypeAnnotation(bool allowPack); + AstTypeOrPack parseTypeOrPackAnnotation(); AstType* parseTypeAnnotation(TempVector& parts, const Location& begin); AstType* parseTypeAnnotation(); @@ -263,7 +222,7 @@ private: AstExpr* parseSimpleExpr(); // args ::= `(' [explist] `)' | tableconstructor | String - AstExpr* parseFunctionArgs(AstExpr* func, bool self, const Location& selfLocation); + AstExpr* parseFunctionArgs(AstExpr* func, bool self); // tableconstructor ::= `{' [fieldlist] `}' // fieldlist ::= field {fieldsep field} [fieldsep] @@ -274,19 +233,23 @@ private: // TODO: Add grammar rules here? AstExpr* parseIfElseExpr(); + // stringinterp ::= exp { exp} + AstExpr* parseInterpString(); + // Name std::optional parseNameOpt(const char* context = nullptr); Name parseName(const char* context = nullptr); Name parseIndexName(const char* context, const Position& previous); // `<' namelist `>' - std::pair, AstArray> parseGenericTypeList(); - std::pair, AstArray> parseGenericTypeListIfFFlagParseGenericFunctions(); + std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); // `<' typeAnnotation[, ...] `>' - AstArray parseTypeParams(); + AstArray parseTypeParams(); + std::optional> parseCharArray(); AstExpr* parseString(); + AstExpr* parseNumber(); AstLocal* pushLocal(const Binding& binding); @@ -299,11 +262,24 @@ private: bool expectAndConsume(Lexeme::Type type, const char* context = nullptr); void expectAndConsumeFail(Lexeme::Type type, const char* context); - bool expectMatchAndConsume(char value, const Lexeme& begin, bool searchForMissing = false); - void expectMatchAndConsumeFail(Lexeme::Type type, const Lexeme& begin, const char* extra = nullptr); + struct MatchLexeme + { + MatchLexeme(const Lexeme& l) + : type(l.type) + , position(l.location.begin) + { + } - bool expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin); - void expectMatchEndAndConsumeFail(Lexeme::Type type, const Lexeme& begin); + Lexeme::Type type; + Position position; + }; + + bool expectMatchAndConsume(char value, const MatchLexeme& begin, bool searchForMissing = false); + void expectMatchAndConsumeFail(Lexeme::Type type, const MatchLexeme& begin, const char* extra = nullptr); + bool expectMatchAndConsumeRecover(char value, const MatchLexeme& begin, bool searchForMissing); + + bool expectMatchEndAndConsume(Lexeme::Type type, const MatchLexeme& begin); + void expectMatchEndAndConsumeFail(Lexeme::Type type, const MatchLexeme& begin); template AstArray copy(const T* data, std::size_t size); @@ -326,10 +302,19 @@ private: AstStatError* reportStatError(const Location& location, const AstArray& expressions, const AstArray& statements, const char* format, ...) LUAU_PRINTF_ATTR(5, 6); AstExprError* reportExprError(const Location& location, const AstArray& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); - AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) - LUAU_PRINTF_ATTR(5, 6); + AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) + LUAU_PRINTF_ATTR(4, 5); + // `parseErrorLocation` is associated with the parser error + // `astErrorLocation` is associated with the AstTypeError created + // It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely + // define the location (possibly of zero size) where a type annotation is expected. + AstTypeError* reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) + LUAU_PRINTF_ATTR(4, 5); - const Lexeme& nextLexeme(); + AstExpr* reportFunctionArgsError(AstExpr* func, bool self); + void reportAmbiguousCallError(); + + void nextLexeme(); struct Function { @@ -385,6 +370,9 @@ private: Allocator& allocator; std::vector commentLocations; + std::vector hotcomments; + + bool hotcommentHeader = true; unsigned int recursionCounter; @@ -393,7 +381,7 @@ private: AstName nameError; AstName nameNil; - Lexeme endMismatchSuspect; + MatchLexeme endMismatchSuspect; std::vector functionStack; @@ -405,6 +393,7 @@ private: std::vector matchRecoveryStopOnToken; std::vector scratchStat; + std::vector> scratchString; std::vector scratchExpr; std::vector scratchExprAux; std::vector scratchName; @@ -413,9 +402,12 @@ private: std::vector scratchLocal; std::vector scratchTableTypeProps; std::vector scratchAnnotation; + std::vector scratchTypeOrPackAnnotation; std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; + std::vector scratchGenericTypes; + std::vector scratchGenericTypePacks; std::vector> scratchOptArgName; std::string scratchData; }; diff --git a/Ast/include/Luau/StringUtils.h b/Ast/include/Luau/StringUtils.h index 4f7673fa..6345fde4 100644 --- a/Ast/include/Luau/StringUtils.h +++ b/Ast/include/Luau/StringUtils.h @@ -1,17 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Common.h" + #include #include #include -#if defined(__GNUC__) -#define LUAU_PRINTF_ATTR(fmt, arg) __attribute__((format(printf, fmt, arg))) -#else -#define LUAU_PRINTF_ATTR(fmt, arg) -#endif - namespace Luau { @@ -19,6 +15,7 @@ std::string format(const char* fmt, ...) LUAU_PRINTF_ATTR(1, 2); std::string vformat(const char* fmt, va_list args); void formatAppend(std::string& str, const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); +void vformatAppend(std::string& ret, const char* fmt, va_list args); std::string join(const std::vector& segments, std::string_view delimiter); std::string join(const std::vector& segments, std::string_view delimiter); @@ -34,4 +31,6 @@ bool equalsLower(std::string_view lhs, std::string_view rhs); size_t hashRange(const char* data, size_t size); +std::string escape(std::string_view s, bool escapeForInterpString = false); +bool isIdentifier(std::string_view s); } // namespace Luau diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h new file mode 100644 index 00000000..be282827 --- /dev/null +++ b/Ast/include/Luau/TimeTrace.h @@ -0,0 +1,230 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include + +#include + +LUAU_FASTFLAG(DebugLuauTimeTracing) + +namespace Luau +{ +namespace TimeTrace +{ +double getClock(); +uint32_t getClockMicroseconds(); +} // namespace TimeTrace +} // namespace Luau + +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ +struct Token +{ + const char* name; + const char* category; +}; + +enum class EventType : uint8_t +{ + Enter, + Leave, + + ArgName, + ArgValue, +}; + +struct Event +{ + EventType type; + uint16_t token; + + union + { + uint32_t microsec; // 1 hour trace limit + uint32_t dataPos; + } data; +}; + +struct GlobalContext; +struct ThreadContext; + +GlobalContext& getGlobalContext(); + +uint16_t createToken(GlobalContext& context, const char* name, const char* category); +uint32_t createThread(GlobalContext& context, ThreadContext* threadContext); +void releaseThread(GlobalContext& context, ThreadContext* threadContext); +void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector& events, const std::vector& data); + +struct ThreadContext +{ + ThreadContext() + : globalContext(getGlobalContext()) + { + threadId = createThread(globalContext, this); + } + + ~ThreadContext() + { + if (!events.empty()) + flushEvents(); + + releaseThread(globalContext, this); + } + + void flushEvents() + { + static uint16_t flushToken = createToken(globalContext, "flushEvents", "TimeTrace"); + + events.push_back({EventType::Enter, flushToken, {getClockMicroseconds()}}); + + TimeTrace::flushEvents(globalContext, threadId, events, data); + + events.clear(); + data.clear(); + + events.push_back({EventType::Leave, 0, {getClockMicroseconds()}}); + } + + void eventEnter(uint16_t token) + { + eventEnter(token, getClockMicroseconds()); + } + + void eventEnter(uint16_t token, uint32_t microsec) + { + events.push_back({EventType::Enter, token, {microsec}}); + } + + void eventLeave() + { + eventLeave(getClockMicroseconds()); + } + + void eventLeave(uint32_t microsec) + { + events.push_back({EventType::Leave, 0, {microsec}}); + + if (events.size() > kEventFlushLimit) + flushEvents(); + } + + void eventArgument(const char* name, const char* value) + { + uint32_t pos = uint32_t(data.size()); + data.insert(data.end(), name, name + strlen(name) + 1); + events.push_back({EventType::ArgName, 0, {pos}}); + + pos = uint32_t(data.size()); + data.insert(data.end(), value, value + strlen(value) + 1); + events.push_back({EventType::ArgValue, 0, {pos}}); + } + + GlobalContext& globalContext; + uint32_t threadId; + std::vector events; + std::vector data; + + static constexpr size_t kEventFlushLimit = 8192; +}; + +ThreadContext& getThreadContext(); + +struct Scope +{ + explicit Scope(uint16_t token) + : context(getThreadContext()) + { + if (!FFlag::DebugLuauTimeTracing) + return; + + context.eventEnter(token); + } + + ~Scope() + { + if (!FFlag::DebugLuauTimeTracing) + return; + + context.eventLeave(); + } + + ThreadContext& context; +}; + +struct OptionalTailScope +{ + explicit OptionalTailScope(uint16_t token, uint32_t threshold) + : context(getThreadContext()) + , token(token) + , threshold(threshold) + { + if (!FFlag::DebugLuauTimeTracing) + return; + + pos = uint32_t(context.events.size()); + microsec = getClockMicroseconds(); + } + + ~OptionalTailScope() + { + if (!FFlag::DebugLuauTimeTracing) + return; + + if (pos == context.events.size()) + { + uint32_t curr = getClockMicroseconds(); + + if (curr - microsec > threshold) + { + context.eventEnter(token, microsec); + context.eventLeave(curr); + } + } + } + + ThreadContext& context; + uint16_t token; + uint32_t threshold; + uint32_t microsec; + uint32_t pos; +}; + +LUAU_NOINLINE uint16_t createScopeData(const char* name, const char* category); + +} // namespace TimeTrace +} // namespace Luau + +// Regular scope +#define LUAU_TIMETRACE_SCOPE(name, category) \ + static uint16_t lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::Scope lttScope(lttScopeStatic) + +// A scope without nested scopes that may be skipped if the time it took is less than the threshold +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ + static uint16_t lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail, microsec) + +// Extra key/value data can be added to regular scopes +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + if (FFlag::DebugLuauTimeTracing) \ + lttScope.context.eventArgument(name, value); \ + } while (false) + +#else + +#define LUAU_TIMETRACE_SCOPE(name, category) +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + } while (false) + +#endif diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index fff1537d..cbed8bae 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -50,9 +50,10 @@ void AstExprConstantBool::visit(AstVisitor* visitor) visitor->visit(this); } -AstExprConstantNumber::AstExprConstantNumber(const Location& location, double value) +AstExprConstantNumber::AstExprConstantNumber(const Location& location, double value, ConstantNumberParseResult parseResult) : AstExpr(ClassIndex(), location) , value(value) + , parseResult(parseResult) { } @@ -158,18 +159,18 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) } } -AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, - const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, - std::optional returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, std::optional argLocation) +AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, + const AstName& debugname, const std::optional& returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, + const std::optional& argLocation) : AstExpr(ClassIndex(), location) , generics(generics) , genericPacks(genericPacks) , self(self) , args(args) - , hasReturnAnnotation(returnAnnotation.has_value()) - , returnAnnotation() - , vararg(vararg.has_value()) - , varargLocation(vararg.value_or(Location())) + , returnAnnotation(returnAnnotation) + , vararg(vararg) + , varargLocation(varargLocation) , varargAnnotation(varargAnnotation) , body(body) , functionDepth(functionDepth) @@ -177,8 +178,6 @@ AstExprFunction::AstExprFunction(const Location& location, const AstArrayreturnAnnotation = *returnAnnotation; } void AstExprFunction::visit(AstVisitor* visitor) @@ -194,8 +193,8 @@ void AstExprFunction::visit(AstVisitor* visitor) if (varargAnnotation) varargAnnotation->visit(visitor); - if (hasReturnAnnotation) - visitTypeList(visitor, returnAnnotation); + if (returnAnnotation) + visitTypeList(visitor, *returnAnnotation); body->visit(visitor); } @@ -350,6 +349,22 @@ AstExprError::AstExprError(const Location& location, const AstArray& e { } +AstExprInterpString::AstExprInterpString(const Location& location, const AstArray>& strings, const AstArray& expressions) + : AstExpr(ClassIndex(), location) + , strings(strings) + , expressions(expressions) +{ +} + +void AstExprInterpString::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstExpr* expr : expressions) + expr->visit(visitor); + } +} + void AstExprError::visit(AstVisitor* visitor) { if (visitor->visit(this)) @@ -374,21 +389,16 @@ void AstStatBlock::visit(AstVisitor* visitor) } } -AstStatIf::AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, bool hasThen, - const Location& thenLocation, const std::optional& elseLocation, bool hasEnd) +AstStatIf::AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, + const std::optional& thenLocation, const std::optional& elseLocation, bool hasEnd) : AstStat(ClassIndex(), location) , condition(condition) , thenbody(thenbody) , elsebody(elsebody) - , hasThen(hasThen) , thenLocation(thenLocation) + , elseLocation(elseLocation) , hasEnd(hasEnd) { - if (bool(elseLocation)) - { - hasElse = true; - this->elseLocation = *elseLocation; - } } void AstStatIf::visit(AstVisitor* visitor) @@ -491,12 +501,8 @@ AstStatLocal::AstStatLocal( : AstStat(ClassIndex(), location) , vars(vars) , values(values) + , equalsSignLocation(equalsSignLocation) { - if (bool(equalsSignLocation)) - { - hasEqualsSign = true; - this->equalsSignLocation = *equalsSignLocation; - } } void AstStatLocal::visit(AstVisitor* visitor) @@ -641,10 +647,12 @@ void AstStatLocalFunction::visit(AstVisitor* visitor) func->visit(visitor); } -AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, AstType* type, bool exported) +AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported) : AstStat(ClassIndex(), location) , name(name) , generics(generics) + , genericPacks(genericPacks) , type(type) , exported(exported) { @@ -653,7 +661,21 @@ AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name void AstStatTypeAlias::visit(AstVisitor* visitor) { if (visitor->visit(this)) + { + for (const AstGenericType& el : generics) + { + if (el.defaultValue) + el.defaultValue->visit(visitor); + } + + for (const AstGenericTypePack& el : genericPacks) + { + if (el.defaultValue) + el.defaultValue->visit(visitor); + } + type->visit(visitor); + } } AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type) @@ -669,8 +691,9 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor) type->visit(visitor); } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes) +AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, + const AstTypeList& retTypes) : AstStat(ClassIndex(), location) , name(name) , generics(generics) @@ -729,12 +752,13 @@ void AstStatError::visit(AstVisitor* visitor) } } -AstTypeReference::AstTypeReference(const Location& location, std::optional prefix, AstName name, const AstArray& generics) +AstTypeReference::AstTypeReference( + const Location& location, std::optional prefix, AstName name, bool hasParameterList, const AstArray& parameters) : AstType(ClassIndex(), location) - , hasPrefix(bool(prefix)) - , prefix(prefix ? *prefix : AstName()) + , hasParameterList(hasParameterList) + , prefix(prefix) , name(name) - , generics(generics) + , parameters(parameters) { } @@ -742,8 +766,13 @@ void AstTypeReference::visit(AstVisitor* visitor) { if (visitor->visit(this)) { - for (AstType* generic : generics) - generic->visit(visitor); + for (const AstTypeOrPack& param : parameters) + { + if (param.type) + param.type->visit(visitor); + else + param.typePack->visit(visitor); + } } } @@ -769,7 +798,7 @@ void AstTypeTable::visit(AstVisitor* visitor) } } -AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, +AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes) : AstType(ClassIndex(), location) , generics(generics) @@ -832,6 +861,28 @@ void AstTypeIntersection::visit(AstVisitor* visitor) } } +AstTypeSingletonBool::AstTypeSingletonBool(const Location& location, bool value) + : AstType(ClassIndex(), location) + , value(value) +{ +} + +void AstTypeSingletonBool::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstTypeSingletonString::AstTypeSingletonString(const Location& location, const AstArray& value) + : AstType(ClassIndex(), location) + , value(value) +{ +} + +void AstTypeSingletonString::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + AstTypeError::AstTypeError(const Location& location, const AstArray& types, bool isMissing, unsigned messageIndex) : AstType(ClassIndex(), location) , types(types) @@ -849,6 +900,24 @@ void AstTypeError::visit(AstVisitor* visitor) } } +AstTypePackExplicit::AstTypePackExplicit(const Location& location, AstTypeList typeList) + : AstTypePack(ClassIndex(), location) + , typeList(typeList) +{ +} + +void AstTypePackExplicit::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstType* type : typeList.types) + type->visit(visitor); + + if (typeList.tailType) + typeList.tailType->visit(visitor); + } +} + AstTypePackVariadic::AstTypePackVariadic(const Location& location, AstType* variadicType) : AstTypePack(ClassIndex(), location) , variadicType(variadicType) @@ -883,4 +952,17 @@ AstName getIdentifier(AstExpr* node) return AstName(); } +Location getLocation(const AstTypeList& typeList) +{ + Location result; + if (typeList.types.size) + { + result = Location{typeList.types.data[0]->location, typeList.types.data[typeList.types.size - 1]->location}; + } + if (typeList.tailType) + result.end = typeList.tailType->location.end; + + return result; +} + } // namespace Luau diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index a7aa24ca..66436acd 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,6 +6,8 @@ #include +LUAU_FASTFLAG(LuauInterpolatedStringBaseSupport) + namespace Luau { @@ -89,7 +91,8 @@ Lexeme::Lexeme(const Location& location, Type type, const char* data, size_t siz , length(unsigned(size)) , data(data) { - LUAU_ASSERT(type == RawString || type == QuotedString || type == Number || type == Comment || type == BlockComment); + LUAU_ASSERT(type == RawString || type == QuotedString || type == InterpStringBegin || type == InterpStringMid || type == InterpStringEnd || + type == InterpStringSimple || type == BrokenInterpDoubleBrace || type == Number || type == Comment || type == BlockComment); } Lexeme::Lexeme(const Location& location, Type type, const char* name) @@ -101,11 +104,6 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name) LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); } -static bool isComment(const Lexeme& lexeme) -{ - return lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment; -} - static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while"}; @@ -165,6 +163,18 @@ std::string Lexeme::toString() const case QuotedString: return data ? format("\"%.*s\"", length, data) : "string"; + case InterpStringBegin: + return data ? format("`%.*s{", length, data) : "the beginning of an interpolated string"; + + case InterpStringMid: + return data ? format("}%.*s{", length, data) : "the middle of an interpolated string"; + + case InterpStringEnd: + return data ? format("}%.*s`", length, data) : "the end of an interpolated string"; + + case InterpStringSimple: + return data ? format("`%.*s`", length, data) : "interpolated string"; + case Number: return data ? format("'%.*s'", length, data) : "number"; @@ -180,6 +190,9 @@ std::string Lexeme::toString() const case BrokenComment: return "unfinished comment"; + case BrokenInterpDoubleBrace: + return "'{{', which is invalid (did you mean '\\{'?)"; + case BrokenUnicode: if (codepoint) { @@ -282,11 +295,6 @@ AstName AstNameTable::get(const char* name) const return getWithType(name, strlen(name)).first; } -inline bool isSpace(char ch) -{ - return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || ch == '\v' || ch == '\f'; -} - inline bool isAlpha(char ch) { // use or trick to convert to lower case and unsigned comparison to do range check @@ -357,10 +365,10 @@ void Lexer::setReadNames(bool read) const Lexeme& Lexer::next() { - return next(this->skipComments); + return next(this->skipComments, true); } -const Lexeme& Lexer::next(bool skipComments) +const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation) { // in skipComments mode we reject valid comments do @@ -369,10 +377,12 @@ const Lexeme& Lexer::next(bool skipComments) while (isSpace(peekch())) consume(); - prevLocation = lexeme.location; + if (updatePrevLocation) + prevLocation = lexeme.location; lexeme = readNext(); - } while (skipComments && isComment(lexeme)); + updatePrevLocation = false; + } while (skipComments && (lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment)); return lexeme; } @@ -523,6 +533,32 @@ Lexeme Lexer::readLongString(const Position& start, int sep, Lexeme::Type ok, Le return Lexeme(Location(start, position()), broken); } +void Lexer::readBackslashInString() +{ + LUAU_ASSERT(peekch() == '\\'); + consume(); + switch (peekch()) + { + case '\r': + consume(); + if (peekch() == '\n') + consume(); + break; + + case 0: + break; + + case 'z': + consume(); + while (isSpace(peekch())) + consume(); + break; + + default: + consume(); + } +} + Lexeme Lexer::readQuotedString() { Position start = position(); @@ -543,27 +579,7 @@ Lexeme Lexer::readQuotedString() return Lexeme(Location(start, position()), Lexeme::BrokenString); case '\\': - consume(); - switch (peekch()) - { - case '\r': - consume(); - if (peekch() == '\n') - consume(); - break; - - case 0: - break; - - case 'z': - consume(); - while (isSpace(peekch())) - consume(); - break; - - default: - consume(); - } + readBackslashInString(); break; default: @@ -576,6 +592,70 @@ Lexeme Lexer::readQuotedString() return Lexeme(Location(start, position()), Lexeme::QuotedString, &buffer[startOffset], offset - startOffset - 1); } +Lexeme Lexer::readInterpolatedStringBegin() +{ + LUAU_ASSERT(peekch() == '`'); + + Position start = position(); + consume(); + + return readInterpolatedStringSection(start, Lexeme::InterpStringBegin, Lexeme::InterpStringSimple); +} + +Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatType, Lexeme::Type endType) +{ + unsigned int startOffset = offset; + + while (peekch() != '`') + { + switch (peekch()) + { + case 0: + case '\r': + case '\n': + return Lexeme(Location(start, position()), Lexeme::BrokenString); + + case '\\': + // Allow for \u{}, which would otherwise be consumed by looking for { + if (peekch(1) == 'u' && peekch(2) == '{') + { + consume(); // backslash + consume(); // u + consume(); // { + break; + } + + readBackslashInString(); + break; + + case '{': + { + braceStack.push_back(BraceType::InterpolatedString); + + if (peekch(1) == '{') + { + Lexeme brokenDoubleBrace = + Lexeme(Location(start, position()), Lexeme::BrokenInterpDoubleBrace, &buffer[startOffset], offset - startOffset); + consume(); + consume(); + return brokenDoubleBrace; + } + + consume(); + Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset - 1); + return lexemeOutput; + } + + default: + consume(); + } + } + + consume(); + + return Lexeme(Location(start, position()), endType, &buffer[startOffset], offset - startOffset - 1); +} + Lexeme Lexer::readNumber(const Position& start, unsigned int startOffset) { LUAU_ASSERT(isDigit(peekch())); @@ -668,6 +748,36 @@ Lexeme Lexer::readNext() } } + case '{': + { + consume(); + + if (!braceStack.empty()) + braceStack.push_back(BraceType::Normal); + + return Lexeme(Location(start, 1), '{'); + } + + case '}': + { + consume(); + + if (braceStack.empty()) + { + return Lexeme(Location(start, 1), '}'); + } + + const BraceType braceStackTop = braceStack.back(); + braceStack.pop_back(); + + if (braceStackTop != BraceType::InterpolatedString) + { + return Lexeme(Location(start, 1), '}'); + } + + return readInterpolatedStringSection(position(), Lexeme::InterpStringMid, Lexeme::InterpStringEnd); + } + case '=': { consume(); @@ -724,6 +834,15 @@ Lexeme Lexer::readNext() case '\'': return readQuotedString(); + case '`': + if (FFlag::LuauInterpolatedStringBaseSupport) + return readInterpolatedStringBegin(); + else + { + consume(); + return Lexeme(Location(start, 1), '`'); + } + case '.': consume(); @@ -825,8 +944,6 @@ Lexeme Lexer::readNext() case '(': case ')': - case '{': - case '}': case ']': case ';': case ',': diff --git a/Ast/src/Location.cpp b/Ast/src/Location.cpp index d7a899ed..67c2dd4b 100644 --- a/Ast/src/Location.cpp +++ b/Ast/src/Location.cpp @@ -4,14 +4,128 @@ namespace Luau { -std::string toString(const Position& position) +Position::Position(unsigned int line, unsigned int column) + : line(line) + , column(column) { - return "{ line = " + std::to_string(position.line) + ", col = " + std::to_string(position.column) + " }"; } -std::string toString(const Location& location) +bool Position::operator==(const Position& rhs) const { - return "Location { " + toString(location.begin) + ", " + toString(location.end) + " }"; + return this->column == rhs.column && this->line == rhs.line; +} + +bool Position::operator!=(const Position& rhs) const +{ + return !(*this == rhs); +} + +bool Position::operator<(const Position& rhs) const +{ + if (line == rhs.line) + return column < rhs.column; + else + return line < rhs.line; +} + +bool Position::operator>(const Position& rhs) const +{ + if (line == rhs.line) + return column > rhs.column; + else + return line > rhs.line; +} + +bool Position::operator<=(const Position& rhs) const +{ + return *this == rhs || *this < rhs; +} + +bool Position::operator>=(const Position& rhs) const +{ + return *this == rhs || *this > rhs; +} + +void Position::shift(const Position& start, const Position& oldEnd, const Position& newEnd) +{ + if (*this >= start) + { + if (this->line > oldEnd.line) + this->line += (newEnd.line - oldEnd.line); + else + { + this->line = newEnd.line; + this->column += (newEnd.column - oldEnd.column); + } + } +} + +Location::Location() + : begin(0, 0) + , end(0, 0) +{ +} + +Location::Location(const Position& begin, const Position& end) + : begin(begin) + , end(end) +{ +} + +Location::Location(const Position& begin, unsigned int length) + : begin(begin) + , end(begin.line, begin.column + length) +{ +} + +Location::Location(const Location& begin, const Location& end) + : begin(begin.begin) + , end(end.end) +{ +} + +bool Location::operator==(const Location& rhs) const +{ + return this->begin == rhs.begin && this->end == rhs.end; +} + +bool Location::operator!=(const Location& rhs) const +{ + return !(*this == rhs); +} + +bool Location::encloses(const Location& l) const +{ + return begin <= l.begin && end >= l.end; +} + +bool Location::overlaps(const Location& l) const +{ + return (begin <= l.begin && end >= l.begin) || (begin <= l.end && end >= l.end) || (begin >= l.begin && end <= l.end); +} + +bool Location::contains(const Position& p) const +{ + return begin <= p && p < end; +} + +bool Location::containsClosed(const Position& p) const +{ + return begin <= p && p <= end; +} + +void Location::extend(const Location& other) +{ + if (other.begin < begin) + begin = other.begin; + if (other.end > end) + end = other.end; +} + +void Location::shift(const Position& start, const Position& oldEnd, const Position& newEnd) +{ + begin.shift(start, oldEnd, newEnd); + end.shift(start, oldEnd, newEnd); } } // namespace Luau diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 6672efe8..5cd5f743 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1,32 +1,40 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/TimeTrace.h" + #include +#include +#include + // Warning: If you are introducing new syntax, ensure that it is behind a separate // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsParserFix, false) -LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctions, false) -LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) -LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) -LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) + +LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false) + +bool lua_telemetry_parsed_named_non_function_type = false; + +LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) + +LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) + +LUAU_FASTFLAGVARIABLE(LuauParserErrorsOnMissingDefaultTypePackArgument, false) + +bool lua_telemetry_parsed_out_of_range_bin_integer = false; +bool lua_telemetry_parsed_out_of_range_hex_integer = false; +bool lua_telemetry_parsed_double_prefix_hex_integer = false; + +#define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?" namespace Luau { -inline bool isSpace(char ch) -{ - return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || ch == '\v' || ch == '\f'; -} - -static bool isComment(const Lexeme& lexeme) -{ - return lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment; -} - ParseError::ParseError(const Location& location, const std::string& message) : location(location) , message(message) @@ -148,74 +156,64 @@ static bool shouldParseTypePackAnnotation(Lexer& lexer) ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options) { - Parser p(buffer, bufferSize, names, allocator); + LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser"); + + Parser p(buffer, bufferSize, names, allocator, options); try { - std::vector hotcomments; - - while (isComment(p.lexer.current()) || (FFlag::LuauCaptureBrokenCommentSpans && p.lexer.current().type == Lexeme::BrokenComment)) - { - const char* text = p.lexer.current().data; - unsigned int length = p.lexer.current().length; - - if (length && text[0] == '!') - { - unsigned int end = length; - while (end > 0 && isSpace(text[end - 1])) - --end; - - hotcomments.push_back(std::string(text + 1, text + end)); - } - - const Lexeme::Type type = p.lexer.current().type; - const Location loc = p.lexer.current().location; - - p.lexer.next(); - - if (options.captureComments) - p.commentLocations.push_back(Comment{type, loc}); - } - - p.lexer.setSkipComments(true); - - p.options = options; - AstStatBlock* root = p.parseChunk(); + size_t lines = p.lexer.current().location.end.line + (bufferSize > 0 && buffer[bufferSize - 1] != '\n'); - return ParseResult{root, hotcomments, p.parseErrors, std::move(p.commentLocations)}; + return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; } catch (ParseError& err) { // when catching a fatal error, append it to the list of non-fatal errors and return p.parseErrors.push_back(err); - return ParseResult{nullptr, {}, p.parseErrors}; + return ParseResult{nullptr, 0, {}, p.parseErrors}; } } -Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator) - : lexer(buffer, bufferSize, names) +Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, const ParseOptions& options) + : options(options) + , lexer(buffer, bufferSize, names) , allocator(allocator) , recursionCounter(0) - , endMismatchSuspect(Location(), Lexeme::Eof) + , endMismatchSuspect(Lexeme(Location(), Lexeme::Eof)) , localMap(AstName()) { Function top; top.vararg = true; + functionStack.reserve(8); functionStack.push_back(top); nameSelf = names.addStatic("self"); nameNumber = names.addStatic("number"); - nameError = names.addStatic(errorName); + nameError = names.addStatic(kParseNameError); nameNil = names.getOrAdd("nil"); // nil is a reserved keyword matchRecoveryStopOnToken.assign(Lexeme::Type::Reserved_END, 0); matchRecoveryStopOnToken[Lexeme::Type::Eof] = 1; - // read first lexeme + // required for lookahead() to work across a comment boundary and for nextLexeme() to work when captureComments is false + lexer.setSkipComments(true); + + // read first lexeme (any hot comments get .header = true) + LUAU_ASSERT(hotcommentHeader); nextLexeme(); + + // all hot comments parsed after the first non-comment lexeme are special in that they don't affect type checking / linting mode + hotcommentHeader = false; + + // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays + localStack.reserve(16); + scratchStat.reserve(16); + scratchExpr.reserve(16); + scratchLocal.reserve(16); + scratchBinding.reserve(16); } bool Parser::blockFollow(const Lexeme& l) @@ -380,7 +378,9 @@ AstStat* Parser::parseIf() AstExpr* cond = parseExpr(); Lexeme matchThen = lexer.current(); - bool hasThen = expectAndConsume(Lexeme::ReservedThen, "if statement"); + std::optional thenLocation; + if (expectAndConsume(Lexeme::ReservedThen, "if statement")) + thenLocation = matchThen.location; AstStatBlock* thenbody = parseBlock(); @@ -391,23 +391,13 @@ AstStat* Parser::parseIf() if (lexer.current().type == Lexeme::ReservedElseif) { - if (FFlag::LuauIfStatementRecursionGuard) - { - unsigned int recursionCounterOld = recursionCounter; - incrementRecursionCounter("elseif"); - elseLocation = lexer.current().location; - elsebody = parseIf(); - end = elsebody->location; - hasEnd = elsebody->as()->hasEnd; - recursionCounter = recursionCounterOld; - } - else - { - elseLocation = lexer.current().location; - elsebody = parseIf(); - end = elsebody->location; - hasEnd = elsebody->as()->hasEnd; - } + unsigned int recursionCounterOld = recursionCounter; + incrementRecursionCounter("elseif"); + elseLocation = lexer.current().location; + elsebody = parseIf(); + end = elsebody->location; + hasEnd = elsebody->as()->hasEnd; + recursionCounter = recursionCounterOld; } else { @@ -428,7 +418,7 @@ AstStat* Parser::parseIf() hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchThenElse); } - return allocator.alloc(Location(start, end), cond, thenbody, elsebody, hasThen, matchThen.location, elseLocation, hasEnd); + return allocator.alloc(Location(start, end), cond, thenbody, elsebody, thenLocation, elseLocation, hasEnd); } // while exp do block end @@ -617,16 +607,11 @@ AstStat* Parser::parseFor() } } -// function funcname funcbody | // funcname ::= Name {`.' Name} [`:' Name] -AstStat* Parser::parseFunctionStat() +AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debugname) { - Location start = lexer.current().location; - - Lexeme matchFunction = lexer.current(); - nextLexeme(); - - AstName debugname = (lexer.current().type == Lexeme::Name) ? AstName(lexer.current().name) : AstName(); + if (lexer.current().type == Lexeme::Name) + debugname = AstName(lexer.current().name); // parse funcname into a chain of indexing operators AstExpr* expr = parseNameExpr("function name"); @@ -652,8 +637,6 @@ AstStat* Parser::parseFunctionStat() recursionCounter = recursionCounterOld; // finish with : - bool hasself = false; - if (lexer.current().type == ':') { Position opPosition = lexer.current().location.begin; @@ -669,9 +652,24 @@ AstStat* Parser::parseFunctionStat() hasself = true; } + return expr; +} + +// function funcname funcbody +AstStat* Parser::parseFunctionStat() +{ + Location start = lexer.current().location; + + Lexeme matchFunction = lexer.current(); + nextLexeme(); + + bool hasself = false; + AstName debugname; + AstExpr* expr = parseFunctionName(start, hasself, debugname); + matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, {}).first; + AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; @@ -700,7 +698,7 @@ AstStat* Parser::parseLocal() matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; - auto [body, var] = parseFunctionBody(false, matchFunction, name.name, name); + auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name); matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; @@ -763,20 +761,19 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) { // note: `type` token is already parsed for us, so we just need to parse the rest - auto name = parseNameOpt("type name"); + std::optional name = parseNameOpt("type name"); // Use error name if the name is missing if (!name) name = Name(nameError, lexer.current().location); - // TODO: support generic type pack parameters in type aliases CLI-39907 - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ true); expectAndConsume('=', "type alias"); AstType* type = parseTypeAnnotation(); - return allocator.alloc(Location(start, type->location), name->name, generics, type, exported); + return allocator.alloc(Location(start, type->location), name->name, generics, genericPacks, type, exported); } AstDeclaredClassProp Parser::parseDeclaredClassMethod() @@ -786,22 +783,23 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() Name fnName = parseName("function name"); // TODO: generic method declarations CLI-39909 - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; generics.size = 0; generics.data = nullptr; genericPacks.size = 0; genericPacks.data = nullptr; - Lexeme matchParen = lexer.current(); + MatchLexeme matchParen = lexer.current(); expectAndConsume('(', "function parameter list start"); TempVector args(scratchBinding); - std::optional vararg = std::nullopt; + bool vararg = false; + Location varargLocation; AstTypePack* varargAnnotation = nullptr; if (lexer.current().type != ')') - std::tie(vararg, varargAnnotation) = parseBindingList(args, /* allowDot3 */ true); + std::tie(vararg, varargLocation, varargAnnotation) = parseBindingList(args, /* allowDot3 */ true); expectMatchAndConsume(')', matchParen); @@ -813,9 +811,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) { - return AstDeclaredClassProp{fnName.name, - reportTypeAnnotationError(Location(start, end), {}, /*isMissing*/ false, "'self' must be present as the unannotated first parameter"), - true}; + return AstDeclaredClassProp{ + fnName.name, reportTypeAnnotationError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; } // Skip the first index. @@ -826,8 +823,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args[i].annotation) vars.push_back(args[i].annotation); else - vars.push_back(reportTypeAnnotationError( - Location(start, end), {}, /*isMissing*/ false, "All declaration parameters aside from 'self' must be annotated")); + vars.push_back(reportTypeAnnotationError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated")); } if (vararg && !varargAnnotation) @@ -847,19 +843,20 @@ AstStat* Parser::parseDeclaration(const Location& start) nextLexeme(); Name globalName = parseName("global function name"); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); - Lexeme matchParen = lexer.current(); + MatchLexeme matchParen = lexer.current(); expectAndConsume('(', "global function declaration"); TempVector args(scratchBinding); - std::optional vararg; + bool vararg = false; + Location varargLocation; AstTypePack* varargAnnotation = nullptr; if (lexer.current().type != ')') - std::tie(vararg, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true); + std::tie(vararg, varargLocation, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true); expectMatchAndConsume(')', matchParen); @@ -906,6 +903,25 @@ AstStat* Parser::parseDeclaration(const Location& start) { props.push_back(parseDeclaredClassMethod()); } + else if (lexer.current().type == '[') + { + const Lexeme begin = lexer.current(); + nextLexeme(); // [ + + std::optional> chars = parseCharArray(); + + expectMatchAndConsume(']', begin); + expectAndConsume(':', "property type annotation"); + AstType* type = parseTypeAnnotation(); + + // TODO: since AstName conains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back(AstDeclaredClassProp{AstName(chars->data), type, false}); + else + report(begin.location, "String literal contains malformed escape sequence"); + } else { Name propName = parseName("property name"); @@ -920,7 +936,7 @@ AstStat* Parser::parseDeclaration(const Location& start) return allocator.alloc(Location(classStart, classEnd), className.name, superName, copy(props)); } - else if (auto globalName = parseNameOpt("global variable name")) + else if (std::optional globalName = parseNameOpt("global variable name")) { expectAndConsume(':', "global variable declaration"); @@ -951,7 +967,7 @@ AstStat* Parser::parseAssignment(AstExpr* initial) { nextLexeme(); - AstExpr* expr = parsePrimaryExpr(/* asStatement= */ false); + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); if (!isExprLValue(expr)) expr = reportExprError(expr->location, copy({expr}), "Assigned expression must be a variable or a field"); @@ -982,29 +998,47 @@ AstStat* Parser::parseCompoundAssignment(AstExpr* initial, AstExprBinary::Op op) return allocator.alloc(Location(initial->location, value->location), op, initial, value); } +std::pair> Parser::prepareFunctionArguments(const Location& start, bool hasself, const TempVector& args) +{ + AstLocal* self = nullptr; + + if (hasself) + self = pushLocal(Binding(Name(nameSelf, start), nullptr)); + + TempVector vars(scratchLocal); + + for (size_t i = 0; i < args.size(); ++i) + vars.push_back(pushLocal(args[i])); + + return {self, copy(vars)}; +} + // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // parlist ::= bindinglist [`,' `...'] | `...' std::pair Parser::parseFunctionBody( - bool hasself, const Lexeme& matchFunction, const AstName& debugname, std::optional localName) + bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName) { Location start = matchFunction.location; - auto [generics, genericPacks] = parseGenericTypeListIfFFlagParseGenericFunctions(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); - Lexeme matchParen = lexer.current(); + MatchLexeme matchParen = lexer.current(); expectAndConsume('(', "function"); TempVector args(scratchBinding); - std::optional vararg; + bool vararg = false; + Location varargLocation; AstTypePack* varargAnnotation = nullptr; if (lexer.current().type != ')') - std::tie(vararg, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true); + std::tie(vararg, varargLocation, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true); + + std::optional argLocation; + + if (matchParen.type == Lexeme::Type('(') && lexer.current().type == Lexeme::Type(')')) + argLocation = Location(matchParen.position, lexer.current().location.end); - std::optional argLocation = matchParen.type == Lexeme::Type('(') && lexer.current().type == Lexeme::Type(')') - ? std::make_optional(Location(matchParen.location.begin, lexer.current().location.end)) - : std::nullopt; expectMatchAndConsume(')', matchParen, true); std::optional typelist = parseOptionalReturnTypeAnnotation(); @@ -1017,19 +1051,11 @@ std::pair Parser::parseFunctionBody( unsigned int localsBegin = saveLocals(); Function fun; - fun.vararg = vararg.has_value(); + fun.vararg = vararg; - functionStack.push_back(fun); + functionStack.emplace_back(fun); - AstLocal* self = nullptr; - - if (hasself) - self = pushLocal(Binding(Name(nameSelf, start), nullptr)); - - TempVector vars(scratchLocal); - - for (size_t i = 0; i < args.size(); ++i) - vars.push_back(pushLocal(args[i])); + auto [self, vars] = prepareFunctionArguments(start, hasself, args); AstStatBlock* body = parseBlock(); @@ -1041,8 +1067,8 @@ std::pair Parser::parseFunctionBody( bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); - return {allocator.alloc(Location(start, end), generics, genericPacks, self, copy(vars), vararg, body, functionStack.size(), - debugname, typelist, varargAnnotation, hasEnd, argLocation), + return {allocator.alloc(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body, + functionStack.size(), debugname, typelist, varargAnnotation, hasEnd, argLocation), funLocal}; } @@ -1055,13 +1081,19 @@ void Parser::parseExprList(TempVector& result) { nextLexeme(); + if (lexer.current().type == ')') + { + report(lexer.current().location, "Expected expression after ',' but got ')' instead"); + break; + } + result.push_back(parseExpr()); } } Parser::Binding Parser::parseBinding() { - auto name = parseNameOpt("variable name"); + std::optional name = parseNameOpt("variable name"); // Use placeholder if the name is missing if (!name) @@ -1073,7 +1105,7 @@ Parser::Binding Parser::parseBinding() } // bindinglist ::= (binding | `...') [`,' bindinglist] -std::pair, AstTypePack*> Parser::parseBindingList(TempVector& result, bool allowDot3) +std::tuple Parser::parseBindingList(TempVector& result, bool allowDot3) { while (true) { @@ -1089,7 +1121,7 @@ std::pair, AstTypePack*> Parser::parseBindingList(TempVe tailAnnotation = parseVariadicArgumentAnnotation(); } - return {varargLocation, tailAnnotation}; + return {true, varargLocation, tailAnnotation}; } result.push_back(parseBinding()); @@ -1099,7 +1131,7 @@ std::pair, AstTypePack*> Parser::parseBindingList(TempVe nextLexeme(); } - return {std::nullopt, nullptr}; + return {false, Location(), nullptr}; } AstType* Parser::parseOptionalTypeAnnotation() @@ -1128,7 +1160,7 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector& result, TempVector& result, TempVector Parser::parseOptionalReturnTypeAnnotation() { - if (options.allowTypeAnnotations && lexer.current().type == ':') + if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)) { + if (lexer.current().type == Lexeme::SkinnyArrow) + report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); + nextLexeme(); unsigned int oldRecursionCount = recursionCounter; @@ -1226,8 +1268,8 @@ std::pair Parser::parseReturnTypeAnnotation() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstArray generics{nullptr, 0}; - AstArray genericPacks{nullptr, 0}; + AstArray generics{nullptr, 0}; + AstArray genericPacks{nullptr, 0}; AstArray types = copy(result); AstArray> names = copy(resultNames); @@ -1267,12 +1309,31 @@ AstType* Parser::parseTableTypeAnnotation() Location start = lexer.current().location; - Lexeme matchBrace = lexer.current(); + MatchLexeme matchBrace = lexer.current(); expectAndConsume('{', "table type"); while (lexer.current().type != '}') { - if (lexer.current().type == '[') + if (lexer.current().type == '[' && (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) + { + const Lexeme begin = lexer.current(); + nextLexeme(); // [ + std::optional> chars = parseCharArray(); + + expectMatchAndConsume(']', begin); + expectAndConsume(':', "table field"); + + AstType* type = parseTypeAnnotation(); + + // TODO: since AstName conains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back({AstName(chars->data), begin.location, type}); + else + report(begin.location, "String literal contains malformed escape sequence"); + } + else if (lexer.current().type == '[') { if (indexer) { @@ -1300,7 +1361,7 @@ AstType* Parser::parseTableTypeAnnotation() } else { - auto name = parseNameOpt("table field"); + std::optional name = parseNameOpt("table field"); if (!name) break; @@ -1333,23 +1394,19 @@ AstType* Parser::parseTableTypeAnnotation() // ReturnType ::= TypeAnnotation | `(' TypeList `)' // FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstType* Parser::parseFunctionTypeAnnotation() +AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); - bool monomorphic = !(FFlag::LuauParseGenericFunctions && lexer.current().type == '<'); - - auto [generics, genericPacks] = parseGenericTypeListIfFFlagParseGenericFunctions(); + bool forceFunctionType = lexer.current().type == '<'; Lexeme begin = lexer.current(); - if (FFlag::LuauGenericFunctionsParserFix) - expectAndConsume('(', "function parameters"); - else - { - LUAU_ASSERT(begin.type == '('); - nextLexeme(); // ( - } + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); + + Lexeme parameterStart = lexer.current(); + + expectAndConsume('(', "function parameters"); matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; @@ -1360,28 +1417,55 @@ AstType* Parser::parseFunctionTypeAnnotation() if (lexer.current().type != ')') varargAnnotation = parseTypeList(params, names); - expectMatchAndConsume(')', begin, true); + expectMatchAndConsume(')', parameterStart, true); matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--; - // Not a function at all. Just a parenthesized type. - if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow) - return params[0]; - AstArray paramTypes = copy(params); + + if (FFlag::LuauFixNamedFunctionParse && !names.empty()) + forceFunctionType = true; + + bool returnTypeIntroducer = lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':'; + + // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element + if (params.size() == 1 && !varargAnnotation && !forceFunctionType && !returnTypeIntroducer) + { + if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) + lua_telemetry_parsed_named_non_function_type = true; + + if (allowPack) + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; + else + return {params[0], {}}; + } + + if (!forceFunctionType && !returnTypeIntroducer && allowPack) + { + if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) + lua_telemetry_parsed_named_non_function_type = true; + + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; + } + AstArray> paramNames = copy(names); - return parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation); + return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, +AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation) { incrementRecursionCounter("type annotation"); + if (lexer.current().type == ':') + { + report(lexer.current().location, "Return types in function type annotations are written after '->' instead of ':'"); + lexer.next(); + } // Users occasionally write '()' as the 'unit' type when they actually want to use 'nil', here we'll try to give a more specific error - if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0) + else if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0) { report(Location(begin.location, lexer.previousLocation()), "Expected '->' after '()' when parsing function type; did you mean 'nil'?"); @@ -1392,7 +1476,7 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray(Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList); @@ -1421,7 +1505,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isUnion = true; } else if (c == '?') @@ -1434,9 +1518,14 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isIntersection = true; } + else if (c == Lexeme::Dot3) + { + report(lexer.current().location, "Unexpected '...' after type annotation"); + nextLexeme(); + } else break; } @@ -1446,7 +1535,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (isUnion && isIntersection) { - return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), /*isMissing*/ false, + return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); } @@ -1462,6 +1551,30 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location ParseError::raise(begin, "Composite type was not an intersection or union."); } +AstTypeOrPack Parser::parseTypeOrPackAnnotation() +{ + unsigned int oldRecursionCount = recursionCounter; + incrementRecursionCounter("type annotation"); + + Location begin = lexer.current().location; + + TempVector parts(scratchAnnotation); + + auto [type, typePack] = parseSimpleTypeAnnotation(/* allowPack= */ true); + + if (typePack) + { + LUAU_ASSERT(!type); + return {{}, typePack}; + } + + parts.push_back(type); + + recursionCounter = oldRecursionCount; + + return {parseTypeAnnotation(parts, begin), {}}; +} + AstType* Parser::parseTypeAnnotation() { unsigned int oldRecursionCount = recursionCounter; @@ -1470,7 +1583,7 @@ AstType* Parser::parseTypeAnnotation() Location begin = lexer.current().location; TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); recursionCounter = oldRecursionCount; @@ -1479,16 +1592,47 @@ AstType* Parser::parseTypeAnnotation() // typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' // | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstType* Parser::parseSimpleTypeAnnotation() +AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); - Location begin = lexer.current().location; + Location start = lexer.current().location; if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); - return allocator.alloc(begin, std::nullopt, nameNil); + return {allocator.alloc(start, std::nullopt, nameNil), {}}; + } + else if (lexer.current().type == Lexeme::ReservedTrue) + { + nextLexeme(); + return {allocator.alloc(start, true)}; + } + else if (lexer.current().type == Lexeme::ReservedFalse) + { + nextLexeme(); + return {allocator.alloc(start, false)}; + } + else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) + { + if (std::optional> value = parseCharArray()) + { + AstArray svalue = *value; + return {allocator.alloc(start, svalue)}; + } + else + return {reportTypeAnnotationError(start, {}, "String literal contains malformed escape sequence")}; + } + else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) + { + parseInterpString(); + + return {reportTypeAnnotationError(start, {}, "Interpolated string literals cannot be used as types")}; + } + else if (lexer.current().type == Lexeme::BrokenString) + { + nextLexeme(); + return {reportTypeAnnotationError(start, {}, "Malformed string")}; } else if (lexer.current().type == Lexeme::Name) { @@ -1503,6 +1647,11 @@ AstType* Parser::parseSimpleTypeAnnotation() prefix = name.name; name = parseIndexName("field name", pointPosition); } + else if (lexer.current().type == Lexeme::Dot3) + { + report(lexer.current().location, "Unexpected '...' after type name; type pack is not allowed in this context"); + nextLexeme(); + } else if (name.name == "typeof") { Lexeme typeofBegin = lexer.current(); @@ -1514,31 +1663,48 @@ AstType* Parser::parseSimpleTypeAnnotation() expectMatchAndConsume(')', typeofBegin); - return allocator.alloc(Location(begin, end), expr); + return {allocator.alloc(Location(start, end), expr), {}}; } - AstArray generics = parseTypeParams(); + bool hasParameters = false; + AstArray parameters{}; + + if (lexer.current().type == '<') + { + hasParameters = true; + parameters = parseTypeParams(); + } Location end = lexer.previousLocation(); - return allocator.alloc(Location(begin, end), prefix, name.name, generics); + return {allocator.alloc(Location(start, end), prefix, name.name, hasParameters, parameters), {}}; } else if (lexer.current().type == '{') { - return parseTableTypeAnnotation(); + return {parseTableTypeAnnotation(), {}}; } - else if (lexer.current().type == '(' || (FFlag::LuauParseGenericFunctions && lexer.current().type == '<')) + else if (lexer.current().type == '(' || lexer.current().type == '<') { - return parseFunctionTypeAnnotation(); + return parseFunctionTypeAnnotation(allowPack); + } + else if (lexer.current().type == Lexeme::ReservedFunction) + { + nextLexeme(); + + return {reportTypeAnnotationError(start, {}, + "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " + "...any'"), + {}}; } else { - Location location = lexer.current().location; - - // For a missing type annoation, capture 'space' between last token and the next one - location = Location(lexer.previousLocation().end, lexer.current().location.begin); - - return reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()); + // For a missing type annotation, capture 'space' between last token and the next one + Location astErrorlocation(lexer.previousLocation().end, start.begin); + // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). + // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. + Location parseErrorLocation(lexer.previousLocation().end, start.end); + return { + reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}}; } } @@ -1782,7 +1948,7 @@ AstExpr* Parser::parseExpr(unsigned int limit) // NAME AstExpr* Parser::parseNameExpr(const char* context) { - auto name = parseNameOpt(context); + std::optional name = parseNameOpt(context); if (!name) return allocator.alloc(lexer.current().location, copy({}), unsigned(parseErrors.size() - 1)); @@ -1804,14 +1970,14 @@ AstExpr* Parser::parsePrefixExpr() { if (lexer.current().type == '(') { - Location start = lexer.current().location; + Position start = lexer.current().location.begin; - Lexeme matchParen = lexer.current(); + MatchLexeme matchParen = lexer.current(); nextLexeme(); AstExpr* expr = parseExpr(); - Location end = lexer.current().location; + Position end = lexer.current().location.end; if (lexer.current().type != ')') { @@ -1819,7 +1985,7 @@ AstExpr* Parser::parsePrefixExpr() expectMatchAndConsumeFail(static_cast(')'), matchParen, suggestion); - end = lexer.previousLocation(); + end = lexer.previousLocation().end; } else { @@ -1837,7 +2003,7 @@ AstExpr* Parser::parsePrefixExpr() // primaryexp -> prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs } AstExpr* Parser::parsePrimaryExpr(bool asStatement) { - Location start = lexer.current().location; + Position start = lexer.current().location.begin; AstExpr* expr = parsePrefixExpr(); @@ -1852,16 +2018,16 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) Name index = parseIndexName(nullptr, opPosition); - expr = allocator.alloc(Location(start, index.location), expr, index.name, index.location, opPosition, '.'); + expr = allocator.alloc(Location(start, index.location.end), expr, index.name, index.location, opPosition, '.'); } else if (lexer.current().type == '[') { - Lexeme matchBracket = lexer.current(); + MatchLexeme matchBracket = lexer.current(); nextLexeme(); AstExpr* index = parseExpr(); - Location end = lexer.current().location; + Position end = lexer.current().location.end; expectMatchAndConsume(']', matchBracket); @@ -1873,27 +2039,24 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) nextLexeme(); Name index = parseIndexName("method name", opPosition); - AstExpr* func = allocator.alloc(Location(start, index.location), expr, index.name, index.location, opPosition, ':'); + AstExpr* func = allocator.alloc(Location(start, index.location.end), expr, index.name, index.location, opPosition, ':'); - expr = parseFunctionArgs(func, true, index.location); + expr = parseFunctionArgs(func, true); } else if (lexer.current().type == '(') { // This error is handled inside 'parseFunctionArgs' as well, but for better error recovery we need to break out the current loop here if (!asStatement && expr->location.end.line != lexer.current().location.begin.line) { - report(lexer.current().location, - "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " - "new statement; use ';' to separate statements"); - + reportAmbiguousCallError(); break; } - expr = parseFunctionArgs(expr, false, Location()); + expr = parseFunctionArgs(expr, false); } else if (lexer.current().type == '{' || lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) { - expr = parseFunctionArgs(expr, false, Location()); + expr = parseFunctionArgs(expr, false); } else { @@ -1925,34 +2088,68 @@ AstExpr* Parser::parseAssertionExpr() return expr; } -static bool parseNumber(double& result, const char* data) +static ConstantNumberParseResult parseInteger(double& result, const char* data, int base) +{ + LUAU_ASSERT(base == 2 || base == 16); + + char* end = nullptr; + unsigned long long value = strtoull(data, &end, base); + + if (*end != 0) + return ConstantNumberParseResult::Malformed; + + result = double(value); + + if (value == ULLONG_MAX && errno == ERANGE) + { + // 'errno' might have been set before we called 'strtoull', but we don't want the overhead of resetting a TLS variable on each call + // so we only reset it when we get a result that might be an out-of-range error and parse again to make sure + errno = 0; + value = strtoull(data, &end, base); + + if (errno == ERANGE) + { + if (DFFlag::LuaReportParseIntegerIssues) + { + if (base == 2) + lua_telemetry_parsed_out_of_range_bin_integer = true; + else + lua_telemetry_parsed_out_of_range_hex_integer = true; + } + + return base == 2 ? ConstantNumberParseResult::BinOverflow : ConstantNumberParseResult::HexOverflow; + } + } + + return ConstantNumberParseResult::Ok; +} + +static ConstantNumberParseResult parseDouble(double& result, const char* data) { // binary literal if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) - { - char* end = nullptr; - unsigned long long value = strtoull(data + 2, &end, 2); + return parseInteger(result, data + 2, 2); - result = double(value); - return *end == 0; - } // hexadecimal literal - else if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2]) + if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2]) { - char* end = nullptr; - unsigned long long value = strtoull(data + 2, &end, 16); + if (!FFlag::LuauErrorDoubleHexPrefix && data[2] == '0' && (data[3] == 'x' || data[3] == 'X')) + { + if (DFFlag::LuaReportParseIntegerIssues) + lua_telemetry_parsed_double_prefix_hex_integer = true; - result = double(value); - return *end == 0; - } - else - { - char* end = nullptr; - double value = strtod(data, &end); + ConstantNumberParseResult parseResult = parseInteger(result, data + 2, 16); + return parseResult == ConstantNumberParseResult::Malformed ? parseResult : ConstantNumberParseResult::DoublePrefix; + } - result = value; - return *end == 0; + return parseInteger(result, data, 16); // pass in '0x' prefix, it's handled by 'strtoull' } + + char* end = nullptr; + double value = strtod(data, &end); + + result = value; + return *end == 0 ? ConstantNumberParseResult::Ok : ConstantNumberParseResult::Malformed; } // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp @@ -1983,41 +2180,31 @@ AstExpr* Parser::parseSimpleExpr() Lexeme matchFunction = lexer.current(); nextLexeme(); - return parseFunctionBody(false, matchFunction, AstName(), {}).first; + return parseFunctionBody(false, matchFunction, AstName(), nullptr).first; } else if (lexer.current().type == Lexeme::Number) { - scratchData.assign(lexer.current().data, lexer.current().length); - - // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al - if (scratchData.find('_') != std::string::npos) - { - scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end()); - } - - double value = 0; - if (parseNumber(value, scratchData.c_str())) - { - nextLexeme(); - - return allocator.alloc(start, value); - } - else - { - nextLexeme(); - - return reportExprError(start, {}, "Malformed number"); - } + return parseNumber(); } - else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) + else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || + (FFlag::LuauInterpolatedStringBaseSupport && lexer.current().type == Lexeme::InterpStringSimple)) { return parseString(); } + else if (FFlag::LuauInterpolatedStringBaseSupport && lexer.current().type == Lexeme::InterpStringBegin) + { + return parseInterpString(); + } else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); return reportExprError(start, {}, "Malformed string"); } + else if (lexer.current().type == Lexeme::BrokenInterpDoubleBrace) + { + nextLexeme(); + return reportExprError(start, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + } else if (lexer.current().type == Lexeme::Dot3) { if (functionStack.back().vararg) @@ -2037,7 +2224,7 @@ AstExpr* Parser::parseSimpleExpr() { return parseTableConstructor(); } - else if (FFlag::LuauIfElseExpressionBaseSupport && lexer.current().type == Lexeme::ReservedIf) + else if (lexer.current().type == Lexeme::ReservedIf) { return parseIfElseExpr(); } @@ -2048,18 +2235,15 @@ AstExpr* Parser::parseSimpleExpr() } // args ::= `(' [explist] `)' | tableconstructor | String -AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self, const Location& selfLocation) +AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self) { if (lexer.current().type == '(') { Position argStart = lexer.current().location.end; if (func->location.end.line != lexer.current().location.begin.line) - { - report(lexer.current().location, "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " - "new statement; use ';' to separate statements"); - } + reportAmbiguousCallError(); - Lexeme matchParen = lexer.current(); + MatchLexeme matchParen = lexer.current(); nextLexeme(); TempVector args(scratchExpr); @@ -2091,18 +2275,29 @@ AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self, const Location& sel } else { - if (self && lexer.current().location.begin.line != func->location.end.line) - { - return reportExprError(func->location, copy({func}), "Expected function call arguments after '('"); - } - else - { - return reportExprError(Location(func->location.begin, lexer.current().location.begin), copy({func}), - "Expected '(', '{' or when parsing function call, got %s", lexer.current().toString().c_str()); - } + return reportFunctionArgsError(func, self); } } +LUAU_NOINLINE AstExpr* Parser::reportFunctionArgsError(AstExpr* func, bool self) +{ + if (self && lexer.current().location.begin.line != func->location.end.line) + { + return reportExprError(func->location, copy({func}), "Expected function call arguments after '('"); + } + else + { + return reportExprError(Location(func->location.begin, lexer.current().location.begin), copy({func}), + "Expected '(', '{' or when parsing function call, got %s", lexer.current().toString().c_str()); + } +} + +LUAU_NOINLINE void Parser::reportAmbiguousCallError() +{ + report(lexer.current().location, "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " + "new statement; use ';' to separate statements"); +} + // tableconstructor ::= `{' [fieldlist] `}' // fieldlist ::= field {fieldsep field} [fieldsep] // field ::= `[' exp `]' `=' exp | Name `=' exp | exp @@ -2113,14 +2308,17 @@ AstExpr* Parser::parseTableConstructor() Location start = lexer.current().location; - Lexeme matchBrace = lexer.current(); + MatchLexeme matchBrace = lexer.current(); expectAndConsume('{', "table literal"); + unsigned lastElementIndent = 0; while (lexer.current().type != '}') { + lastElementIndent = lexer.current().location.begin.column; + if (lexer.current().type == '[') { - Lexeme matchLocationBracket = lexer.current(); + MatchLexeme matchLocationBracket = lexer.current(); nextLexeme(); AstExpr* key = parseExpr(); @@ -2146,6 +2344,9 @@ AstExpr* Parser::parseTableConstructor() AstExpr* key = allocator.alloc(name.location, nameString); AstExpr* value = parseExpr(); + if (AstExprFunction* func = value->as()) + func->debugname = name.name; + items.push_back({AstExprTable::Item::Record, key, value}); } else @@ -2159,10 +2360,13 @@ AstExpr* Parser::parseTableConstructor() { nextLexeme(); } - else + else if ((lexer.current().type == '[' || lexer.current().type == Lexeme::Name) && lexer.current().location.begin.column == lastElementIndent) { - if (lexer.current().type != '}') - break; + report(lexer.current().location, "Expected ',' after table constructor element"); + } + else if (lexer.current().type != '}') + { + break; } } @@ -2226,7 +2430,7 @@ std::optional Parser::parseNameOpt(const char* context) Parser::Name Parser::parseName(const char* context) { - if (auto name = parseNameOpt(context)) + if (std::optional name = parseNameOpt(context)) return *name; Location location = lexer.current().location; @@ -2237,7 +2441,7 @@ Parser::Name Parser::parseName(const char* context) Parser::Name Parser::parseIndexName(const char* context, const Position& previous) { - if (auto name = parseNameOpt(context)) + if (std::optional name = parseNameOpt(context)) return *name; // If we have a reserved keyword next at the same line, assume it's an incomplete name @@ -2257,23 +2461,10 @@ Parser::Name Parser::parseIndexName(const char* context, const Position& previou return Name(nameError, location); } -std::pair, AstArray> Parser::parseGenericTypeListIfFFlagParseGenericFunctions() +std::pair, AstArray> Parser::parseGenericTypeList(bool withDefaultValues) { - if (FFlag::LuauParseGenericFunctions) - return Parser::parseGenericTypeList(); - AstArray generics; - AstArray genericPacks; - generics.size = 0; - generics.data = nullptr; - genericPacks.size = 0; - genericPacks.data = nullptr; - return std::pair(generics, genericPacks); -} - -std::pair, AstArray> Parser::parseGenericTypeList() -{ - TempVector names{scratchName}; - TempVector namePacks{scratchPackName}; + TempVector names{scratchGenericTypes}; + TempVector namePacks{scratchGenericTypePacks}; if (lexer.current().type == '<') { @@ -2281,25 +2472,91 @@ std::pair, AstArray> Parser::parseGenericTypeList() nextLexeme(); bool seenPack = false; + bool seenDefault = false; + while (true) { + Location nameLocation = lexer.current().location; AstName name = parseName().name; - if (FFlag::LuauParseGenericFunctions && lexer.current().type == Lexeme::Dot3) + if (lexer.current().type == Lexeme::Dot3 || seenPack) { seenPack = true; - nextLexeme(); - namePacks.push_back(name); + + if (lexer.current().type != Lexeme::Dot3) + report(lexer.current().location, "Generic types come before generic type packs"); + else + nextLexeme(); + + if (withDefaultValues && lexer.current().type == '=') + { + seenDefault = true; + nextLexeme(); + + Lexeme packBegin = lexer.current(); + + if (shouldParseTypePackAnnotation(lexer)) + { + AstTypePack* typePack = parseTypePackAnnotation(); + + namePacks.push_back({name, nameLocation, typePack}); + } + else if (!FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument && lexer.current().type == '(') + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (type) + report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type"); + + namePacks.push_back({name, nameLocation, typePack}); + } + else if (FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument) + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (type) + report(type->location, "Expected type pack after '=', got type"); + + namePacks.push_back({name, nameLocation, typePack}); + } + } + else + { + if (seenDefault) + report(lexer.current().location, "Expected default type pack after type pack name"); + + namePacks.push_back({name, nameLocation, nullptr}); + } } else { - if (seenPack) - report(lexer.current().location, "Generic types come before generic type packs"); + if (withDefaultValues && lexer.current().type == '=') + { + seenDefault = true; + nextLexeme(); - names.push_back(name); + AstType* defaultType = parseTypeAnnotation(); + + names.push_back({name, nameLocation, defaultType}); + } + else + { + if (seenDefault) + report(lexer.current().location, "Expected default type after type name"); + + names.push_back({name, nameLocation, nullptr}); + } } if (lexer.current().type == ',') + { nextLexeme(); + + if (lexer.current().type == '>') + { + report(lexer.current().location, "Expected type after ',' but got '>' instead"); + break; + } + } else break; } @@ -2307,14 +2564,14 @@ std::pair, AstArray> Parser::parseGenericTypeList() expectMatchAndConsume('>', begin); } - AstArray generics = copy(names); - AstArray genericPacks = copy(namePacks); + AstArray generics = copy(names); + AstArray genericPacks = copy(namePacks); return {generics, genericPacks}; } -AstArray Parser::parseTypeParams() +AstArray Parser::parseTypeParams() { - TempVector result{scratchAnnotation}; + TempVector parameters{scratchTypeOrPackAnnotation}; if (lexer.current().type == '<') { @@ -2323,7 +2580,30 @@ AstArray Parser::parseTypeParams() while (true) { - result.push_back(parseTypeAnnotation()); + if (shouldParseTypePackAnnotation(lexer)) + { + AstTypePack* typePack = parseTypePackAnnotation(); + + parameters.push_back({{}, typePack}); + } + else if (lexer.current().type == '(') + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (typePack) + parameters.push_back({{}, typePack}); + else + parameters.push_back({type, {}}); + } + else if (lexer.current().type == '>' && parameters.empty()) + { + break; + } + else + { + parameters.push_back({parseTypeAnnotation(), {}}); + } + if (lexer.current().type == ',') nextLexeme(); else @@ -2333,24 +2613,22 @@ AstArray Parser::parseTypeParams() expectMatchAndConsume('>', begin); } - return copy(result); + return copy(parameters); } -AstExpr* Parser::parseString() +std::optional> Parser::parseCharArray() { - LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString); + LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || + lexer.current().type == Lexeme::InterpStringSimple); scratchData.assign(lexer.current().data, lexer.current().length); - if (lexer.current().type == Lexeme::QuotedString) + if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { if (!Lexer::fixupQuotedString(scratchData)) { - Location location = lexer.current().location; - nextLexeme(); - - return reportExprError(location, {}, "String literal contains malformed escape sequence"); + return std::nullopt; } } else @@ -2358,12 +2636,127 @@ AstExpr* Parser::parseString() Lexer::fixupMultilineString(scratchData); } - Location start = lexer.current().location; AstArray value = copy(scratchData); + nextLexeme(); + return value; +} +AstExpr* Parser::parseString() +{ + Location location = lexer.current().location; + if (std::optional> value = parseCharArray()) + return allocator.alloc(location, *value); + else + return reportExprError(location, {}, "String literal contains malformed escape sequence"); +} + +AstExpr* Parser::parseInterpString() +{ + TempVector> strings(scratchString); + TempVector expressions(scratchExpr); + + Location startLocation = lexer.current().location; + Location endLocation; + + do + { + Lexeme currentLexeme = lexer.current(); + LUAU_ASSERT(currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid || + currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple); + + endLocation = currentLexeme.location; + + Location startOfBrace = Location(endLocation.end, 1); + + scratchData.assign(currentLexeme.data, currentLexeme.length); + + if (!Lexer::fixupQuotedString(scratchData)) + { + nextLexeme(); + return reportExprError(Location{startLocation, endLocation}, {}, "Interpolated string literal contains malformed escape sequence"); + } + + AstArray chars = copy(scratchData); + + nextLexeme(); + + strings.push_back(chars); + + if (currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple) + { + break; + } + + bool errorWhileChecking = false; + + switch (lexer.current().type) + { + case Lexeme::InterpStringMid: + case Lexeme::InterpStringEnd: + { + errorWhileChecking = true; + nextLexeme(); + expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string, expected expression inside '{}'")); + break; + } + case Lexeme::BrokenString: + { + errorWhileChecking = true; + nextLexeme(); + expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '`'?")); + break; + } + default: + expressions.push_back(parseExpr()); + } + + if (errorWhileChecking) + { + break; + } + + switch (lexer.current().type) + { + case Lexeme::InterpStringBegin: + case Lexeme::InterpStringMid: + case Lexeme::InterpStringEnd: + break; + case Lexeme::BrokenInterpDoubleBrace: + nextLexeme(); + return reportExprError(endLocation, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + case Lexeme::BrokenString: + nextLexeme(); + return reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '}'?"); + default: + return reportExprError(endLocation, {}, "Malformed interpolated string, got %s", lexer.current().toString().c_str()); + } + } while (true); + + AstArray> stringsArray = copy(strings); + AstArray expressionsArray = copy(expressions); + return allocator.alloc(Location{startLocation, endLocation}, stringsArray, expressionsArray); +} + +AstExpr* Parser::parseNumber() +{ + Location start = lexer.current().location; + + scratchData.assign(lexer.current().data, lexer.current().length); + + // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al + if (scratchData.find('_') != std::string::npos) + { + scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end()); + } + + double value = 0; + ConstantNumberParseResult result = parseDouble(value, scratchData.c_str()); nextLexeme(); - return allocator.alloc(start, value); + if (result == ConstantNumberParseResult::Malformed) + return reportExprError(start, {}, "Malformed number"); + + return allocator.alloc(start, value, result); } AstLocal* Parser::pushLocal(const Binding& binding) @@ -2437,7 +2830,7 @@ LUAU_NOINLINE void Parser::expectAndConsumeFail(Lexeme::Type type, const char* c report(lexer.current().location, "Expected %s, got %s", typeString.c_str(), currLexemeString.c_str()); } -bool Parser::expectMatchAndConsume(char value, const Lexeme& begin, bool searchForMissing) +bool Parser::expectMatchAndConsume(char value, const MatchLexeme& begin, bool searchForMissing) { Lexeme::Type type = static_cast(static_cast(value)); @@ -2445,42 +2838,7 @@ bool Parser::expectMatchAndConsume(char value, const Lexeme& begin, bool searchF { expectMatchAndConsumeFail(type, begin); - if (searchForMissing) - { - // previous location is taken because 'current' lexeme is already the next token - unsigned currentLine = lexer.previousLocation().end.line; - - // search to the end of the line for expected token - // we will also stop if we hit a token that can be handled by parsing function above the current one - Lexeme::Type lexemeType = lexer.current().type; - - while (currentLine == lexer.current().location.begin.line && lexemeType != type && matchRecoveryStopOnToken[lexemeType] == 0) - { - nextLexeme(); - lexemeType = lexer.current().type; - } - - if (lexemeType == type) - { - nextLexeme(); - - return true; - } - } - else - { - // check if this is an extra token and the expected token is next - if (lexer.lookahead().type == type) - { - // skip invalid and consume expected - nextLexeme(); - nextLexeme(); - - return true; - } - } - - return false; + return expectMatchAndConsumeRecover(value, begin, searchForMissing); } else { @@ -2490,21 +2848,64 @@ bool Parser::expectMatchAndConsume(char value, const Lexeme& begin, bool searchF } } -// LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is -// cold -LUAU_NOINLINE void Parser::expectMatchAndConsumeFail(Lexeme::Type type, const Lexeme& begin, const char* extra) +LUAU_NOINLINE bool Parser::expectMatchAndConsumeRecover(char value, const MatchLexeme& begin, bool searchForMissing) { - std::string typeString = Lexeme(Location(Position(0, 0), 0), type).toString(); + Lexeme::Type type = static_cast(static_cast(value)); - if (lexer.current().location.begin.line == begin.location.begin.line) - report(lexer.current().location, "Expected %s (to close %s at column %d), got %s%s", typeString.c_str(), begin.toString().c_str(), - begin.location.begin.column + 1, lexer.current().toString().c_str(), extra ? extra : ""); + if (searchForMissing) + { + // previous location is taken because 'current' lexeme is already the next token + unsigned currentLine = lexer.previousLocation().end.line; + + // search to the end of the line for expected token + // we will also stop if we hit a token that can be handled by parsing function above the current one + Lexeme::Type lexemeType = lexer.current().type; + + while (currentLine == lexer.current().location.begin.line && lexemeType != type && matchRecoveryStopOnToken[lexemeType] == 0) + { + nextLexeme(); + lexemeType = lexer.current().type; + } + + if (lexemeType == type) + { + nextLexeme(); + + return true; + } + } else - report(lexer.current().location, "Expected %s (to close %s at line %d), got %s%s", typeString.c_str(), begin.toString().c_str(), - begin.location.begin.line + 1, lexer.current().toString().c_str(), extra ? extra : ""); + { + // check if this is an extra token and the expected token is next + if (lexer.lookahead().type == type) + { + // skip invalid and consume expected + nextLexeme(); + nextLexeme(); + + return true; + } + } + + return false; } -bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin) +// LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is +// cold +LUAU_NOINLINE void Parser::expectMatchAndConsumeFail(Lexeme::Type type, const MatchLexeme& begin, const char* extra) +{ + std::string typeString = Lexeme(Location(Position(0, 0), 0), type).toString(); + std::string matchString = Lexeme(Location(Position(0, 0), 0), begin.type).toString(); + + if (lexer.current().location.begin.line == begin.position.line) + report(lexer.current().location, "Expected %s (to close %s at column %d), got %s%s", typeString.c_str(), matchString.c_str(), + begin.position.column + 1, lexer.current().toString().c_str(), extra ? extra : ""); + else + report(lexer.current().location, "Expected %s (to close %s at line %d), got %s%s", typeString.c_str(), matchString.c_str(), + begin.position.line + 1, lexer.current().toString().c_str(), extra ? extra : ""); +} + +bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const MatchLexeme& begin) { if (lexer.current().type != type) { @@ -2526,9 +2927,8 @@ bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin) { // If the token matches on a different line and a different column, it suggests misleading indentation // This can be used to pinpoint the problem location for a possible future *actual* mismatch - if (lexer.current().location.begin.line != begin.location.begin.line && - lexer.current().location.begin.column != begin.location.begin.column && - endMismatchSuspect.location.begin.line < begin.location.begin.line) // Only replace the previous suspect with more recent suspects + if (lexer.current().location.begin.line != begin.position.line && lexer.current().location.begin.column != begin.position.column && + endMismatchSuspect.position.line < begin.position.line) // Only replace the previous suspect with more recent suspects { endMismatchSuspect = begin; } @@ -2541,12 +2941,12 @@ bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin) // LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is // cold -LUAU_NOINLINE void Parser::expectMatchEndAndConsumeFail(Lexeme::Type type, const Lexeme& begin) +LUAU_NOINLINE void Parser::expectMatchEndAndConsumeFail(Lexeme::Type type, const MatchLexeme& begin) { - if (endMismatchSuspect.type != Lexeme::Eof && endMismatchSuspect.location.begin.line > begin.location.begin.line) + if (endMismatchSuspect.type != Lexeme::Eof && endMismatchSuspect.position.line > begin.position.line) { - std::string suggestion = - format("; did you forget to close %s at line %d?", endMismatchSuspect.toString().c_str(), endMismatchSuspect.location.begin.line + 1); + std::string matchString = Lexeme(Location(Position(0, 0), 0), endMismatchSuspect.type).toString(); + std::string suggestion = format("; did you forget to close %s at line %d?", matchString.c_str(), endMismatchSuspect.position.line + 1); expectMatchAndConsumeFail(type, begin, suggestion.c_str()); } @@ -2659,35 +3059,56 @@ AstExprError* Parser::reportExprError(const Location& location, const AstArray(location, expressions, unsigned(parseErrors.size() - 1)); } -AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) +AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) { va_list args; va_start(args, format); report(location, format, args); va_end(args); - return allocator.alloc(location, types, isMissing, unsigned(parseErrors.size() - 1)); + return allocator.alloc(location, types, false, unsigned(parseErrors.size() - 1)); } -const Lexeme& Parser::nextLexeme() +AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) { - if (options.captureComments) + va_list args; + va_start(args, format); + report(parseErrorLocation, format, args); + va_end(args); + + return allocator.alloc(astErrorLocation, AstArray{}, true, unsigned(parseErrors.size() - 1)); +} + +void Parser::nextLexeme() +{ + Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type; + + while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) { - while (true) + const Lexeme& lexeme = lexer.current(); + + if (options.captureComments) + commentLocations.push_back(Comment{lexeme.type, lexeme.location}); + + // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. + // The parser will turn this into a proper syntax error. + if (lexeme.type == Lexeme::BrokenComment) + return; + + // Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling + if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') { - const Lexeme& lexeme = lexer.next(/*skipComments*/ false); - // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. - // The parser will turn this into a proper syntax error. - if (FFlag::LuauCaptureBrokenCommentSpans && lexeme.type == Lexeme::BrokenComment) - commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - if (isComment(lexeme)) - commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - else - return lexeme; + const char* text = lexeme.data; + + unsigned int end = lexeme.length; + while (end > 0 && isSpace(text[end - 1])) + --end; + + hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); } + + type = lexer.next(/* skipComments= */ false, /* updatePrevLocation= */ false).type; } - else - return lexer.next(); } } // namespace Luau diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 24b2283a..11e0076a 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -11,7 +11,7 @@ namespace Luau { -static void vformatAppend(std::string& ret, const char* fmt, va_list args) +void vformatAppend(std::string& ret, const char* fmt, va_list args) { va_list argscopy; va_copy(argscopy, args); @@ -225,4 +225,68 @@ size_t hashRange(const char* data, size_t size) return hash; } +bool isIdentifier(std::string_view s) +{ + return (s.find_first_not_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_") == std::string::npos); +} + +std::string escape(std::string_view s, bool escapeForInterpString) +{ + std::string r; + r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting + + for (uint8_t c : s) + { + if (c >= ' ' && c != '\\' && c != '\'' && c != '\"' && c != '`' && c != '{') + r += c; + else + { + r += '\\'; + + if (escapeForInterpString && (c == '`' || c == '{')) + { + r += c; + continue; + } + + switch (c) + { + case '\a': + r += 'a'; + break; + case '\b': + r += 'b'; + break; + case '\f': + r += 'f'; + break; + case '\n': + r += 'n'; + break; + case '\r': + r += 'r'; + break; + case '\t': + r += 't'; + break; + case '\v': + r += 'v'; + break; + case '\'': + r += '\''; + break; + case '\"': + r += '\"'; + break; + case '\\': + r += '\\'; + break; + default: + Luau::formatAppend(r, "%03u", c); + } + } + } + + return r; +} } // namespace Luau diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp new file mode 100644 index 00000000..e3807683 --- /dev/null +++ b/Ast/src/TimeTrace.cpp @@ -0,0 +1,269 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TimeTrace.h" + +#include "Luau/StringUtils.h" + +#include +#include + +#include + +#ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#endif + +#ifdef __APPLE__ +#include +#include +#endif + +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false) +namespace Luau +{ +namespace TimeTrace +{ +static double getClockPeriod() +{ +#if defined(_WIN32) + LARGE_INTEGER result = {}; + QueryPerformanceFrequency(&result); + return 1.0 / double(result.QuadPart); +#elif defined(__APPLE__) + mach_timebase_info_data_t result = {}; + mach_timebase_info(&result); + return double(result.numer) / double(result.denom) * 1e-9; +#elif defined(__linux__) + return 1e-9; +#else + return 1.0 / double(CLOCKS_PER_SEC); +#endif +} + +static double getClockTimestamp() +{ +#if defined(_WIN32) + LARGE_INTEGER result = {}; + QueryPerformanceCounter(&result); + return double(result.QuadPart); +#elif defined(__APPLE__) + return double(mach_absolute_time()); +#elif defined(__linux__) + timespec now; + clock_gettime(CLOCK_MONOTONIC, &now); + return now.tv_sec * 1e9 + now.tv_nsec; +#else + return double(clock()); +#endif +} + +double getClock() +{ + static double period = getClockPeriod(); + static double start = getClockTimestamp(); + + return (getClockTimestamp() - start) * period; +} + +uint32_t getClockMicroseconds() +{ + static double period = getClockPeriod() * 1e6; + static double start = getClockTimestamp(); + + return uint32_t((getClockTimestamp() - start) * period); +} +} // namespace TimeTrace +} // namespace Luau + +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ +struct GlobalContext +{ + GlobalContext() = default; + ~GlobalContext() + { + // Ideally we would want all ThreadContext destructors to run + // But in VS, not all thread_local object instances are destroyed + for (ThreadContext* context : threads) + { + if (!context->events.empty()) + context->flushEvents(); + } + + if (traceFile) + fclose(traceFile); + } + + std::mutex mutex; + std::vector threads; + uint32_t nextThreadId = 0; + std::vector tokens; + FILE* traceFile = nullptr; +}; + +GlobalContext& getGlobalContext() +{ + static GlobalContext context; + return context; +} + +uint16_t createToken(GlobalContext& context, const char* name, const char* category) +{ + std::scoped_lock lock(context.mutex); + + LUAU_ASSERT(context.tokens.size() < 64 * 1024); + + context.tokens.push_back({name, category}); + return uint16_t(context.tokens.size() - 1); +} + +uint32_t createThread(GlobalContext& context, ThreadContext* threadContext) +{ + std::scoped_lock lock(context.mutex); + + context.threads.push_back(threadContext); + + return ++context.nextThreadId; +} + +void releaseThread(GlobalContext& context, ThreadContext* threadContext) +{ + std::scoped_lock lock(context.mutex); + + if (auto it = std::find(context.threads.begin(), context.threads.end(), threadContext); it != context.threads.end()) + context.threads.erase(it); +} + +void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector& events, const std::vector& data) +{ + std::scoped_lock lock(context.mutex); + + if (!context.traceFile) + { + context.traceFile = fopen("trace.json", "w"); + + if (!context.traceFile) + return; + + fprintf(context.traceFile, "[\n"); + } + + std::string temp; + const unsigned tempReserve = 64 * 1024; + temp.reserve(tempReserve); + + const char* rawData = data.data(); + + // Formatting state + bool unfinishedEnter = false; + bool unfinishedArgs = false; + + for (const Event& ev : events) + { + switch (ev.type) + { + case EventType::Enter: + { + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + Token& token = context.tokens[ev.token]; + + formatAppend(temp, R"({"name": "%s", "cat": "%s", "ph": "B", "ts": %u, "pid": 0, "tid": %u)", token.name, token.category, + ev.data.microsec, threadId); + unfinishedEnter = true; + } + break; + case EventType::Leave: + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + formatAppend(temp, + R"({"ph": "E", "ts": %u, "pid": 0, "tid": %u},)" + "\n", + ev.data.microsec, threadId); + break; + case EventType::ArgName: + LUAU_ASSERT(unfinishedEnter); + + if (!unfinishedArgs) + { + formatAppend(temp, R"(, "args": { "%s": )", rawData + ev.data.dataPos); + unfinishedArgs = true; + } + else + { + formatAppend(temp, R"(, "%s": )", rawData + ev.data.dataPos); + } + break; + case EventType::ArgValue: + LUAU_ASSERT(unfinishedArgs); + formatAppend(temp, R"("%s")", rawData + ev.data.dataPos); + break; + } + + // Don't want to hit the string capacity and reallocate + if (temp.size() > tempReserve - 1024) + { + fwrite(temp.data(), 1, temp.size(), context.traceFile); + temp.clear(); + } + } + + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + fwrite(temp.data(), 1, temp.size(), context.traceFile); + fflush(context.traceFile); +} + +ThreadContext& getThreadContext() +{ + thread_local ThreadContext context; + return context; +} + +uint16_t createScopeData(const char* name, const char* category) +{ + return createToken(Luau::TimeTrace::getGlobalContext(), name, category); +} +} // namespace TimeTrace +} // namespace Luau + +#endif diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 920502b8..6257e2f3 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -7,39 +7,56 @@ #include "Luau/Transpiler.h" #include "FileUtils.h" +#include "Flags.h" + +#ifdef CALLGRIND +#include +#endif + +LUAU_FASTFLAG(DebugLuauTimeTracing) enum class ReportFormat { Default, - Luacheck + Luacheck, + Gnu, }; -static void report(ReportFormat format, const char* name, const Luau::Location& location, const char* type, const char* message) +static void report(ReportFormat format, const char* name, const Luau::Location& loc, const char* type, const char* message) { switch (format) { case ReportFormat::Default: - fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); + fprintf(stderr, "%s(%d,%d): %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, type, message); break; case ReportFormat::Luacheck: { // Note: luacheck's end column is inclusive but our end column is exclusive // In addition, luacheck doesn't support multi-line messages, so if the error is multiline we'll fake end column as 100 and hope for the best - int columnEnd = (location.begin.line == location.end.line) ? location.end.column : 100; + int columnEnd = (loc.begin.line == loc.end.line) ? loc.end.column : 100; - fprintf(stdout, "%s:%d:%d-%d: (W0) %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, columnEnd, type, message); + // Use stdout to match luacheck behavior + fprintf(stdout, "%s:%d:%d-%d: (W0) %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, columnEnd, type, message); break; } + + case ReportFormat::Gnu: + // Note: GNU end column is inclusive but our end column is exclusive + fprintf(stderr, "%s:%d.%d-%d.%d: %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, loc.end.line + 1, loc.end.column, type, message); + break; } } -static void reportError(ReportFormat format, const char* name, const Luau::TypeError& error) +static void reportError(const Luau::Frontend& frontend, ReportFormat format, const Luau::TypeError& error) { + std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(error.moduleName); + if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) - report(format, name, error.location, "SyntaxError", syntaxError->message.c_str()); + report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); else - report(format, name, error.location, "TypeError", Luau::toString(error).c_str()); + report(format, humanReadableName.c_str(), error.location, "TypeError", + Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); } static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) @@ -49,7 +66,10 @@ static void reportWarning(ReportFormat format, const char* name, const Luau::Lin static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat format, bool annotate) { - Luau::CheckResult cr = frontend.check(name); + Luau::CheckResult cr; + + if (frontend.isDirty(name)) + cr = frontend.check(name); if (!frontend.getSourceModule(name)) { @@ -58,14 +78,15 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat } for (auto& error : cr.errors) - reportError(format, name, error); + reportError(frontend, format, error); Luau::LintResult lr = frontend.lint(name); + std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); for (auto& error : lr.errors) - reportWarning(format, name, error); + reportWarning(format, humanReadableName.c_str(), error); for (auto& warning : lr.warnings) - reportWarning(format, name, warning); + reportWarning(format, humanReadableName.c_str(), warning); if (annotate) { @@ -92,9 +113,12 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); + printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); + printf(" --mode=strict: default to strict mode when typechecking\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); } -static int assertionHandler(const char* expr, const char* file, int line) +static int assertionHandler(const char* expr, const char* file, int line, const char* function) { printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); return 1; @@ -104,36 +128,49 @@ struct CliFileResolver : Luau::FileResolver { std::optional readSource(const Luau::ModuleName& name) override { - std::optional source = readFile(name); + Luau::SourceCode::Type sourceType; + std::optional source = std::nullopt; + + // If the module name is "-", then read source from stdin + if (name == "-") + { + source = readStdin(); + sourceType = Luau::SourceCode::Script; + } + else + { + source = readFile(name); + sourceType = Luau::SourceCode::Module; + } + if (!source) return std::nullopt; - return Luau::SourceCode{*source, Luau::SourceCode::Module}; + return Luau::SourceCode{*source, sourceType}; } - bool moduleExists(const Luau::ModuleName& name) const override + std::optional resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* node) override { - return !!readFile(name); - } + if (Luau::AstExprConstantString* expr = node->as()) + { + Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".luau"; + if (!readFile(name)) + { + // fall back to .lua if a module with .luau doesn't exist + name = std::string(expr->value.data, expr->value.size) + ".lua"; + } + + return {{name}}; + } - std::optional fromAstFragment(Luau::AstExpr* expr) const override - { return std::nullopt; } - Luau::ModuleName concat(const Luau::ModuleName& lhs, std::string_view rhs) const override + std::string getHumanReadableModuleName(const Luau::ModuleName& name) const override { - return lhs + "/" + std::string(rhs); - } - - std::optional getParentModuleName(const Luau::ModuleName& name) const override - { - return std::nullopt; - } - - std::optional getEnvironmentForModule(const Luau::ModuleName& name) const override - { - return std::nullopt; + if (name == "-") + return "stdin"; + return name; } }; @@ -144,9 +181,9 @@ struct CliConfigResolver : Luau::ConfigResolver mutable std::unordered_map configCache; mutable std::vector> configErrors; - CliConfigResolver() + CliConfigResolver(Luau::Mode mode) { - defaultConfig.mode = Luau::Mode::Nonstrict; + defaultConfig.mode = mode; } const Luau::Config& getConfig(const Luau::ModuleName& name) const override @@ -184,9 +221,7 @@ int main(int argc, char** argv) { Luau::assertHandler() = assertionHandler; - for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) - if (strncmp(flag->name, "Luau", 4) == 0) - flag->value = true; + setLuauFlagsDefault(); if (argc >= 2 && strcmp(argv[1], "--help") == 0) { @@ -195,6 +230,7 @@ int main(int argc, char** argv) } ReportFormat format = ReportFormat::Default; + Luau::Mode mode = Luau::Mode::Nonstrict; bool annotate = false; for (int i = 1; i < argc; ++i) @@ -204,39 +240,46 @@ int main(int argc, char** argv) if (strcmp(argv[i], "--formatter=plain") == 0) format = ReportFormat::Luacheck; + else if (strcmp(argv[i], "--formatter=gnu") == 0) + format = ReportFormat::Gnu; + else if (strcmp(argv[i], "--mode=strict") == 0) + mode = Luau::Mode::Strict; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; + else if (strcmp(argv[i], "--timetrace") == 0) + FFlag::DebugLuauTimeTracing.value = true; + else if (strncmp(argv[i], "--fflags=", 9) == 0) + setLuauFlags(argv[i] + 9); } +#if !defined(LUAU_ENABLE_TIME_TRACE) + if (FFlag::DebugLuauTimeTracing) + { + fprintf(stderr, "To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; + } +#endif + Luau::FrontendOptions frontendOptions; frontendOptions.retainFullTypeGraphs = annotate; CliFileResolver fileResolver; - CliConfigResolver configResolver; + CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); - Luau::registerBuiltinTypes(frontend.typeChecker); + Luau::registerBuiltinGlobals(frontend.typeChecker); Luau::freeze(frontend.typeChecker.globalTypes); +#ifdef CALLGRIND + CALLGRIND_ZERO_STATS; +#endif + + std::vector files = getSourceFiles(argc, argv); + int failed = 0; - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !analyzeFile(frontend, name.c_str(), format, annotate); - }); - } - else - { - failed += !analyzeFile(frontend, argv[i], format, annotate); - } - } + for (const std::string& path : files) + failed += !analyzeFile(frontend, path.c_str(), format, annotate); if (!configResolver.configErrors.empty()) { @@ -246,7 +289,8 @@ int main(int argc, char** argv) fprintf(stderr, "%s: %s\n", pair.first.c_str(), pair.second.c_str()); } - return (format == ReportFormat::Luacheck) ? 0 : failed; + if (format == ReportFormat::Luacheck) + return 0; + else + return failed ? 1 : 0; } - - diff --git a/CLI/Ast.cpp b/CLI/Ast.cpp new file mode 100644 index 00000000..99c58393 --- /dev/null +++ b/CLI/Ast.cpp @@ -0,0 +1,86 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include + +#include "Luau/Common.h" +#include "Luau/Ast.h" +#include "Luau/AstJsonEncoder.h" +#include "Luau/Parser.h" +#include "Luau/ParseOptions.h" +#include "Luau/ToString.h" + +#include "FileUtils.h" + +static void displayHelp(const char* argv0) +{ + printf("Usage: %s [file]\n", argv0); +} + +static int assertionHandler(const char* expr, const char* file, int line, const char* function) +{ + printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + return 1; +} + +int main(int argc, char** argv) +{ + Luau::assertHandler() = assertionHandler; + + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + if (argc >= 2 && strcmp(argv[1], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + else if (argc < 2) + { + displayHelp(argv[0]); + return 1; + } + + const char* name = argv[1]; + std::optional maybeSource = std::nullopt; + if (strcmp(name, "-") == 0) + { + maybeSource = readStdin(); + } + else + { + maybeSource = readFile(name); + } + + if (!maybeSource) + { + fprintf(stderr, "Couldn't read source %s\n", name); + return 1; + } + + std::string source = *maybeSource; + + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + + Luau::ParseOptions options; + options.captureComments = true; + options.supportContinueStatement = true; + options.allowTypeAnnotations = true; + options.allowDeclarationSyntax = true; + + Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); + + if (parseResult.errors.size() > 0) + { + fprintf(stderr, "Parse errors were encountered:\n"); + for (const Luau::ParseError& error : parseResult.errors) + { + fprintf(stderr, " %s - %s\n", toString(error.getLocation()).c_str(), error.getMessage().c_str()); + } + fprintf(stderr, "\n"); + } + + printf("%s", Luau::toJson(parseResult.root, parseResult.commentLocations).c_str()); + + return parseResult.errors.size() > 0 ? 1 : 0; +} diff --git a/CLI/Coverage.cpp b/CLI/Coverage.cpp new file mode 100644 index 00000000..a509ab89 --- /dev/null +++ b/CLI/Coverage.cpp @@ -0,0 +1,88 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Coverage.h" + +#include "lua.h" + +#include +#include + +struct Coverage +{ + lua_State* L = nullptr; + std::vector functions; +} gCoverage; + +void coverageInit(lua_State* L) +{ + gCoverage.L = lua_mainthread(L); +} + +bool coverageActive() +{ + return gCoverage.L != nullptr; +} + +void coverageTrack(lua_State* L, int funcindex) +{ + int ref = lua_ref(L, funcindex); + gCoverage.functions.push_back(ref); +} + +static void coverageCallback(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size) +{ + FILE* f = static_cast(context); + + std::string name; + + if (depth == 0) + name = "
"; + else if (function) + name = std::string(function) + ":" + std::to_string(linedefined); + else + name = ":" + std::to_string(linedefined); + + fprintf(f, "FN:%d,%s\n", linedefined, name.c_str()); + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + { + fprintf(f, "FNDA:%d,%s\n", hits[i], name.c_str()); + break; + } + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + fprintf(f, "DA:%d,%d\n", int(i), hits[i]); +} + +void coverageDump(const char* path) +{ + lua_State* L = gCoverage.L; + + FILE* f = fopen(path, "w"); + if (!f) + { + fprintf(stderr, "Error opening coverage %s\n", path); + return; + } + + fprintf(f, "TN:\n"); + + for (int fref : gCoverage.functions) + { + lua_getref(L, fref); + + lua_Debug ar = {}; + lua_getinfo(L, -1, "s", &ar); + + fprintf(f, "SF:%s\n", ar.short_src); + lua_getcoverage(L, -1, f, coverageCallback); + fprintf(f, "end_of_record\n"); + + lua_pop(L, 1); + } + + fclose(f); + + printf("Coverage dump written to %s (%d functions)\n", path, int(gCoverage.functions.size())); +} diff --git a/CLI/Coverage.h b/CLI/Coverage.h new file mode 100644 index 00000000..74be4e5c --- /dev/null +++ b/CLI/Coverage.h @@ -0,0 +1,10 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +struct lua_State; + +void coverageInit(lua_State* L); +bool coverageActive(); + +void coverageTrack(lua_State* L, int funcindex); +void coverageDump(const char* path); diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index 0702b74f..39a14ec7 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -4,8 +4,13 @@ #include "Luau/Common.h" #ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN -#include +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include #else #include #include @@ -67,6 +72,25 @@ std::optional readFile(const std::string& name) if (read != size_t(length)) return std::nullopt; + // Skip first line if it's a shebang + if (length > 2 && result[0] == '#' && result[1] == '!') + result.erase(0, result.find('\n')); + + return result; +} + +std::optional readStdin() +{ + std::string result; + char buffer[4096] = {}; + + while (fgets(buffer, sizeof(buffer), stdin) != nullptr) + result.append(buffer); + + // If eof was not reached for stdin, then a read error occurred + if (!feof(stdin)) + return std::nullopt; + return result; } @@ -142,6 +166,7 @@ static bool traverseDirectoryRec(const std::string& path, const std::function getParentPath(const std::string& path) return std::nullopt; #endif - std::string::size_type slash = path.find_last_of("\\/", path.size() - 1); + size_t slash = path.find_last_of("\\/", path.size() - 1); if (slash == 0) return "/"; @@ -222,3 +250,42 @@ std::optional getParentPath(const std::string& path) return ""; } + +static std::string getExtension(const std::string& path) +{ + size_t dot = path.find_last_of(".\\/"); + + if (dot == std::string::npos || path[dot] != '.') + return ""; + + return path.substr(dot); +} + +std::vector getSourceFiles(int argc, char** argv) +{ + std::vector files; + + for (int i = 1; i < argc; ++i) + { + // Treat '-' as a special file whose source is read from stdin + // All other arguments that start with '-' are skipped + if (argv[i][0] == '-' && argv[i][1] != '\0') + continue; + + if (isDirectory(argv[i])) + { + traverseDirectory(argv[i], [&](const std::string& name) { + std::string ext = getExtension(name); + + if (ext == ".lua" || ext == ".luau") + files.push_back(name); + }); + } + else + { + files.push_back(argv[i]); + } + } + + return files; +} diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h index f7fbe8af..97471cdc 100644 --- a/CLI/FileUtils.h +++ b/CLI/FileUtils.h @@ -4,11 +4,15 @@ #include #include #include +#include std::optional readFile(const std::string& name); +std::optional readStdin(); bool isDirectory(const std::string& path); bool traverseDirectory(const std::string& path, const std::function& callback); std::string joinPaths(const std::string& lhs, const std::string& rhs); std::optional getParentPath(const std::string& path); + +std::vector getSourceFiles(int argc, char** argv); diff --git a/CLI/Flags.cpp b/CLI/Flags.cpp new file mode 100644 index 00000000..4e261171 --- /dev/null +++ b/CLI/Flags.cpp @@ -0,0 +1,75 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" +#include "Luau/ExperimentalFlags.h" + +#include + +#include +#include + +static void setLuauFlag(std::string_view name, bool state) +{ + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + { + if (name == flag->name) + { + flag->value = state; + return; + } + } + + fprintf(stderr, "Warning: unrecognized flag '%.*s'.\n", int(name.length()), name.data()); +} + +static void setLuauFlags(bool state) +{ + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = state; +} + +void setLuauFlagsDefault() +{ + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0 && !Luau::isFlagExperimental(flag->name)) + flag->value = true; +} + +void setLuauFlags(const char* list) +{ + std::string_view rest = list; + + while (!rest.empty()) + { + size_t ending = rest.find(","); + std::string_view element = rest.substr(0, ending); + + if (size_t separator = element.find('='); separator != std::string_view::npos) + { + std::string_view key = element.substr(0, separator); + std::string_view value = element.substr(separator + 1); + + if (value == "true" || value == "True") + setLuauFlag(key, true); + else if (value == "false" || value == "False") + setLuauFlag(key, false); + else + fprintf(stderr, "Warning: unrecognized value '%.*s' for flag '%.*s'.\n", int(value.length()), value.data(), int(key.length()), + key.data()); + } + else + { + if (element == "true" || element == "True") + setLuauFlags(true); + else if (element == "false" || element == "False") + setLuauFlags(false); + else + setLuauFlag(element, true); + } + + if (ending != std::string_view::npos) + rest.remove_prefix(ending + 1); + else + break; + } +} diff --git a/CLI/Flags.h b/CLI/Flags.h new file mode 100644 index 00000000..8dfb0a29 --- /dev/null +++ b/CLI/Flags.h @@ -0,0 +1,5 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +void setLuauFlagsDefault(); +void setLuauFlags(const char* list); diff --git a/CLI/Profiler.cpp b/CLI/Profiler.cpp index c6d15a7f..d3ad4e99 100644 --- a/CLI/Profiler.cpp +++ b/CLI/Profiler.cpp @@ -82,11 +82,13 @@ static void profilerLoop() if (now - last >= 1.0 / double(gProfiler.frequency)) { - gProfiler.ticks += uint64_t((now - last) * 1e6); + int64_t ticks = int64_t((now - last) * 1e6); + + gProfiler.ticks += ticks; gProfiler.samples++; gProfiler.callbacks->interrupt = profilerTrigger; - last = now; + last += ticks * 1e-6; } else { @@ -110,12 +112,12 @@ void profilerStop() gProfiler.thread.join(); } -void profilerDump(const char* name) +void profilerDump(const char* path) { - FILE* f = fopen(name, "wb"); + FILE* f = fopen(path, "wb"); if (!f) { - fprintf(stderr, "Error opening profile %s\n", name); + fprintf(stderr, "Error opening profile %s\n", path); return; } @@ -129,7 +131,7 @@ void profilerDump(const char* name) fclose(f); - printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", name, double(total) / 1e6, + printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", path, double(total) / 1e6, static_cast(gProfiler.samples.load()), static_cast(gProfiler.data.size())); uint64_t totalgc = 0; diff --git a/CLI/Profiler.h b/CLI/Profiler.h index 0a407e47..67b1acfd 100644 --- a/CLI/Profiler.h +++ b/CLI/Profiler.h @@ -5,4 +5,4 @@ struct lua_State; void profilerStart(lua_State* L, int frequency); void profilerStop(); -void profilerDump(const char* name); \ No newline at end of file +void profilerDump(const char* path); diff --git a/CLI/Reduce.cpp b/CLI/Reduce.cpp new file mode 100644 index 00000000..d24c9874 --- /dev/null +++ b/CLI/Reduce.cpp @@ -0,0 +1,511 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Ast.h" +#include "Luau/Common.h" +#include "Luau/Parser.h" +#include "Luau/Transpiler.h" + +#include "FileUtils.h" + +#include +#include +#include +#include +#include + +#define VERBOSE 0 // 1 - print out commandline invocations. 2 - print out stdout + +#ifdef _WIN32 + +const auto popen = &_popen; +const auto pclose = &_pclose; + +#endif + +using namespace Luau; + +enum class TestResult +{ + BugFound, // We encountered the bug we are trying to isolate + NoBug, // We did not encounter the bug we are trying to isolate +}; + +struct Enqueuer : public AstVisitor +{ + std::queue* queue; + + explicit Enqueuer(std::queue* queue) + : queue(queue) + { + LUAU_ASSERT(queue); + } + + bool visit(AstStatBlock* block) override + { + queue->push(block); + return false; + } +}; + +struct Reducer +{ + Allocator allocator; + AstNameTable nameTable{allocator}; + ParseOptions parseOptions; + + ParseResult parseResult; + AstStatBlock* root; + + std::string tempScriptName; + + std::string appName; + std::vector appArgs; + std::string_view searchText; + + Reducer() + { + parseOptions.captureComments = true; + } + + std::string readLine(FILE* f) + { + std::string line = ""; + char buffer[256]; + while (fgets(buffer, sizeof(buffer), f)) + { + auto len = strlen(buffer); + line += std::string(buffer, len); + if (buffer[len - 1] == '\n') + break; + } + + return line; + } + + void writeTempScript(bool minify = false) + { + std::string source = transpileWithTypes(*root); + + if (minify) + { + size_t pos = 0; + do + { + pos = source.find("\n\n", pos); + if (pos == std::string::npos) + break; + + source.erase(pos, 1); + } while (true); + } + + FILE* f = fopen(tempScriptName.c_str(), "w"); + if (!f) + { + printf("Unable to open temp script to %s\n", tempScriptName.c_str()); + exit(2); + } + + for (const HotComment& comment : parseResult.hotcomments) + fprintf(f, "--!%s\n", comment.content.c_str()); + + auto written = fwrite(source.data(), 1, source.size(), f); + if (written != source.size()) + { + printf("??? %zu %zu\n", written, source.size()); + printf("Unable to write to temp script %s\n", tempScriptName.c_str()); + exit(3); + } + + fclose(f); + } + + int step = 0; + + std::string escape(const std::string& s) + { + std::string result; + result.reserve(s.size() + 20); // guess + result += '"'; + for (char c : s) + { + if (c == '"') + result += '\\'; + result += c; + } + result += '"'; + + return result; + } + + TestResult run() + { + writeTempScript(); + + std::string command = appName + " " + escape(tempScriptName); + for (const auto& arg : appArgs) + command += " " + escape(arg); + +#if VERBOSE >= 1 + printf("running %s\n", command.c_str()); +#endif + + TestResult result = TestResult::NoBug; + + ++step; + printf("Step %4d...\n", step); + + FILE* p = popen(command.c_str(), "r"); + + while (!feof(p)) + { + std::string s = readLine(p); +#if VERBOSE >= 2 + printf("%s", s.c_str()); +#endif + if (std::string::npos != s.find(searchText)) + { + result = TestResult::BugFound; + break; + } + } + + pclose(p); + + return result; + } + + std::vector getNestedStats(AstStat* stat) + { + std::vector result; + + auto append = [&](AstStatBlock* block) { + if (block) + result.insert(result.end(), block->body.data, block->body.data + block->body.size); + }; + + if (auto block = stat->as()) + append(block); + else if (auto ifs = stat->as()) + { + append(ifs->thenbody); + if (ifs->elsebody) + { + if (AstStatBlock* elseBlock = ifs->elsebody->as()) + append(elseBlock); + else if (AstStatIf* elseIf = ifs->elsebody->as()) + { + auto innerStats = getNestedStats(elseIf); + result.insert(end(result), begin(innerStats), end(innerStats)); + } + else + { + printf("AstStatIf's else clause can have more statement types than I thought\n"); + LUAU_ASSERT(0); + } + } + } + else if (auto w = stat->as()) + append(w->body); + else if (auto r = stat->as()) + append(r->body); + else if (auto f = stat->as()) + append(f->body); + else if (auto f = stat->as()) + append(f->body); + else if (auto f = stat->as()) + append(f->func->body); + else if (auto f = stat->as()) + append(f->func->body); + + return result; + } + + // Move new body data into allocator-managed storage so that it's safe to keep around longterm. + AstStat** reallocateStatements(const std::vector& statements) + { + AstStat** newData = static_cast(allocator.allocate(sizeof(AstStat*) * statements.size())); + std::copy(statements.data(), statements.data() + statements.size(), newData); + + return newData; + } + + // Semiopen interval + using Span = std::pair; + + // Generates 'chunks' semiopen spans of equal-ish size to span the indeces running from 0 to 'size' + // Also inverses. + std::vector> generateSpans(size_t size, size_t chunks) + { + if (size <= 1) + return {}; + + LUAU_ASSERT(chunks > 0); + size_t chunkLength = std::max(1, size / chunks); + + std::vector> result; + + auto append = [&result](Span a, Span b) { + if (a.first == a.second && b.first == b.second) + return; + else + result.emplace_back(a, b); + }; + + size_t i = 0; + while (i < size) + { + size_t end = std::min(i + chunkLength, size); + append(Span{0, i}, Span{end, size}); + + i = end; + } + + i = 0; + while (i < size) + { + size_t end = std::min(i + chunkLength, size); + append(Span{i, end}, Span{size, size}); + + i = end; + } + + return result; + } + + // Returns the statements of block within span1 and span2 + // Also has the hokey restriction that span1 must come before span2 + std::vector prunedSpan(AstStatBlock* block, Span span1, Span span2) + { + std::vector result; + + for (size_t i = span1.first; i < span1.second; ++i) + result.push_back(block->body.data[i]); + + for (size_t i = span2.first; i < span2.second; ++i) + result.push_back(block->body.data[i]); + + return result; + } + + // returns true if anything was culled plus the chunk count + std::pair deleteChildStatements(AstStatBlock* block, size_t chunkCount) + { + if (block->body.size == 0) + return {false, chunkCount}; + + do + { + auto permutations = generateSpans(block->body.size, chunkCount); + for (const auto& [span1, span2] : permutations) + { + auto tempStatements = prunedSpan(block, span1, span2); + AstArray backupBody{tempStatements.data(), tempStatements.size()}; + + std::swap(block->body, backupBody); + TestResult result = run(); + if (result == TestResult::BugFound) + { + // The bug still reproduces without the statements we've culled. Commit. + block->body.data = reallocateStatements(tempStatements); + return {true, std::max(2, chunkCount - 1)}; + } + else + { + // The statements we've culled are critical for the reproduction of the bug. + // TODO try promoting its contents into this scope + std::swap(block->body, backupBody); + } + } + + chunkCount *= 2; + } while (chunkCount <= block->body.size); + + return {false, block->body.size}; + } + + bool deleteChildStatements(AstStatBlock* b) + { + bool result = false; + + size_t chunkCount = 2; + while (true) + { + auto [workDone, newChunkCount] = deleteChildStatements(b, chunkCount); + if (workDone) + { + result = true; + chunkCount = newChunkCount; + continue; + } + else + break; + } + + return result; + } + + bool tryPromotingChildStatements(AstStatBlock* b, size_t index) + { + std::vector tempStats(b->body.data, b->body.data + b->body.size); + AstStat* removed = tempStats.at(index); + tempStats.erase(begin(tempStats) + index); + + std::vector nestedStats = getNestedStats(removed); + tempStats.insert(begin(tempStats) + index, begin(nestedStats), end(nestedStats)); + + AstArray tempArray{tempStats.data(), tempStats.size()}; + std::swap(b->body, tempArray); + + TestResult result = run(); + + if (result == TestResult::BugFound) + { + b->body.data = reallocateStatements(tempStats); + return true; + } + else + { + std::swap(b->body, tempArray); + return false; + } + } + + // We live with some weirdness because I'm kind of lazy: If a statement's + // contents are promoted, we try promoting those prometed statements right + // away. I don't think it matters: If we can delete a statement and still + // exhibit the bug, we should do so. The order isn't so important. + bool tryPromotingChildStatements(AstStatBlock* b) + { + size_t i = 0; + while (i < b->body.size) + { + bool promoted = tryPromotingChildStatements(b, i); + if (!promoted) + ++i; + } + + return false; + } + + void walk(AstStatBlock* block) + { + std::queue queue; + Enqueuer enqueuer{&queue}; + + queue.push(block); + + while (!queue.empty()) + { + AstStatBlock* b = queue.front(); + queue.pop(); + + bool result = false; + do + { + result = deleteChildStatements(b); + + /* Try other reductions here before we walk into child statements + * Other reductions to try someday: + * + * Promoting a statement's children to the enclosing block. + * Deleting type annotations + * Deleting parts of type annotations + * Replacing subexpressions with ({} :: any) + * Inlining type aliases + * Inlining constants + * Inlining functions + */ + result |= tryPromotingChildStatements(b); + } while (result); + + for (AstStat* stat : b->body) + stat->visit(&enqueuer); + } + } + + void run(const std::string scriptName, const std::string appName, const std::vector& appArgs, std::string_view source, + std::string_view searchText) + { + tempScriptName = scriptName; + if (tempScriptName.substr(tempScriptName.size() - 4) == ".lua") + { + tempScriptName.erase(tempScriptName.size() - 4); + tempScriptName += "-reduced.lua"; + } + else + { + this->tempScriptName = scriptName + "-reduced"; + } + +#if 0 + // Handy debugging trick: VS Code will update its view of the file in realtime as it is edited. + std::string wheee = "code " + tempScriptName; + system(wheee.c_str()); +#endif + + printf("Temp script: %s\n", tempScriptName.c_str()); + + this->appName = appName; + this->appArgs = appArgs; + this->searchText = searchText; + + parseResult = Parser::parse(source.data(), source.size(), nameTable, allocator, parseOptions); + if (!parseResult.errors.empty()) + { + printf("Parse errors\n"); + exit(1); + } + + root = parseResult.root; + + const TestResult initialResult = run(); + if (initialResult == TestResult::NoBug) + { + printf("Could not find failure string in the unmodified script! Check your commandline arguments\n"); + exit(2); + } + + walk(root); + + writeTempScript(/* minify */ true); + + printf("Done! Check %s\n", tempScriptName.c_str()); + } +}; + +[[noreturn]] void help(const std::vector& args) +{ + printf("Syntax: %s script application \"search text\" [arguments]\n", args[0].data()); + exit(1); +} + +int main(int argc, char** argv) +{ + const std::vector args(argv, argv + argc); + + if (args.size() < 4) + help(args); + + for (int i = 1; i < args.size(); ++i) + { + if (args[i] == "--help") + help(args); + } + + const std::string scriptName = argv[1]; + const std::string appName = argv[2]; + const std::string searchText = argv[3]; + const std::vector appArgs(begin(args) + 4, end(args)); + + std::optional source = readFile(scriptName); + + if (!source) + { + printf("Could not read source %s\n", argv[1]); + exit(1); + } + + Reducer reducer; + reducer.run(scriptName, appName, appArgs, *source, searchText); +} diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 6baa21ea..e567725e 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -1,18 +1,108 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Repl.h" + #include "lua.h" #include "lualib.h" +#include "Luau/CodeGen.h" #include "Luau/Compiler.h" #include "Luau/BytecodeBuilder.h" #include "Luau/Parser.h" +#include "Coverage.h" #include "FileUtils.h" +#include "Flags.h" #include "Profiler.h" -#include "linenoise.hpp" +#include "isocline.h" #include +#ifdef _WIN32 +#include +#include + +#define WIN32_LEAN_AND_MEAN +#include +#endif + +#ifdef CALLGRIND +#include +#endif + +#include +#include + +LUAU_FASTFLAG(DebugLuauTimeTracing) + +enum class CliMode +{ + Unknown, + Repl, + Compile, + RunSourceFiles +}; + +enum class CompileFormat +{ + Text, + Binary, + Remarks, + Codegen, + CodegenVerbose, + CodegenNull, + Null +}; + +constexpr int MaxTraversalLimit = 50; + +static bool codegen = false; + +// Ctrl-C handling +static void sigintCallback(lua_State* L, int gc) +{ + if (gc >= 0) + return; + + lua_callbacks(L)->interrupt = NULL; + + lua_rawcheckstack(L, 1); // reserve space for error string + luaL_error(L, "Execution interrupted"); +} + +static lua_State* replState = NULL; + +#ifdef _WIN32 +BOOL WINAPI sigintHandler(DWORD signal) +{ + if (signal == CTRL_C_EVENT && replState) + lua_callbacks(replState)->interrupt = &sigintCallback; + return TRUE; +} +#else +static void sigintHandler(int signum) +{ + if (signum == SIGINT && replState) + lua_callbacks(replState)->interrupt = &sigintCallback; +} +#endif + +struct GlobalOptions +{ + int optimizationLevel = 1; + int debugLevel = 1; +} globalOptions; + +static Luau::CompileOptions copts() +{ + Luau::CompileOptions result = {}; + result.optimizationLevel = globalOptions.optimizationLevel; + result.debugLevel = globalOptions.debugLevel; + result.coverageLevel = coverageActive() ? 2 : 0; + + return result; +} + static int lua_loadstring(lua_State* L) { size_t l = 0; @@ -21,13 +111,13 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); - std::string bytecode = Luau::compile(std::string(s, l)); - if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + std::string bytecode = Luau::compile(std::string(s, l), copts()); + if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0) return 1; lua_pushnil(L); - lua_insert(L, -2); /* put before error message */ - return 2; /* return nil plus error message */ + lua_insert(L, -2); // put before error message + return 2; // return nil plus error message } static int finishrequire(lua_State* L) @@ -48,14 +138,23 @@ static int lua_require(lua_State* L) // return the module from the cache lua_getfield(L, -1, name.c_str()); if (!lua_isnil(L, -1)) + { + // L stack: _MODULES result return finishrequire(L); + } + lua_pop(L, 1); - std::optional source = readFile(name + ".lua"); + std::optional source = readFile(name + ".luau"); if (!source) - luaL_argerrorL(L, 1, ("error loading " + name).c_str()); + { + source = readFile(name + ".lua"); // try .lua if .luau doesn't exist + if (!source) + luaL_argerrorL(L, 1, ("error loading " + name).c_str()); // if neither .luau nor .lua exist, we have an error + } // module needs to run in a new thread, isolated from the rest + // note: we create ML on main thread so that it doesn't inherit environment of L lua_State* GL = lua_mainthread(L); lua_State* ML = lua_newthread(GL); lua_xmove(GL, L, 1); @@ -64,9 +163,15 @@ static int lua_require(lua_State* L) luaL_sandboxthread(ML); // now we can compile & run module on the new thread - std::string bytecode = Luau::compile(*source); - if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + std::string bytecode = Luau::compile(*source, copts()); + if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { + if (codegen) + Luau::CodeGen::compile(ML, -1); + + if (coverageActive()) + coverageTrack(ML, -1); + int status = lua_resume(ML, L, 0); if (status == 0) @@ -86,11 +191,12 @@ static int lua_require(lua_State* L) } } - // there's now a return value on top of ML; stack of L is MODULES thread + // there's now a return value on top of ML; L stack: _MODULES ML lua_xmove(ML, L, 1); lua_pushvalue(L, -1); lua_setfield(L, -4, name.c_str()); + // L stack: _MODULES ML result return finishrequire(L); } @@ -114,14 +220,50 @@ static int lua_collectgarbage(lua_State* L) luaL_error(L, "collectgarbage must be called with 'count' or 'collect'"); } -static void setupState(lua_State* L) +#ifdef CALLGRIND +static int lua_callgrind(lua_State* L) { + const char* option = luaL_checkstring(L, 1); + + if (strcmp(option, "running") == 0) + { + int r = RUNNING_ON_VALGRIND; + lua_pushboolean(L, r); + return 1; + } + + if (strcmp(option, "zero") == 0) + { + CALLGRIND_ZERO_STATS; + return 0; + } + + if (strcmp(option, "dump") == 0) + { + const char* name = luaL_checkstring(L, 2); + + CALLGRIND_DUMP_STATS_AT(name); + return 0; + } + + luaL_error(L, "callgrind must be called with one of 'running', 'zero', 'dump'"); +} +#endif + +void setupState(lua_State* L) +{ + if (codegen) + Luau::CodeGen::create(L); + luaL_openlibs(L); static const luaL_Reg funcs[] = { {"loadstring", lua_loadstring}, {"require", lua_require}, {"collectgarbage", lua_collectgarbage}, +#ifdef CALLGRIND + {"callgrind", lua_callgrind}, +#endif {NULL, NULL}, }; @@ -132,11 +274,11 @@ static void setupState(lua_State* L) luaL_sandbox(L); } -static std::string runCode(lua_State* L, const std::string& source) +std::string runCode(lua_State* L, const std::string& source) { - std::string bytecode = Luau::compile(source); + std::string bytecode = Luau::compile(source, copts()); - if (luau_load(L, "=stdin", bytecode.data(), bytecode.size()) != 0) + if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) { size_t len; const char* msg = lua_tolstring(L, -1, &len); @@ -147,6 +289,9 @@ static std::string runCode(lua_State* L, const std::string& source) return error; } + if (codegen) + Luau::CodeGen::compile(L, -1); + lua_State* T = lua_newthread(L); lua_pushvalue(L, -2); @@ -162,14 +307,30 @@ static std::string runCode(lua_State* L, const std::string& source) if (n) { luaL_checkstack(T, LUA_MINSTACK, "too many results to print"); - lua_getglobal(T, "print"); + lua_getglobal(T, "_PRETTYPRINT"); + // If _PRETTYPRINT is nil, then use the standard print function instead + if (lua_isnil(T, -1)) + { + lua_pop(T, 1); + lua_getglobal(T, "print"); + } lua_insert(T, 1); lua_pcall(T, n, 0, 0); } } else { - std::string error = (status == LUA_YIELD) ? "thread yielded unexpectedly" : lua_tostring(T, -1); + std::string error; + + if (status == LUA_YIELD) + { + error = "thread yielded unexpectedly"; + } + else if (const char* str = lua_tostring(T, -1)) + { + error = str; + } + error += "\nstack backtrace:\n"; error += lua_debugtrace(T); @@ -180,95 +341,225 @@ static std::string runCode(lua_State* L, const std::string& source) return std::string(); } -static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector& completions) +// Replaces the top of the lua stack with the metatable __index for the value +// if it exists. Returns true iff __index exists. +static bool tryReplaceTopWithIndex(lua_State* L) { - std::string_view lookup = editBuffer + start; - - for (;;) + if (luaL_getmetafield(L, -1, "__index")) { - size_t dot = lookup.find('.'); - std::string_view prefix = lookup.substr(0, dot); + // Remove the table leaving __index on the top of stack + lua_remove(L, -2); + return true; + } + return false; +} - if (dot == std::string_view::npos) + +// This function is similar to lua_gettable, but it avoids calling any +// lua callback functions (e.g. __index) which might modify the Lua VM state. +static void safeGetTable(lua_State* L, int tableIndex) +{ + lua_pushvalue(L, tableIndex); // Duplicate the table + + // The loop invariant is that the table to search is at -1 + // and the key is at -2. + for (int loopCount = 0;; loopCount++) + { + lua_pushvalue(L, -2); // Duplicate the key + lua_rawget(L, -2); // Try to find the key + if (!lua_isnil(L, -1) || loopCount >= MaxTraversalLimit) { - // table, key - lua_pushnil(L); + // Either the key has been found, and/or we have reached the max traversal limit + break; + } + else + { + lua_pop(L, 1); // Pop the nil result + if (!luaL_getmetafield(L, -1, "__index")) + { + lua_pushnil(L); + break; + } + else if (lua_istable(L, -1)) + { + // Replace the current table being searched with __index table + lua_replace(L, -2); + } + else + { + lua_pop(L, 1); // Pop the value + lua_pushnil(L); + break; + } + } + } - while (lua_next(L, -2) != 0) + lua_remove(L, -2); // Remove the table + lua_remove(L, -2); // Remove the original key +} + +// completePartialMatches finds keys that match the specified 'prefix' +// Note: the table/object to be searched must be on the top of the Lua stack +static void completePartialMatches(lua_State* L, bool completeOnlyFunctions, const std::string& editBuffer, std::string_view prefix, + const AddCompletionCallback& addCompletionCallback) +{ + for (int i = 0; i < MaxTraversalLimit && lua_istable(L, -1); i++) + { + // table, key + lua_pushnil(L); + + // Loop over all the keys in the current table + while (lua_next(L, -2) != 0) + { + if (lua_type(L, -2) == LUA_TSTRING) { // table, key, value std::string_view key = lua_tostring(L, -2); + int valueType = lua_type(L, -1); - if (!key.empty() && Luau::startsWith(key, prefix)) - completions.push_back(editBuffer + std::string(key.substr(prefix.size()))); + // If the last separator was a ':' (i.e. a method call) then only functions should be completed. + bool requiredValueType = (!completeOnlyFunctions || valueType == LUA_TFUNCTION); - lua_pop(L, 1); + if (!key.empty() && requiredValueType && Luau::startsWith(key, prefix)) + { + std::string completedComponent(key.substr(prefix.size())); + std::string completion(editBuffer + completedComponent); + if (valueType == LUA_TFUNCTION) + { + // Add an opening paren for function calls by default. + completion += "("; + } + addCompletionCallback(completion, std::string(key)); + } } + lua_pop(L, 1); + } + // Replace the current table being searched with an __index table if one exists + if (!tryReplaceTopWithIndex(L)) + { + break; + } + } +} + +static void completeIndexer(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback) +{ + std::string_view lookup = editBuffer; + bool completeOnlyFunctions = false; + + // Push the global variable table to begin the search + lua_pushvalue(L, LUA_GLOBALSINDEX); + + for (;;) + { + size_t sep = lookup.find_first_of(".:"); + std::string_view prefix = lookup.substr(0, sep); + + if (sep == std::string_view::npos) + { + completePartialMatches(L, completeOnlyFunctions, editBuffer, prefix, addCompletionCallback); break; } else { // find the key in the table lua_pushlstring(L, prefix.data(), prefix.size()); - lua_rawget(L, -2); + safeGetTable(L, -2); lua_remove(L, -2); - if (lua_isnil(L, -1)) + if (lua_istable(L, -1) || tryReplaceTopWithIndex(L)) + { + completeOnlyFunctions = lookup[sep] == ':'; + lookup.remove_prefix(sep + 1); + } + else + { + // Unable to search for keys, so stop searching break; - - lookup.remove_prefix(dot + 1); + } } } lua_pop(L, 1); } -static void completeRepl(lua_State* L, const char* editBuffer, std::vector& completions) +void getCompletions(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback) { - size_t start = strlen(editBuffer); - while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.')) - start--; - - // look the value up in current global table first - lua_pushvalue(L, LUA_GLOBALSINDEX); - completeIndexer(L, editBuffer, start, completions); - - // and in actual global table after that - lua_getglobal(L, "_G"); - completeIndexer(L, editBuffer, start, completions); + completeIndexer(L, editBuffer, addCompletionCallback); } -static void runRepl() +static void icGetCompletions(ic_completion_env_t* cenv, const char* editBuffer) { - std::unique_ptr globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); + auto* L = reinterpret_cast(ic_completion_arg(cenv)); - setupState(L); - - luaL_sandboxthread(L); - - linenoise::SetCompletionCallback([L](const char* editBuffer, std::vector& completions) { - completeRepl(L, editBuffer, completions); + getCompletions(L, std::string(editBuffer), [cenv](const std::string& completion, const std::string& display) { + ic_add_completion_ex(cenv, completion.data(), display.data(), nullptr); }); +} + +static bool isMethodOrFunctionChar(const char* s, long len) +{ + char c = *s; + return len == 1 && (isalnum(c) || c == '.' || c == ':' || c == '_'); +} + +static void completeRepl(ic_completion_env_t* cenv, const char* editBuffer) +{ + ic_complete_word(cenv, editBuffer, icGetCompletions, isMethodOrFunctionChar); +} + +static void loadHistory(const char* name) +{ + std::string path; + + if (const char* home = getenv("HOME")) + { + path = joinPaths(home, name); + } + else if (const char* userProfile = getenv("USERPROFILE")) + { + path = joinPaths(userProfile, name); + } + + if (!path.empty()) + ic_set_history(path.c_str(), -1 /* default entries (= 200) */); +} + +static void runReplImpl(lua_State* L) +{ + ic_set_default_completer(completeRepl, L); + + // Reset the locale to C + setlocale(LC_ALL, "C"); + + // Make brace matching easier to see + ic_style_def("ic-bracematch", "teal"); + + // Prevent auto insertion of braces + ic_enable_brace_insertion(false); + + // Loads history from the given file; isocline automatically saves the history on process exit + loadHistory(".luau_history"); std::string buffer; for (;;) { - bool quit = false; - std::string line = linenoise::Readline(buffer.empty() ? "> " : ">> ", quit); - if (quit) + const char* prompt = buffer.empty() ? "" : ">"; + std::unique_ptr line(ic_readline(prompt), free); + if (!line) break; - if (buffer.empty() && runCode(L, std::string("return ") + line) == std::string()) + if (buffer.empty() && runCode(L, std::string("return ") + line.get()) == std::string()) { - linenoise::AddHistory(line.c_str()); + ic_history_add(line.get()); continue; } - buffer += line; - buffer += " "; // linenoise doesn't work very well with multiline history entries + if (!buffer.empty()) + buffer += "\n"; + buffer += line.get(); std::string error = runCode(L, buffer); @@ -282,12 +573,32 @@ static void runRepl() fprintf(stdout, "%s\n", error.c_str()); } - linenoise::AddHistory(buffer.c_str()); + ic_history_add(buffer.c_str()); buffer.clear(); } } -static bool runFile(const char* name, lua_State* GL) +static void runRepl() +{ + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + setupState(L); + + // setup Ctrl+C handling + replState = L; +#ifdef _WIN32 + SetConsoleCtrlHandler(sigintHandler, TRUE); +#else + signal(SIGINT, sigintHandler); +#endif + + luaL_sandboxthread(L); + runReplImpl(L); +} + +// `repl` is used it indicate if a repl should be started after executing the file. +static bool runFile(const char* name, lua_State* GL, bool repl) { std::optional source = readFile(name); if (!source) @@ -304,11 +615,17 @@ static bool runFile(const char* name, lua_State* GL) std::string chunkname = "=" + std::string(name); - std::string bytecode = Luau::compile(*source); + std::string bytecode = Luau::compile(*source, copts()); int status = 0; - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { + if (codegen) + Luau::CodeGen::compile(L, -1); + + if (coverageActive()) + coverageTrack(L, -1); + status = lua_resume(L, NULL, 0); } else @@ -316,19 +633,31 @@ static bool runFile(const char* name, lua_State* GL) status = LUA_ERRSYNTAX; } - if (status == 0) + if (status != 0) { - return true; - } - else - { - std::string error = (status == LUA_YIELD) ? "thread yielded unexpectedly" : lua_tostring(L, -1); + std::string error; + + if (status == LUA_YIELD) + { + error = "thread yielded unexpectedly"; + } + else if (const char* str = lua_tostring(L, -1)) + { + error = str; + } + error += "\nstacktrace:\n"; error += lua_debugtrace(L); fprintf(stderr, "%s", error.c_str()); - return false; } + + if (repl) + { + runReplImpl(L); + } + lua_pop(GL, 1); + return status == 0; } static void report(const char* name, const Luau::Location& location, const char* type, const char* message) @@ -346,7 +675,33 @@ static void reportError(const char* name, const Luau::CompileError& error) report(name, error.getLocation(), "CompileError", error.what()); } -static bool compileFile(const char* name) +static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options) +{ + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) + return Luau::CodeGen::getAssembly(L, -1, options); + + fprintf(stderr, "Error loading bytecode %s\n", name); + return ""; +} + +static void annotateInstruction(void* context, std::string& text, int fid, int instpos) +{ + Luau::BytecodeBuilder& bcb = *(Luau::BytecodeBuilder*)context; + + bcb.annotateInstruction(text, fid, instpos); +} + +struct CompileStats +{ + size_t lines; + size_t bytecode; + size_t codegen; +}; + +static bool compileFile(const char* name, CompileFormat format, CompileStats& stats) { std::optional source = readFile(name); if (!source) @@ -355,15 +710,65 @@ static bool compileFile(const char* name) return false; } + // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example) + // This function is much more complicated because it supports many output human-readable formats through internal interfaces + try { Luau::BytecodeBuilder bcb; - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); - bcb.setDumpSource(*source); + Luau::CodeGen::AssemblyOptions options = {format == CompileFormat::CodegenNull, format == CompileFormat::Codegen, annotateInstruction, &bcb}; - Luau::compileOrThrow(bcb, *source); + if (format == CompileFormat::Text) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(*source); + } + else if (format == CompileFormat::Remarks) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(*source); + } + else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenVerbose) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(*source); + } - printf("%s", bcb.dumpEverything().c_str()); + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator); + + if (!result.errors.empty()) + throw Luau::ParseErrors(result.errors); + + stats.lines += result.lines; + + Luau::compileOrThrow(bcb, result, names, copts()); + stats.bytecode += bcb.getBytecode().size(); + + switch (format) + { + case CompileFormat::Text: + printf("%s", bcb.dumpEverything().c_str()); + break; + case CompileFormat::Remarks: + printf("%s", bcb.dumpSourceRemarks().c_str()); + break; + case CompileFormat::Binary: + fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); + break; + case CompileFormat::Codegen: + case CompileFormat::CodegenVerbose: + printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str()); + break; + case CompileFormat::CodegenNull: + stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size(); + break; + case CompileFormat::Null: + break; + } return true; } @@ -388,98 +793,220 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available modes:\n"); printf(" omitted: compile and run input files one by one\n"); - printf(" --compile: compile input files and output resulting bytecode\n"); + printf(" --compile[=format]: compile input files and output resulting bytecode/assembly (binary, text, remarks, codegen)\n"); printf("\n"); printf("Available options:\n"); + printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); + printf(" -h, --help: Display this usage message.\n"); + printf(" -i, --interactive: Run an interactive REPL after executing the last script specified.\n"); + printf(" -O: compile with optimization level n (default 1, n should be between 0 and 2).\n"); + printf(" -g: compile with debug level n (default 1, n should be between 0 and 2).\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); + printf(" --codegen: execute code using native code generation\n"); } -static int assertionHandler(const char* expr, const char* file, int line) +static int assertionHandler(const char* expr, const char* file, int line, const char* function) { printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); return 1; } -int main(int argc, char** argv) +int replMain(int argc, char** argv) { Luau::assertHandler() = assertionHandler; - for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) - if (strncmp(flag->name, "Luau", 4) == 0) - flag->value = true; + setLuauFlagsDefault(); - if (argc == 1) + CliMode mode = CliMode::Unknown; + CompileFormat compileFormat{}; + int profile = 0; + bool coverage = false; + bool interactive = false; + + // Set the mode if the user has explicitly specified one. + int argStart = 1; + if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) + { + argStart++; + mode = CliMode::Compile; + if (strcmp(argv[1], "--compile") == 0) + { + compileFormat = CompileFormat::Text; + } + else if (strcmp(argv[1], "--compile=binary") == 0) + { + compileFormat = CompileFormat::Binary; + } + else if (strcmp(argv[1], "--compile=text") == 0) + { + compileFormat = CompileFormat::Text; + } + else if (strcmp(argv[1], "--compile=remarks") == 0) + { + compileFormat = CompileFormat::Remarks; + } + else if (strcmp(argv[1], "--compile=codegen") == 0) + { + compileFormat = CompileFormat::Codegen; + } + else if (strcmp(argv[1], "--compile=codegenverbose") == 0) + { + compileFormat = CompileFormat::CodegenVerbose; + } + else if (strcmp(argv[1], "--compile=codegennull") == 0) + { + compileFormat = CompileFormat::CodegenNull; + } + else if (strcmp(argv[1], "--compile=null") == 0) + { + compileFormat = CompileFormat::Null; + } + else + { + fprintf(stderr, "Error: Unrecognized value for '--compile' specified.\n"); + return 1; + } + } + + for (int i = argStart; i < argc; i++) + { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + else if (strcmp(argv[i], "-i") == 0 || strcmp(argv[i], "--interactive") == 0) + { + interactive = true; + } + else if (strncmp(argv[i], "-O", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Optimization level must be between 0 and 2 inclusive.\n"); + return 1; + } + globalOptions.optimizationLevel = level; + } + else if (strncmp(argv[i], "-g", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Debug level must be between 0 and 2 inclusive.\n"); + return 1; + } + globalOptions.debugLevel = level; + } + else if (strcmp(argv[i], "--profile") == 0) + { + profile = 10000; // default to 10 KHz + } + else if (strncmp(argv[i], "--profile=", 10) == 0) + { + profile = atoi(argv[i] + 10); + } + else if (strcmp(argv[i], "--codegen") == 0) + { + codegen = true; + } + else if (strcmp(argv[i], "--coverage") == 0) + { + coverage = true; + } + else if (strcmp(argv[i], "--timetrace") == 0) + { + FFlag::DebugLuauTimeTracing.value = true; + } + else if (strncmp(argv[i], "--fflags=", 9) == 0) + { + setLuauFlags(argv[i] + 9); + } + else if (argv[i][0] == '-') + { + fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); + displayHelp(argv[0]); + return 1; + } + } + +#if !defined(LUAU_ENABLE_TIME_TRACE) + if (FFlag::DebugLuauTimeTracing) + { + fprintf(stderr, "To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; + } +#endif + +#if !LUA_CUSTOM_EXECUTION + if (codegen) + { + fprintf(stderr, "To run with --codegen, Luau has to be built with LUA_CUSTOM_EXECUTION enabled\n"); + return 1; + } +#endif + + const std::vector files = getSourceFiles(argc, argv); + if (mode == CliMode::Unknown) + { + mode = files.empty() ? CliMode::Repl : CliMode::RunSourceFiles; + } + + if (mode != CliMode::Compile && codegen && !Luau::CodeGen::isSupported()) + { + fprintf(stderr, "Cannot enable --codegen, native code generation is not supported in current configuration\n"); + return 1; + } + + switch (mode) + { + case CliMode::Compile: + { +#ifdef _WIN32 + if (compileFormat == CompileFormat::Binary) + _setmode(_fileno(stdout), _O_BINARY); +#endif + + CompileStats stats = {}; + int failed = 0; + + for (const std::string& path : files) + failed += !compileFile(path.c_str(), compileFormat, stats); + + if (compileFormat == CompileFormat::Null) + printf("Compiled %d KLOC into %d KB bytecode\n", int(stats.lines / 1000), int(stats.bytecode / 1024)); + else if (compileFormat == CompileFormat::CodegenNull) + printf("Compiled %d KLOC into %d KB bytecode => %d KB native code\n", int(stats.lines / 1000), int(stats.bytecode / 1024), + int(stats.codegen / 1024)); + + return failed ? 1 : 0; + } + case CliMode::Repl: { runRepl(); return 0; } - - if (argc >= 2 && strcmp(argv[1], "--help") == 0) - { - displayHelp(argv[0]); - return 0; - } - - if (argc >= 2 && strcmp(argv[1], "--compile") == 0) - { - int failed = 0; - - for (int i = 2; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !compileFile(name.c_str()); - }); - } - else - { - failed += !compileFile(argv[i]); - } - } - - return failed; - } - + case CliMode::RunSourceFiles: { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); setupState(L); - int profile = 0; - - for (int i = 1; i < argc; ++i) - if (strcmp(argv[i], "--profile") == 0) - profile = 10000; // default to 10 KHz - else if (strncmp(argv[i], "--profile=", 10) == 0) - profile = atoi(argv[i] + 10); - if (profile) profilerStart(L, profile); + if (coverage) + coverageInit(L); + int failed = 0; - for (int i = 1; i < argc; ++i) + for (size_t i = 0; i < files.size(); ++i) { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !runFile(name.c_str(), L); - }); - } - else - { - failed += !runFile(argv[i], L); - } + bool isLastFile = i == files.size() - 1; + failed += !runFile(files[i].c_str(), L, interactive && isLastFile); } if (profile) @@ -488,8 +1015,14 @@ int main(int argc, char** argv) profilerDump("profile.out"); } - return failed; + if (coverage) + coverageDump("coverage.out"); + + return failed ? 1 : 0; + } + case CliMode::Unknown: + default: + LUAU_ASSERT(!"Unhandled cli mode."); + return 1; } } - - diff --git a/CLI/Repl.h b/CLI/Repl.h new file mode 100644 index 00000000..cd54b7e0 --- /dev/null +++ b/CLI/Repl.h @@ -0,0 +1,17 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "lua.h" + +#include +#include + +using AddCompletionCallback = std::function; + +// Note: These are internal functions which are being exposed in a header +// so they can be included by unit tests. +void setupState(lua_State* L); +std::string runCode(lua_State* L, const std::string& source); +void getCompletions(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback); + +int replMain(int argc, char** argv); diff --git a/CLI/ReplEntry.cpp b/CLI/ReplEntry.cpp new file mode 100644 index 00000000..8543e3f7 --- /dev/null +++ b/CLI/ReplEntry.cpp @@ -0,0 +1,7 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Repl.h" + +int main(int argc, char** argv) +{ + return replMain(argc, argv); +} diff --git a/CLI/Web.cpp b/CLI/Web.cpp new file mode 100644 index 00000000..416a79f2 --- /dev/null +++ b/CLI/Web.cpp @@ -0,0 +1,114 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" +#include "luacode.h" + +#include "Luau/Common.h" + +#include + +#include + +static void setupState(lua_State* L) +{ + luaL_openlibs(L); + + luaL_sandbox(L); +} + +static std::string runCode(lua_State* L, const std::string& source) +{ + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.length(), nullptr, &bytecodeSize); + int result = luau_load(L, "=stdin", bytecode, bytecodeSize, 0); + free(bytecode); + + if (result != 0) + { + size_t len; + const char* msg = lua_tolstring(L, -1, &len); + + std::string error(msg, len); + lua_pop(L, 1); + + return error; + } + + lua_State* T = lua_newthread(L); + + lua_pushvalue(L, -2); + lua_remove(L, -3); + lua_xmove(L, T, 1); + + int status = lua_resume(T, NULL, 0); + + if (status == 0) + { + int n = lua_gettop(T); + + if (n) + { + luaL_checkstack(T, LUA_MINSTACK, "too many results to print"); + lua_getglobal(T, "print"); + lua_insert(T, 1); + lua_pcall(T, n, 0, 0); + } + + lua_pop(L, 1); // pop T + return std::string(); + } + else + { + std::string error; + + lua_Debug ar; + if (lua_getinfo(L, 0, "sln", &ar)) + { + error += ar.short_src; + error += ':'; + error += std::to_string(ar.currentline); + error += ": "; + } + + if (status == LUA_YIELD) + { + error += "thread yielded unexpectedly"; + } + else if (const char* str = lua_tostring(T, -1)) + { + error += str; + } + + error += "\nstack backtrace:\n"; + error += lua_debugtrace(T); + + lua_pop(L, 1); // pop T + return error; + } +} + +extern "C" const char* executeScript(const char* source) +{ + // setup flags + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + // create new state + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // setup state + setupState(L); + + // sandbox thread + luaL_sandboxthread(L); + + // static string for caching result (prevents dangling ptr on function exit) + static std::string result; + + // run code + collect error + result = runCode(L, source); + + return result.empty() ? NULL : result.c_str(); +} diff --git a/CMakeLists.txt b/CMakeLists.txt index d6598f2a..05d701ee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,33 +5,63 @@ if(EXT_PLATFORM_STRING) endif() cmake_minimum_required(VERSION 3.0) -project(Luau LANGUAGES CXX) option(LUAU_BUILD_CLI "Build CLI" ON) option(LUAU_BUILD_TESTS "Build tests" ON) +option(LUAU_BUILD_WEB "Build Web module" OFF) +option(LUAU_WERROR "Warnings as errors" OFF) +option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF) +option(LUAU_EXTERN_C "Use extern C for all APIs" OFF) +option(LUAU_NATIVE "Enable support for native code generation" OFF) +if(LUAU_STATIC_CRT) + cmake_minimum_required(VERSION 3.15) + cmake_policy(SET CMP0091 NEW) + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") +endif() + +project(Luau LANGUAGES CXX C) +add_library(Luau.Common INTERFACE) add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) add_library(Luau.Analysis STATIC) +add_library(Luau.CodeGen STATIC) add_library(Luau.VM STATIC) +add_library(isocline STATIC) if(LUAU_BUILD_CLI) add_executable(Luau.Repl.CLI) add_executable(Luau.Analyze.CLI) + add_executable(Luau.Ast.CLI) + add_executable(Luau.Reduce.CLI) # This also adds target `name` on Linux/macOS and `name.exe` on Windows set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau) set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) + set_target_properties(Luau.Ast.CLI PROPERTIES OUTPUT_NAME luau-ast) + set_target_properties(Luau.Reduce.CLI PROPERTIES OUTPUT_NAME luau-reduce) endif() if(LUAU_BUILD_TESTS) add_executable(Luau.UnitTest) add_executable(Luau.Conformance) + add_executable(Luau.CLI.Test) endif() + +if(LUAU_BUILD_WEB) + add_executable(Luau.Web) +endif() + +# Proxy target to make it possible to depend on private VM headers +add_library(Luau.VM.Internals INTERFACE) + include(Sources.cmake) +target_include_directories(Luau.Common INTERFACE Common/include) + target_compile_features(Luau.Ast PUBLIC cxx_std_17) target_include_directories(Luau.Ast PUBLIC Ast/include) +target_link_libraries(Luau.Ast PUBLIC Luau.Common) target_compile_features(Luau.Compiler PUBLIC cxx_std_17) target_include_directories(Luau.Compiler PUBLIC Compiler/include) @@ -41,48 +71,162 @@ target_compile_features(Luau.Analysis PUBLIC cxx_std_17) target_include_directories(Luau.Analysis PUBLIC Analysis/include) target_link_libraries(Luau.Analysis PUBLIC Luau.Ast) +target_compile_features(Luau.CodeGen PRIVATE cxx_std_17) +target_include_directories(Luau.CodeGen PUBLIC CodeGen/include) +target_link_libraries(Luau.CodeGen PRIVATE Luau.VM Luau.VM.Internals) # Code generation needs VM internals +target_link_libraries(Luau.CodeGen PUBLIC Luau.Common) + target_compile_features(Luau.VM PRIVATE cxx_std_11) target_include_directories(Luau.VM PUBLIC VM/include) +target_link_libraries(Luau.VM PUBLIC Luau.Common) + +target_include_directories(isocline PUBLIC extern/isocline/include) + +target_include_directories(Luau.VM.Internals INTERFACE VM/src) set(LUAU_OPTIONS) if(MSVC) list(APPEND LUAU_OPTIONS /D_CRT_SECURE_NO_WARNINGS) # We need to use the portable CRT functions. - list(APPEND LUAU_OPTIONS /WX) # Warnings are errors list(APPEND LUAU_OPTIONS /MP) # Distribute single project compilation across multiple cores else() list(APPEND LUAU_OPTIONS -Wall) # All warnings - list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors +endif() - if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - list(APPEND LUAU_OPTIONS -Wno-unused) # GCC considers variables declared/checked in if() as unused +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + # Some gcc versions treat var in `if (type var = val)` as unused + # Some gcc versions treat variables used in constexpr if blocks as unused + list(APPEND LUAU_OPTIONS -Wno-unused) +endif() + +# Enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere +if(LUAU_WERROR) + if(MSVC) + list(APPEND LUAU_OPTIONS /WX) # Warnings are errors + else() + list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors endif() endif() +if(LUAU_BUILD_WEB) + # add -fexceptions for emscripten to allow exceptions to be caught in C++ + list(APPEND LUAU_OPTIONS -fexceptions) +endif() + +set(ISOCLINE_OPTIONS) + +if (NOT MSVC) + list(APPEND ISOCLINE_OPTIONS -Wno-unused-function) +endif() + target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) +target_compile_options(Luau.CodeGen PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) +target_compile_options(isocline PRIVATE ${LUAU_OPTIONS} ${ISOCLINE_OPTIONS}) + +if(LUAU_EXTERN_C) + # enable extern "C" for VM (lua.h, lualib.h) and Compiler (luacode.h) to make Luau friendlier to use from non-C++ languages + # note that we enable LUA_USE_LONGJMP=1 as well; otherwise functions like luaL_error will throw C++ exceptions, which can't be done from extern "C" functions + target_compile_definitions(Luau.VM PUBLIC LUA_USE_LONGJMP=1) + target_compile_definitions(Luau.VM PUBLIC LUA_API=extern\"C\") + target_compile_definitions(Luau.Compiler PUBLIC LUACODE_API=extern\"C\") +endif() + +if(LUAU_NATIVE) + target_compile_definitions(Luau.VM PUBLIC LUA_CUSTOM_EXECUTION=1) +endif() + +if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) + # disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022: + # https://developercommunity.visualstudio.com/t/performance-regression-on-a-complex-interpreter-lo/1631863 + set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) +endif() + +if (NOT MSVC) + # disable support for math_errno which allows compilers to lower sqrt() into a single CPU instruction + target_compile_options(Luau.VM PRIVATE -fno-math-errno) +endif() + +if(MSVC AND LUAU_BUILD_CLI) + # the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger + set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) + set_target_properties(Luau.Repl.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) +endif() + +# embed .natvis inside the library debug information +if(MSVC) + target_link_options(Luau.Ast INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Ast.natvis) + target_link_options(Luau.Analysis INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Analysis.natvis) + target_link_options(Luau.CodeGen INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/CodeGen.natvis) + target_link_options(Luau.VM INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/VM.natvis) +endif() + +# make .natvis visible inside the solution +if(MSVC_IDE) + target_sources(Luau.Ast PRIVATE tools/natvis/Ast.natvis) + target_sources(Luau.Analysis PRIVATE tools/natvis/Analysis.natvis) + target_sources(Luau.CodeGen PRIVATE tools/natvis/CodeGen.natvis) + target_sources(Luau.VM PRIVATE tools/natvis/VM.natvis) +endif() if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) + target_compile_options(Luau.Ast.CLI PRIVATE ${LUAU_OPTIONS}) - target_include_directories(Luau.Repl.CLI PRIVATE extern) - target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) + target_include_directories(Luau.Repl.CLI PRIVATE extern extern/isocline/include) + + target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.CodeGen Luau.VM isocline) if(UNIX) - target_link_libraries(Luau.Repl.CLI PRIVATE pthread) + find_library(LIBPTHREAD pthread) + if (LIBPTHREAD) + target_link_libraries(Luau.Repl.CLI PRIVATE pthread) + endif() endif() target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) + + target_link_libraries(Luau.Ast.CLI PRIVATE Luau.Ast Luau.Analysis) + + target_compile_features(Luau.Reduce.CLI PRIVATE cxx_std_17) + target_include_directories(Luau.Reduce.CLI PUBLIC Reduce/include) + target_link_libraries(Luau.Reduce.CLI PRIVATE Luau.Common Luau.Ast Luau.Analysis) endif() if(LUAU_BUILD_TESTS) target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) + target_compile_definitions(Luau.UnitTest PRIVATE DOCTEST_CONFIG_DOUBLE_STRINGIFY) target_include_directories(Luau.UnitTest PRIVATE extern) - target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler) + target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen) target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Conformance PRIVATE extern) - target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.VM) + target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen Luau.VM) + + target_compile_options(Luau.CLI.Test PRIVATE ${LUAU_OPTIONS}) + target_include_directories(Luau.CLI.Test PRIVATE extern CLI) + target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.CodeGen Luau.VM isocline) + if(UNIX) + find_library(LIBPTHREAD pthread) + if (LIBPTHREAD) + target_link_libraries(Luau.CLI.Test PRIVATE pthread) + endif() + endif() + +endif() + +if(LUAU_BUILD_WEB) + target_compile_options(Luau.Web PRIVATE ${LUAU_OPTIONS}) + target_link_libraries(Luau.Web PRIVATE Luau.Compiler Luau.VM) + + # declare exported functions to emscripten + target_link_options(Luau.Web PRIVATE -sEXPORTED_FUNCTIONS=['_executeScript'] -sEXPORTED_RUNTIME_METHODS=['ccall','cwrap']) + + # add -fexceptions for emscripten to allow exceptions to be caught in C++ + target_link_options(Luau.Web PRIVATE -fexceptions) + + # the output is a single .js file with an embedded wasm blob + target_link_options(Luau.Web PRIVATE -sSINGLE_FILE=1) endif() diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b48066ee..460e4b43 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -26,10 +26,9 @@ If you're thinking of adding a new feature to the language, library, analysis to Luau team has internal priorities and a roadmap that may or may not align with specific features, so before starting to work on a feature please submit an issue describing the missing feature that you'd like to add. For features that result in observable change of language syntax or semantics, you'd need to [create an RFC](https://github.com/Roblox/luau/blob/master/rfcs/README.md) to make sure that the feature is needed and well designed. -Similarly to the above, please create an issue first so that we can see if we should go through with an RFC process. Finally, please note that Luau tries to carry a minimal feature set. All features must be evaluated not just for the benefits that they provide, but also for the downsides/costs in terms of language simplicity, maintainability, cross-feature interaction etc. -As such, feature requests may not be accepted, or may get to an RFC stage and get rejected there - don't expect Luau to gain a feature just because another programming language has it. +As such, feature requests may not be accepted even if a comprehensive RFC is written - don't expect Luau to gain a feature just because another programming language has it. ## Code style @@ -47,6 +46,12 @@ You can run the tests yourself using `make test` or using `cmake` to build `Luau When making code changes please try to make sure they are covered by an existing test or add a new test accordingly. +## Performance + +One of the central feature of Luau is performance; our runtime in particular is heavily optimized for high performance and low memory consumption, and code is generally carefully tuned to result in close to optimal assembly for x64 and AArch64 architectures. The analysis code is not optimized to the same level of detail, but performance is still very important to make sure that we can support interactive IDE features. + +As such, it's important to make sure that the changes, including bug fixes, improve or at least do not regress performance. For VM this can be validated by running `bench.py` script from `bench` folder on two binaries built in Release mode, before and after the changes, although note that our benchmark coverage is not complete and in some cases additional performance testing will be necessary to determine if the change can be merged. + ## Feature flags For large bug fixes or features that apply to the Luau components and not just the CLI tools, we may ask that you introduce a feature flag to gate your changes. The feature flags use `LUAU_FASTFLAG` macro family defined in `Luau/Common.h` and allow us to ensure that the change can be safely shipped, enabled, and rolled back on the Roblox platform when the change makes it into our production codebase. The tests run the code with flags in their default state and enabled state as well to ensure correctness. diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h new file mode 100644 index 00000000..53efd3c3 --- /dev/null +++ b/CodeGen/include/Luau/AddressA64.h @@ -0,0 +1,53 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterA64.h" + +namespace Luau +{ +namespace CodeGen +{ + +enum class AddressKindA64 : uint8_t +{ + imm, // reg + imm + reg, // reg + reg + + // TODO: + // reg + reg << shift + // reg + sext(reg) << shift + // reg + uext(reg) << shift +}; + +struct AddressA64 +{ + AddressA64(RegisterA64 base, int off = 0) + : kind(AddressKindA64::imm) + , base(base) + , offset(xzr) + , data(off) + { + LUAU_ASSERT(base.kind == KindA64::x || base == sp); + LUAU_ASSERT(off >= -256 && off < 4096); + } + + AddressA64(RegisterA64 base, RegisterA64 offset) + : kind(AddressKindA64::reg) + , base(base) + , offset(offset) + , data(0) + { + LUAU_ASSERT(base.kind == KindA64::x); + LUAU_ASSERT(offset.kind == KindA64::x); + } + + AddressKindA64 kind; + RegisterA64 base; + RegisterA64 offset; + int data; +}; + +using mem = AddressA64; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h new file mode 100644 index 00000000..9e12168a --- /dev/null +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -0,0 +1,161 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterA64.h" +#include "Luau/AddressA64.h" +#include "Luau/ConditionA64.h" +#include "Luau/Label.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +class AssemblyBuilderA64 +{ +public: + explicit AssemblyBuilderA64(bool logText); + ~AssemblyBuilderA64(); + + // Moves + void mov(RegisterA64 dst, RegisterA64 src); + void mov(RegisterA64 dst, uint16_t src, int shift = 0); + void movk(RegisterA64 dst, uint16_t src, int shift = 0); + + // Arithmetics + void add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void add(RegisterA64 dst, RegisterA64 src1, int src2); + void sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + void sub(RegisterA64 dst, RegisterA64 src1, int src2); + void neg(RegisterA64 dst, RegisterA64 src); + + // Comparisons + // Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm + void cmp(RegisterA64 src1, RegisterA64 src2); + void cmp(RegisterA64 src1, int src2); + + // Bitwise + // Note: shifted-register support and bitfield operations are omitted for simplicity + // TODO: support immediate arguments (they have odd encoding and forbid many values) + void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void mvn(RegisterA64 dst, RegisterA64 src); + + // Shifts + void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void ror(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void clz(RegisterA64 dst, RegisterA64 src); + void rbit(RegisterA64 dst, RegisterA64 src); + + // Load + // Note: paired loads are currently omitted for simplicity + void ldr(RegisterA64 dst, AddressA64 src); + void ldrb(RegisterA64 dst, AddressA64 src); + void ldrh(RegisterA64 dst, AddressA64 src); + void ldrsb(RegisterA64 dst, AddressA64 src); + void ldrsh(RegisterA64 dst, AddressA64 src); + void ldrsw(RegisterA64 dst, AddressA64 src); + + // Store + void str(RegisterA64 src, AddressA64 dst); + void strb(RegisterA64 src, AddressA64 dst); + void strh(RegisterA64 src, AddressA64 dst); + + // Control flow + // Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks + void b(Label& label); + void b(ConditionA64 cond, Label& label); + void cbz(RegisterA64 src, Label& label); + void cbnz(RegisterA64 src, Label& label); + void br(RegisterA64 src); + void blr(RegisterA64 src); + void ret(); + + // Address of embedded data + void adr(RegisterA64 dst, const void* ptr, size_t size); + void adr(RegisterA64 dst, uint64_t value); + void adr(RegisterA64 dst, double value); + + // Run final checks + bool finalize(); + + // Places a label at current location and returns it + Label setLabel(); + + // Assigns label position to the current location + void setLabel(Label& label); + + void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); + + uint32_t getCodeSize() const; + + // Resulting data and code that need to be copied over one after the other + // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' + std::vector data; + std::vector code; + + std::string text; + + const bool logText = false; + +private: + // Instruction archetypes + void place0(const char* name, uint32_t word); + void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0); + void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op, uint8_t op2 = 0); + void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2); + void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); + void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); + void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); + void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size); + void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); + void placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond); + void placeBR(const char* name, RegisterA64 src, uint32_t op); + void placeADR(const char* name, RegisterA64 src, uint8_t op); + + void place(uint32_t word); + + void patchLabel(Label& label); + void patchImm19(uint32_t location, int value); + + void commit(); + LUAU_NOINLINE void extend(); + + // Data + size_t allocateData(size_t size, size_t align); + + // Logging of assembly in text form + LUAU_NOINLINE void log(const char* opcode); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift = 0); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, RegisterA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 src); + LUAU_NOINLINE void log(const char* opcode, Label label); + LUAU_NOINLINE void log(Label label); + LUAU_NOINLINE void log(RegisterA64 reg); + LUAU_NOINLINE void log(AddressA64 addr); + + uint32_t nextLabel = 1; + std::vector(a) -> a +``` + +which would also be a parse error by the time we reach `->`, that is the production of the above is semantically equivalent to `(f < a) > (a)` which would compare whether the value of `f` is less than the value of `a`, then whether the result of that value is greater than `a`. + +## Alternatives + +Only allow these syntax when used inside parentheses e.g. `(F T)` or `(F(T))`. This makes it inconsistent with the existing `typeof(...)` type annotation, and changing that over is also breaking change. + +Support backtracking in the parser, so if `: MyType(t or u):m()` is invalid syntax, revert and parse `MyType` as a type, and `(t or u):m()` as an expression statement. Even so, this option is terrible for: + 1. parsing performance (backtracking means losing progress on invalid input), + 2. user experience (why was this annotation parsed as `X(...)` instead of `X` followed by a statement `(...)`), + 3. has false positives (`foo(bar)(baz)` may be parsed as `foo(bar)` as the type annotation and `(baz)` is the remainder to parse) + +## Drawbacks + +To be able to expose some kind of type-level operations using `F` syntax, means one of the following must be chosen: + 1. introduce the concept of "magic type functions" into type inference, or + 2. introduce them into the prelude as `export type F = ...` (where `...` is to be read as "we haven't decided") diff --git a/rfcs/function-bit32-countlz-countrz.md b/rfcs/function-bit32-countlz-countrz.md new file mode 100644 index 00000000..b4ccb197 --- /dev/null +++ b/rfcs/function-bit32-countlz-countrz.md @@ -0,0 +1,52 @@ +# bit32.countlz/countrz + +**Status**: Implemented + +## Summary + +Add bit32.countlz (count left zeroes) and bit32.countrz (count right zeroes) to accelerate bit scanning + +## Motivation + +All CPUs have instructions to determine the position of first/last set bit in an integer. These instructions have a variety of uses, the popular ones being: + +- Fast implementation of integer logarithm (essentially allowing to compute `floor(log2(value))` quickly) +- Scanning set bits in an integer, which allows efficient traversal of compact representation of bitmaps +- Allocating bits out of a bitmap quickly + +Today it's possible to approximate `countlz` using `floor` and `log` but this approximation is relatively slow; approximating `countrz` is difficult without iterating through each bit. + +## Design + +`bit32` library will gain two new functions, `countlz` and `countrz`: + +``` +function bit32.countlz(n: number): number +function bit32.countrz(n: number): number +``` + +`countlz` takes an integer number (converting the input number to a 32-bit unsigned integer as all other `bit32` functions do), and returns the number of consecutive left-most zero bits - that is, the number of most significant zero bits in a 32-bit number until the first 1. The result is in `[0, 32]` range. + +For example, when the input number is `0`, it's `32`. When the input number is `2^k`, the result is `31-k`. + +`countrz` takes an integer number (converting the input number to a 32-bit unsigned integer as all other `bit32` functions do), and returns the number of consecutive right-most zero bits - that is, +the number of least significant zero bits in a 32-bit number until the first 1. The result is in `[0, 32]` range. + +For example, when the input number is `0`, it's `32`. When the input number is `2^k`, the result is `k`. + +> Non-normative: a proof of concept implementation shows that a polyfill for `countlz` takes ~34 ns per loop iteration when computing `countlz` for an increasing number sequence, whereas +> a builtin implementation takes ~4 ns. + +## Drawbacks + +None known. + +## Alternatives + +These functions can be alternatively specified as "find the position of the most/least significant bit set" (e.g. "ffs"/"fls" for "find first set"/"find last set"). This formulation +can be more immediately useful since the bit position is usually more important than the number of bits. However, the bit position is undefined when the input number is zero, +returning a sentinel such as -1 seems non-idiomatic, and returning `nil` seems awkward for calling code. Counting functions don't have this problem. + +An early version of this proposal suggested `clz`/`ctz` (leading/trailing) as names; however, using a full verb is more consistent with other operations like shift/rotate, and left/right may be easier to understand intuitively compared to leading/trailing. left/right are used by C++20. + +Of the two functions, `countlz` is vastly more useful than `countrz`; we could implement just `countlz`, but having both is nice for symmetry. diff --git a/rfcs/function-coroutine-close.md b/rfcs/function-coroutine-close.md new file mode 100644 index 00000000..b9ffbf6f --- /dev/null +++ b/rfcs/function-coroutine-close.md @@ -0,0 +1,36 @@ +# coroutine.close + +**Status**: Implemented + +## Summary + +Add `coroutine.close` function from Lua 5.4 that takes a suspended coroutine and makes it "dead" (non-runnable). + +## Motivation + +When implementing various higher level objects on top of coroutines, such as promises, it can be useful to cancel the coroutine execution externally - when the caller is not +interested in getting the results anymore, execution can be aborted. Since coroutines don't provide a way to do that externally, this requires the framework to implement +cancellation on top of coroutines by keeping extra status/token and checking that token in all places where the coroutine is resumed. + +Since coroutine execution can be aborted with an error at any point, coroutines already implement support for "dead" status. If it were possible to externally transition a coroutine +to that status, it would be easier to implement cancellable promises on top of coroutines. + +## Design + +We implement Lua 5.4 behavior exactly with the exception of to-be-closed variables that we don't support. Quoting Lua 5.4 manual: + +> coroutine.close (co) +> Closes coroutine co, that is, puts the coroutine in a dead state. The given coroutine must be dead or suspended. In case of error (either the original error that stopped the coroutine or errors in closing methods), returns false plus the error object; otherwise returns true. + +The `co` argument must be a coroutine object (of type `thread`). + +After closing the coroutine, it gets transitioned to dead state which means that `coroutine.status` will return `"dead"` and attempts to resume the coroutine will fail. In addition, the coroutine stack (which can be accessed via `debug.traceback` or `debug.info`) will become empty. Calling `coroutine.close` on a closed coroutine will return `true` - after closing, the coroutine transitions into a "dead" state with no error information. + +## Drawbacks + +None known, as this function doesn't introduce any existing states to coroutines, and is similar to running the coroutine to completion/error. + +## Alternatives + +Lua's name for this function is likely in part motivated by to-be-closed variables that we don't support. As such, a more appropriate name could be `coroutine.cancel` which also +aligns with use cases better. However, since the semantics is otherwise the same, using the same name as Lua 5.4 reduces library fragmentation. diff --git a/rfcs/function-debug-info.md b/rfcs/function-debug-info.md index 1362dc91..5f486db4 100644 --- a/rfcs/function-debug-info.md +++ b/rfcs/function-debug-info.md @@ -14,7 +14,7 @@ Today Luau provides only one method to get the callstack, `debug.traceback`. Thi There are a few cases where this can be inconvenient: -- Sometimes it is useful to pass the resulting call stack to some system expecting a structured input, e.g. for crash aggreggation +- Sometimes it is useful to pass the resulting call stack to some system expecting a structured input, e.g. for crash aggregation - Sometimes it is useful to use the information about the caller for logging or filtering purposes; in these cases using just the script name can be useful, and getting script name out of the traceback is slow and imprecise Additionally, in some cases instead of getting the information (such as script or function name) from the callstack, it can be useful to get it from a function object for diagnostic purposes. For example, maybe you want to call a callback and if it doesn't return expected results, display a user-friendly error message that contains the function name & script location - these aren't possible today at all. diff --git a/rfcs/function-table-clone.md b/rfcs/function-table-clone.md new file mode 100644 index 00000000..8cb97984 --- /dev/null +++ b/rfcs/function-table-clone.md @@ -0,0 +1,66 @@ +# table.clone + +**Status**: Implemented + +## Summary + +Add `table.clone` function that, given a table, produces a copy of that table with the same keys/values/metatable. + +## Motivation + +There are multiple cases today when cloning tables is a useful operation. + +- When working with tables as data containers, some algorithms may require modifying the table that can't be done in place for some reason. +- When working with tables as objects, it can be useful to obtain an identical copy of the object for further modification, preserving the metatable. +- When working with immutable data structures, any modification needs to clone some parts of the data structure to produce a new version of the object. + +While it's possible to implement this function in user code today, it's impossible to implement it with maximum efficiency; furthermore, cloning is a reasonably fundamental +operation so from the ergonomics perspective it can be expected to be provided by the standard library. + +## Design + +`table.clone(t)` takes a table, `t`, and returns a new table that: + +- has the same metatable +- has the same keys and values +- is not frozen, even if `t` was + +The copy is shallow: implementing a deep recursive copy automatically is challenging (for similar reasons why we decided to avoid this in `table.freeze`), and often only certain keys need to be cloned recursively which can be done after the initial clone. + +The table can be modified after cloning; as such, functions that compute a slightly modified copy of the table can be easily built on top of `table.clone`. + +`table.clone(t)` is functionally equivalent to the following code, but it's more ergonomic (on the account of being built-in) and significantly faster: + +```lua +assert(type(t) == "table") +local nt = {} +for k,v in pairs(t) do + nt[k] = v +end +if type(getmetatable(t)) == "table" then + setmetatable(nt, getmetatable(t)) +end +``` + +The reason why `table.clone` can be dramatically more efficient is that it can directly copy the internal structure, preserving capacity and exact key order, and is thus +limited purely by memory bandwidth. In comparison, the code above can't predict the table size ahead of time, has to recreate the internal table structure one key at a time, +and bears the interpreter overhead (which can be avoided for numeric keys with `table.move` but that doesn't work for the general case of dictionaries). + +Out of the abundance of caution, `table.clone` will fail to clone the table if it has a protected metatable. This is motivated by the fact that you can't do this today, so +there are no new potential vectors to escape various sandboxes. Superficially it seems like it's probably reasonable to allow cloning tables with protected metatables, but +there may be cases where code manufactures tables with unique protected metatables expecting 1-1 relationship and cloning would break that, so for now this RFC proposes a more +conservative route. We are likely to relax this restriction in the future. + +## Drawbacks + +Adding a new function to `table` library theoretically increases complexity. In practice though, we already effectively implement `table.clone` internally for some VM optimizations, so exposing this to the users bears no cost. + +Assigning a type to this function is a little difficult if we want to enforce the "argument must be a table" constraint. It's likely that we'll need to type this as `table.clone(T): T` for the time being, which is less precise. + +## Alternatives + +We can implement something similar to `Object.assign` from JavaScript instead, that simultaneously assigns extra keys. However, this won't be fundamentally more efficient than +assigning the keys afterwards, and can be implemented in user space. Additionally, we can later extend `clone` with an extra argument if we so choose, so this proposal is the +minimal viable one. + +We can immediately remove the rule wrt protected metatables, as it's not clear that it's actually problematic to be able to clone tables with protected metatables. diff --git a/rfcs/function-table-freeze.md b/rfcs/function-table-freeze.md index a99da64f..ca819882 100644 --- a/rfcs/function-table-freeze.md +++ b/rfcs/function-table-freeze.md @@ -2,6 +2,8 @@ > Note: this RFC was adapted from an internal proposal that predates RFC process +**Status**: Implemented + ## Summary Add `table.freeze` which allows to make a table read-only in a shallow way. @@ -42,7 +44,7 @@ To limit the use of `table.freeze` to cases when table contents can be freely ma Exposing the internal "readonly" feature may have an impact on interoperability between scripts - for example, it becomes possible to freeze some tables that scripts may be expecting to have write access to from other scripts. Since we don't provide a way to unfreeze tables and freezing a table with a locked metatable fails, in theory the impact should not be any worse than allowing to change a metatable, but the full extents are unclear. -There may be existing code in the VM that allows changing frozen tables in ways that are benign to the current sanboxing code, but expose a "gap" in the implementation that becomes significant with this feature; thus we would need to audit all table writes when implementing this. +There may be existing code in the VM that allows changing frozen tables in ways that are benign to the current sandboxing code, but expose a "gap" in the implementation that becomes significant with this feature; thus we would need to audit all table writes when implementing this. ## Alternatives diff --git a/rfcs/generalized-iteration.md b/rfcs/generalized-iteration.md new file mode 100644 index 00000000..c28156ff --- /dev/null +++ b/rfcs/generalized-iteration.md @@ -0,0 +1,126 @@ +# Generalized iteration + +**Status**: Implemented + +## Summary + +Introduce support for iterating over tables without using `pairs`/`ipairs` as well as a generic customization point for iteration via `__iter` metamethod. + +## Motivation + +Today there are many different ways to iterate through various containers that are syntactically incompatible. + +To iterate over arrays, you need to use `ipairs`: `for i, v in ipairs(t) do`. The traversal goes over a sequence `1..k` of numeric keys until `t[k] == nil`, preserving order. + +To iterate over dictionaries, you need to use `pairs`: `for k, v in pairs(t) do`. The traversal goes over all keys, numeric and otherwise, but doesn't guarantee an order; when iterating over arrays this may happen to work but is not guaranteed to work, as it depends on how keys are distributed between array and hash portion. + +To iterate over custom objects, whether they are represented as tables (user-specified) or userdata (host-specified), you need to expose special iteration methods, for example `for k, v in obj:Iterator() do`. + +All of these rely on the standard Lua iteration protocol, but it's impossible to trigger them in a generic fashion. Additionally, you *must* use one of `pairs`/`ipairs`/`next` to iterate over tables, which is easy to forget - a naive `for k, v in tab do` doesn't work and produces a hard-to-understand error `attempt to call a table value`. + +This proposal solves all of these by providing a way to implement uniform iteration with self-iterating objects by allowing to iterate over objects and tables directly via convenient `for k, v in obj do` syntax, and specifies the default iteration behavior for tables, thus mostly rendering `pairs`/`ipairs` obsolete - making Luau easier to use and teach. + +## Design + +In Lua, `for vars in iter do` has the following semantics (otherwise known as the iteration protocol): `iter` is expanded into three variables, `gen`, `state` and `index` (using `nil` if `iter` evaluates to fewer than 3 results); after this the loop is converted to the following pseudocode: + +```lua +while true do + vars... = gen(state, index) + index = vars... -- copy the first variable into the index + if index == nil then break end + + -- loop body goes here +end +``` + +This is a general mechanism that can support iteration through many containers, especially if `gen` is allowed to mutate state. Importantly, the *first* returned variable (which is exposed to the user) is used to continue the process on the next iteration - this can be limiting because it may require `gen` or `state` to carry extra internal iteration data for efficiency. To work around this for table iteration to avoid repeated calls to `next`, Luau compiler produces a special instruction sequence that recognizes `pairs`/`ipairs` iterators and stores the iteration index separately. + +Thus, today the loop `for k, v in tab do` effectively executes `k, v = tab()` on the first iteration, which is why it yields `attempt to call a table value`. If the object defines `__call` metamethod then it can act as a self-iterating method, but this is not idiomatic, not efficient and not pure/clean. + +This proposal comes in two parts: general support for `__iter` metamethod and default implementation for tables without one. With both of these in place, there's going to be a single, idiomatic, general and performant way to iterate through the object of any type: + +```lua +for k, v in obj do +... +end +``` + +### __iter + +To support self-iterating objects, we modify the iteration protocol as follows: instead of simply expanding the result of expression `iter` into three variables (`gen`, `state` and `index`), we check if the first result has an `__iter` metamethod (which can be the case if it's a table, userdata or another composite object (e.g. a record in the future). If it does, the metamethod is called with `gen` as the first argument, and the returned three values replace `gen`/`state`/`index`. This happens *before* the loop: + +```lua +local genmt = rawgetmetatable(gen) -- pseudo code for getmetatable that bypasses __metatable +local iterf = genmt and rawget(genmt, "__iter") +if iterf then + gen, state, index = iterf(gen) +end +``` + +This check is comparatively trivial: usually `gen` is a function, and functions don't have metatables; as such we can simply check the type of `gen` and if it's a table/userdata, we can check if it has a metamethod `__iter`. Due to tag-method cache, this check is also very cheap if the metamethod is absent. + +This allows objects to provide a custom function that guides the iteration. Since the function is called once, it is easy to reuse other functions in the implementation, for example here's a node object that exposes iteration through its children: + +```lua +local Node = {} +Node.__index = Node + +function Node.new(children) + return setmetatable({ children = children }, Node) +end + +function Node:__iter() + return next, self.children +end +``` + +Luau compiler already emits a bytecode instruction, FORGPREP*, to perform initial loop setup - this is where we can evaluate `__iter` as well. + +Naturally, this means that if the table has `__iter` metamethod and you need to iterate through the table fields instead of using the provided metamethod, you can't rely on the general iteration scheme and need to use `pairs`. This is similar to other parts of the language, like `t[k]` vs `rawget(t, 'k')`, where the default behavior is overrideable but a library function can help peek behind the curtain. + +### Default table iteration + +If the argument is a table and it does not implement `__iter` metamethod, we treat this as an attempt to iterate through the table using the builtin iteration order. + +> Note: we also check if the table implements `__call`; if it does, we fall back to the default handling. We may be able to remove this check in the future, but we will need this initially to preserve backwards compatibility with custom table-driven iterator objects that implement `__call`. In either case, we will be able to collect detailed analytics about the use of `__call` in iteration, and if neither is present we can emit a specialized error message such as `object X is not iteratable`. + +To have a single, unified, iteration scheme over tables regardless of whether they are arrays or dictionaries, we establish the following semantics: + +- First, the traversal goes over numeric keys in range `1..k` up until reaching the first `k` such that `t[k] == nil` +- Then, the traversal goes over the remaining keys (with non-nil values), numeric and otherwise, in unspecified order. + +For arrays with gaps, this iterates until the first gap in order, and the remaining order is not specified. + +> Note: This behavior is similar to what `pairs` happens to provide today, but `pairs` doesn't give any guarantees, and it doesn't always provide this behavior in practice. + +To ensure that this traversal is performant, the actual implementation of the traversal involves going over the array part (in index order) and then over the hash part (in hash order). For that implementation to satisfy the criteria above, we need to make two additional changes to table insertion/rehash: + +- When inserting key `k` in the table when `k == t->sizearray + 1`, we force the table to rehash (resize its array portion). Today this is only performed if the hash portion is full, as such sometimes numeric keys can end up in the hash part. +- When rehashing the table, we ensure that the hash part doesn't contain the key `newsizearray + 1`. This requires checking if the table has this key, which may require an additional hash lookup but we only need to do this in rare cases based on the analysis of power-of-two key buckets that we already collect during rehash. + +These changes guarantee that the order observed via standard traversal with `next`/`pairs` matches the guarantee above, which is nice because it means we can minimize the complexity cost of this change by reusing the traversal code, including VM optimizations. They also mean that the array boundary (aka `#t`) can *always* be computed from just the array portion, which simplifies the table length computation and may slightly speed it up. + +## Drawbacks + +This makes `for` desugaring and implementation a little more complicated; it's not a large complexity factor in Luau because we already have special handling for `for` loops in the VM, but it's something to keep in mind. + +While the proposed iteration scheme should be a superset to both `pairs` and `ipairs` for tables, for arrays `ipairs` may in some cases be faster because it stops at the first `nil`, whereas the proposed new scheme (like `pairs`) needs to iterate through the rest of the table's array storage. This may be fixable in the future, if we replace our cached table length (`aboundary`) with Lua 5.4's `alimit`, which maintains the invariant that all values after `alimit` in the array are `nil`. This would make default table iteration maximally performant as well as help us accelerate GC in some cases, but will require extra checks during table assignments which is a cost we may not be willing to pay. Thus it is theoretically possible that we will end up with `ipairs` being a slightly faster equivalent for array iteration forever. + +The resulting iteration behavior, while powerful, increases the divergence between Luau and Lua, making more programs that are written for Luau not runnable in Lua. Luau language in general does not consider this type of compatibility essential, but this is noted for posterity. + +The changes in insertion behavior that facilitate single iteration order may have a small cost; that said, they are currently understood to belong to paths that are already slow and the added cost is minimal. + +The extra semantics will make inferring the types of the variables in a for loop more difficult - if we know the type of the expression that is being iterated through it probably is not a problem though. + +## Alternatives + +Other major designs have been considered. + +A minor variation of the proposal involves having `__iter` be called on every iteration instead of at loop startup, effectively having `__iter` work as an alternative to `__call`. The issue with this variant is that while it's a little simpler to specify and implement, it restricts the options when implementing custom iteratable objects, because it would be difficult for iteratable objects to store custom iteration state elsewhere since `__iter` method would effectively need to be pure, as it can't modify the object itself as more than one concurrent iteration needs to be supported. + +A major variation of the proposal involves instead supporting `__pairs` from Lua 5.2. The issue with this variant is that it still requires the use of a library method, `pairs`, to work, which doesn't make the language simpler as far as table iteration, which is the 95% case, is concerned. Additionally, with some rare exceptions metamethods today extend the *language* behavior, not the *library* behavior, and extending extra library functions with metamethods does not seem true to the core of the language. Finally, this only works if the user uses `pairs` to iterate and doesn't work with `ipairs`/`next`. + +Another variation involves using a new pseudo-keyword, `foreach`, instead of overloading existing `for`, and only using the new `__iter` semantics there. This can more cleanly separate behavior, requiring the object to have an `__iter` metamethod (or be a table) in `foreach` - which also avoids having to deal with `__call` - but it also requires teaching the users a new keyword which fragments the iteration space a little bit more. Compared to that, the main proposal doesn't introduce new divergent syntax, and merely tweaks existing behavior to be more general, thus making an existing construct easier to use. + +Finally, the author also considered and rejected extending the iteration protocol as part of this change. One problem with the current protocol is that the iterator requires an allocation (per loop execution) to keep extra state that isn't exposed to the user. The builtin iterators like `pairs`/`ipairs` work around this by feeding the user-visible index back to the search function, but that's not always practical. That said, having a different iteration protocol in effect only when `__iter` is used makes the language more complicated for unclear efficiency gains, thus this design doesn't suggest a new core protocol to favor simplicity. diff --git a/rfcs/generic-function-subtyping.md b/rfcs/generic-function-subtyping.md new file mode 100644 index 00000000..b9c0c430 --- /dev/null +++ b/rfcs/generic-function-subtyping.md @@ -0,0 +1,211 @@ +# Expanded Subtyping for Generic Function Types + +## Summary + +Extend the subtyping relation for function types to relate generic function +types with compatible instantiated function types. + +## Motivation + +As Luau does not have an explicit syntax for instantiation, there are a number +of places where the typechecker will automatically perform instantiation with +the goal of permitting more programs. These instances of instantiation are +ad-hoc and strategic, but useful in practice for permitting programs such as: + +```lua +function id(x: T): T + return x +end + +local idNum : (number) -> number +idNum = id -- ok +``` + +However, they have also been a source of some typechecking bugs because of how +they actually make a determination as to whether the instantation should happen, +and they currently open up some potential soundness holes when instantiating +functions in table types since properties of tables are mutable and thus need to +be invariant (which the automatic-instantiation potentially masks). + +## Design + +The goal then is to rework subtyping to support the relationship we want in the +first place: allowing polymorphic functions to be used where instantiated +functions are expected. In particular, this means adding instantiation itself to +the subtyping relation. Formally, that'd look something like: + +``` +instantiate((T1) -> T2) = (T1') -> T2' +(T1') -> T2' <: (T3) -> T4 +-------------------------------------------- +(T1) -> T2 <: (T3) -> T4) +``` + +Or informally, we'd say that a generic function type is a subtype of another +function type if we can instantiate it and show that instantiated function type +to be a subtype of the original function type. Implementation-wise, this loose +formal rule suggests a strategy of when we'll want to apply instantiation. +Namely, whenever the subtype and supertype are both functions with the potential +subtype having some generic parameters and the supertype having none. So, if we +look once again at our simple example from motivation, we can walk through how +we expect it to type check: + +```lua +function id(x: T): T + return x +end + +local idNum : (number) -> number +idNum = id -- ok +``` + +First, `id` is given the type `(T) -> T` and `idNum` is given the type +`(number) -> number`. When we actually perform the assignment, we must show that +the type of the right-hand side is compatible with the type of the left-hand +side according to subtyping. That is, we'll ask if `(T) -> T` is a subtype of +`(number) -> number` which matches the rule to apply instantiation since the +would-be subtype has a generic parameter while the would-be supertype has no +generic parameters. This contrasts with the current implementation which, before +asking the subtyping question, checks if the type of the right-hand side +contains any generics at any point and if the type of the left-hand side cannot +_possibly_ contain generics and instantiates the right-hand side if so. + +Adding instantiation to subtyping does pose some additional questions still +about when exactly to instantiate. Namely, we need to consider cases like +function application. We can see why by looking at some examples: + +```lua +function rank2(f: (a) -> a): (number) -> number + return f +end +``` + +In this case, we expect to allow the instantiation of `f` from `(a) -> a` to +`(number) -> number`. After all, we can consider other cases like where the body +instead applies `f` to some particular value, e.g. `f(42)`, and we'd want the +instantiation to be allowed there. However, this means we'd potentially run into +issues if we allowed call sites to `rank2` to pass in non-polymorphic functions. +A naive approach to implementing this proposal would do exactly that because we +currently treat contravariant subtyping positions (i.e. for the arguments of +functions) as being the same as our normal (i.e. covariant) subtyping relation +but with the arguments reversed. So, to type check an application like +`rank2(function(str: string) return str + "s" end)` (where the function argument +is of type `(string) -> string`), we would ask if `(a) -> a` is a subtype of +`(string) -> string`. This is precisely the question we asked in the original +example, but in the contravariant context, this is actually unsound since +`rank2` would then function as a general coercion from, e.g., +`(string) -> string` to `(number) -> number`. + +This sort of behavior does come up in other languages that mix polymorphism and +subtyping. If we consider the same example in F#, we can compare its behavior: + +```fsharp +let ranktwo (f : 'a -> 'a) : int -> int = f +let pluralize (s : string) : string = s + "s" +let x = ranktwo pluralize +``` + +For this example, F# produces one warning and one error. The warning is applied +to the function definition of `ranktwo` itself (coded `FS0064`), and says "This +construct causes code to be less generic than indicated by the type annotations. +The type variable 'a has been constrained to be type 'int'." This warning +highlights the actual difference between our example in Luau and the F# +translation. In F#, `'a` is really a free type variable, rather than a generic +type parameter of the function `ranktwo`, as such, this code actually +constrains the type of `ranktwo` to be `(int -> int) -> (int -> int)`. As such, +the application on line 3 errors because our `(string -> string)` function is +simply not compatible with that type. With higher-rank polymorphic function +parameters, it doesn't make sense to warn on their instantiation (as illustrated +by the example of actually applying `f` to some particular data in the +definition of `rank2`), but it's still just as problematic if we were to accept +instantiated functions at polymorphic types. Thus, it's important that we +actually ensure that we only instantiate in covariant contexts. So, we must +ensure that subtyping only instantiates in covariant contexts. + +It may also be helpful to consider an example of rank-1 polymorphism to +understand the full scope of the behavior. So, we can look at what happens if we +simply move the type parameter out in our working example: + +```lua +function rank1(f: (a) -> a): (number) -> number + return f +end +``` + +In this case, we expect an error to occur because the type of `f` depends on +what we instantiate `rank1` with. If we allowed this, it would naturally be +unsound because we could again provide a `(string) -> string` argument (by +instantiating `a` with `string`). This reinforces the idea that the presence of +the generic type parameter is likely to be a good option for determining +instantiation (at least when compared to the presence of free type variables). + +## Drawbacks + +One of the aims of this proposal is to provide a clear and predictable mental +model of when instantiation will take place in Luau. The author feels this +proposal is step forward compared to the existing ad-hoc usage of instantiation +in the typechecker, but it's possible that programmers are already comfortable +with the mental model they have built for the existing implementation. +Hopefully, this is mitigated by the fact that the new setup should allow all of +the _sound_ uses of instantiation permitted by the existing system. Notably, +however, programmers may be surprised by the added restriction when it comes to +properties in tables. In particular, we can consider a small variation of our +original example with identity functions: + +```lua +function id(x: T): T + return x +end + +local poly : { id : (a) -> a } = { id = id } + +local mono : { id : (number) -> number } +mono = poly -- error! +mono.id = id -- also an error! +``` + +In this case, the fact that we're dealing with a _property_ of a table type +means that we're in a context that needs to be invariant (i.e. not allow +subtyping) to avoid unsoundness caused by interactions between mutable +references and polymorphism (see things like the [value +restriction in OCaml][value-restriction] to understand why). In most cases, we +believe programmers will be using functions in tables as an implementation of +methods for objects, so we don't anticipate that they'll actually _want_ to do +the unsound thing here. The accepted RFC for [read-only +properties][read-only-props] gives us a technically-precise solution since +read-only properties would be free to be typechecked as a covariant context +(since they disallow mutation), and thus if the property `id` was marked +read-only, we'd be able to do both of the assignments in the above example. + +## Alternatives + +The main alternatives would likely be keeping the existing solution (and +likely having to tactically fix future bugs where instantiation either happens +too much or not enough), or removing automatic instantiation altogether in favor +of manual instantiation syntax. The former solution (changing nothing) is cheap +now (both in terms of runtime performance and also development cost), but the +existing implementation involves extra walks of both types to make a decision +about whether or not to perform instantiation. To minimize the performance +impact, the functions that perform these questions (`isGeneric` and +`maybeGeneric`) actually do not perform a full walk, and instead try to +strategically look at only enough to make the decision. We already found and +fixed one bug that was caused by these functions being too imprecise against +their spec, but fleshing them out entirely could potentially be a noticeable +performance regression since the decision to potentially instantiate is one that +comes up often. + +Removing automatic instantiation altogether, by contrast, will definitely be +"correct" in that we'll never instantiate in the wrong spot and programmers will +always have the ability to instantiate, but it would be a marked regression on +developer experience since it would increase the annotation burden considerably +and generally runs counter to the overall design strategy of Luau (which focuses +heavily on type inference). It would also require us to actually pick a syntax +for manual instantiation (which we are still open to do in the future if we +maintain an automatic instantiation solution) which is frought with parser +ambiguity issues or requires the introduction of a sigil like Rust's turbofish +for instantiation. Discussion of that syntax is present in the [generic +functions][generic-functions] RFC. + +[value-restriction]: https://stackoverflow.com/questions/22507448/the-value-restriction#22507665 +[read-only-props]: https://github.com/Roblox/luau/blob/master/rfcs/property-readonly.md +[generic-functions]: https://github.com/Roblox/luau/blob/master/rfcs/generic-functions.md diff --git a/rfcs/generic-functions.md b/rfcs/generic-functions.md index 1e9add96..3ac1bbba 100644 --- a/rfcs/generic-functions.md +++ b/rfcs/generic-functions.md @@ -1,5 +1,7 @@ # Generic functions +**Status**: Implemented + ## Summary Extend the syntax and semantics of functions to support explicit generic functions, which can bind type parameters as well as data parameters. @@ -30,6 +32,12 @@ Functions may also take generic type pack arguments for varargs, for instance: function compose(... : a...) -> (a...) return ... end ``` +Generic type and type pack parameters can also be used in function types, for instance: + +```lua +local id: (a)->a = function(x) return x end +``` + This change is *not* only syntax, as explicit type parameters need to be part of the semantics of types. For example, we can define a generic identity function ```lua @@ -84,6 +92,34 @@ Currently, Luau does not have explicit type binders, so `f` and `g` have the sam We propose supporting type parameters which can be instantiated with any type (jargon: Rank-N Types) but not type functions (jargon: Higher Kinded Types) or types with constraints (jargon: F-bounded polymorphism). +## Turbofish + +Note that this RFC proposes a syntax for adding generic parameters to functions, but it does *not* propose syntax for adding generic arguments to function call site. For example, for `id` function you *can* write: + +```lua + -- generic type gets inferred as a number in all these cases +local x = id(4) +local x = id(y) :: number +local x: number = id(y) +``` + +but you can *not* write `id(y)`. + +This syntax is difficult to parse as it's ambiguous wrt grammar for comparison, and disambiguating it requires being able to parse types in expression context which makes parsing slow and complicated. It's also worth noting that today there are programs with this syntax that are grammatically correct (eg `id('4')` parses as "compare variable `id` to variable `string`, and compare the result to string '4'"). The specific example with a single argument will always fail at runtime because booleans can't be compared with relational operators, but multi-argument cases such as `print(foo(4))` can execute without errors in certain cases. + +Note that in many cases the types can be inferred, whether through function arguments (`id(4)`) or through expected return type (`id(y) :: number`). It's also often possible to cast the function object to a given type, even though that can be unwieldy (`(id :: (number)->number)(y)`). Some languages don't have a way to specify the types at call site either, Swift being a prominent example. Thus it's not a given we need this feature in Luau. + +If we ever want to implement this though, we can use a solution inspired by Rust's turbofish and require an extra token before `<`. Rust uses `::<` but that doesn't work in Luau because as part of this RFC, `id::(a)->a` is a valid, if redundant, type ascription, so we need to choose a different prefix. + +The following two variants are grammatically unambiguous in expression context in Luau, and are a better parallel for Rust's turbofish (in Rust, `::` is more similar to Luau's `:` or `.` than `::`, which in Rust is called `as`): + +```lua +foo:() -- require : before <; this is only valid in Luau in variable declaration context, so it's safe to use in expression context +foo.() -- require . before <; this is currently never valid in Luau +``` + +This RFC doesn't propose using either of these options, but notes that either one of these options is possible to specify & implement in the future if we so desire. + ## Drawbacks This is a breaking change, in that examples like the unsound program above will no longer typecheck. diff --git a/rfcs/len-metamethod-rawlen.md b/rfcs/len-metamethod-rawlen.md new file mode 100644 index 00000000..60278dda --- /dev/null +++ b/rfcs/len-metamethod-rawlen.md @@ -0,0 +1,45 @@ +# Support `__len` metamethod for tables and `rawlen` function + +**Status**: Implemented + +## Summary + +`__len` metamethod will be called by `#` operator on tables, matching Lua 5.2 + +## Motivation + +Lua 5.1 invokes `__len` only on userdata objects, whereas Lua 5.2 extends this to tables. In addition to making `__len` metamethod more uniform and making Luau +more compatible with later versions of Lua, this has the important advantage which is that it makes it possible to implement an index based container. + +Before `__iter` and `__len` it was possible to implement a custom container using `__index`/`__newindex`, but to iterate through the container a custom function was +necessary, because Luau didn't support generalized iteration, `__pairs`/`__ipairs` from Lua 5.2, or `#` override. + +With generalized iteration, a custom container can implement its own iteration behavior so as long as code uses `for k,v in obj` iteration style, the container can +be interfaced with the same way as a table. However, when the container uses integer indices, manual iteration via `#` would still not work - which is required for some +more complicated algorithms, or even to simply iterate through the container backwards. + +Supporting `__len` would make it possible to implement a custom integer based container that exposes the same interface as a table does. + +## Design + +`#v` will call `__len` metamethod if the object is a table and the metamethod exists; the result of the metamethod will be returned if it's a number (an error will be raised otherwise). + +`table.` functions that implicitly compute table length, such as `table.getn`, `table.insert`, will continue using the actual table length. This is consistent with the +general policy that Luau doesn't support metamethods in `table.` functions. + +A new function, `rawlen(v)`, will be added to the standard library; given a string or a table, it will return the length of the object without calling any metamethods. +The new function has the previous behavior of `#` operator with the exception of not supporting userdata inputs, as userdata doesn't have an inherent definition of length. + +## Drawbacks + +`#` is an operator that is used frequently and as such an extra metatable check here may impact performance. However, `#` is usually called on tables without metatables, +and even when it is, using the existing metamethod-absence-caching approach we use for many other metamethods a test version of the change to support `__len` shows no +statistically significant difference on existing benchmark suite. This does complicate the `#` computation a little more which may affect JIT as well, but even if the +table doesn't have a metatable the process of computing `#` involves a series of condition checks and as such will likely require slow paths anyway. + +This is technically changing semantics of `#` when called on tables with an existing `__len` metamethod, and as such has a potential to change behavior of an existing valid program. +That said, it's unlikely that any table would have a metatable with `__len` metamethod as outside of userdata it would not anything, and this drawback is not feasible to resolve with any alternate version of the proposal. + +## Alternatives + +Do not implement `__len`. diff --git a/rfcs/lower-bounds-calculation.md b/rfcs/lower-bounds-calculation.md new file mode 100644 index 00000000..a1793884 --- /dev/null +++ b/rfcs/lower-bounds-calculation.md @@ -0,0 +1,217 @@ +# Lower Bounds Calculation + +## Summary + +We propose adapting lower bounds calculation from Pierce's Local Type Inference paper into the Luau type inference algorithm. + +https://www.cis.upenn.edu/~bcpierce/papers/lti-toplas.pdf + +## Motivation + +There are a number of important scenarios that occur where Luau cannot infer a sensible type without annotations. + +Many of these revolve around type variables that occur in contravariant positions. + +### Function Return Types + +A very common thing to write in Luau is a function to try to find something in some data structure. These functions habitually return the relevant datum when it is successfully found, or `nil` in the case that it cannot. For instance: + +```lua +-- A.lua +function find_first_if(vec, f) + for i, e in ipairs(vec) do + if f(e) then + return i + end + end + + return nil +end +``` + +This function has two `return` statements: One returns `number` and the other `nil`. Today, Luau flags this as an error. We ask authors to add a return annotation to make this error go away. + +We would like to automatically infer `find_first_if : ({T}, (T) -> boolean) -> number?`. + +Higher order functions also present a similar problem. + +```lua +-- B.lua +function foo(f) + f(5) + f("string") +end +``` + +There is nothing wrong with the implementation of `foo` here, but Luau fails to typecheck it all the same because `f` is used in an inconsistent way. This too can be worked around by introducing a type annotation for `f`. + +The fact that the return type of `f` is never used confounds things a little, but for now it would be a big improvement if we inferred `f : ((number | string) -> T...) -> ()`. + +## Design + +We introduce a new kind of TypeVar, `ConstrainedTypeVar` to represent a TypeVar whose lower bounds are known. We will never expose syntax for a user to write these types: They only temporarily exist as type inference is being performed. + +When unifying some type with a `ConstrainedTypeVar` we _broaden_ the set of constraints that can be placed upon it. + +It may help to realize that what we have been doing up until now has been _upper bounds calculation_. + +When we `quantify` a function, we will _normalize_ each type and convert each `ConstrainedTypeVar` into a `UnionTypeVar`. + +### Normalization + +When computing lower bounds, we need to have some process by which we reduce types down to a minimal shape and canonicalize them, if only to have a clean way to flush out degenerate unions like `A | A`. Normalization is about reducing union and intersection types to a minimal, canonicalizable shape. + +A normalized union is one where there do not exist two branches on the union where one is a subtype of the other. It is quite straightforward to implement. + +A normalized intersection is a little bit more complicated: + +1. The tables of an intersection are always combined into a single table. Coincident properties are merged into intersections of their own. + * eg `normalize({x: number, y: string} & {y: number, z: number}) == {x: number, y: string & number, z: number}` + * This is recursive. eg `normalize({x: {y: number}} & {x: {y: string}}) == {x: {y: number & string}}` +1. If two functions in the intersection have a subtyping relationship, the normalization results only in the super-type-most function. (more on function subtyping later) + +### Function subtyping relationships + +If we are going to infer intersections of functions, then we need to be very careful about keeping combinatorics under control. We therefore need to be very deliberate about what subtyping rules we have for functions of differing arity. We have some important requirements: + +* We'd like some way to canonicalize intersections of functions, and yet +* optional function arguments are a great feature that we don't want to break + +A very important use case for us is the case where the user is providing a callback to some higher-order function, and that function will be invoked with extra arguments that the original customer doesn't actually care about. For example: + +```lua +-- C.lua +function map_array(arr, f) + local result = {} + for i, e in ipairs(arr) do + table.insert(result, f(e, i, arr)) + end + return result +end + +local example = {1, 2, 3, 4} +local example_result = map_array(example, function(i) return i * 2 end) +``` + +This function mirrors the actual `Array.map` function in JavaScript. It is very frequent for users of this function to provide a lambda that only accepts one argument. It would be annoying for callers to be forced to provide a lambda that accepts two unused arguments. This obviously becomes even worse if the function later changes to provide yet more optional information to the callback. + +This use case is very important for Roblox, as we have many APIs that accept callbacks. Implementors of those callbacks frequently omit arguments that they don't care about. + +Here is an example straight out of the Roblox developer documentation. ([full example here](https://developer.roblox.com/en-us/api-reference/event/BasePart/Touched)) + +```lua +-- D.lua +local part = script.Parent + +local function blink() + -- ... +end + +part.Touched:Connect(blink) +``` + +The `Touched` event actually passes a single argument: the part that touched the `Instance` in question. In this example, it is omitted from the callback handler. + +We therefore want _oversaturation_ of a function to be allowed, but this combines with optional function arguments to create a problem with soundness. Consider the following: + +```lua +-- E.lua +type Callback = (Instance) -> () + +local cb: Callback +function register_callback(c: Callback) + cb = c +end + +function invoke_callback(i: Instance) + cb(i) +end + +--- + +function bad_callback(x: number?) +end + +local obscured: () -> () = bad_callback + +register_callback(obscured) + +function good_callback() +end + +register_callback(good_callback) +``` + +The problem we run into is, if we allow the subtyping rule `(T?) -> () <: () -> ()` and also allow oversaturation of a function, it becomes easy to obscure an argument type and pass the wrong type of value to it. + +Next, consider the following type alias + +```lua +-- F.lua +type OldFunctionType = (any, any) -> any +type NewFunctionType = (any) -> any +type FunctionType = OldFunctionType & NewFunctionType +``` + +If we have a subtyping rule `(T0..TN) <: (T0..TN-1)` to permit the function subtyping relationship `(T0..TN-1) -> R <: (T0..TN) -> R`, then the above type alias normalizes to `(any) -> any`. In order to call the two-argument variation, we would need to permit oversaturation, which runs afoul of the soundness hole from the previous example. + +We need a solution here. + +To resolve this, let's reframe things in simpler terms: + +If there is never a subtyping relationship between packs of different length, then we don't have any soundness issues, but we find ourselves unable to register `good_callback`. + +To resolve _that_, consider that we are in truth being a bit hasty when we say `good_callback : () -> ()`. We can pass any number of arguments to this function safely. We could choose to type `good_callback : () -> () & (any) -> () & (any, any) -> () & ...`. Luau already has syntax for this particular sort of infinite intersection: `good_callback : (any...) -> ()`. + +So, we propose some different inference rules for functions: + +1. The AST fragment `function(arg0..argN) ... end` is typed `(T0..TN, any...) -> R` where `arg0..argN : T0..TN` and `R` is the inferred return type of the function body. Function statements are inferred the same way. +1. Type annotations are unchanged. `() -> ()` is still a nullary function. + +For reference, the subtyping rules for unions and functions are unchanged. We include them here for clarity. + +1. `A <: A | B` +1. `B <: A | B` +1. `A | B <: T` if `A <: T` or `B <: T` +1. `T -> R <: U -> S` if `U <: T` and `R <: S` + +We propose new subtyping rules for type packs: + +1. `(T0..TN) <: (U0..UN)` if, for each `T` and `U`, `T <: U` +1. `(U...)` is the same as `() | (U) | (U, U) | (U, U, U) | ...`, therefore +1. `(T0..TN) <: (U...)` if for each `T`, `T <: U`, therefore +1. `(U...) -> R <: (T0..TN) -> R` if for each `T`, `T <: U` + +The important difference is that we remove all subtyping rules that mention options. Functions of different arities are no longer considered subtypes of one another. Optional function arguments are still allowed, but function as a feature of function calls. + +Under these rules, functions of different arities can never be converted to one another, but actual functions are known to be safe to oversaturate with anything, and so gain a type that says so. + +Under these subtyping rules, snippets `C.lua` and `D.lua`, check the way we want: literal functions are implicitly safe to oversaturate, so it is fine to cast them as the necessary callback function type. + +`E.lua` also typechecks the way we need it to: `(Instance) -> () ()` and so `obscured` cannot receive the value `bad_callback`, which prevents it from being passed to `register_callback`. However, `good_callback : (any...) -> ()` and `(any...) -> () <: (Instance) -> ()` and so it is safe to register `good_callback`. + +Snippet `F.lua` is also fixed with this ruleset: There is no subtyping relationship between `(any) -> ()` and `(any, any) -> ()`, so the intersection is not combined under normalization. + +This works, but itself creates some small problems that we need to resolve: + +First, the `...` symbol still needs to be unavailable for functions that have been given this implicit `...any` type. This is actually taken care of in the Luau parser, so no code change is required. + +Secondly, we do not want to silently allow oversaturation of direct calls to a function if we know that the arguments will be ignored. We need to treat these variadic packs differently when unifying for function calls. + +Thirdly, we don't want to display this variadic in the signature if the author doesn't expect to see it. + +We solve these issues by adding a property `bool VariadicTypePack::hidden` to the implementation and switching on it in the above scenarios. The implementation is relatively straightforward for all 3 cases. + +## Drawbacks + +There is a potential cause for concern that we will be inferring unions of functions in cases where we previously did not. Unions are known to be potential sources of performance issues. One possibility is to allow Luau to be less intelligent and have it "give up" and produce less precise types. This would come at the cost of accuracy and soundness. + +If we allow functions to be oversaturated, we are going to miss out on opportunities to warn the user about legitimate problems with their program. I think we will have to work out some kind of special logic to detect when we are oversaturating a function whose exact definition is known and warn on that. + +Allowing indirect function calls to be oversaturated with `nil` values only should be safe, but a little bit unfortunate. As long as we statically know for certain that `nil` is actually a permissible value for that argument position, it should be safe. + +## Alternatives + +If we are willing to sacrifice soundness, we could adopt success typing and come up with an inference algorithm that produces less precise type information. + +We could also technically choose to do nothing, but this has some unpalatable consequences: Something I would like to do in the near future is to have the inference algorithm assume the same `self` type for all methods of a table. This will make inference of common OO patterns dramatically more intuitive and ergonomic, but inference of polymorphic methods requires some kind of lower bounds calculation to work correctly. diff --git a/rfcs/never-and-unknown-types.md b/rfcs/never-and-unknown-types.md new file mode 100644 index 00000000..5ad216ef --- /dev/null +++ b/rfcs/never-and-unknown-types.md @@ -0,0 +1,146 @@ +# never and unknown types + +**Status**: Implemented + +## Summary + +Add `unknown` and `never` types that are inhabited by everything and nothing respectively. + +## Motivation + +There are lots of cases in local type inference, semantic subtyping, +and type normalization, where it would be useful to have top and +bottom types. Currently, `any` is filling that role, but it has +special "switch off the type system" superpowers. + +Any use of `unknown` must be narrowed by type refinements unless another `unknown` or `any` is expected. For +example a function which can return any value is: + +```lua + function anything() : unknown ... end +``` + +and can be used as: + +```lua + local x = anything() + if type(x) == "number" then + print(x + 1) + end +``` + +The type of this function cannot be given concisely in current +Luau. The nearest equivalent is `any`, but this switches off the type system, for example +if the type of `anything` is `() -> any` then the following code typechecks: + +```lua + local x = anything() + print(x + 1) +``` + +This is fine in nonstrict mode, but strict mode should flag this as an error. + +The `never` type comes up whenever type inference infers incompatible types for a variable, for example + +```lua + function oops(x) + print("hi " .. x) -- constrains x must be a string + print(math.abs(x)) -- constrains x must be a number + end +``` + +The most general type of `x` is `string & number`, so this code gives +a type error, but we still need to provide a type for `oops`. With a +`never` type, we can infer the type `oops : (never) -> ()`. + +or when exhaustive type casing is achieved: + +```lua + function f(x: string | number) + if type(x) == "string" then + -- x : string + elseif type(x) == "number" then + -- x : number + else + -- x : never + end + end +``` + +or even when the type casing is simply nonsensical: + +```lua + function f(x: string | number) + if type(x) == "string" and type(x) == "number" then + -- x : string & number which is never + end + end +``` + +The `never` type is also useful in cases such as tagged unions where +some of the cases are impossible. For example: + +```lua + type Result = { err: false, val: T } | { err: true, err: E } +``` + +For code which we know is successful, we would like to be able to +indicate that the error case is impossible. With a `never` type, we +can do this with `Result`. Similarly, code which cannot succeed +has type `Result`. + +These types can _almost_ be defined in current Luau, but only quite verbosely: + +```lua + type never = number & string + type unknown = nil | number | boolean | string | {} | (...never) -> (...unknown) +``` + +But even for `unknown` it is impossible to include every single data types, e.g. every root class. + +Providing `never` and `unknown` as built-in types makes the code for +type inference simpler, for example we have a way to present a union +type with no options (as `never`). Otherwise we have to contend with ad hoc +corner cases. + +## Design + +Add: + +* a type `never`, inhabited by nothing, and +* a type `unknown`, inhabited by everything. + +And under success types (nonstrict mode), `unknown` is exactly equivalent to `any` because `unknown` +encompasses everything as does `any`. + +The interesting thing is that `() -> (never, string)` is equivalent to `() -> never` because all +values in a pack must be inhabitable in order for the pack itself to also be inhabitable. In fact, +the type `() -> never` is not completely accurate, it should be `() -> (never, ...never)` to avoid +cascading type errors. Ditto for when an expression list `f(), g()` where the resulting type pack is +`(never, string, number)` is still the same as `(never, ...never)`. + +```lua + function f(): never error() end + function g(): string return "" end + + -- no cascading type error where count mismatches, because the expression list f(), g() + -- was made to return (never, ...never) due to the presence of a never type in the pack + local x, y, z = f(), g() + -- x : never + -- y : never + -- z : never +``` + +## Drawbacks + +Another bit of complexity budget spent. + +These types will be visible to creators, so yay bikeshedding! + +Replacing `any` with `unknown` is a breaking change: code in strict mode may now produce errors. + +## Alternatives + +Stick with the current use of `any` for these cases. + +Make `never` and `unknown` type aliases rather than built-ins. diff --git a/rfcs/recursive-type-restriction.md b/rfcs/recursive-type-restriction.md index eb9b2124..6f69d43a 100644 --- a/rfcs/recursive-type-restriction.md +++ b/rfcs/recursive-type-restriction.md @@ -1,5 +1,11 @@ # Recursive type restriction +**Status**: Implemented + +## Summary + +Restrict generic type aliases to only be able to refer to the exact same instantiation of the generic that's being declared. + ## Motivation Luau supports recursive type aliases, but with an important restriction: diff --git a/rfcs/sealed-table-subtyping.md b/rfcs/sealed-table-subtyping.md index 7310795a..73714909 100644 --- a/rfcs/sealed-table-subtyping.md +++ b/rfcs/sealed-table-subtyping.md @@ -1,5 +1,9 @@ # Sealed table subtyping +**Status**: Implemented + +## Summary + In Luau, tables have a state, which can, among others, be "sealed". A sealed table is one that we know the full shape of and cannot have new properties added to it. We would like to introduce subtyping for sealed tables, to allow users to express some subtyping relationships that they currently cannot. ## Motivation diff --git a/rfcs/syntax-continue-statement.md b/rfcs/syntax-continue-statement.md index 4adde09e..94e2009a 100644 --- a/rfcs/syntax-continue-statement.md +++ b/rfcs/syntax-continue-statement.md @@ -56,7 +56,7 @@ end These rules are simple to implement. In any Lua parser there is already a point where you have to disambiguate an identifier that starts an assignment statement (`foo = 5`) from an identifier that starts a function call (`foo(5)`). It's one of the few, if not the only, place in the Lua grammar where single token lookahead is not sufficient to parse Lua, because you could have `foo.bar(5)` or `foo.bar=5` or `foo.bar(5)[6] = 7`. -Because of this, we need to parse the entire left hand side of an assignment statement (primaryexpr in Lua's BNF) and then check if it was a function call; if not, we'd expect it to be an assignment statement. +Because of this, we need to parse the entire left hand side of an assignment statement (primaryexp in Lua's BNF) and then check if it was a function call; if not, we'd expect it to be an assignment statement. Alternatively in this specific case we could parse "continue", parse the next token, and if it's one of the exclusion list above, roll the parser state back and re-parse the non-continue statement. Our lexer currently doesn't support rollbacks but it's also an easy strategy that other implementations might employ for `continue` specifically. diff --git a/rfcs/syntax-default-type-alias-type-parameters.md b/rfcs/syntax-default-type-alias-type-parameters.md index 23bd71a8..443bbac3 100644 --- a/rfcs/syntax-default-type-alias-type-parameters.md +++ b/rfcs/syntax-default-type-alias-type-parameters.md @@ -1,5 +1,7 @@ # Default type alias type parameters +**Status**: Implemented + ## Summary Introduce syntax to provide default type values inside the type alias type parameter list. @@ -48,9 +50,12 @@ type A = ... -- not allowed Default type parameter values are also allowed for type packs: ```lua -type A -- ok, variadic type pack -type C -- ok, type pack with one element -type D -- ok +type A -- ok, variadic type pack +type B -- ok, type pack with no elements +type C -- ok, type pack with one element +type D -- ok, type pack with two elements +type E -- ok, variadic type pack with a different first element +type F -- ok, same type pack as T... ``` --- @@ -68,7 +73,7 @@ If all type parameters have a default type value, it is now possible to referenc type All = ... local a: All -- ok -local b: All<> -- not allowed +local b: All<> -- ok as well ``` If type is exported from a module, default type parameter values will still be available when module is imported. diff --git a/rfcs/syntax-if-expression.md b/rfcs/syntax-if-expression.md index c900409e..76f76cf1 100644 --- a/rfcs/syntax-if-expression.md +++ b/rfcs/syntax-if-expression.md @@ -2,6 +2,8 @@ > Note: this RFC was adapted from an internal proposal that predates RFC process +**Status**: Implemented + ## Summary Introduce a form of ternary conditional using `if cond then value else alternative` syntax. @@ -10,7 +12,7 @@ Introduce a form of ternary conditional using `if cond then value else alternati Luau does not have a first-class ternary operator; when a ternary operator is needed, it is usually emulated with `and/or` expression, such as `cond and value or alternative`. -This expression evaluates to `value` if `cond` and `value` are truthful, and `alternative` otherwise. In particular it means that when `value` is `false` or `nil`, the result of the entire expression is `alternative` even when `cond` is truthful - which doesn't match the expected ternary logic and is a frequent source of subtle errors. +This expression evaluates to `value` if `cond` and `value` are truthy, and `alternative` otherwise. In particular it means that when `value` is `false` or `nil`, the result of the entire expression is `alternative` even when `cond` is truthy - which doesn't match the expected ternary logic and is a frequent source of subtle errors. Instead of `and/or`, `if/else` statement can be used but since that requires a separate mutable variable, this option isn't ergonomic. An immediately invoked function expression is also unergonomic and results in performance issues at runtime. @@ -18,9 +20,9 @@ Instead of `and/or`, `if/else` statement can be used but since that requires a s To solve these problems, we propose introducing a first-class ternary conditional. Instead of `? :` common in C-like languages, we propose an `if-then-else` expression form that is syntactically similar to `if-then-else` statement, but lacks terminating `end`. -Concretely, the `if-then-else` expression must match `if then else `; it can also contain an arbitrary number of `elseif` clauses, like `if then elseif then else `. Note that in either case, `else` is mandatory. +Concretely, the `if-then-else` expression must match `if then else `; it can also contain an arbitrary number of `elseif` clauses, like `if then elseif then else `. Unlike if statements, `else` is mandatory. -The result of the expression is the then-expression when condition is truthful (not `nil` or `false`) and else-expression otherwise. Only one of the two possible resulting expressions is evaluated. +The result of the expression is the then-expression when condition is truthy (not `nil` or `false`) and else-expression otherwise. Only one of the two possible resulting expressions is evaluated. Example: @@ -28,11 +30,27 @@ Example: local x = if FFlagFoo then A else B MyComponent.validateProps = t.strictInterface({ - layoutOrder = t.optional(t.number), - newThing = if FFlagUseNewThing then t.whatever() else nil, + layoutOrder = t.optional(t.number), + newThing = if FFlagUseNewThing then t.whatever() else nil, }) ``` +Note that `else` is mandatory because it's always better to be explicit. If it weren't mandatory, it opens the possiblity that someone might be writing a chain of if-then-else and forgot to add in the final `else` that _doesn't_ return a `nil` value! Enforcing this syntactically ensures the program does not run. Also, with it being mandatory, it solves many cases where parsing the expression is ambiguous due to the infamous [dangling else](https://en.wikipedia.org/wiki/Dangling_else). + +This example will not do what it looks like it's supposed to do! The if expression will _successfully_ parse and be interpreted as to return `h()` if `g()` evaluates to some falsy value, when in actual fact the clear intention is to evaluate `h()` only if `f()` is falsy. + +```lua +if f() then + ... + local foo = if g() then x +else + h() + ... +end +``` + +The only way to solve this had we chose optional `else` branch would be to wrap the if expression in parentheses or to place a semi-colon. + ## Drawbacks Studio's script editor autocomplete currently adds an indented block followed by `end` whenever a line ends that includes a `then` token. This can make use of the if expression unpleasant as developers have to keep fixing the code by removing auto-inserted `end`. We can work around this on the editor side by (short-term) differentiating between whether `if` token is the first on its line, and (long-term) by refactoring completion engine to use infallible parser for the block completer. diff --git a/rfcs/syntax-method-call-on-string-literals.md b/rfcs/syntax-method-call-on-string-literals.md deleted file mode 100644 index 877774d0..00000000 --- a/rfcs/syntax-method-call-on-string-literals.md +++ /dev/null @@ -1,70 +0,0 @@ -# Allow method call on string literals - -> Note: this RFC was adapted from an internal proposal that predates RFC process - -## Summary - -Allow string literals to be indexed on without parentheses or from an identifier. That is, the following snippet will become legal under this proposal: - -```lua -print("Hello, %s!":format("world")) -print("0123456789ABCDEF":sub(i, i)) -``` - -## Motivation - -Experienced Lua developers occasionally run into this paper-cut even after years of working with the language. Programmers in Lua frequently wants to format a user-facing message using a constant string, but the parser will not accept it as legible syntax. - -## Design - -Formally, the proposal is to move the `String` parser from `exp` to `prefixexp`: - -```diff - var ::= Name | prefixexp `[´ exp `]´ | prefixexp `.´ Name -- exp ::= nil | false | true | Number | String | `...´ | function | -+ exp ::= nil | false | true | Number | `...´ | function - | prefixexp | tableconstructor | exp binop exp | unop exp -- prefixexp ::= var | functioncall | `(´ exp `)´ -+ prefixexp ::= String | var | functioncall | `(´ exp `)´ - functioncall ::= prefixexp args | prefixexp `:´ Name args -``` - -By itself, this change introduces an additional ambiguity because of the combination of non-significant whitespace and function calls with string literal as the first argument without the use of parentheses. - -Consider code like this: - -```lua -local foo = bar -("fmt"):format(...) -``` - -The grammar for this sequence suggests that the above is a function call to bar with a single string literal argument, "fmt", and format method is called on the result. This is a consequence of line endings not being significant, but humans don't read the code like this, and are likely to think that here, format is called on the string literal "fmt". - -Because of this, Lua 5.1 produces a syntax error whenever function call arguments start on the next line. Luau has the same error production rule; future versions of Lua remove this restriction but it's not clear that we want to remove this as this does help prevent errors. - -The grammar today also allows calling functions with string literals as their first (and only) argument without the use of parentheses; bar "fmt" is a function call. This is helpful when defining embedded domain-specific languages. - -By itself, this proposal thus would create a similar ambiguity in code like this: - -```lua -local foo = bar -"fmt":format(...) -``` - -While we could extend the line-based error check to include function literal arguments, this is not syntactically backwards compatible and as such may break existing code. A simpler and more conservative solution is to disallow string literal as the leading token of a new statement - there are no cases when this is valid today, so it's safe to limit this. - -Doing so would prohibit code like this: - -```lua -"fmt":format(...) -``` - -However, there are no methods on the string object where code like this would be meaningful. As such, in addition to changing the grammar wrt string literals, we will add an extra ambiguity error whenever a statement starts with a string literal. - -## Drawbacks - -None known. - -## Alternatives - -The infallible parser could be mended in this exact scenario to report a more friendly error message. We decided not to do this because there is more value to gain by simply supporting the main proposal. diff --git a/rfcs/syntax-nil-forgiving-operator.md b/rfcs/syntax-nil-forgiving-operator.md deleted file mode 100644 index 63d5cabb..00000000 --- a/rfcs/syntax-nil-forgiving-operator.md +++ /dev/null @@ -1,90 +0,0 @@ -# nil-forgiving postfix operator ! - -## Summary - -Introduce syntax to suppress typechecking errors for nilable types by ascribing them into a non-nil type. - -## Motivation - -Typechecking might not be able to figure out that a certain expression is a non-nil type, but the user might know additional context of the expression. - -Using `::` ascriptions is the current work-around for this issue, but it's much more verbose and requires providing the full type name when you only want to ascribe T? to T. - -The nil-forgiving operator will also allow chaining to be written in a very terse manner: -```lua -local p = a!.b!.c -``` -instead of -```lua -local ((p :: Part).b :: Folder).c -``` - -Note that nil-forgiving operator is **not** a part of member access operator, it can be used in standalone expressions, indexing and other places: -```lua -local p = f(a!)! -local q = b!['X'] -``` - -Nil-forgiving operator can be found in some programming languages such as C# (null-forgiving or null-suppression operator) and TypeScript (non-null assertion operator). - -## Design - -To implement this, we will change the syntax of the *primaryexp*. - -Before: -``` -primaryexp ::= prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs } -``` -After: -``` -postfixeexp ::= (`.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs) [`!'] -primaryexp ::= prefixexp [`!'] { postfixeexp } -``` - -When we get the `!` token, we will wrap the expression that we have into a new AstExprNonNilAssertion node. - -An error is generated when the type of expression node this is one of the following: -* AstExprConstantBool -* AstExprConstantNumber -* AstExprConstantString -* AstExprFunction -* AstExprTable - -This operator doesn't have any impact on the run-time behavior of the program, it will only affect the type of the expression in the typechecker. - ---- -While parsing an assignment expression starts with a *primaryexp*, it performs a check that it has an l-value based on a fixed set of AstNode types. - -Since using `!` on an l-value has no effect, we don't extend this list with the new node and will generate a specialized parse error for code like: -```lua -p.a! = b -``` - ---- -When operator is used on expression of a union type with a `nil` option, it removes that option from the set. -If only one option remains, the union type is replaced with the type of a single option. - -If the type is `nil`, typechecker will generate a warning that the operator cannot be applied to the expression. - -For any other type, it has no effect and doesn't generate additional warnings. - -The reason for the last rule is to simplify movement of existing code where context in each location is slightly different. - -As an example from Roblox, instance path could dynamically change from being know to exist to be missing when script is changed in edit mode. - -## Drawbacks - -### Unnecessary operator use - -It might be useful to warn about unnecessary uses of this operator when the value cannot be `nil`, but we have no way of enabling this behavior. - -### Bad practice - -The operator might be placed by users to ignore/silence correct warnings and lower the strength of type checking that Luau provides. - -## Alternatives - -Aside from type assertion operator :: it should be possible to place `assert` function calls before the operation. -Type refinement/constraints should handle that statement and avoid warning in the following expressions. - -But `assert` call will introduce runtime overhead without adding extra safety to the case when the type is nil at run time, in both cases an error will be thrown. diff --git a/rfcs/syntax-singleton-types.md b/rfcs/syntax-singleton-types.md index 749006c7..2c1f5442 100644 --- a/rfcs/syntax-singleton-types.md +++ b/rfcs/syntax-singleton-types.md @@ -2,6 +2,8 @@ > Note: this RFC was adapted from an internal proposal that predates RFC process +**Status**: Implemented + ## Summary Introduce a new kind of type variable, called singleton types. They are just like normal types but has the capability to represent a constant runtime value as a type. @@ -48,6 +50,18 @@ type Animals = "Dog" | "Cat" | "Bird" type TrueOrNil = true? ``` +Adding constant strings as type means that it is now legal to write +`{["foo"]:T}` as a table type. This should be parsed as a property, +not an indexer. For example: +```lua + type T = { + ["foo"]: number, + ["$$bar"]: string, + baz: boolean, + } +``` +The table type `T` is a table with three properties and no indexer. + ### Semantics You are allowed to provide a constant value to the generic primitive type. diff --git a/rfcs/syntax-string-interpolation.md b/rfcs/syntax-string-interpolation.md new file mode 100644 index 00000000..2fbb04b0 --- /dev/null +++ b/rfcs/syntax-string-interpolation.md @@ -0,0 +1,169 @@ +# String interpolation + +## Summary + +New string interpolation syntax. + +## Motivation + +The problems with `string.format` are many. + +1. Must be exact about the types and its corresponding value. +2. Using `%d` is the idiomatic default for most people, but this loses precision. + * `%d` casts the number into `long long`, which has a lower max value than `double` and does not support decimals. + * `%f` by default will format to the millionths, e.g. `5.5` is `5.500000`. + * `%g` by default will format up to the hundred thousandths, e.g. `5.5` is `5.5` and `5.5312389` is `5.53123`. It will also convert the number to scientific notation when it encounters a number equal to or greater than 10^6. + * To not lose too much precision, you need to use `%s`, but even so the type checker assumes you actually wanted strings. +3. No support for `boolean`. You must use `%s` **and** call `tostring`. +4. No support for values implementing the `__tostring` metamethod. +5. Using `%` is in itself a dangerous operation within `string.format`. + * `"Your health is %d% so you need to heal up."` causes a runtime error because `% so` is actually parsed as `(%s)o` and now requires a corresponding string. +6. Having to use parentheses around string literals just to call a method of it. + +## Design + +To fix all of those issues, we need to do a few things. + +1. A new string interpolation expression (fixes #5, #6) +2. Extend `string.format` to accept values of arbitrary types (fixes #1, #2, #3, #4) + +Because we care about backward compatibility, we need some new syntax in order to not change the meaning of existing strings. There are a few components of this new expression: + +1. A string chunk (`` `...{ ``, `}...{`, and `` }...` ``) where `...` is a range of 0 to many characters. + * `\` escapes `` ` ``, `{`, and itself `\`. + * The pairs must be on the same line (unless a `\` escapes the newline) but expressions needn't be on the same line. +2. An expression between the braces. This is the value that will be interpolated into the string. + * Restriction: we explicitly reject `{{` as it is considered an attempt to escape and get a single `{` character at runtime. +3. Formatting specification may follow after the expression, delimited by an unambiguous character. + * Restriction: the formatting specification must be constant at parse time. + * In the absence of an explicit formatting specification, the `%*` token will be used. + * For now, we explicitly reject any formatting specification syntax. A future extension may be introduced to extend the syntax with an optional specification. + +To put the above into formal EBNF grammar: + +``` +stringinterp ::= exp { exp} +``` + +Which, in actual Luau code, will look like the following: + +``` +local world = "world" +print(`Hello {world}!`) +--> Hello world! + +local combo = {5, 2, 8, 9} +print(`The lock combinations are: {table.concat(combo, ", ")}`) +--> The lock combinations are: 5, 2, 8, 9 + +local set1 = Set.new({0, 1, 3}) +local set2 = Set.new({0, 5, 4}) +print(`{set1} ∪ {set2} = {Set.union(set1, set2)}`) +--> {0, 1, 3} ∪ {0, 5, 4} = {0, 1, 3, 4, 5} + +print(`Some example escaping the braces \{like so}`) +print(`backslash \ that escapes the space is not a part of the string...`) +print(`backslash \\ will escape the second backslash...`) +print(`Some text that also includes \`...`) +--> Some example escaping the braces {like so} +--> backslash that escapes the space is not a part of the string... +--> backslash \ will escape the second backslash... +--> Some text that also includes `... +``` + +As for how newlines are handled, they are handled the same as other string literals. Any text between the `{}` delimiters are not considered part of the string, hence newlines are OK. The main thing is that one opening pair will scan until either a closing pair is encountered, or an unescaped newline. + +``` +local name = "Luau" + +print(`Welcome to { + name +}!`) +--> Welcome to Luau! + +print(`Welcome to \ +{name}!`) +--> Welcome to +-- Luau! +``` + +We currently *prohibit* using interpolated strings in function calls without parentheses, this is illegal: + +``` +local name = "world" +print`Hello {name}` +``` + +> Note: This restriction is likely temporary while we work through string interpolation DSLs, an ability to pass individual components of interpolated strings to a function. + +The restriction on `{{` exists solely for the people coming from languages e.g. C#, Rust, or Python which uses `{{` to escape and get the character `{` at runtime. We're also rejecting this at parse time too, since the proper way to escape it is `\{`, so: + +```lua +print(`{{1, 2, 3}} = {myCoolSet}`) -- parse error +``` + +If we did not apply this as a parse error, then the above would wind up printing as the following, which is obviously a gotcha we can and should avoid. + +``` +--> table: 0xSOMEADDRESS = {1, 2, 3} +``` + +Since the string interpolation expression is going to be lowered into a `string.format` call, we'll also need to extend `string.format`. The bare minimum to support the lowering is to add a new token whose definition is to perform a `tostring` call. `%*` is currently an invalid token, so this is a backward compatible extension. This RFC shall define `%*` to have the same behavior as if `tostring` was called. + +```lua +print(string.format("%* %*", 1, 2)) +--> 1 2 +``` + +The offset must always be within bound of the numbers of values passed to `string.format`. + +```lua +local function return_one_thing() return "hi" end +local function return_two_nils() return nil, nil end + +print(string.format("%*", return_one_thing())) +--> "hi" + +print(string.format("%*", Set.new({1, 2, 3}))) +--> {1, 2, 3} + +print(string.format("%* %*", return_two_nils())) +--> nil nil + +print(string.format("%* %* %*", return_two_nils())) +--> error: value #3 is missing, got 2 +``` + +It must be said that we are not allowing this style of string literals in type annotations at this time, regardless of zero or many interpolating expressions, so the following two type annotations below are illegal syntax: + +```lua +local foo: `foo` +local bar: `bar{baz}` +``` + +String interpolation syntax will also support escape sequences. Except `\u{...}`, there is no ambiguity with other escape sequences. If `\u{...}` occurs within a string interpolation literal, it takes priority. + +```lua +local foo = `foo\tbar` -- "foo bar" +local bar = `\u{0041} \u{42}` -- "A B" +``` + +## Drawbacks + +If we want to use backticks for other purposes, it may introduce some potential ambiguity. One option to solve that is to only ever produce string interpolation tokens from the context of an expression. This is messy but doable because the parser and the lexer are already implemented to work in tandem. The other option is to pick a different delimiter syntax to keep backticks available for use in the future. + +If we were to naively compile the expression into a `string.format` call, then implementation details would be observable if you write `` `Your health is {hp}% so you need to heal up.` ``. When lowering the expression, we would need to implicitly insert a `%` character anytime one shows up in a string interpolation token. Otherwise attempting to run this will produce a runtime error where the `%s` token is missing its corresponding string value. + +## Alternatives + +Rather than coming up with a new syntax (which doesn't help issue #5 and #6) and extending `string.format` to accept an extra token, we could just make `%s` call `tostring` and be done. However, doing so would cause programs to be more lenient and the type checker would have no way to infer strings from a `string.format` call. To preserve that, we would need a different token anyway. + +Language | Syntax | Conclusion +----------:|:----------------------|:----------- +Python | `f'Hello {name}'` | Rejected because it's ambiguous with function call syntax. +Swift | `"Hello \(name)"` | Rejected because it changes the meaning of existing strings. +Ruby | `"Hello #{name}"` | Rejected because it changes the meaning of existing strings. +JavaScript | `` `Hello ${name}` `` | Viable option as long as we don't intend to use backticks for other purposes. +C# | `$"Hello {name}"` | Viable option and guarantees no ambiguities with future syntax. + +This leaves us with only two syntax that already exists in other programming languages. The current proposal are for backticks, so the only backward compatible alternative are `$""` literals. We don't necessarily need to use `$` symbol here, but if we were to choose a different symbol, `#` cannot be used. I picked backticks because it doesn't require us to add a stack of closing delimiters in the lexer to make sure each nested string interpolation literals are correctly closed with its opening pair. You only have to count them. diff --git a/rfcs/syntax-type-alias-type-packs.md b/rfcs/syntax-type-alias-type-packs.md index aa088299..d5bb6065 100644 --- a/rfcs/syntax-type-alias-type-packs.md +++ b/rfcs/syntax-type-alias-type-packs.md @@ -1,5 +1,7 @@ # Type alias type packs +**Status**: Implemented + ## Summary Provide semantics for referencing type packs inside the body of a type alias declaration diff --git a/rfcs/syntax-type-ascription-bidi.md b/rfcs/syntax-type-ascription-bidi.md index 9d31939e..0831aba5 100644 --- a/rfcs/syntax-type-ascription-bidi.md +++ b/rfcs/syntax-type-ascription-bidi.md @@ -1,5 +1,7 @@ # Relaxing type assertions +**Status**: Implemented + ## Summary The way `::` works today is really strange. The best solution we can come up with is to allow `::` to convert between any two related types. diff --git a/rfcs/unsealed-table-assign-optional-property.md b/rfcs/unsealed-table-assign-optional-property.md index 78f674aa..477399c2 100644 --- a/rfcs/unsealed-table-assign-optional-property.md +++ b/rfcs/unsealed-table-assign-optional-property.md @@ -1,5 +1,9 @@ # Unsealed table assignment creates an optional property +**Status**: Implemented + +## Summary + In Luau, tables have a state, which can, among others, be "unsealed". An unsealed table is one that we are still constructing. Currently assigning a table literal to an unsealed table does not introduce new @@ -53,4 +57,4 @@ and so needs access to an allocator. ## Alternatives Rather than introducing optional properties, we could introduce an indexer. For example we could infer the type of -`t` as `{ u: { [string]: number } }`. \ No newline at end of file +`t` as `{ u: { [string]: number } }`. diff --git a/rfcs/unsealed-table-literals.md b/rfcs/unsealed-table-literals.md new file mode 100644 index 00000000..669b67d4 --- /dev/null +++ b/rfcs/unsealed-table-literals.md @@ -0,0 +1,78 @@ +# Unsealed table literals + +**Status**: Implemented + +## Summary + +Currently the only way to create an unsealed table is as an empty table literal `{}`. +This RFC proposes making all table literals unsealed. + +## Motivation + +Table types can be *sealed* or *unsealed*. These are different in that: + +* Unsealed table types are *precise*: if a table has unsealed type `{ p: number, q: string }` + then it is guaranteed to have only properties `p` and `q`. + +* Sealed tables support *width subtyping*: if a table has sealed type `{ p: number }` + then it is guaranteed to have at least property `p`, so we allow `{ p: number, q: string }` + to be treated as a subtype of `{ p: number }` + +* Unsealed tables can have properties added to them: if `t` has unsealed type + `{ p: number }` then after the assignment `t.q = "hi"`, `t`'s type is updated to be + `{ p: number, q: string }`. + +* Unsealed tables are subtypes of sealed tables. + +Currently the only way to create an unsealed table is using an empty table literal, so +```lua + local t = {} + t.p = 5 + t.q = "hi" +``` +typechecks, but +```lua + local t = { p = 5 } + t.q = "hi" +``` +does not. + +This causes problems in examples, in particular developers +may initialize properties but not methods: +```lua + local t = { p = 5 } + function t.f() return t.p end +``` + +## Design + +The proposed change is straightforward: make all table literals unsealed. + +## Drawbacks + +Making all table literals unsealed is a conservative change, it only removes type errors. + +It does encourage developers to add new properties to tables during initialization, which +may be considered poor style. + +It does mean that some spelling mistakes will not be caught, for example +```lua +local t = {x = 1, y = 2} +if foo then + t.z = 3 -- is z a typo or intentional 2-vs-3 choice? +end +``` + +In particular, we no longer warn about adding properties to array-like tables. +```lua +local a = {1,2,3} +a.p = 5 +``` + +## Alternatives + +We could introduce a new table state for unsealed-but-precise +tables. The trade-off is that that would be more precise, at the cost +of adding user-visible complexity to the type system. + +We could continue to treat array-like tables as sealed. diff --git a/rfcs/unsealed-table-subtyping-strips-optional-properties.md b/rfcs/unsealed-table-subtyping-strips-optional-properties.md new file mode 100644 index 00000000..d99c1f81 --- /dev/null +++ b/rfcs/unsealed-table-subtyping-strips-optional-properties.md @@ -0,0 +1,68 @@ +# Only strip optional properties from unsealed tables during subtyping + +**Status**: Implemented + +## Summary + +Currently subtyping allows optional properties to be stripped from table types during subtyping. +This RFC proposes only allowing that when the subtype is unsealed and the supertype is sealed. + +## Motivation + +Table types can be *sealed* or *unsealed*. These are different in that: + +* Unsealed table types are *precise*: if a table has unsealed type `{ p: number, q: string }` + then it is guaranteed to have only properties `p` and `q`. + +* Sealed tables support *width subtyping*: if a table has sealed type `{ p: number }` + then it is guaranteed to have at least property `p`, so we allow `{ p: number, q: string }` + to be treated as a subtype of `{ p: number }` + +* Unsealed tables can have properties added to them: if `t` has unsealed type + `{ p: number }` then after the assignment `t.q = "hi"`, `t`'s type is updated to be + `{ p: number, q: string }`. + +* Unsealed tables are subtypes of sealed tables. + +Currently we allow subtyping to strip away optional fields +as long as the supertype is sealed. +This is necessary for examples, for instance: +```lua + local t : { p: number, q: string? } = { p = 5, q = "hi" } + t = { p = 7 } +``` +typechecks because `{ p : number }` is a subtype of +`{ p : number, q : string? }`. Unfortunately this is not sound, +since sealed tables support width subtyping: +```lua + local t : { p: number, q: string? } = { p = 5, q = "hi" } + local u : { p: number } = { p = 5, q = false } + t = u +``` + +## Design + +The fix for this source of unsoundness is twofold: + +1. make all table literals unsealed, and +2. only allow stripping optional properties from when the + supertype is sealed and the subtype is unsealed. + +This RFC is for (2). There is a [separate RFC](unsealed-table-literals.md) for (1). + +## Drawbacks + +This introduces new type errors (it has to, since it is fixing a source of +unsoundness). This means that there are now false positives such as: +```lua + local t : { p: number, q: string? } = { p = 5, q = "hi" } + local u : { p: number } = { p = 5, q = "lo" } + t = u +``` +These false positives are so similar to sources of unsoundness +that it is difficult to see how to allow them soundly. + +## Alternatives + +We could just live with unsoundness. + diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp new file mode 100644 index 00000000..d808ac49 --- /dev/null +++ b/tests/AssemblyBuilderA64.test.cpp @@ -0,0 +1,272 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/AssemblyBuilderA64.h" +#include "Luau/StringUtils.h" + +#include "doctest.h" + +#include + +using namespace Luau::CodeGen; + +static std::string bytecodeAsArray(const std::vector& bytecode) +{ + std::string result = "{"; + + for (size_t i = 0; i < bytecode.size(); i++) + Luau::formatAppend(result, "%s0x%02x", i == 0 ? "" : ", ", bytecode[i]); + + return result.append("}"); +} + +static std::string bytecodeAsArray(const std::vector& code) +{ + std::string result = "{"; + + for (size_t i = 0; i < code.size(); i++) + Luau::formatAppend(result, "%s0x%08x", i == 0 ? "" : ", ", code[i]); + + return result.append("}"); +} + +class AssemblyBuilderA64Fixture +{ +public: + bool check(void (*f)(AssemblyBuilderA64& build), std::vector code, std::vector data = {}) + { + AssemblyBuilderA64 build(/* logText= */ false); + + f(build); + + build.finalize(); + + if (build.code != code) + { + printf("Expected code: %s\nReceived code: %s\n", bytecodeAsArray(code).c_str(), bytecodeAsArray(build.code).c_str()); + return false; + } + + if (build.data != data) + { + printf("Expected data: %s\nReceived data: %s\n", bytecodeAsArray(data).c_str(), bytecodeAsArray(build.data).c_str()); + return false; + } + + return true; + } +}; + +// armconverter.com can be used to validate instruction sequences +TEST_SUITE_BEGIN("A64Assembly"); + +#define SINGLE_COMPARE(inst, ...) \ + CHECK(check( \ + [](AssemblyBuilderA64& build) { \ + build.inst; \ + }, \ + {__VA_ARGS__})) + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Unary") +{ + SINGLE_COMPARE(neg(x0, x1), 0xCB0103E0); + SINGLE_COMPARE(neg(w0, w1), 0x4B0103E0); + SINGLE_COMPARE(mvn(x0, x1), 0xAA2103E0); + + SINGLE_COMPARE(clz(x0, x1), 0xDAC01020); + SINGLE_COMPARE(clz(w0, w1), 0x5AC01020); + SINGLE_COMPARE(rbit(x0, x1), 0xDAC00020); + SINGLE_COMPARE(rbit(w0, w1), 0x5AC00020); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary") +{ + // reg, reg + SINGLE_COMPARE(add(x0, x1, x2), 0x8B020020); + SINGLE_COMPARE(add(w0, w1, w2), 0x0B020020); + SINGLE_COMPARE(add(x0, x1, x2, 7), 0x8B021C20); + SINGLE_COMPARE(sub(x0, x1, x2), 0xCB020020); + SINGLE_COMPARE(and_(x0, x1, x2), 0x8A020020); + SINGLE_COMPARE(orr(x0, x1, x2), 0xAA020020); + SINGLE_COMPARE(eor(x0, x1, x2), 0xCA020020); + SINGLE_COMPARE(lsl(x0, x1, x2), 0x9AC22020); + SINGLE_COMPARE(lsl(w0, w1, w2), 0x1AC22020); + SINGLE_COMPARE(lsr(x0, x1, x2), 0x9AC22420); + SINGLE_COMPARE(asr(x0, x1, x2), 0x9AC22820); + SINGLE_COMPARE(ror(x0, x1, x2), 0x9AC22C20); + SINGLE_COMPARE(cmp(x0, x1), 0xEB01001F); + + // reg, imm + SINGLE_COMPARE(add(x3, x7, 78), 0x910138E3); + SINGLE_COMPARE(add(w3, w7, 78), 0x110138E3); + SINGLE_COMPARE(sub(w3, w7, 78), 0x510138E3); + SINGLE_COMPARE(cmp(w0, 42), 0x7100A81F); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Loads") +{ + // address forms + SINGLE_COMPARE(ldr(x0, x1), 0xF9400020); + SINGLE_COMPARE(ldr(x0, mem(x1, 8)), 0xF9400420); + SINGLE_COMPARE(ldr(x0, mem(x1, x7)), 0xF8676820); + SINGLE_COMPARE(ldr(x0, mem(x1, -7)), 0xF85F9020); + + // load sizes + SINGLE_COMPARE(ldr(x0, x1), 0xF9400020); + SINGLE_COMPARE(ldr(w0, x1), 0xB9400020); + SINGLE_COMPARE(ldrb(w0, x1), 0x39400020); + SINGLE_COMPARE(ldrh(w0, x1), 0x79400020); + SINGLE_COMPARE(ldrsb(x0, x1), 0x39800020); + SINGLE_COMPARE(ldrsb(w0, x1), 0x39C00020); + SINGLE_COMPARE(ldrsh(x0, x1), 0x79800020); + SINGLE_COMPARE(ldrsh(w0, x1), 0x79C00020); + SINGLE_COMPARE(ldrsw(x0, x1), 0xB9800020); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Stores") +{ + // address forms + SINGLE_COMPARE(str(x0, x1), 0xF9000020); + SINGLE_COMPARE(str(x0, mem(x1, 8)), 0xF9000420); + SINGLE_COMPARE(str(x0, mem(x1, x7)), 0xF8276820); + SINGLE_COMPARE(strh(w0, mem(x1, -7)), 0x781F9020); + + // store sizes + SINGLE_COMPARE(str(x0, x1), 0xF9000020); + SINGLE_COMPARE(str(w0, x1), 0xB9000020); + SINGLE_COMPARE(strb(w0, x1), 0x39000020); + SINGLE_COMPARE(strh(w0, x1), 0x79000020); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Moves") +{ + SINGLE_COMPARE(mov(x0, x1), 0xAA0103E0); + SINGLE_COMPARE(mov(w0, w1), 0x2A0103E0); + SINGLE_COMPARE(mov(x0, 42), 0xD2800540); + SINGLE_COMPARE(mov(w0, 42), 0x52800540); + SINGLE_COMPARE(movk(x0, 42, 16), 0xF2A00540); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "ControlFlow") +{ + // Jump back + CHECK(check( + [](AssemblyBuilderA64& build) { + Label start = build.setLabel(); + build.mov(x0, x1); + build.b(ConditionA64::Equal, start); + }, + {0xAA0103E0, 0x54FFFFE0})); + + // Jump forward + CHECK(check( + [](AssemblyBuilderA64& build) { + Label skip; + build.b(ConditionA64::Equal, skip); + build.mov(x0, x1); + build.setLabel(skip); + }, + {0x54000040, 0xAA0103E0})); + + // Jumps + CHECK(check( + [](AssemblyBuilderA64& build) { + Label skip; + build.b(ConditionA64::Equal, skip); + build.cbz(x0, skip); + build.cbnz(x0, skip); + build.setLabel(skip); + build.b(skip); + }, + {0x54000060, 0xB4000040, 0xB5000020, 0x5400000E})); + + // Basic control flow + SINGLE_COMPARE(br(x0), 0xD61F0000); + SINGLE_COMPARE(blr(x0), 0xD63F0000); + SINGLE_COMPARE(ret(), 0xD65F03C0); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "StackOps") +{ + SINGLE_COMPARE(mov(x0, sp), 0x910003E0); + SINGLE_COMPARE(mov(sp, x0), 0x9100001F); + + SINGLE_COMPARE(add(sp, sp, 4), 0x910013FF); + SINGLE_COMPARE(sub(sp, sp, 4), 0xD10013FF); + + SINGLE_COMPARE(add(x0, sp, 4), 0x910013E0); + SINGLE_COMPARE(sub(sp, x0, 4), 0xD100101F); + + SINGLE_COMPARE(ldr(x0, mem(sp, 8)), 0xF94007E0); + SINGLE_COMPARE(str(x0, mem(sp, 8)), 0xF90007E0); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Constants") +{ + // clang-format off + CHECK(check( + [](AssemblyBuilderA64& build) { + char arr[12] = "hello world"; + build.adr(x0, arr, 12); + build.adr(x0, uint64_t(0x1234567887654321)); + build.adr(x0, 1.0); + }, + { + 0x10ffffa0, 0x10ffff20, 0x10fffec0 + }, + { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, + 0x21, 0x43, 0x65, 0x87, 0x78, 0x56, 0x34, 0x12, + 0x00, 0x00, 0x00, 0x00, // 4b padding to align double + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', 0x0, + })); + // clang-format on +} + +TEST_CASE("LogTest") +{ + AssemblyBuilderA64 build(/* logText= */ true); + + build.add(sp, sp, 4); + build.add(w0, w1, w2); + build.add(x0, x1, x2, 2); + build.add(w7, w8, 5); + build.add(x7, x8, 5); + build.ldr(x7, x8); + build.ldr(x7, mem(x8, 8)); + build.ldr(x7, mem(x8, x9)); + build.mov(x1, x2); + build.movk(x1, 42, 16); + build.cmp(x1, x2); + build.blr(x0); + + Label l; + build.b(ConditionA64::Plus, l); + build.cbz(x7, l); + + build.setLabel(l); + build.ret(); + + build.finalize(); + + std::string expected = R"( + add sp,sp,#4 + add w0,w1,w2 + add x0,x1,x2 LSL #2 + add w7,w8,#5 + add x7,x8,#5 + ldr x7,[x8] + ldr x7,[x8,#8] + ldr x7,[x8,x9] + mov x1,x2 + movk x1,#42 LSL #16 + cmp x1,x2 + blr x0 + b.pl .L1 + cbz x7,.L1 +.L1: + ret +)"; + + CHECK("\n" + build.text == expected); +} + +TEST_SUITE_END(); diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp new file mode 100644 index 00000000..1af03e4f --- /dev/null +++ b/tests/AssemblyBuilderX64.test.cpp @@ -0,0 +1,655 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/AssemblyBuilderX64.h" +#include "Luau/StringUtils.h" + +#include "doctest.h" + +#include + +using namespace Luau::CodeGen; + +static std::string bytecodeAsArray(const std::vector& bytecode) +{ + std::string result = "{"; + + for (size_t i = 0; i < bytecode.size(); i++) + Luau::formatAppend(result, "%s0x%02x", i == 0 ? "" : ", ", bytecode[i]); + + return result.append("}"); +} + +class AssemblyBuilderX64Fixture +{ +public: + bool check(void (*f)(AssemblyBuilderX64& build), std::vector code, std::vector data = {}) + { + AssemblyBuilderX64 build(/* logText= */ false); + + f(build); + + build.finalize(); + + if (build.code != code) + { + printf("Expected code: %s\nReceived code: %s\n", bytecodeAsArray(code).c_str(), bytecodeAsArray(build.code).c_str()); + return false; + } + + if (build.data != data) + { + printf("Expected data: %s\nReceived data: %s\n", bytecodeAsArray(data).c_str(), bytecodeAsArray(build.data).c_str()); + return false; + } + + return true; + } +}; + +TEST_SUITE_BEGIN("x64Assembly"); + +#define SINGLE_COMPARE(inst, ...) \ + CHECK(check( \ + [](AssemblyBuilderX64& build) { \ + build.inst; \ + }, \ + {__VA_ARGS__})) + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "BaseBinaryInstructionForms") +{ + // reg, reg + SINGLE_COMPARE(add(rax, rcx), 0x48, 0x03, 0xc1); + SINGLE_COMPARE(add(rsp, r12), 0x49, 0x03, 0xe4); + SINGLE_COMPARE(add(r14, r10), 0x4d, 0x03, 0xf2); + + // reg, imm + SINGLE_COMPARE(add(rax, 0), 0x48, 0x83, 0xc0, 0x00); + SINGLE_COMPARE(add(rax, 0x7f), 0x48, 0x83, 0xc0, 0x7f); + SINGLE_COMPARE(add(rax, 0x80), 0x48, 0x81, 0xc0, 0x80, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(r10, 0x7fffffff), 0x49, 0x81, 0xc2, 0xff, 0xff, 0xff, 0x7f); + + // reg, [reg] + SINGLE_COMPARE(add(rax, qword[rax]), 0x48, 0x03, 0x00); + SINGLE_COMPARE(add(rax, qword[rbx]), 0x48, 0x03, 0x03); + SINGLE_COMPARE(add(rax, qword[rsp]), 0x48, 0x03, 0x04, 0x24); + SINGLE_COMPARE(add(rax, qword[rbp]), 0x48, 0x03, 0x45, 0x00); + SINGLE_COMPARE(add(rax, qword[r10]), 0x49, 0x03, 0x02); + SINGLE_COMPARE(add(rax, qword[r12]), 0x49, 0x03, 0x04, 0x24); + SINGLE_COMPARE(add(rax, qword[r13]), 0x49, 0x03, 0x45, 0x00); + + SINGLE_COMPARE(add(r12, qword[rax]), 0x4c, 0x03, 0x20); + SINGLE_COMPARE(add(r12, qword[rbx]), 0x4c, 0x03, 0x23); + SINGLE_COMPARE(add(r12, qword[rsp]), 0x4c, 0x03, 0x24, 0x24); + SINGLE_COMPARE(add(r12, qword[rbp]), 0x4c, 0x03, 0x65, 0x00); + SINGLE_COMPARE(add(r12, qword[r10]), 0x4d, 0x03, 0x22); + SINGLE_COMPARE(add(r12, qword[r12]), 0x4d, 0x03, 0x24, 0x24); + SINGLE_COMPARE(add(r12, qword[r13]), 0x4d, 0x03, 0x65, 0x00); + + // reg, [base+imm8] + SINGLE_COMPARE(add(rax, qword[rax + 0x1b]), 0x48, 0x03, 0x40, 0x1b); + SINGLE_COMPARE(add(rax, qword[rbx + 0x1b]), 0x48, 0x03, 0x43, 0x1b); + SINGLE_COMPARE(add(rax, qword[rsp + 0x1b]), 0x48, 0x03, 0x44, 0x24, 0x1b); + SINGLE_COMPARE(add(rax, qword[rbp + 0x1b]), 0x48, 0x03, 0x45, 0x1b); + SINGLE_COMPARE(add(rax, qword[r10 + 0x1b]), 0x49, 0x03, 0x42, 0x1b); + SINGLE_COMPARE(add(rax, qword[r12 + 0x1b]), 0x49, 0x03, 0x44, 0x24, 0x1b); + SINGLE_COMPARE(add(rax, qword[r13 + 0x1b]), 0x49, 0x03, 0x45, 0x1b); + + SINGLE_COMPARE(add(r12, qword[rax + 0x1b]), 0x4c, 0x03, 0x60, 0x1b); + SINGLE_COMPARE(add(r12, qword[rbx + 0x1b]), 0x4c, 0x03, 0x63, 0x1b); + SINGLE_COMPARE(add(r12, qword[rsp + 0x1b]), 0x4c, 0x03, 0x64, 0x24, 0x1b); + SINGLE_COMPARE(add(r12, qword[rbp + 0x1b]), 0x4c, 0x03, 0x65, 0x1b); + SINGLE_COMPARE(add(r12, qword[r10 + 0x1b]), 0x4d, 0x03, 0x62, 0x1b); + SINGLE_COMPARE(add(r12, qword[r12 + 0x1b]), 0x4d, 0x03, 0x64, 0x24, 0x1b); + SINGLE_COMPARE(add(r12, qword[r13 + 0x1b]), 0x4d, 0x03, 0x65, 0x1b); + + // reg, [base+imm32] + SINGLE_COMPARE(add(rax, qword[rax + 0xabab]), 0x48, 0x03, 0x80, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rbx + 0xabab]), 0x48, 0x03, 0x83, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rsp + 0xabab]), 0x48, 0x03, 0x84, 0x24, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rbp + 0xabab]), 0x48, 0x03, 0x85, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[r10 + 0xabab]), 0x49, 0x03, 0x82, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[r12 + 0xabab]), 0x49, 0x03, 0x84, 0x24, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[r13 + 0xabab]), 0x49, 0x03, 0x85, 0xab, 0xab, 0x00, 0x00); + + SINGLE_COMPARE(add(r12, qword[rax + 0xabab]), 0x4c, 0x03, 0xa0, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rbx + 0xabab]), 0x4c, 0x03, 0xa3, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rsp + 0xabab]), 0x4c, 0x03, 0xa4, 0x24, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rbp + 0xabab]), 0x4c, 0x03, 0xa5, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[r10 + 0xabab]), 0x4d, 0x03, 0xa2, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[r12 + 0xabab]), 0x4d, 0x03, 0xa4, 0x24, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[r13 + 0xabab]), 0x4d, 0x03, 0xa5, 0xab, 0xab, 0x00, 0x00); + + // reg, [index*scale] + SINGLE_COMPARE(add(rax, qword[rax * 2]), 0x48, 0x03, 0x04, 0x45, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rbx * 2]), 0x48, 0x03, 0x04, 0x5d, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rbp * 2]), 0x48, 0x03, 0x04, 0x6d, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[r10 * 2]), 0x4a, 0x03, 0x04, 0x55, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[r12 * 2]), 0x4a, 0x03, 0x04, 0x65, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[r13 * 2]), 0x4a, 0x03, 0x04, 0x6d, 0x00, 0x00, 0x00, 0x00); + + SINGLE_COMPARE(add(r12, qword[rax * 2]), 0x4c, 0x03, 0x24, 0x45, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rbx * 2]), 0x4c, 0x03, 0x24, 0x5d, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rbp * 2]), 0x4c, 0x03, 0x24, 0x6d, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[r10 * 2]), 0x4e, 0x03, 0x24, 0x55, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[r12 * 2]), 0x4e, 0x03, 0x24, 0x65, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[r13 * 2]), 0x4e, 0x03, 0x24, 0x6d, 0x00, 0x00, 0x00, 0x00); + + // reg, [base+index*scale+imm] + SINGLE_COMPARE(add(rax, qword[rax + rax * 2]), 0x48, 0x03, 0x04, 0x40); + SINGLE_COMPARE(add(rax, qword[rax + rbx * 2 + 0x1b]), 0x48, 0x03, 0x44, 0x58, 0x1b); + SINGLE_COMPARE(add(rax, qword[rax + rbp * 2]), 0x48, 0x03, 0x04, 0x68); + SINGLE_COMPARE(add(rax, qword[rax + rbp + 0xabab]), 0x48, 0x03, 0x84, 0x28, 0xAB, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rax + r12 + 0x1b]), 0x4a, 0x03, 0x44, 0x20, 0x1b); + SINGLE_COMPARE(add(rax, qword[rax + r12 * 4 + 0xabab]), 0x4a, 0x03, 0x84, 0xa0, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[rax + r13 * 2 + 0x1b]), 0x4a, 0x03, 0x44, 0x68, 0x1b); + SINGLE_COMPARE(add(rax, qword[rax + r13 + 0xabab]), 0x4a, 0x03, 0x84, 0x28, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rax + r12 * 2]), 0x4e, 0x03, 0x24, 0x60); + SINGLE_COMPARE(add(r12, qword[rax + r13 + 0xabab]), 0x4e, 0x03, 0xA4, 0x28, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(r12, qword[rax + rbp * 2 + 0x1b]), 0x4c, 0x03, 0x64, 0x68, 0x1b); + + // reg, [imm32] + SINGLE_COMPARE(add(rax, qword[0]), 0x48, 0x03, 0x04, 0x25, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(add(rax, qword[0xabab]), 0x48, 0x03, 0x04, 0x25, 0xab, 0xab, 0x00, 0x00); + + // [addr], reg + SINGLE_COMPARE(add(qword[rax], rax), 0x48, 0x01, 0x00); + SINGLE_COMPARE(add(qword[rax + rax * 4 + 0xabab], rax), 0x48, 0x01, 0x84, 0x80, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(qword[rbx + rax * 2 + 0x1b], rax), 0x48, 0x01, 0x44, 0x43, 0x1b); + SINGLE_COMPARE(add(qword[rbx + rbp * 2 + 0x1b], rax), 0x48, 0x01, 0x44, 0x6b, 0x1b); + SINGLE_COMPARE(add(qword[rbp + rbp * 4 + 0xabab], rax), 0x48, 0x01, 0x84, 0xad, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(qword[rbp + r12 + 0x1b], rax), 0x4a, 0x01, 0x44, 0x25, 0x1b); + SINGLE_COMPARE(add(qword[r12], rax), 0x49, 0x01, 0x04, 0x24); + SINGLE_COMPARE(add(qword[r13 + rbx + 0xabab], rax), 0x49, 0x01, 0x84, 0x1d, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(qword[rax + r13 * 2 + 0x1b], rsi), 0x4a, 0x01, 0x74, 0x68, 0x1b); + SINGLE_COMPARE(add(qword[rbp + rbx * 2], rsi), 0x48, 0x01, 0x74, 0x5d, 0x00); + SINGLE_COMPARE(add(qword[rsp + r10 * 2 + 0x1b], r10), 0x4e, 0x01, 0x54, 0x54, 0x1b); + + // [addr], imm + SINGLE_COMPARE(add(byte[rax], 2), 0x80, 0x00, 0x02); + SINGLE_COMPARE(add(dword[rax], 2), 0x83, 0x00, 0x02); + SINGLE_COMPARE(add(dword[rax], 0xabcd), 0x81, 0x00, 0xcd, 0xab, 0x00, 0x00); + SINGLE_COMPARE(add(qword[rax], 2), 0x48, 0x83, 0x00, 0x02); + SINGLE_COMPARE(add(qword[rax], 0xabcd), 0x48, 0x81, 0x00, 0xcd, 0xab, 0x00, 0x00); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "BaseUnaryInstructionForms") +{ + SINGLE_COMPARE(div(rcx), 0x48, 0xf7, 0xf1); + SINGLE_COMPARE(idiv(qword[rax]), 0x48, 0xf7, 0x38); + SINGLE_COMPARE(mul(qword[rax + rbx]), 0x48, 0xf7, 0x24, 0x18); + SINGLE_COMPARE(imul(r9), 0x49, 0xf7, 0xe9); + SINGLE_COMPARE(neg(r9), 0x49, 0xf7, 0xd9); + SINGLE_COMPARE(not_(r12), 0x49, 0xf7, 0xd4); + SINGLE_COMPARE(inc(r12), 0x49, 0xff, 0xc4); + SINGLE_COMPARE(dec(ecx), 0xff, 0xc9); + SINGLE_COMPARE(dec(byte[rdx]), 0xfe, 0x0a); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfMov") +{ + SINGLE_COMPARE(mov(rcx, 1), 0x48, 0xb9, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(mov64(rcx, 0x1234567812345678ll), 0x48, 0xb9, 0x78, 0x56, 0x34, 0x12, 0x78, 0x56, 0x34, 0x12); + SINGLE_COMPARE(mov(ecx, 2), 0xb9, 0x02, 0x00, 0x00, 0x00); + SINGLE_COMPARE(mov(cl, 2), 0xb1, 0x02); + SINGLE_COMPARE(mov(rcx, qword[rdi]), 0x48, 0x8b, 0x0f); + SINGLE_COMPARE(mov(dword[rax], 0xabcd), 0xc7, 0x00, 0xcd, 0xab, 0x00, 0x00); + SINGLE_COMPARE(mov(r13, 1), 0x49, 0xbd, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(mov64(r13, 0x1234567812345678ll), 0x49, 0xbd, 0x78, 0x56, 0x34, 0x12, 0x78, 0x56, 0x34, 0x12); + SINGLE_COMPARE(mov(r13d, 2), 0x41, 0xbd, 0x02, 0x00, 0x00, 0x00); + SINGLE_COMPARE(mov(r13, qword[r12]), 0x4d, 0x8b, 0x2c, 0x24); + SINGLE_COMPARE(mov(dword[r13], 0xabcd), 0x41, 0xc7, 0x45, 0x00, 0xcd, 0xab, 0x00, 0x00); + SINGLE_COMPARE(mov(qword[rdx], r9), 0x4c, 0x89, 0x0a); + SINGLE_COMPARE(mov(byte[rsi], 0x3), 0xc6, 0x06, 0x03); + SINGLE_COMPARE(mov(byte[rsi], al), 0x88, 0x06); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfMovExtended") +{ + SINGLE_COMPARE(movsx(eax, byte[rcx]), 0x0f, 0xbe, 0x01); + SINGLE_COMPARE(movsx(r12, byte[r10]), 0x4d, 0x0f, 0xbe, 0x22); + SINGLE_COMPARE(movsx(ebx, word[r11]), 0x41, 0x0f, 0xbf, 0x1b); + SINGLE_COMPARE(movsx(rdx, word[rcx]), 0x48, 0x0f, 0xbf, 0x11); + SINGLE_COMPARE(movzx(eax, byte[rcx]), 0x0f, 0xb6, 0x01); + SINGLE_COMPARE(movzx(r12, byte[r10]), 0x4d, 0x0f, 0xb6, 0x22); + SINGLE_COMPARE(movzx(ebx, word[r11]), 0x41, 0x0f, 0xb7, 0x1b); + SINGLE_COMPARE(movzx(rdx, word[rcx]), 0x48, 0x0f, 0xb7, 0x11); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfTest") +{ + SINGLE_COMPARE(test(al, 8), 0xf6, 0xc0, 0x08); + SINGLE_COMPARE(test(eax, 8), 0xf7, 0xc0, 0x08, 0x00, 0x00, 0x00); + SINGLE_COMPARE(test(rax, 8), 0x48, 0xf7, 0xc0, 0x08, 0x00, 0x00, 0x00); + SINGLE_COMPARE(test(rcx, 0xabab), 0x48, 0xf7, 0xc1, 0xab, 0xab, 0x00, 0x00); + SINGLE_COMPARE(test(rcx, rax), 0x48, 0x85, 0xc8); + SINGLE_COMPARE(test(rax, qword[rcx]), 0x48, 0x85, 0x01); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfShift") +{ + SINGLE_COMPARE(shl(al, 1), 0xd0, 0xe0); + SINGLE_COMPARE(shl(al, cl), 0xd2, 0xe0); + SINGLE_COMPARE(shr(al, 4), 0xc0, 0xe8, 0x04); + SINGLE_COMPARE(shr(eax, 1), 0xd1, 0xe8); + SINGLE_COMPARE(sal(eax, cl), 0xd3, 0xe0); + SINGLE_COMPARE(sal(eax, 4), 0xc1, 0xe0, 0x04); + SINGLE_COMPARE(sar(rax, 4), 0x48, 0xc1, 0xf8, 0x04); + SINGLE_COMPARE(sar(r11, 1), 0x49, 0xd1, 0xfb); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea") +{ + SINGLE_COMPARE(lea(rax, addr[rdx + rcx]), 0x48, 0x8d, 0x04, 0x0a); + SINGLE_COMPARE(lea(rax, addr[rdx + rax * 4]), 0x48, 0x8d, 0x04, 0x82); + SINGLE_COMPARE(lea(rax, addr[r13 + r12 * 4 + 4]), 0x4b, 0x8d, 0x44, 0xa5, 0x04); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfAbsoluteJumps") +{ + SINGLE_COMPARE(jmp(rax), 0xff, 0xe0); + SINGLE_COMPARE(jmp(r14), 0x41, 0xff, 0xe6); + SINGLE_COMPARE(jmp(qword[r14 + rdx * 4]), 0x41, 0xff, 0x24, 0x96); + SINGLE_COMPARE(call(rax), 0xff, 0xd0); + SINGLE_COMPARE(call(r14), 0x41, 0xff, 0xd6); + SINGLE_COMPARE(call(qword[r14 + rdx * 4]), 0x41, 0xff, 0x14, 0x96); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfImul") +{ + SINGLE_COMPARE(imul(ecx, esi), 0x0f, 0xaf, 0xce); + SINGLE_COMPARE(imul(r12, rax), 0x4c, 0x0f, 0xaf, 0xe0); + SINGLE_COMPARE(imul(r12, qword[rdx + rdi]), 0x4c, 0x0f, 0xaf, 0x24, 0x3a); + SINGLE_COMPARE(imul(ecx, edx, 8), 0x6b, 0xca, 0x08); + SINGLE_COMPARE(imul(ecx, r9d, 0xabcd), 0x41, 0x69, 0xc9, 0xcd, 0xab, 0x00, 0x00); + SINGLE_COMPARE(imul(r8d, eax, -9), 0x44, 0x6b, 0xc0, 0xf7); + SINGLE_COMPARE(imul(rcx, rdx, 17), 0x48, 0x6b, 0xca, 0x11); + SINGLE_COMPARE(imul(rcx, r12, 0xabcd), 0x49, 0x69, 0xcc, 0xcd, 0xab, 0x00, 0x00); + SINGLE_COMPARE(imul(r12, rax, -13), 0x4c, 0x6b, 0xe0, 0xf3); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "NopForms") +{ + SINGLE_COMPARE(nop(), 0x90); + SINGLE_COMPARE(nop(2), 0x66, 0x90); + SINGLE_COMPARE(nop(3), 0x0f, 0x1f, 0x00); + SINGLE_COMPARE(nop(4), 0x0f, 0x1f, 0x40, 0x00); + SINGLE_COMPARE(nop(5), 0x0f, 0x1f, 0x44, 0x00, 0x00); + SINGLE_COMPARE(nop(6), 0x66, 0x0f, 0x1f, 0x44, 0x00, 0x00); + SINGLE_COMPARE(nop(7), 0x0f, 0x1f, 0x80, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(nop(8), 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(nop(9), 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00); + SINGLE_COMPARE(nop(15), 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x44, 0x00, 0x00); // 9+6 +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentForms") +{ + CHECK(check( + [](AssemblyBuilderX64& build) { + build.ret(); + build.align(8, AlignmentDataX64::Nop); + }, + {0xc3, 0x0f, 0x1f, 0x80, 0x00, 0x00, 0x00, 0x00})); + + CHECK(check( + [](AssemblyBuilderX64& build) { + build.ret(); + build.align(32, AlignmentDataX64::Nop); + }, + {0xc3, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x66, 0x0f, 0x1f, 0x84, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x1f, 0x40, 0x00})); + + CHECK(check( + [](AssemblyBuilderX64& build) { + build.ret(); + build.align(8, AlignmentDataX64::Int3); + }, + {0xc3, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc})); + + CHECK(check( + [](AssemblyBuilderX64& build) { + build.ret(); + build.align(8, AlignmentDataX64::Ud2); + }, + {0xc3, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0xcc})); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AlignmentOverflow") +{ + // Test that alignment correctly resizes the code buffer + { + AssemblyBuilderX64 build(/* logText */ false); + + build.ret(); + build.align(8192, AlignmentDataX64::Nop); + build.finalize(); + } + + { + AssemblyBuilderX64 build(/* logText */ false); + + build.ret(); + build.align(8192, AlignmentDataX64::Int3); + build.finalize(); + } + + { + AssemblyBuilderX64 build(/* logText */ false); + + for (int i = 0; i < 8192; i++) + build.int3(); + build.finalize(); + } + + { + AssemblyBuilderX64 build(/* logText */ false); + + build.ret(); + build.align(8192, AlignmentDataX64::Ud2); + build.finalize(); + } +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow") +{ + // Jump back + CHECK(check( + [](AssemblyBuilderX64& build) { + Label start = build.setLabel(); + build.add(rsi, 1); + build.cmp(rsi, rdi); + build.jcc(ConditionX64::Equal, start); + }, + {0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf3, 0xff, 0xff, 0xff})); + + // Jump back, but the label is set before use + CHECK(check( + [](AssemblyBuilderX64& build) { + Label start; + build.add(rsi, 1); + build.setLabel(start); + build.cmp(rsi, rdi); + build.jcc(ConditionX64::Equal, start); + }, + {0x48, 0x83, 0xc6, 0x01, 0x48, 0x3b, 0xf7, 0x0f, 0x84, 0xf7, 0xff, 0xff, 0xff})); + + // Jump forward + CHECK(check( + [](AssemblyBuilderX64& build) { + Label skip; + + build.cmp(rsi, rdi); + build.jcc(ConditionX64::Greater, skip); + build.or_(rdi, 0x3e); + build.setLabel(skip); + }, + {0x48, 0x3b, 0xf7, 0x0f, 0x8f, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xcf, 0x3e})); + + // Regular jump + CHECK(check( + [](AssemblyBuilderX64& build) { + Label skip; + + build.jmp(skip); + build.and_(rdi, 0x3e); + build.setLabel(skip); + }, + {0xe9, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xe7, 0x3e})); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelCall") +{ + CHECK(check( + [](AssemblyBuilderX64& build) { + Label fnB; + + build.and_(rcx, 0x3e); + build.call(fnB); + build.ret(); + + build.setLabel(fnB); + build.lea(rax, addr[rcx + 0x1f]); + build.ret(); + }, + {0x48, 0x83, 0xe1, 0x3e, 0xe8, 0x01, 0x00, 0x00, 0x00, 0xc3, 0x48, 0x8d, 0x41, 0x1f, 0xc3})); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") +{ + SINGLE_COMPARE(vaddpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x29, 0x58, 0xc6); + SINGLE_COMPARE(vaddpd(xmm8, xmm10, xmmword[r9]), 0xc4, 0x41, 0x29, 0x58, 0x01); + SINGLE_COMPARE(vaddpd(ymm8, ymm10, ymm14), 0xc4, 0x41, 0x2d, 0x58, 0xc6); + SINGLE_COMPARE(vaddpd(ymm8, ymm10, ymmword[r9]), 0xc4, 0x41, 0x2d, 0x58, 0x01); + SINGLE_COMPARE(vaddps(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x28, 0x58, 0xc6); + SINGLE_COMPARE(vaddps(xmm8, xmm10, xmmword[r9]), 0xc4, 0x41, 0x28, 0x58, 0x01); + SINGLE_COMPARE(vaddsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x58, 0xc6); + SINGLE_COMPARE(vaddsd(xmm8, xmm10, qword[r9]), 0xc4, 0x41, 0x2b, 0x58, 0x01); + SINGLE_COMPARE(vaddss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2a, 0x58, 0xc6); + SINGLE_COMPARE(vaddss(xmm8, xmm10, dword[r9]), 0xc4, 0x41, 0x2a, 0x58, 0x01); + + SINGLE_COMPARE(vaddps(xmm1, xmm2, xmm3), 0xc4, 0xe1, 0x68, 0x58, 0xcb); + SINGLE_COMPARE(vaddps(xmm9, xmm12, xmmword[r9 + r14 * 2 + 0x1c]), 0xc4, 0x01, 0x18, 0x58, 0x4c, 0x71, 0x1c); + SINGLE_COMPARE(vaddps(ymm1, ymm2, ymm3), 0xc4, 0xe1, 0x6c, 0x58, 0xcb); + SINGLE_COMPARE(vaddps(ymm9, ymm12, ymmword[r9 + r14 * 2 + 0x1c]), 0xc4, 0x01, 0x1c, 0x58, 0x4c, 0x71, 0x1c); + + // Coverage for other instructions that follow the same pattern + SINGLE_COMPARE(vsubsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5c, 0xc6); + SINGLE_COMPARE(vmulsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x59, 0xc6); + SINGLE_COMPARE(vdivsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5e, 0xc6); + + SINGLE_COMPARE(vxorpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x29, 0x57, 0xc6); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXUnaryMergeInstructionForms") +{ + SINGLE_COMPARE(vsqrtpd(xmm8, xmm10), 0xc4, 0x41, 0x79, 0x51, 0xc2); + SINGLE_COMPARE(vsqrtpd(xmm8, xmmword[r9]), 0xc4, 0x41, 0x79, 0x51, 0x01); + SINGLE_COMPARE(vsqrtpd(ymm8, ymm10), 0xc4, 0x41, 0x7d, 0x51, 0xc2); + SINGLE_COMPARE(vsqrtpd(ymm8, ymmword[r9]), 0xc4, 0x41, 0x7d, 0x51, 0x01); + SINGLE_COMPARE(vsqrtps(xmm8, xmm10), 0xc4, 0x41, 0x78, 0x51, 0xc2); + SINGLE_COMPARE(vsqrtps(xmm8, xmmword[r9]), 0xc4, 0x41, 0x78, 0x51, 0x01); + SINGLE_COMPARE(vsqrtsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x51, 0xc6); + SINGLE_COMPARE(vsqrtsd(xmm8, xmm10, qword[r9]), 0xc4, 0x41, 0x2b, 0x51, 0x01); + SINGLE_COMPARE(vsqrtss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2a, 0x51, 0xc6); + SINGLE_COMPARE(vsqrtss(xmm8, xmm10, dword[r9]), 0xc4, 0x41, 0x2a, 0x51, 0x01); + + // Coverage for other instructions that follow the same pattern + SINGLE_COMPARE(vucomisd(xmm1, xmm4), 0xc4, 0xe1, 0x79, 0x2e, 0xcc); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXMoveInstructionForms") +{ + SINGLE_COMPARE(vmovsd(qword[r9], xmm10), 0xc4, 0x41, 0x7b, 0x11, 0x11); + SINGLE_COMPARE(vmovsd(xmm8, qword[r9]), 0xc4, 0x41, 0x7b, 0x10, 0x01); + SINGLE_COMPARE(vmovsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x10, 0xc6); + SINGLE_COMPARE(vmovss(dword[r9], xmm10), 0xc4, 0x41, 0x7a, 0x11, 0x11); + SINGLE_COMPARE(vmovss(xmm8, dword[r9]), 0xc4, 0x41, 0x7a, 0x10, 0x01); + SINGLE_COMPARE(vmovss(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2a, 0x10, 0xc6); + SINGLE_COMPARE(vmovapd(xmm8, xmmword[r9]), 0xc4, 0x41, 0x79, 0x28, 0x01); + SINGLE_COMPARE(vmovapd(xmmword[r9], xmm10), 0xc4, 0x41, 0x79, 0x29, 0x11); + SINGLE_COMPARE(vmovapd(ymm8, ymmword[r9]), 0xc4, 0x41, 0x7d, 0x28, 0x01); + SINGLE_COMPARE(vmovaps(xmm8, xmmword[r9]), 0xc4, 0x41, 0x78, 0x28, 0x01); + SINGLE_COMPARE(vmovaps(xmmword[r9], xmm10), 0xc4, 0x41, 0x78, 0x29, 0x11); + SINGLE_COMPARE(vmovaps(ymm8, ymmword[r9]), 0xc4, 0x41, 0x7c, 0x28, 0x01); + SINGLE_COMPARE(vmovupd(xmm8, xmmword[r9]), 0xc4, 0x41, 0x79, 0x10, 0x01); + SINGLE_COMPARE(vmovupd(xmmword[r9], xmm10), 0xc4, 0x41, 0x79, 0x11, 0x11); + SINGLE_COMPARE(vmovupd(ymm8, ymmword[r9]), 0xc4, 0x41, 0x7d, 0x10, 0x01); + SINGLE_COMPARE(vmovups(xmm8, xmmword[r9]), 0xc4, 0x41, 0x78, 0x10, 0x01); + SINGLE_COMPARE(vmovups(xmmword[r9], xmm10), 0xc4, 0x41, 0x78, 0x11, 0x11); + SINGLE_COMPARE(vmovups(ymm8, ymmword[r9]), 0xc4, 0x41, 0x7c, 0x10, 0x01); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXConversionInstructionForms") +{ + SINGLE_COMPARE(vcvttsd2si(ecx, xmm0), 0xc4, 0xe1, 0x7b, 0x2c, 0xc8); + SINGLE_COMPARE(vcvttsd2si(r9d, xmmword[rcx + rdx]), 0xc4, 0x61, 0x7b, 0x2c, 0x0c, 0x11); + SINGLE_COMPARE(vcvttsd2si(rdx, xmm0), 0xc4, 0xe1, 0xfb, 0x2c, 0xd0); + SINGLE_COMPARE(vcvttsd2si(r13, xmmword[rcx + rdx]), 0xc4, 0x61, 0xfb, 0x2c, 0x2c, 0x11); + SINGLE_COMPARE(vcvtsi2sd(xmm5, xmm10, ecx), 0xc4, 0xe1, 0x2b, 0x2a, 0xe9); + SINGLE_COMPARE(vcvtsi2sd(xmm6, xmm11, dword[rcx + rdx]), 0xc4, 0xe1, 0x23, 0x2a, 0x34, 0x11); + SINGLE_COMPARE(vcvtsi2sd(xmm5, xmm10, r13), 0xc4, 0xc1, 0xab, 0x2a, 0xed); + SINGLE_COMPARE(vcvtsi2sd(xmm6, xmm11, qword[rcx + rdx]), 0xc4, 0xe1, 0xa3, 0x2a, 0x34, 0x11); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms") +{ + SINGLE_COMPARE(vroundsd(xmm7, xmm12, xmm3, RoundingModeX64::RoundToNegativeInfinity), 0xc4, 0xe3, 0x19, 0x0b, 0xfb, 0x09); + SINGLE_COMPARE( + vroundsd(xmm8, xmm13, xmmword[r13 + rdx], RoundingModeX64::RoundToPositiveInfinity), 0xc4, 0x43, 0x11, 0x0b, 0x44, 0x15, 0x00, 0x0a); + SINGLE_COMPARE(vroundsd(xmm9, xmm14, xmmword[rcx + r10], RoundingModeX64::RoundToZero), 0xc4, 0x23, 0x09, 0x0b, 0x0c, 0x11, 0x0b); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions") +{ + SINGLE_COMPARE(int3(), 0xcc); +} + +TEST_CASE("LogTest") +{ + AssemblyBuilderX64 build(/* logText= */ true); + + build.push(r12); + build.align(8); + build.align(8, AlignmentDataX64::Int3); + build.align(8, AlignmentDataX64::Ud2); + + build.add(rax, rdi); + build.add(rcx, 8); + build.sub(dword[rax], 0x1fdc); + build.and_(dword[rcx], 0x37); + build.mov(rdi, qword[rax + rsi * 2]); + build.vaddss(xmm0, xmm0, dword[rax + r14 * 2 + 0x1c]); + + Label start = build.setLabel(); + build.cmp(rsi, rdi); + build.jcc(ConditionX64::Equal, start); + + build.jmp(qword[rdx]); + build.vaddps(ymm9, ymm12, ymmword[rbp + 0xc]); + build.vaddpd(ymm2, ymm7, build.f64(2.5)); + build.neg(qword[rbp + r12 * 2]); + build.mov64(r10, 0x1234567812345678ll); + build.vmovapd(xmmword[rax], xmm11); + build.movzx(eax, byte[rcx]); + build.movsx(rsi, word[r12]); + build.imul(rcx, rdx); + build.imul(rcx, rdx, 8); + build.vroundsd(xmm1, xmm2, xmm3, RoundingModeX64::RoundToNearestEven); + build.add(rdx, qword[rcx - 12]); + build.pop(r12); + build.ret(); + build.int3(); + + build.nop(); + build.nop(2); + build.nop(3); + build.nop(4); + build.nop(5); + build.nop(6); + build.nop(7); + build.nop(8); + build.nop(9); + + build.finalize(); + + std::string expected = R"( + push r12 +; align 8 + nop word ptr[rax+rax] ; 6-byte nop +; align 8 using int3 +; align 8 using ud2 + add rax,rdi + add rcx,8 + sub dword ptr [rax],1FDCh + and dword ptr [rcx],37h + mov rdi,qword ptr [rax+rsi*2] + vaddss xmm0,xmm0,dword ptr [rax+r14*2+01Ch] +.L1: + cmp rsi,rdi + je .L1 + jmp qword ptr [rdx] + vaddps ymm9,ymm12,ymmword ptr [rbp+0Ch] + vaddpd ymm2,ymm7,qword ptr [.start-8] + neg qword ptr [rbp+r12*2] + mov r10,1234567812345678h + vmovapd xmmword ptr [rax],xmm11 + movzx eax,byte ptr [rcx] + movsx rsi,word ptr [r12] + imul rcx,rdx + imul rcx,rdx,8 + vroundsd xmm1,xmm2,xmm3,8 + add rdx,qword ptr [rcx-0Ch] + pop r12 + ret + int3 + nop + xchg ax, ax ; 2-byte nop + nop dword ptr[rax] ; 3-byte nop + nop dword ptr[rax] ; 4-byte nop + nop dword ptr[rax+rax] ; 5-byte nop + nop word ptr[rax+rax] ; 6-byte nop + nop dword ptr[rax] ; 7-byte nop + nop dword ptr[rax+rax] ; 8-byte nop + nop word ptr[rax+rax] ; 9-byte nop +)"; + + CHECK("\n" + build.text == expected); +} + +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "Constants") +{ + // clang-format off + CHECK(check( + [](AssemblyBuilderX64& build) { + build.xor_(rax, rax); + build.add(rax, build.i64(0x1234567887654321)); + build.vmovss(xmm2, build.f32(1.0f)); + build.vmovsd(xmm3, build.f64(1.0)); + build.vmovaps(xmm4, build.f32x4(1.0f, 2.0f, 4.0f, 8.0f)); + char arr[16] = "hello world!123"; + build.vmovupd(xmm5, build.bytes(arr, 16, 8)); + build.ret(); + }, + { + 0x48, 0x33, 0xc0, + 0x48, 0x03, 0x05, 0xee, 0xff, 0xff, 0xff, + 0xc4, 0xe1, 0x7a, 0x10, 0x15, 0xe1, 0xff, 0xff, 0xff, + 0xc4, 0xe1, 0x7b, 0x10, 0x1d, 0xcc, 0xff, 0xff, 0xff, + 0xc4, 0xe1, 0x78, 0x28, 0x25, 0xab, 0xff, 0xff, 0xff, + 0xc4, 0xe1, 0x79, 0x10, 0x2d, 0x92, 0xff, 0xff, 0xff, + 0xc3 + }, + { + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', '1', '2', '3', 0x0, + 0x00, 0x00, 0x80, 0x3f, + 0x00, 0x00, 0x00, 0x40, + 0x00, 0x00, 0x80, 0x40, + 0x00, 0x00, 0x00, 0x41, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // padding to align f32x4 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, + 0x00, 0x00, 0x00, 0x00, // padding to align f64 + 0x00, 0x00, 0x80, 0x3f, + 0x21, 0x43, 0x65, 0x87, 0x78, 0x56, 0x34, 0x12, + })); + // clang-format on +} + +TEST_CASE("ConstantStorage") +{ + AssemblyBuilderX64 build(/* logText= */ false); + + for (int i = 0; i <= 3000; i++) + build.vaddss(xmm0, xmm0, build.f32(1.0f)); + + build.finalize(); + + LUAU_ASSERT(build.data.size() == 12004); + + for (int i = 0; i <= 3000; i++) + { + LUAU_ASSERT(build.data[i * 4 + 0] == 0x00); + LUAU_ASSERT(build.data[i * 4 + 1] == 0x00); + LUAU_ASSERT(build.data[i * 4 + 2] == 0x80); + LUAU_ASSERT(build.data[i * 4 + 3] == 0x3f); + } +} + +TEST_SUITE_END(); diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp new file mode 100644 index 00000000..1532b7a8 --- /dev/null +++ b/tests/AstJsonEncoder.test.cpp @@ -0,0 +1,483 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Ast.h" +#include "Luau/AstJsonEncoder.h" +#include "Luau/Parser.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +#include +#include + +using namespace Luau; + +struct JsonEncoderFixture +{ + Allocator allocator; + AstNameTable names{allocator}; + + ParseResult parse(std::string_view src) + { + ParseOptions opts; + opts.allowDeclarationSyntax = true; + return Parser::parse(src.data(), src.size(), names, allocator, opts); + } + + AstStatBlock* expectParse(std::string_view src) + { + ParseResult res = parse(src); + REQUIRE(res.errors.size() == 0); + return res.root; + } + + AstStat* expectParseStatement(std::string_view src) + { + AstStatBlock* root = expectParse(src); + REQUIRE(1 == root->body.size); + return root->body.data[0]; + } + + AstExpr* expectParseExpr(std::string_view src) + { + std::string s = "a = "; + s.append(src); + AstStatBlock* root = expectParse(s); + + AstStatAssign* statAssign = root->body.data[0]->as(); + REQUIRE(statAssign != nullptr); + REQUIRE(statAssign->values.size == 1); + + return statAssign->values.data[0]; + } +}; + +TEST_SUITE_BEGIN("JsonEncoderTests"); + +TEST_CASE("encode_constants") +{ + AstExprConstantNil nil{Location()}; + AstExprConstantBool b{Location(), true}; + AstExprConstantNumber n{Location(), 8.2}; + AstExprConstantNumber bigNum{Location(), 0.1677721600000003}; + AstExprConstantNumber positiveInfinity{Location(), INFINITY}; + AstExprConstantNumber negativeInfinity{Location(), -INFINITY}; + AstExprConstantNumber nan{Location(), NAN}; + + AstArray charString; + charString.data = const_cast("a\x1d\0\\\"b"); + charString.size = 6; + + AstExprConstantString needsEscaping{Location(), charString}; + + CHECK_EQ(R"({"type":"AstExprConstantNil","location":"0,0 - 0,0"})", toJson(&nil)); + CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b)); + CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":8.1999999999999993})", toJson(&n)); + CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":0.16777216000000031})", toJson(&bigNum)); + CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":Infinity})", toJson(&positiveInfinity)); + CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":-Infinity})", toJson(&negativeInfinity)); + CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":NaN})", toJson(&nan)); + CHECK_EQ("{\"type\":\"AstExprConstantString\",\"location\":\"0,0 - 0,0\",\"value\":\"a\\u001d\\u0000\\\\\\\"b\"}", toJson(&needsEscaping)); +} + +TEST_CASE("basic_escaping") +{ + std::string s = "hello \"world\""; + AstArray theString{s.data(), s.size()}; + AstExprConstantString str{Location(), theString}; + + std::string expected = R"({"type":"AstExprConstantString","location":"0,0 - 0,0","value":"hello \"world\""})"; + CHECK_EQ(expected, toJson(&str)); +} + +TEST_CASE("encode_AstStatBlock") +{ + AstLocal astlocal{AstName{"a_local"}, Location(), nullptr, 0, 0, nullptr}; + AstLocal* astlocalarray[] = {&astlocal}; + + AstArray vars{astlocalarray, 1}; + AstArray values{nullptr, 0}; + AstStatLocal local{Location(), vars, values, std::nullopt}; + AstStat* statArray[] = {&local}; + + AstArray bodyArray{statArray, 1}; + + AstStatBlock block{Location(), bodyArray}; + + CHECK_EQ( + (R"({"type":"AstStatBlock","location":"0,0 - 0,0","body":[{"type":"AstStatLocal","location":"0,0 - 0,0","vars":[{"luauType":null,"name":"a_local","type":"AstLocal","location":"0,0 - 0,0"}],"values":[]}]})"), + toJson(&block)); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_tables") +{ + std::string src = R"( + local x: { + foo: number + } = { + foo = 123, + } + )"; + + AstStatBlock* root = expectParse(src); + std::string json = toJson(root); + + CHECK( + json == + R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"luauType":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","type":"AstTableProp","location":"2,12 - 2,15","propType":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":null},"name":"x","type":"AstLocal","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"type":"AstExprTableItem","kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_table_array") +{ + std::string src = R"(type X = {string})"; + + AstStatBlock* root = expectParse(src); + std::string json = toJson(root); + + CHECK( + json == + R"({"type":"AstStatBlock","location":"0,0 - 0,17","body":[{"type":"AstStatTypeAlias","location":"0,0 - 0,17","name":"X","generics":[],"genericPacks":[],"type":{"type":"AstTypeTable","location":"0,9 - 0,17","props":[],"indexer":{"location":"0,10 - 0,16","indexType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"number","parameters":[]},"resultType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"string","parameters":[]}}},"exported":false}]})"); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_table_indexer") +{ + std::string src = R"(type X = {string})"; + + AstStatBlock* root = expectParse(src); + std::string json = toJson(root); + + CHECK( + json == + R"({"type":"AstStatBlock","location":"0,0 - 0,17","body":[{"type":"AstStatTypeAlias","location":"0,0 - 0,17","name":"X","generics":[],"genericPacks":[],"type":{"type":"AstTypeTable","location":"0,9 - 0,17","props":[],"indexer":{"location":"0,10 - 0,16","indexType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"number","parameters":[]},"resultType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"string","parameters":[]}}},"exported":false}]})"); +} + +TEST_CASE("encode_AstExprGroup") +{ + AstExprConstantNumber number{Location{}, 5.0}; + AstExprGroup group{Location{}, &number}; + + std::string json = toJson(&group); + + const std::string expected = + R"({"type":"AstExprGroup","location":"0,0 - 0,0","expr":{"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":5}})"; + + CHECK(json == expected); +} + +TEST_CASE("encode_AstExprGlobal") +{ + AstExprGlobal global{Location{}, AstName{"print"}}; + + std::string json = toJson(&global); + std::string expected = R"({"type":"AstExprGlobal","location":"0,0 - 0,0","global":"print"})"; + + CHECK(json == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprIfThen") +{ + AstStat* statement = expectParseStatement("local a = if x then y else z"); + + std::string_view expected = + R"({"type":"AstStatLocal","location":"0,0 - 0,28","vars":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,6 - 0,7"}],"values":[{"type":"AstExprIfElse","location":"0,10 - 0,28","condition":{"type":"AstExprGlobal","location":"0,13 - 0,14","global":"x"},"hasThen":true,"trueExpr":{"type":"AstExprGlobal","location":"0,20 - 0,21","global":"y"},"hasElse":true,"falseExpr":{"type":"AstExprGlobal","location":"0,27 - 0,28","global":"z"}}]})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprInterpString") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + AstStat* statement = expectParseStatement("local a = `var = {x}`"); + + std::string_view expected = + R"({"type":"AstStatLocal","location":"0,0 - 0,21","vars":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,6 - 0,7"}],"values":[{"type":"AstExprInterpString","location":"0,10 - 0,21","strings":["var = ",""],"expressions":[{"type":"AstExprGlobal","location":"0,18 - 0,19","global":"x"}]}]})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE("encode_AstExprLocal") +{ + AstLocal local{AstName{"foo"}, Location{}, nullptr, 0, 0, nullptr}; + AstExprLocal exprLocal{Location{}, &local, false}; + + CHECK(toJson(&exprLocal) == + R"({"type":"AstExprLocal","location":"0,0 - 0,0","local":{"luauType":null,"name":"foo","type":"AstLocal","location":"0,0 - 0,0"}})"); +} + +TEST_CASE("encode_AstExprVarargs") +{ + AstExprVarargs varargs{Location{}}; + + CHECK(toJson(&varargs) == R"({"type":"AstExprVarargs","location":"0,0 - 0,0"})"); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprCall") +{ + AstExpr* expr = expectParseExpr("foo(1, 2, 3)"); + std::string_view expected = + R"({"type":"AstExprCall","location":"0,4 - 0,16","func":{"type":"AstExprGlobal","location":"0,4 - 0,7","global":"foo"},"args":[{"type":"AstExprConstantNumber","location":"0,8 - 0,9","value":1},{"type":"AstExprConstantNumber","location":"0,11 - 0,12","value":2},{"type":"AstExprConstantNumber","location":"0,14 - 0,15","value":3}],"self":false,"argLocation":"0,8 - 0,16"})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprIndexName") +{ + AstExpr* expr = expectParseExpr("foo.bar"); + + std::string_view expected = + R"({"type":"AstExprIndexName","location":"0,4 - 0,11","expr":{"type":"AstExprGlobal","location":"0,4 - 0,7","global":"foo"},"index":"bar","indexLocation":"0,8 - 0,11","op":"."})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprIndexExpr") +{ + AstExpr* expr = expectParseExpr("foo['bar']"); + + std::string_view expected = + R"({"type":"AstExprIndexExpr","location":"0,4 - 0,14","expr":{"type":"AstExprGlobal","location":"0,4 - 0,7","global":"foo"},"index":{"type":"AstExprConstantString","location":"0,8 - 0,13","value":"bar"}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprFunction") +{ + AstExpr* expr = expectParseExpr("function (a) return a end"); + + std::string_view expected = + R"({"type":"AstExprFunction","location":"0,4 - 0,29","generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,14 - 0,15"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,16 - 0,26","body":[{"type":"AstStatReturn","location":"0,17 - 0,25","list":[{"type":"AstExprLocal","location":"0,24 - 0,25","local":{"luauType":null,"name":"a","type":"AstLocal","location":"0,14 - 0,15"}}]}]},"functionDepth":1,"debugname":"","hasEnd":true})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprTable") +{ + AstExpr* expr = expectParseExpr("{true, key=true, [key2]=true}"); + + std::string_view expected = + R"({"type":"AstExprTable","location":"0,4 - 0,33","items":[{"type":"AstExprTableItem","kind":"item","value":{"type":"AstExprConstantBool","location":"0,5 - 0,9","value":true}},{"type":"AstExprTableItem","kind":"record","key":{"type":"AstExprConstantString","location":"0,11 - 0,14","value":"key"},"value":{"type":"AstExprConstantBool","location":"0,15 - 0,19","value":true}},{"type":"AstExprTableItem","kind":"general","key":{"type":"AstExprGlobal","location":"0,22 - 0,26","global":"key2"},"value":{"type":"AstExprConstantBool","location":"0,28 - 0,32","value":true}}]})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprUnary") +{ + AstExpr* expr = expectParseExpr("-b"); + + std::string_view expected = + R"({"type":"AstExprUnary","location":"0,4 - 0,6","op":"Minus","expr":{"type":"AstExprGlobal","location":"0,5 - 0,6","global":"b"}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprBinary") +{ + AstExpr* expr = expectParseExpr("b + c"); + + std::string_view expected = + R"({"type":"AstExprBinary","location":"0,4 - 0,9","op":"Add","left":{"type":"AstExprGlobal","location":"0,4 - 0,5","global":"b"},"right":{"type":"AstExprGlobal","location":"0,8 - 0,9","global":"c"}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprTypeAssertion") +{ + AstExpr* expr = expectParseExpr("b :: any"); + + std::string_view expected = + R"({"type":"AstExprTypeAssertion","location":"0,4 - 0,12","expr":{"type":"AstExprGlobal","location":"0,4 - 0,5","global":"b"},"annotation":{"type":"AstTypeReference","location":"0,9 - 0,12","name":"any","parameters":[]}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprError") +{ + std::string_view src = "a = "; + ParseResult parseResult = Parser::parse(src.data(), src.size(), names, allocator); + + REQUIRE(1 == parseResult.root->body.size); + + AstStatAssign* statAssign = parseResult.root->body.data[0]->as(); + REQUIRE(statAssign != nullptr); + REQUIRE(1 == statAssign->values.size); + + AstExpr* expr = statAssign->values.data[0]; + + std::string_view expected = R"({"type":"AstExprError","location":"0,4 - 0,4","expressions":[],"messageIndex":0})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatIf") +{ + AstStat* statement = expectParseStatement("if true then else end"); + + std::string_view expected = + R"({"type":"AstStatIf","location":"0,0 - 0,21","condition":{"type":"AstExprConstantBool","location":"0,3 - 0,7","value":true},"thenbody":{"type":"AstStatBlock","location":"0,12 - 0,13","body":[]},"elsebody":{"type":"AstStatBlock","location":"0,17 - 0,18","body":[]},"hasThen":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatWhile") +{ + AstStat* statement = expectParseStatement("while true do end"); + + std::string_view expected = + R"({"type":"AstStatWhile","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatRepeat") +{ + AstStat* statement = expectParseStatement("repeat until true"); + + std::string_view expected = + R"({"type":"AstStatRepeat","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,13 - 0,17","value":true},"body":{"type":"AstStatBlock","location":"0,6 - 0,7","body":[]},"hasUntil":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatBreak") +{ + AstStat* statement = expectParseStatement("while true do break end"); + + std::string_view expected = + R"({"type":"AstStatWhile","location":"0,0 - 0,23","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,20","body":[{"type":"AstStatBreak","location":"0,14 - 0,19"}]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatContinue") +{ + AstStat* statement = expectParseStatement("while true do continue end"); + + std::string_view expected = + R"({"type":"AstStatWhile","location":"0,0 - 0,26","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,23","body":[{"type":"AstStatContinue","location":"0,14 - 0,22"}]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatFor") +{ + AstStat* statement = expectParseStatement("for a=0,1 do end"); + + std::string_view expected = + R"({"type":"AstStatFor","location":"0,0 - 0,16","var":{"luauType":null,"name":"a","type":"AstLocal","location":"0,4 - 0,5"},"from":{"type":"AstExprConstantNumber","location":"0,6 - 0,7","value":0},"to":{"type":"AstExprConstantNumber","location":"0,8 - 0,9","value":1},"body":{"type":"AstStatBlock","location":"0,12 - 0,13","body":[]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatForIn") +{ + AstStat* statement = expectParseStatement("for a in b do end"); + + std::string_view expected = + R"({"type":"AstStatForIn","location":"0,0 - 0,17","vars":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,4 - 0,5"}],"values":[{"type":"AstExprGlobal","location":"0,9 - 0,10","global":"b"}],"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasIn":true,"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatCompoundAssign") +{ + AstStat* statement = expectParseStatement("a += b"); + + std::string_view expected = + R"({"type":"AstStatCompoundAssign","location":"0,0 - 0,6","op":"Add","var":{"type":"AstExprGlobal","location":"0,0 - 0,1","global":"a"},"value":{"type":"AstExprGlobal","location":"0,5 - 0,6","global":"b"}})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatLocalFunction") +{ + AstStat* statement = expectParseStatement("local function a(b) return end"); + + std::string_view expected = + R"({"type":"AstStatLocalFunction","location":"0,0 - 0,30","name":{"luauType":null,"name":"a","type":"AstLocal","location":"0,15 - 0,16"},"func":{"type":"AstExprFunction","location":"0,0 - 0,30","generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"b","type":"AstLocal","location":"0,17 - 0,18"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,19 - 0,27","body":[{"type":"AstStatReturn","location":"0,20 - 0,26","list":[]}]},"functionDepth":1,"debugname":"a","hasEnd":true}})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatTypeAlias") +{ + AstStat* statement = expectParseStatement("type A = B"); + + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,10","name":"A","generics":[],"genericPacks":[],"type":{"type":"AstTypeReference","location":"0,9 - 0,10","name":"B","parameters":[]},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareFunction") +{ + AstStat* statement = expectParseStatement("declare function foo(x: number): string"); + + std::string_view expected = + R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","name":"foo","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","parameters":[]}]},"retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","parameters":[]}]},"generics":[],"genericPacks":[]})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") +{ + AstStatBlock* root = expectParse(R"( + declare class Foo + prop: number + function method(self, foo: number): string + end + + declare class Bar extends Foo + prop2: string + end + )"); + + REQUIRE(2 == root->body.size); + + std::string_view expected1 = + R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","parameters":[]}},{"name":"method","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","parameters":[]}]}}}]})"; + CHECK(toJson(root->body.data[0]) == expected1); + + std::string_view expected2 = + R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","parameters":[]}}]})"; + CHECK(toJson(root->body.data[1]) == expected2); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_annotation") +{ + AstStat* statement = expectParseStatement("type T = ((number) -> (string | nil)) & ((string) -> ())"); + + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,35","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypeError") +{ + ParseResult parseResult = parse("type T = "); + REQUIRE(1 == parseResult.root->body.size); + + AstStat* statement = parseResult.root->body.data[0]; + + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,9","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeError","location":"0,8 - 0,9","types":[],"messageIndex":0},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypePackExplicit") +{ + AstStatBlock* root = expectParse(R"( + type A = () -> T... + local a: A<(number, string)> + )"); + + CHECK(2 == root->body.size); + + std::string_view expected = + R"({"type":"AstStatLocal","location":"2,8 - 2,36","vars":[{"luauType":{"type":"AstTypeReference","location":"2,17 - 2,36","name":"A","parameters":[{"type":"AstTypePackExplicit","location":"2,19 - 2,20","typeList":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"2,20 - 2,26","name":"number","parameters":[]},{"type":"AstTypeReference","location":"2,28 - 2,34","name":"string","parameters":[]}]}}]},"name":"a","type":"AstLocal","location":"2,14 - 2,15"}],"values":[]})"; + + CHECK(toJson(root->body.data[1]) == expected); +} + +TEST_SUITE_END(); diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index dd49e675..a642334a 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -1,13 +1,14 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Fixture.h" #include "Luau/AstQuery.h" +#include "AstQueryDsl.h" #include "doctest.h" +#include "Fixture.h" using namespace Luau; -struct DocumentationSymbolFixture : Fixture +struct DocumentationSymbolFixture : BuiltinsFixture { std::optional getDocSymbol(const std::string& source, Position position) { @@ -44,11 +45,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "prop") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") { - ScopedFastFlag sffs[] = { - {"LuauDontMutatePersistentFunctions", true}, - {"LuauPersistDefinitionFileTypes", true}, - }; - loadDefinition(R"( declare function Connect(fn: (string) -> ()) )"); @@ -64,8 +60,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_fn") { - ScopedFastFlag sffs{"LuauStoreMatchingOverloadFnType", true}; - loadDefinition(R"( declare foo: ((string) -> number) & ((number) -> string) )"); @@ -78,4 +72,214 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_fn") CHECK_EQ(symbol, "@test/global/foo/overload/(string) -> number"); } +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "class_method") +{ + loadDefinition(R"( + declare class Foo + function bar(self, x: string): number + end + )"); + + std::optional symbol = getDocSymbol(R"( + local x: Foo + x:bar("asdf") + )", + Position(2, 11)); + + CHECK_EQ(symbol, "@test/globaltype/Foo.bar"); +} + +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_class_method") +{ + loadDefinition(R"( + declare class Foo + function bar(self, x: string): number + function bar(self, x: number): string + end + )"); + + std::optional symbol = getDocSymbol(R"( + local x: Foo + x:bar("asdf") + )", + Position(2, 11)); + + CHECK_EQ(symbol, "@test/globaltype/Foo.bar/overload/(Foo, string) -> number"); +} + +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_function_prop") +{ + loadDefinition(R"( + declare Foo: { + new: (number) -> string + } + )"); + + std::optional symbol = getDocSymbol(R"( + Foo.new("asdf") + )", + Position(1, 13)); + + CHECK_EQ(symbol, "@test/global/Foo.new"); +} + +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_overloaded_function_prop") +{ + loadDefinition(R"( + declare Foo: { + new: ((number) -> string) & ((string) -> number) + } + )"); + + std::optional symbol = getDocSymbol(R"( + Foo.new("asdf") + )", + Position(1, 13)); + + CHECK_EQ(symbol, "@test/global/Foo.new/overload/(string) -> number"); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("AstQuery"); + +TEST_CASE_FIXTURE(Fixture, "last_argument_function_call_type") +{ + check(R"( +local function foo() return 2 end +local function bar(a: number) return -a end +bar(foo()) + )"); + + auto oty = findTypeAtPosition(Position(3, 7)); + REQUIRE(oty); + CHECK_EQ("number", toString(*oty)); + + auto expectedOty = findExpectedTypeAtPosition(Position(3, 7)); + REQUIRE(expectedOty); + CHECK_EQ("number", toString(*expectedOty)); +} + +TEST_CASE_FIXTURE(Fixture, "ast_ancestry_at_eof") +{ + check(R"( +if true then + )"); + + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), Position(2, 4)); + REQUIRE_GE(ancestry.size(), 2); + AstStat* parentStat = ancestry[ancestry.size() - 2]->asStat(); + REQUIRE(bool(parentStat)); + REQUIRE(parentStat->is()); +} + +TEST_CASE_FIXTURE(Fixture, "ac_ast_ancestry_at_number_const") +{ + check(R"( +print(3.) + )"); + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*getMainSourceModule(), Position(1, 8)); + REQUIRE_GE(ancestry.size(), 2); + REQUIRE(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(Fixture, "ac_ast_ancestry_in_workspace_dot") +{ + check(R"( +print(workspace.) + )"); + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*getMainSourceModule(), Position(1, 16)); + REQUIRE_GE(ancestry.size(), 2); + REQUIRE(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(Fixture, "ac_ast_ancestry_in_workspace_colon") +{ + check(R"( +print(workspace:) + )"); + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*getMainSourceModule(), Position(1, 16)); + REQUIRE_GE(ancestry.size(), 2); + REQUIRE(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(Fixture, "Luau_query") +{ + AstStatBlock* block = parse(R"( + if true then + end + )"); + + AstStatIf* if_ = Luau::query(block); + CHECK(if_); +} + +TEST_CASE_FIXTURE(Fixture, "Luau_query_for_2nd_if_stat_which_doesnt_exist") +{ + AstStatBlock* block = parse(R"( + if true then + end + )"); + + AstStatIf* if_ = Luau::query(block); + CHECK(!if_); +} + +TEST_CASE_FIXTURE(Fixture, "Luau_nested_query") +{ + AstStatBlock* block = parse(R"( + if true then + end + )"); + + AstStatIf* if_ = Luau::query(block); + REQUIRE(if_); + AstExprConstantBool* bool_ = Luau::query(if_); + REQUIRE(bool_); +} + +TEST_CASE_FIXTURE(Fixture, "Luau_nested_query_but_first_query_failed") +{ + AstStatBlock* block = parse(R"( + if true then + end + )"); + + AstStatIf* if_ = Luau::query(block); + REQUIRE(!if_); + AstExprConstantBool* bool_ = Luau::query(if_); // ensure it doesn't crash + REQUIRE(!bool_); +} + +TEST_CASE_FIXTURE(Fixture, "Luau_selectively_query_for_a_different_boolean") +{ + AstStatBlock* block = parse(R"( + local x = false and true + local y = true and false + )"); + + AstExprConstantBool* fst = Luau::query(block, {nth(), nth(2)}); + REQUIRE(fst); + REQUIRE(fst->value == true); + + AstExprConstantBool* snd = Luau::query(block, {nth(2), nth(2)}); + REQUIRE(snd); + REQUIRE(snd->value == false); +} + +TEST_CASE_FIXTURE(Fixture, "Luau_selectively_query_for_a_different_boolean_2") +{ + AstStatBlock* block = parse(R"( + local x = false and true + local y = true and false + )"); + + AstExprConstantBool* snd = Luau::query(block, {nth(2), nth()}); + REQUIRE(snd); + REQUIRE(snd->value == true); +} + TEST_SUITE_END(); diff --git a/tests/AstQueryDsl.cpp b/tests/AstQueryDsl.cpp new file mode 100644 index 00000000..fa487cab --- /dev/null +++ b/tests/AstQueryDsl.cpp @@ -0,0 +1,45 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "AstQueryDsl.h" + +namespace Luau +{ + +FindNthOccurenceOf::FindNthOccurenceOf(Nth nth) + : requestedNth(nth) +{ +} + +bool FindNthOccurenceOf::checkIt(AstNode* n) +{ + if (theNode) + return false; + + if (n->classIndex == requestedNth.classIndex) + { + // Human factor: the requestedNth starts from 1 because of the term `nth`. + if (currentOccurrence + 1 != requestedNth.nth) + ++currentOccurrence; + else + theNode = n; + } + + return !theNode; // once found, returns false and stops traversal +} + +bool FindNthOccurenceOf::visit(AstNode* n) +{ + return checkIt(n); +} + +bool FindNthOccurenceOf::visit(AstType* t) +{ + return checkIt(t); +} + +bool FindNthOccurenceOf::visit(AstTypePack* t) +{ + return checkIt(t); +} + +} // namespace Luau diff --git a/tests/AstQueryDsl.h b/tests/AstQueryDsl.h new file mode 100644 index 00000000..1f77b7ce --- /dev/null +++ b/tests/AstQueryDsl.h @@ -0,0 +1,83 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Common.h" + +#include +#include + +namespace Luau +{ + +struct Nth +{ + int classIndex; + int nth; +}; + +template +Nth nth(int nth = 1) +{ + static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); + LUAU_ASSERT(nth > 0); // Did you mean to use `nth(1)`? + + return Nth{T::ClassIndex(), nth}; +} + +struct FindNthOccurenceOf : public AstVisitor +{ + Nth requestedNth; + int currentOccurrence = 0; + AstNode* theNode = nullptr; + + FindNthOccurenceOf(Nth nth); + + bool checkIt(AstNode* n); + + bool visit(AstNode* n) override; + bool visit(AstType* n) override; + bool visit(AstTypePack* n) override; +}; + +/** DSL querying of the AST. + * + * Given an AST, one can query for a particular node directly without having to manually unwrap the tree, for example: + * + * ``` + * if a and b then + * print(a + b) + * end + * + * function f(x, y) + * return x + y + * end + * ``` + * + * There are numerous ways to access the second AstExprBinary. + * 1. Luau::query(block, {nth(), nth()}) + * 2. Luau::query(Luau::query(block)) + * 3. Luau::query(block, {nth(2)}) + */ +template +T* query(AstNode* node, const std::vector& nths = {nth(N)}) +{ + static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); + + // If a nested query call fails to find the node in question, subsequent calls can propagate rather than trying to do more. + // This supports `query(query(...))` + + for (Nth nth : nths) + { + if (!node) + return nullptr; + + FindNthOccurenceOf finder{nth}; + node->visit(&finder); + node = finder.theNode; + } + + return node ? node->as() : nullptr; +} + +} // namespace Luau diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 9cd642cf..123708ca 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1,13 +1,13 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Autocomplete.h" #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" #include "Luau/StringUtils.h" #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -23,19 +23,22 @@ static std::optional nullCallback(std::string tag, std::op return std::nullopt; } -struct ACFixture : Fixture +template +struct ACFixtureImpl : BaseType { + ACFixtureImpl() + : BaseType(true, true) + { + } + AutocompleteResult autocomplete(unsigned row, unsigned column) { - return Luau::autocomplete(frontend, "MainModule", Position{row, column}, nullCallback); + return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); } AutocompleteResult autocomplete(char marker) { - auto i = markerPosition.find(marker); - LUAU_ASSERT(i != markerPosition.end()); - const Position& pos = i->second; - return Luau::autocomplete(frontend, "MainModule", pos, nullCallback); + return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), nullCallback); } CheckResult check(const std::string& source) @@ -45,16 +48,18 @@ struct ACFixture : Fixture filteredSource.reserve(source.size()); Position curPos(0, 0); + char prevChar{}; for (char c : source) { - if (c == '@' && !filteredSource.empty()) + if (prevChar == '@') { - char prevChar = filteredSource.back(); - filteredSource.pop_back(); - curPos.column--; // Adjust column position since we removed a character from the output - LUAU_ASSERT("Illegal marker character" && prevChar >= '0' && prevChar <= '9'); - LUAU_ASSERT("Duplicate marker found" && markerPosition.count(prevChar) == 0); - markerPosition.insert(std::pair{prevChar, curPos}); + LUAU_ASSERT("Illegal marker character" && c >= '0' && c <= '9'); + LUAU_ASSERT("Duplicate marker found" && markerPosition.count(c) == 0); + markerPosition.insert(std::pair{c, curPos}); + } + else if (c == '@') + { + // skip the '@' character } else { @@ -69,54 +74,94 @@ struct ACFixture : Fixture curPos.column++; } } + prevChar = c; } + LUAU_ASSERT("Digit expected after @ symbol" && prevChar != '@'); - return Fixture::check(filteredSource); + return BaseType::check(filteredSource); + } + + LoadDefinitionFileResult loadDefinition(const std::string& source) + { + TypeChecker& typeChecker = this->frontend.typeCheckerForAutocomplete; + unfreeze(typeChecker.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, source, "@test"); + freeze(typeChecker.globalTypes); + + REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); + return result; + } + + const Position& getPosition(char marker) const + { + auto i = markerPosition.find(marker); + LUAU_ASSERT(i != markerPosition.end()); + return i->second; } // Maps a marker character (0-9 inclusive) to a position in the source code. std::map markerPosition; }; +struct ACFixture : ACFixtureImpl +{ + ACFixture() + : ACFixtureImpl() + { + addGlobalBinding(frontend, "table", Binding{typeChecker.anyType}); + addGlobalBinding(frontend, "math", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.typeCheckerForAutocomplete, "table", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.typeCheckerForAutocomplete, "math", Binding{typeChecker.anyType}); + } +}; + +struct ACBuiltinsFixture : ACFixtureImpl +{ +}; + TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") { - check(" "); + check(" @1"); - auto ac = autocomplete(0, 1); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.empty()); CHECK(ac.entryMap.count("table")); CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "local_initializer") { - check("local a = "); + check("local a = @1"); - auto ac = autocomplete(0, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); } TEST_CASE_FIXTURE(ACFixture, "leave_numbers_alone") { - check("local a = 3.1"); + check("local a = 3.@11"); - auto ac = autocomplete(0, 12); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::Unknown); } TEST_CASE_FIXTURE(ACFixture, "user_defined_globals") { - check("local myLocal = 4; "); + check("local myLocal = 4; @1"); - auto ac = autocomplete(0, 19); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("myLocal")); CHECK(ac.entryMap.count("table")); CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "dont_suggest_local_before_its_definition") @@ -124,20 +169,20 @@ TEST_CASE_FIXTURE(ACFixture, "dont_suggest_local_before_its_definition") check(R"( local myLocal = 4 function abc() - local myInnerLocal = 1 - +@1 local myInnerLocal = 1 +@2 end - )"); +@3 )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("myLocal")); CHECK(!ac.entryMap.count("myInnerLocal")); - ac = autocomplete(4, 0); + ac = autocomplete('2'); CHECK(ac.entryMap.count("myLocal")); CHECK(ac.entryMap.count("myInnerLocal")); - ac = autocomplete(6, 0); + ac = autocomplete('3'); CHECK(ac.entryMap.count("myLocal")); CHECK(!ac.entryMap.count("myInnerLocal")); } @@ -146,11 +191,12 @@ TEST_CASE_FIXTURE(ACFixture, "recursive_function") { check(R"( function foo() - end +@1 end )"); - auto ac = autocomplete(2, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("foo")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "nested_recursive_function") @@ -158,11 +204,11 @@ TEST_CASE_FIXTURE(ACFixture, "nested_recursive_function") check(R"( local function outer() local function inner() - end +@1 end end )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("inner")); CHECK(ac.entryMap.count("outer")); } @@ -171,11 +217,11 @@ TEST_CASE_FIXTURE(ACFixture, "user_defined_local_functions_in_own_definition") { check(R"( local function abc() - +@1 end )"); - auto ac = autocomplete(2, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); CHECK(ac.entryMap.count("table")); @@ -183,11 +229,11 @@ TEST_CASE_FIXTURE(ACFixture, "user_defined_local_functions_in_own_definition") check(R"( local abc = function() - +@1 end )"); - ac = autocomplete(2, 0); + ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); // FIXME: This is actually incorrect! CHECK(ac.entryMap.count("table")); @@ -202,9 +248,9 @@ TEST_CASE_FIXTURE(ACFixture, "global_functions_are_not_scoped_lexically") end end - )"); +@1 )"); - auto ac = autocomplete(6, 0); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.empty()); CHECK(ac.entryMap.count("abc")); @@ -220,9 +266,9 @@ TEST_CASE_FIXTURE(ACFixture, "local_functions_fall_out_of_scope") end end - )"); +@1 )"); - auto ac = autocomplete(6, 0); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(!ac.entryMap.count("abc")); @@ -233,40 +279,41 @@ TEST_CASE_FIXTURE(ACFixture, "function_parameters") check(R"( function abc(test) - end +@1 end )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("test")); } -TEST_CASE_FIXTURE(ACFixture, "get_member_completions") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "get_member_completions") { check(R"( - local a = table. -- Line 1 - -- | Column 23 + local a = table.@1 )"); - auto ac = autocomplete(1, 24); + auto ac = autocomplete('1'); - CHECK_EQ(16, ac.entryMap.size()); + CHECK_EQ(17, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); CHECK(!ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Property); } TEST_CASE_FIXTURE(ACFixture, "nested_member_completions") { check(R"( local tbl = { abc = { def = 1234, egh = false } } - tbl.abc. + tbl.abc. @1 )"); - auto ac = autocomplete(2, 17); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("def")); CHECK(ac.entryMap.count("egh")); + CHECK_EQ(ac.context, AutocompleteContext::Property); } TEST_CASE_FIXTURE(ACFixture, "unsealed_table") @@ -274,12 +321,13 @@ TEST_CASE_FIXTURE(ACFixture, "unsealed_table") check(R"( local tbl = {} tbl.prop = 5 - tbl. + tbl.@1 )"); - auto ac = autocomplete(3, 12); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("prop")); + CHECK_EQ(ac.context, AutocompleteContext::Property); } TEST_CASE_FIXTURE(ACFixture, "unsealed_table_2") @@ -288,12 +336,13 @@ TEST_CASE_FIXTURE(ACFixture, "unsealed_table_2") local tbl = {} local inner = { prop = 5 } tbl.inner = inner - tbl.inner. + tbl.inner. @1 )"); - auto ac = autocomplete(4, 19); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("prop")); + CHECK_EQ(ac.context, AutocompleteContext::Property); } TEST_CASE_FIXTURE(ACFixture, "cyclic_table") @@ -302,11 +351,12 @@ TEST_CASE_FIXTURE(ACFixture, "cyclic_table") local abc = {} local def = { abc = abc } abc.def = def - abc.def. + abc.def. @1 )"); - auto ac = autocomplete(4, 17); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); + CHECK_EQ(ac.context, AutocompleteContext::Property); } TEST_CASE_FIXTURE(ACFixture, "table_union") @@ -315,13 +365,14 @@ TEST_CASE_FIXTURE(ACFixture, "table_union") type t1 = { a1 : string, b2 : number } type t2 = { b2 : string, c3 : string } function func(abc : t1 | t2) - abc. + abc. @1 end )"); - auto ac = autocomplete(4, 18); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("b2")); + CHECK_EQ(ac.context, AutocompleteContext::Property); } TEST_CASE_FIXTURE(ACFixture, "table_intersection") @@ -330,50 +381,53 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") type t1 = { a1 : string, b2 : number } type t2 = { b2 : string, c3 : string } function func(abc : t1 & t2) - abc. + abc. @1 end )"); - auto ac = autocomplete(4, 18); + auto ac = autocomplete('1'); CHECK_EQ(3, ac.entryMap.size()); CHECK(ac.entryMap.count("a1")); CHECK(ac.entryMap.count("b2")); CHECK(ac.entryMap.count("c3")); + CHECK_EQ(ac.context, AutocompleteContext::Property); } -TEST_CASE_FIXTURE(ACFixture, "get_string_completions") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "get_string_completions") { check(R"( - local a = ("foo"): -- Line 1 - -- | Column 26 + local a = ("foo"):@1 )"); - auto ac = autocomplete(1, 26); + auto ac = autocomplete('1'); CHECK_EQ(17, ac.entryMap.size()); + CHECK_EQ(ac.context, AutocompleteContext::Property); } TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_new_statement") { - check(""); + check("@1"); - auto ac = autocomplete(0, 0); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_the_very_start_of_the_script") { - check(R"( + check(R"(@1 function aaa() end )"); - auto ac = autocomplete(0, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") @@ -382,28 +436,30 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") local game = { GetService=function(s) return 'hello' end } function a() - game: + game: @1 end )"); - auto ac = autocomplete(4, 19); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(!ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Property); } -TEST_CASE_FIXTURE(ACFixture, "method_call_inside_if_conditional") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "method_call_inside_if_conditional") { check(R"( - if table: + if table: @1 )"); - auto ac = autocomplete(1, 19); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(ac.entryMap.count("concat")); CHECK(!ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Property); } TEST_CASE_FIXTURE(ACFixture, "statement_between_two_statements") @@ -411,16 +467,18 @@ TEST_CASE_FIXTURE(ACFixture, "statement_between_two_statements") check(R"( function getmyscripts() end - g + g@1 getmyscripts() )"); - auto ac = autocomplete(3, 9); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(ac.entryMap.count("getmyscripts")); + + CHECK_EQ(ac.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") @@ -431,13 +489,14 @@ TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") function B() local A = {two=2} - A + A @1 end )"); - auto ac = autocomplete(6, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("A")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); TypeId t = follow(*ac.entryMap["A"].type); const TableTypeVar* tt = get(t); @@ -448,13 +507,15 @@ TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") TEST_CASE_FIXTURE(ACFixture, "recommend_statement_starting_keywords") { - check(""); - auto ac = autocomplete(0, 0); + check("@1"); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("local")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); - check("local i = "); - auto ac2 = autocomplete(0, 10); + check("local i = @1"); + auto ac2 = autocomplete('1'); CHECK(!ac2.entryMap.count("local")); + CHECK_EQ(ac2.context, AutocompleteContext::Expression); } TEST_CASE_FIXTURE(ACFixture, "do_not_overwrite_context_sensitive_kws") @@ -464,12 +525,13 @@ TEST_CASE_FIXTURE(ACFixture, "do_not_overwrite_context_sensitive_kws") end - )"); +@1 )"); - auto ac = autocomplete(5, 0); + auto ac = autocomplete('1'); AutocompleteEntry entry = ac.entryMap["continue"]; CHECK(entry.kind == AutocompleteEntryKind::Binding); + CHECK_EQ(ac.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_comment") @@ -480,213 +542,232 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_comment") function foo:bar() end --[[ - foo: + foo:@1 ]] )"); - auto ac = autocomplete(6, 16); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); + CHECK_EQ(ac.context, AutocompleteContext::Unknown); } TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_the_end_of_a_comment") { check(R"( - --!strict + --!strict@1 )"); - auto ac = autocomplete(1, 17); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); + CHECK_EQ(ac.context, AutocompleteContext::Unknown); } TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_comment") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check(R"( - --[[ + --[[ @1 )"); - auto ac = autocomplete(1, 13); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); + CHECK_EQ(ac.context, AutocompleteContext::Unknown); } TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_comment_at_the_very_end_of_the_file") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; + check("--[[@1"); - check("--[["); - - auto ac = autocomplete(0, 4); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); + CHECK_EQ(ac.context, AutocompleteContext::Unknown); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") { check(R"( - for x = + for x @1= )"); - auto ac1 = autocomplete(1, 14); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 0); CHECK_EQ(ac1.entryMap.count("end"), 0); + CHECK_EQ(ac1.context, AutocompleteContext::Unknown); check(R"( - for x = 1 + for x =@1 1 )"); - auto ac2 = autocomplete(1, 15); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("do"), 0); CHECK_EQ(ac2.entryMap.count("end"), 0); + CHECK_EQ(ac2.context, AutocompleteContext::Unknown); check(R"( - for x = 1, 2 + for x = 1,@1 2 )"); - auto ac3 = autocomplete(1, 18); + auto ac3 = autocomplete('1'); CHECK_EQ(1, ac3.entryMap.size()); CHECK_EQ(ac3.entryMap.count("do"), 1); + CHECK_EQ(ac3.context, AutocompleteContext::Keyword); check(R"( - for x = 1, 2, + for x = 1, @12, )"); - auto ac4 = autocomplete(1, 19); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.count("do"), 0); CHECK_EQ(ac4.entryMap.count("end"), 0); + CHECK_EQ(ac4.context, AutocompleteContext::Expression); check(R"( - for x = 1, 2, 5 + for x = 1, 2, @15 )"); - auto ac5 = autocomplete(1, 22); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.count("do"), 1); CHECK_EQ(ac5.entryMap.count("end"), 0); + CHECK_EQ(ac5.context, AutocompleteContext::Keyword); check(R"( - for x = 1, 2, 5 f + for x = 1, 2, 5 f@1 )"); - auto ac6 = autocomplete(1, 25); + auto ac6 = autocomplete('1'); CHECK_EQ(ac6.entryMap.size(), 1); CHECK_EQ(ac6.entryMap.count("do"), 1); + CHECK_EQ(ac6.context, AutocompleteContext::Keyword); check(R"( - for x = 1, 2, 5 do + for x = 1, 2, 5 do @1 )"); - auto ac7 = autocomplete(1, 32); + auto ac7 = autocomplete('1'); CHECK_EQ(ac7.entryMap.count("end"), 1); + CHECK_EQ(ac7.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_in_middle_keywords") { check(R"( - for + for @1 )"); - auto ac1 = autocomplete(1, 12); + auto ac1 = autocomplete('1'); CHECK_EQ(0, ac1.entryMap.size()); + CHECK_EQ(ac1.context, AutocompleteContext::Unknown); check(R"( - for x + for x@1 @2 )"); - auto ac2 = autocomplete(1, 13); + auto ac2 = autocomplete('1'); CHECK_EQ(0, ac2.entryMap.size()); + CHECK_EQ(ac2.context, AutocompleteContext::Unknown); - auto ac2a = autocomplete(1, 14); + auto ac2a = autocomplete('2'); CHECK_EQ(1, ac2a.entryMap.size()); CHECK_EQ(1, ac2a.entryMap.count("in")); + CHECK_EQ(ac2a.context, AutocompleteContext::Keyword); check(R"( - for x in y + for x in y@1 )"); - auto ac3 = autocomplete(1, 18); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("table"), 1); CHECK_EQ(ac3.entryMap.count("do"), 0); + CHECK_EQ(ac3.context, AutocompleteContext::Expression); check(R"( - for x in y + for x in y @1 )"); - auto ac4 = autocomplete(1, 19); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.size(), 1); CHECK_EQ(ac4.entryMap.count("do"), 1); + CHECK_EQ(ac4.context, AutocompleteContext::Keyword); check(R"( - for x in f f + for x in f f@1 )"); - auto ac5 = autocomplete(1, 20); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.size(), 1); CHECK_EQ(ac5.entryMap.count("do"), 1); + CHECK_EQ(ac5.context, AutocompleteContext::Keyword); check(R"( - for x in y do + for x in y do @1 )"); - auto ac6 = autocomplete(1, 23); + auto ac6 = autocomplete('1'); CHECK_EQ(ac6.entryMap.count("in"), 0); CHECK_EQ(ac6.entryMap.count("table"), 1); CHECK_EQ(ac6.entryMap.count("end"), 1); CHECK_EQ(ac6.entryMap.count("function"), 1); + CHECK_EQ(ac6.context, AutocompleteContext::Statement); check(R"( - for x in y do e + for x in y do e@1 )"); - auto ac7 = autocomplete(1, 23); + auto ac7 = autocomplete('1'); CHECK_EQ(ac7.entryMap.count("in"), 0); CHECK_EQ(ac7.entryMap.count("table"), 1); CHECK_EQ(ac7.entryMap.count("end"), 1); CHECK_EQ(ac7.entryMap.count("function"), 1); + CHECK_EQ(ac7.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") { check(R"( - while + while@1 )"); - auto ac1 = autocomplete(1, 13); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 0); CHECK_EQ(ac1.entryMap.count("end"), 0); + CHECK_EQ(ac1.context, AutocompleteContext::Expression); check(R"( - while true + while true @1 )"); - auto ac2 = autocomplete(1, 19); + auto ac2 = autocomplete('1'); CHECK_EQ(1, ac2.entryMap.size()); CHECK_EQ(ac2.entryMap.count("do"), 1); + CHECK_EQ(ac2.context, AutocompleteContext::Keyword); check(R"( - while true do + while true do @1 )"); - auto ac3 = autocomplete(1, 23); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("end"), 1); + CHECK_EQ(ac3.context, AutocompleteContext::Statement); check(R"( - while true d + while true d@1 )"); - auto ac4 = autocomplete(1, 20); + auto ac4 = autocomplete('1'); CHECK_EQ(1, ac4.entryMap.size()); CHECK_EQ(ac4.entryMap.count("do"), 1); + CHECK_EQ(ac4.context, AutocompleteContext::Keyword); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") { check(R"( - if + if @1 )"); - auto ac1 = autocomplete(1, 13); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("then"), 0); CHECK_EQ(ac1.entryMap.count("function"), 1); // FIXME: This is kind of dumb. It is technically syntactically valid but you can never do anything interesting with this. @@ -694,247 +775,260 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") CHECK_EQ(ac1.entryMap.count("else"), 0); CHECK_EQ(ac1.entryMap.count("elseif"), 0); CHECK_EQ(ac1.entryMap.count("end"), 0); + CHECK_EQ(ac1.context, AutocompleteContext::Expression); check(R"( - if x + if x @1 )"); - auto ac2 = autocomplete(1, 14); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("then"), 1); CHECK_EQ(ac2.entryMap.count("function"), 0); CHECK_EQ(ac2.entryMap.count("else"), 0); CHECK_EQ(ac2.entryMap.count("elseif"), 0); CHECK_EQ(ac2.entryMap.count("end"), 0); + CHECK_EQ(ac2.context, AutocompleteContext::Keyword); check(R"( - if x t + if x t@1 )"); - auto ac3 = autocomplete(1, 14); + auto ac3 = autocomplete('1'); CHECK_EQ(1, ac3.entryMap.size()); CHECK_EQ(ac3.entryMap.count("then"), 1); + CHECK_EQ(ac3.context, AutocompleteContext::Keyword); check(R"( if x then - +@1 end )"); - auto ac4 = autocomplete(2, 0); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.count("then"), 0); CHECK_EQ(ac4.entryMap.count("else"), 1); CHECK_EQ(ac4.entryMap.count("function"), 1); CHECK_EQ(ac4.entryMap.count("elseif"), 1); CHECK_EQ(ac4.entryMap.count("end"), 0); + CHECK_EQ(ac4.context, AutocompleteContext::Statement); check(R"( if x then - t + t@1 end )"); - auto ac4a = autocomplete(2, 13); + auto ac4a = autocomplete('1'); CHECK_EQ(ac4a.entryMap.count("then"), 0); CHECK_EQ(ac4a.entryMap.count("table"), 1); CHECK_EQ(ac4a.entryMap.count("else"), 1); CHECK_EQ(ac4a.entryMap.count("elseif"), 1); + CHECK_EQ(ac4a.context, AutocompleteContext::Statement); check(R"( if x then - +@1 elseif x then end )"); - auto ac5 = autocomplete(2, 0); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.count("then"), 0); CHECK_EQ(ac5.entryMap.count("function"), 1); CHECK_EQ(ac5.entryMap.count("else"), 0); CHECK_EQ(ac5.entryMap.count("elseif"), 0); CHECK_EQ(ac5.entryMap.count("end"), 0); + CHECK_EQ(ac5.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_in_repeat") { check(R"( - repeat + repeat @1 )"); - auto ac = autocomplete(1, 16); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); CHECK_EQ(ac.entryMap.count("until"), 1); + CHECK_EQ(ac.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_expression") { check(R"( repeat - until + until @1 )"); - auto ac = autocomplete(2, 16); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); + CHECK_EQ(ac.context, AutocompleteContext::Expression); } TEST_CASE_FIXTURE(ACFixture, "local_names") { check(R"( - local ab + local ab@1 )"); - auto ac1 = autocomplete(1, 16); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.size(), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); + CHECK_EQ(ac1.context, AutocompleteContext::Unknown); check(R"( - local ab, cd + local ab, cd@1 )"); - auto ac2 = autocomplete(1, 20); + auto ac2 = autocomplete('1'); CHECK(ac2.entryMap.empty()); + CHECK_EQ(ac2.context, AutocompleteContext::Unknown); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_fn_exprs") { check(R"( - local function f() + local function f() @1 )"); - auto ac = autocomplete(1, 28); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("end"), 1); + CHECK_EQ(ac.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_lambda") { check(R"( - local a = function() local bar = foo en + local a = function() local bar = foo en@1 )"); - auto ac = autocomplete(1, 47); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("end"), 1); + CHECK_EQ(ac.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "stop_at_first_stat_when_recommending_keywords") { check(R"( repeat - for x + for x @1 )"); - auto ac1 = autocomplete(2, 18); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("in"), 1); CHECK_EQ(ac1.entryMap.count("until"), 0); + CHECK_EQ(ac1.context, AutocompleteContext::Keyword); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_repeat_middle_keyword") { check(R"( - repeat + repeat @1 )"); - auto ac1 = autocomplete(1, 15); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); CHECK_EQ(ac1.entryMap.count("until"), 1); check(R"( - repeat f f + repeat f f@1 )"); - auto ac2 = autocomplete(1, 18); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("function"), 1); CHECK_EQ(ac2.entryMap.count("until"), 1); check(R"( repeat - u + u@1 until )"); - auto ac3 = autocomplete(2, 13); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("until"), 0); } TEST_CASE_FIXTURE(ACFixture, "local_function") { check(R"( - local f + local f@1 )"); - auto ac1 = autocomplete(1, 15); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.size(), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); check(R"( - local f, cd + local f@1, cd )"); - auto ac2 = autocomplete(1, 15); + auto ac2 = autocomplete('1'); CHECK(ac2.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "local_function") { check(R"( - local function + local function @1 )"); - auto ac = autocomplete(1, 23); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( - local function s + local function @1s@2 )"); - ac = autocomplete(1, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); - ac = autocomplete(1, 24); + ac = autocomplete('2'); CHECK(ac.entryMap.empty()); check(R"( - local function () + local function @1()@2 )"); - ac = autocomplete(1, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); - ac = autocomplete(1, 25); + ac = autocomplete('2'); CHECK(ac.entryMap.count("end")); check(R"( - local function something + local function something@1 )"); - ac = autocomplete(1, 32); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( local tbl = {} - function tbl.something() end + function tbl.something@1() end )"); - ac = autocomplete(2, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "local_function_params") { check(R"( - local function abc(def) + local function @1a@2bc(@3d@4ef)@5 @6 )"); - CHECK(autocomplete(1, 23).entryMap.empty()); - CHECK(autocomplete(1, 24).entryMap.empty()); - CHECK(autocomplete(1, 27).entryMap.empty()); - CHECK(autocomplete(1, 28).entryMap.empty()); - CHECK(!autocomplete(1, 31).entryMap.empty()); + CHECK(autocomplete('1').entryMap.empty()); + CHECK(autocomplete('2').entryMap.empty()); + CHECK(autocomplete('3').entryMap.empty()); + CHECK(autocomplete('4').entryMap.empty()); + CHECK(!autocomplete('5').entryMap.empty()); - CHECK(!autocomplete(1, 32).entryMap.empty()); + CHECK(!autocomplete('6').entryMap.empty()); check(R"( local function abc(def) - end +@1 end )"); for (unsigned int i = 23; i < 31; ++i) @@ -943,17 +1037,19 @@ TEST_CASE_FIXTURE(ACFixture, "local_function_params") } CHECK(!autocomplete(1, 32).entryMap.empty()); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("abc"), 1); CHECK_EQ(ac2.entryMap.count("def"), 1); + CHECK_EQ(ac2.context, AutocompleteContext::Statement); check(R"( - local function abc(def, ghi) + local function abc(def, ghi@1) end )"); - auto ac3 = autocomplete(1, 35); + auto ac3 = autocomplete('1'); CHECK(ac3.entryMap.empty()); + CHECK_EQ(ac3.context, AutocompleteContext::Unknown); } TEST_CASE_FIXTURE(ACFixture, "global_function_params") @@ -981,48 +1077,50 @@ TEST_CASE_FIXTURE(ACFixture, "global_function_params") check(R"( function abc(def) - +@1 end )"); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("abc"), 1); CHECK_EQ(ac2.entryMap.count("def"), 1); + CHECK_EQ(ac2.context, AutocompleteContext::Statement); check(R"( - function abc(def, ghi) + function abc(def, ghi@1) end )"); - auto ac3 = autocomplete(1, 29); + auto ac3 = autocomplete('1'); CHECK(ac3.entryMap.empty()); + CHECK_EQ(ac3.context, AutocompleteContext::Unknown); } TEST_CASE_FIXTURE(ACFixture, "arguments_to_global_lambda") { check(R"( - abc = function(def, ghi) + abc = function(def, ghi@1) end )"); - auto ac = autocomplete(1, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "function_expr_params") { check(R"( - abc = function(def) + abc = function(def) @1 )"); for (unsigned int i = 20; i < 27; ++i) { CHECK(autocomplete(1, i).entryMap.empty()); } - CHECK(!autocomplete(1, 28).entryMap.empty()); + CHECK(!autocomplete('1').entryMap.empty()); check(R"( - abc = function(def) + abc = function(def) @1 end )"); @@ -1030,25 +1128,26 @@ TEST_CASE_FIXTURE(ACFixture, "function_expr_params") { CHECK(autocomplete(1, i).entryMap.empty()); } - CHECK(!autocomplete(1, 28).entryMap.empty()); + CHECK(!autocomplete('1').entryMap.empty()); check(R"( abc = function(def) - +@1 end )"); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("def"), 1); + CHECK_EQ(ac2.context, AutocompleteContext::Statement); } TEST_CASE_FIXTURE(ACFixture, "local_initializer") { check(R"( - local a = t + local a = t@1 )"); - auto ac = autocomplete(1, 19); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); CHECK_EQ(ac.entryMap.count("true"), 1); } @@ -1056,20 +1155,20 @@ TEST_CASE_FIXTURE(ACFixture, "local_initializer") TEST_CASE_FIXTURE(ACFixture, "local_initializer_2") { check(R"( - local a= + local a=@1 )"); - auto ac = autocomplete(1, 16); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); } TEST_CASE_FIXTURE(ACFixture, "get_member_completions") { check(R"( - local a = 12.3 + local a = 12.@13 )"); - auto ac = autocomplete(1, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } @@ -1083,24 +1182,25 @@ TEST_CASE_FIXTURE(ACFixture, "sometimes_the_metatable_is_an_error") return setmetatable({x=6}, X) -- oops! end local t = T.new() - t. + t. @1 )"); - autocomplete(8, 12); + autocomplete('1'); // Don't crash! } TEST_CASE_FIXTURE(ACFixture, "local_types_builtin") { check(R"( -local a: n +local a: n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); + CHECK_EQ(ac.context, AutocompleteContext::Type); } TEST_CASE_FIXTURE(ACFixture, "private_types") @@ -1108,23 +1208,23 @@ TEST_CASE_FIXTURE(ACFixture, "private_types") check(R"( do type num = number - local a: nu - local b: num + local a: n@1u + local b: nu@2m end -local a: nu +local a: nu@3 )"); - auto ac = autocomplete(3, 14); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(4, 15); + ac = autocomplete('2'); CHECK(ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(6, 11); + ac = autocomplete('3'); CHECK(!ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); @@ -1136,11 +1236,11 @@ TEST_CASE_FIXTURE(ACFixture, "type_scoping_easy") type Table = { a: number, b: number } do type Table = { x: string, y: string } - local a: T + local a: T@1 end )"); - auto ac = autocomplete(4, 14); + auto ac = autocomplete('1'); REQUIRE(ac.entryMap.count("Table")); REQUIRE(ac.entryMap["Table"].type); @@ -1169,6 +1269,7 @@ local a: aa auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 11}, nullCallback); CHECK(ac.entryMap.count("aaa")); + CHECK_EQ(ac.context, AutocompleteContext::Type); } TEST_CASE_FIXTURE(ACFixture, "module_type_members") @@ -1193,78 +1294,82 @@ local a: aaa. CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("A")); CHECK(ac.entryMap.count("B")); + CHECK_EQ(ac.context, AutocompleteContext::Type); } TEST_CASE_FIXTURE(ACFixture, "argument_types") { check(R"( -local function f(a: n +local function f(a: n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); + CHECK_EQ(ac.context, AutocompleteContext::Type); } TEST_CASE_FIXTURE(ACFixture, "return_types") { check(R"( -local function f(a: number): n +local function f(a: number): n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 30); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); + CHECK_EQ(ac.context, AutocompleteContext::Type); } TEST_CASE_FIXTURE(ACFixture, "as_types") { check(R"( local a: any = 5 -local b: number = (a :: n +local b: number = (a :: n@1 )"); - auto ac = autocomplete(2, 25); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); + CHECK_EQ(ac.context, AutocompleteContext::Type); } TEST_CASE_FIXTURE(ACFixture, "function_type_types") { check(R"( -local a: (n -local b: (number, (n -local c: (number, (number) -> n -local d: (number, (number) -> (number, n -local e: (n: n +local a: (n@1 +local b: (number, (n@2 +local c: (number, (number) -> n@3 +local d: (number, (number) -> (number, n@4 +local e: (n: n@5 )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(2, 20); + ac = autocomplete('2'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(3, 31); + ac = autocomplete('3'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(4, 40); + ac = autocomplete('4'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(5, 14); + ac = autocomplete('5'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1272,17 +1377,15 @@ local e: (n: n TEST_CASE_FIXTURE(ACFixture, "generic_types") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - check(R"( -function f(a: T +function f(a: T@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 25); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("Tee")); + CHECK_EQ(ac.context, AutocompleteContext::Type); } TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_in_argument") @@ -1293,10 +1396,10 @@ local function target(a: number, b: string) return a + #b end local one = 4 local two = "hello" -return target(o +return target(o@1 )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1307,10 +1410,10 @@ local function target(a: number, b: string) return a + #b end local one = 4 local two = "hello" -return target(one, t +return target(one, t@1 )"); - ac = autocomplete(5, 20); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1321,10 +1424,10 @@ return target(one, t local function target(a: number, b: string) return a + #b end local a = { one = 4, two = "hello" } -return target(a. +return target(a.@1 )"); - ac = autocomplete(4, 16); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1334,10 +1437,10 @@ return target(a. local function target(a: number, b: string) return a + #b end local a = { one = 4, two = "hello" } -return target(a.one, a. +return target(a.one, a.@1 )"); - ac = autocomplete(4, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1348,10 +1451,10 @@ return target(a.one, a. local function target(a: string?) return #b end local a = { one = 4, two = "hello" } -return target(a. +return target(a.@1 )"); - ac = autocomplete(4, 16); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1363,26 +1466,28 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_in_table") check(R"( type Foo = { a: number, b: string } local a = { one = 4, two = "hello" } -local b: Foo = { a = a. +local b: Foo = { a = a.@1 )"); - auto ac = autocomplete(3, 23); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::None); + CHECK_EQ(ac.context, AutocompleteContext::Property); check(R"( type Foo = { a: number, b: string } local a = { one = 4, two = "hello" } -local b: Foo = { b = a. +local b: Foo = { b = a.@1 )"); - ac = autocomplete(3, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::None); + CHECK_EQ(ac.context, AutocompleteContext::Property); } TEST_CASE_FIXTURE(ACFixture, "type_correct_function_return_types") @@ -1390,12 +1495,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_function_return_types") check(R"( local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end -local function bar2(a: string) reutrn a .. 'x' end +local function bar2(a: string) return a .. 'x' end -return target(b +return target(b@1 )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1406,10 +1511,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end local function bar2(a: string) return a .. 'x' end -return target(bar1, b +return target(bar1, b@1 )"); - ac = autocomplete(5, 21); + ac = autocomplete('1'); CHECK(ac.entryMap.count("bar2")); CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1418,12 +1523,12 @@ return target(bar1, b check(R"( local function target(a: number, b: string) return a + #b end local function bar1(a: number): (...number) return -a, a end -local function bar2(a: string) reutrn a .. 'x' end +local function bar2(a: string) return a .. 'x' end -return target(b +return target(b@1 )"); - ac = autocomplete(5, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1433,69 +1538,69 @@ return target(b TEST_CASE_FIXTURE(ACFixture, "type_correct_local_type_suggestion") { check(R"( -local b: s = "str" +local b: s@1 = "str" )"); - auto ac = autocomplete(1, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f() return "str" end -local b: s = f() +local b: s@1 = f() )"); - ac = autocomplete(2, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: s, c: n = "str", 2 +local b: s@1, c: n@2 = "str", 2 )"); - ac = autocomplete(1, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(1, 16); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f() return 1, "str", 3 end -local a: b, b: n, c: s, d: n = false, f() +local a: b@1, b: n@2, c: s@3, d: n@4 = false, f() )"); - ac = autocomplete(2, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("boolean")); CHECK(ac.entryMap["boolean"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 16); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 22); + ac = autocomplete('3'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 28); + ac = autocomplete('4'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f(): ...number return 1, 2, 3 end -local a: boolean, b: n = false, f() +local a: boolean, b: n@1 = false, f() )"); - ac = autocomplete(2, 22); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1504,46 +1609,46 @@ local a: boolean, b: n = false, f() TEST_CASE_FIXTURE(ACFixture, "type_correct_function_type_suggestion") { check(R"( -local b: (n) -> number = function(a: number, b: string) return a + #b end +local b: (n@1) -> number = function(a: number, b: string) return a + #b end )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, s = function(a: number, b: string) return a + #b end +local b: (number, s@1 = function(a: number, b: string) return a + #b end )"); - ac = autocomplete(1, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, string) -> b = function(a: number, b: string): boolean return a + #b == 0 end +local b: (number, string) -> b@1 = function(a: number, b: string): boolean return a + #b == 0 end )"); - ac = autocomplete(1, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("boolean")); CHECK(ac.entryMap["boolean"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, ...s) = function(a: number, ...: string) return a end +local b: (number, ...s@1) = function(a: number, ...: string) return a end )"); - ac = autocomplete(1, 22); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number) -> ...s = function(a: number): ...string return "a", "b", "c" end +local b: (number) -> ...s@1 = function(a: number): ...string return "a", "b", "c" end )"); - ac = autocomplete(1, 25); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1552,24 +1657,24 @@ local b: (number) -> ...s = function(a: number): ...string return "a", "b", "c" TEST_CASE_FIXTURE(ACFixture, "type_correct_full_type_suggestion") { check(R"( -local b: = "str" +local b:@1 @2= "str" )"); - auto ac = autocomplete(1, 8); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(1, 9); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: = function(a: number) return -a end +local b: @1= function(a: number) return -a end )"); - ac = autocomplete(1, 9); + ac = autocomplete('1'); CHECK(ac.entryMap.count("(number) -> number")); CHECK(ac.entryMap["(number) -> number"].typeCorrect == TypeCorrectKind::Correct); @@ -1580,12 +1685,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion") check(R"( local function target(a: number, b: string) return a + #b end -local function d(a: n, b) - return target(a, b) +local function d(a: n@1, b) + return target(a, b) end )"); - auto ac = autocomplete(3, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1593,12 +1698,12 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a, b: s) - return target(a, b) +local function d(a, b: s@1) + return target(a, b) end )"); - ac = autocomplete(3, 24); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1606,17 +1711,17 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a: , b) - return target(a, b) +local function d(a:@1 @2, b) + return target(a, b) end )"); - ac = autocomplete(3, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 20); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1624,17 +1729,17 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a, b: ): number - return target(a, b) +local function d(a, b: @1)@2: number + return target(a, b) end )"); - ac = autocomplete(3, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 24); + ac = autocomplete('2'); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::None); } @@ -1644,10 +1749,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion") check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: +local x = target(function(a: @1 )"); - auto ac = autocomplete(3, 29); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1655,10 +1760,10 @@ local x = target(function(a: check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: n +local x = target(function(a: n@1 )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1666,17 +1771,17 @@ local x = target(function(a: n check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: n, b: ) - return a + #b +local x = target(function(a: n@1, b: @2) + return a + #b end) )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 35); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1684,12 +1789,12 @@ end) check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(a: n) - return a +local x = target(function(a: n@1) + return a end )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1700,12 +1805,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestio check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(...:n) - return a +local x = target(function(...:n@1) + return a end )"); - auto ac = autocomplete(3, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1713,12 +1818,12 @@ end check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(a:number, b:number, ...:) - return a + b +local x = target(function(a:number, b:number, ...:@1) + return a + b end )"); - ac = autocomplete(3, 50); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1729,12 +1834,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion") check(R"( local function target(callback: () -> number) return callback() end -local x = target(function(): n - return 1 +local x = target(function(): n@1 + return 1 end )"); - auto ac = autocomplete(3, 30); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1742,12 +1847,12 @@ end check(R"( local function target(callback: () -> (number, number)) return callback() end -local x = target(function(): (number, n - return 1, 2 +local x = target(function(): (number, n@1 + return 1, 2 end )"); - ac = autocomplete(3, 39); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1758,12 +1863,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion" check(R"( local function target(callback: () -> ...number) return callback() end -local x = target(function(): ...n - return 1, 2, 3 +local x = target(function(): ...n@1 + return 1, 2, 3 end )"); - auto ac = autocomplete(3, 33); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1771,12 +1876,12 @@ end check(R"( local function target(callback: () -> ...number) return callback() end -local x = target(function(): (number, number, ...n - return 1, 2, 3 +local x = target(function(): (number, number, ...n@1 + return 1, 2, 3 end )"); - ac = autocomplete(3, 50); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1787,10 +1892,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion_opt check(R"( local function target(callback: nil | (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: +local x = target(function(a: @1 )"); - auto ac = autocomplete(3, 29); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1803,21 +1908,21 @@ local t = {} t.x = 5 function t:target(callback: (a: number, b: string) -> number) return callback(self.x, "hello") end -local x = t:target(function(a: , b: ) end) -local y = t.target(t, function(a: number, b: ) end) +local x = t:target(function(a: @1, b:@2 ) end) +local y = t.target(t, function(a: number, b: @3) end) )"); - auto ac = autocomplete(5, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(5, 35); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(6, 45); + ac = autocomplete('3'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1857,7 +1962,7 @@ ex.b(function(x: CHECK(!ac.entryMap.count("(done) -> number")); } -TEST_CASE_FIXTURE(ACFixture, "suggest_external_module_type") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "suggest_external_module_type") { fileResolver.source["Module/A"] = R"( export type done = { x: number, y: number } @@ -1899,26 +2004,25 @@ TEST_CASE_FIXTURE(ACFixture, "do_not_suggest_synthetic_table_name") { check(R"( local foo = { a = 1, b = 2 } -local bar: = foo +local bar: @1= foo )"); - auto ac = autocomplete(2, 11); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.count("foo")); } -// CLI-45692: Remove UnfrozenFixture here -TEST_CASE_FIXTURE(UnfrozenFixture, "type_correct_function_no_parenthesis") +TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") { check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end -local function bar2(a: string) reutrn a .. 'x' end +local function bar2(a: string) return a .. 'x' end -return target(b +return target(b@1 )"); - auto ac = autocomplete(frontend, "MainModule", Position{5, 15}, nullCallback); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::Correct); @@ -1926,20 +2030,47 @@ return target(b CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::None); } +TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses") +{ + check(R"( +local function bar(a: number) return -a end +local abc = b@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("bar")); + CHECK(ac.entryMap["bar"].parens == ParenthesesRecommendation::CursorInside); +} + +TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses") +{ + check(R"( +local function foo() return 1 end +local function bar(a: number) return -a end +local abc = bar(@1) + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("foo")); + CHECK(ac.entryMap["foo"].parens == ParenthesesRecommendation::CursorAfter); +} + TEST_CASE_FIXTURE(ACFixture, "type_correct_sealed_table") { check(R"( local function f(a: { x: number, y: number }) return a.x + a.y end -local fp: = f +local fp: @1= f )"); - auto ac = autocomplete(2, 10); + auto ac = autocomplete('1'); + REQUIRE_EQ("({| x: number, y: number |}) -> number", toString(requireType("f"))); CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -// CLI-45692: Remove UnfrozenFixture here -TEST_CASE_FIXTURE(UnfrozenFixture, "type_correct_keywords") +TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords") { check(R"( local function a(x: boolean) end @@ -1951,33 +2082,33 @@ local function e(x: ((number) -> string) & ((boolean) -> number)) end local tru = {} local ni = false -local ac = a(t) -local bc = b(n) -local cc = c(f) -local dc = d(f) -local ec = e(f) +local ac = a(t@1) +local bc = b(n@2) +local cc = c(f@3) +local dc = d(f@4) +local ec = e(f@5) )"); - auto ac = autocomplete(frontend, "MainModule", Position{10, 14}, nullCallback); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("tru")); CHECK(ac.entryMap["tru"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{11, 14}, nullCallback); + ac = autocomplete('2'); CHECK(ac.entryMap.count("ni")); CHECK(ac.entryMap["ni"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["nil"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{12, 14}, nullCallback); + ac = autocomplete('3'); CHECK(ac.entryMap.count("false")); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{13, 14}, nullCallback); + ac = autocomplete('4'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{14, 14}, nullCallback); + ac = autocomplete('5'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); } @@ -1988,10 +2119,10 @@ local target: ((number) -> string) & ((string) -> number)) local one = 4 local two = "hello" -return target(o) +return target(o@1) )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2002,10 +2133,10 @@ local target: ((number) -> string) & ((number) -> number)) local one = 4 local two = "hello" -return target(o) +return target(o@1) )"); - ac = autocomplete(5, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2016,10 +2147,10 @@ local target: ((number, number) -> string) & ((string) -> number)) local one = 4 local two = "hello" -return target(1, o) +return target(1, o@1) )"); - ac = autocomplete(5, 18); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2032,10 +2163,10 @@ TEST_CASE_FIXTURE(ACFixture, "optional_members") local a = { x = 2, y = 3 } type A = typeof(a) local b: A? = a -return b. +return b.@1 )"); - auto ac = autocomplete(4, 9); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); @@ -2045,10 +2176,10 @@ return b. local a = { x = 2, y = 3 } type A = typeof(a) local b: nil | A = a -return b. +return b.@1 )"); - ac = autocomplete(4, 9); + ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); @@ -2056,10 +2187,10 @@ return b. check(R"( local b: nil | nil -return b. +return b.@1 )"); - ac = autocomplete(2, 9); + ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -2067,26 +2198,26 @@ return b. TEST_CASE_FIXTURE(ACFixture, "no_function_name_suggestions") { check(R"( -function na +function na@1 )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( -local function +local function @1 )"); - ac = autocomplete(1, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( -local function na +local function na@1 )"); - ac = autocomplete(1, 17); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } @@ -2095,20 +2226,20 @@ TEST_CASE_FIXTURE(ACFixture, "skip_current_local") { check(R"( local other = 1 -local name = na +local name = na@1 )"); - auto ac = autocomplete(2, 15); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.count("name")); CHECK(ac.entryMap.count("other")); check(R"( local other = 1 -local name, test = na +local name, test = na@1 )"); - ac = autocomplete(2, 21); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("name")); CHECK(!ac.entryMap.count("test")); @@ -2119,26 +2250,26 @@ TEST_CASE_FIXTURE(ACFixture, "keyword_members") { check(R"( local a = { done = 1, forever = 2 } -local b = a.do -local c = a.for -local d = a. +local b = a.do@1 +local c = a.for@2 +local d = a.@3 do end )"); - auto ac = autocomplete(2, 14); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); CHECK(ac.entryMap.count("forever")); - ac = autocomplete(3, 15); + ac = autocomplete('2'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); CHECK(ac.entryMap.count("forever")); - ac = autocomplete(4, 12); + ac = autocomplete('3'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); @@ -2150,10 +2281,10 @@ TEST_CASE_FIXTURE(ACFixture, "keyword_methods") check(R"( local a = {} function a:done() end -local b = a:do +local b = a:do@1 )"); - auto ac = autocomplete(3, 14); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); @@ -2183,45 +2314,18 @@ local a: aaa.do CHECK(ac.entryMap.count("other")); } -TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") + +TEST_CASE_FIXTURE(ACFixture, "comments") { - std::string_view source = R"( - local a = table. -- Line 1 - -- | Column 23 - )"; + fileResolver.source["Comments"] = "--!str"; - auto ac = autocompleteSource(frontend, source, Position{1, 24}, nullCallback).result; - - CHECK_EQ(16, ac.entryMap.size()); - CHECK(ac.entryMap.count("find")); - CHECK(ac.entryMap.count("pack")); - CHECK(!ac.entryMap.count("math")); -} - -TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_require") -{ - ScopedFastFlag luauResolveModuleNameWithoutACurrentModule("LuauResolveModuleNameWithoutACurrentModule", true); - - std::string_view source = R"( - local a = require(w -- Line 1 - -- | Column 27 - )"; - - // CLI-43699 require shouldn't crash inside autocompleteSource - auto ac = autocompleteSource(frontend, source, Position{1, 27}, nullCallback).result; -} - -TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_comments") -{ - std::string_view source = "--!str"; - - auto ac = autocompleteSource(frontend, source, Position{0, 6}, nullCallback).result; + auto ac = Luau::autocomplete(frontend, "Comments", Position{0, 6}, nullCallback); CHECK_EQ(0, ac.entryMap.size()); } -TEST_CASE_FIXTURE(ACFixture, "autocompleteProp_index_function_metamethod_is_variadic") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteProp_index_function_metamethod_is_variadic") { - std::string_view source = R"( + fileResolver.source["Module/A"] = R"( type Foo = {x: number} local t = {} setmetatable(t, { @@ -2234,7 +2338,7 @@ TEST_CASE_FIXTURE(ACFixture, "autocompleteProp_index_function_metamethod_is_vari -- | Column 20 )"; - auto ac = autocompleteSource(frontend, source, Position{9, 20}, nullCallback).result; + auto ac = Luau::autocomplete(frontend, "Module/A", Position{9, 20}, nullCallback); REQUIRE_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); } @@ -2247,29 +2351,29 @@ local elsewhere = false local doover = false local endurance = true -if 1 then -else +if 1 then@1 +else@2 end -while false do +while false do@3 end -repeat +repeat@4 until )"); - auto ac = autocomplete(6, 9); + auto ac = autocomplete('1'); CHECK(ac.entryMap.size() == 1); CHECK(ac.entryMap.count("then")); - ac = autocomplete(7, 4); + ac = autocomplete('2'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); - ac = autocomplete(10, 14); + ac = autocomplete('3'); CHECK(ac.entryMap.count("do")); - ac = autocomplete(13, 6); + ac = autocomplete('4'); CHECK(ac.entryMap.count("do")); // FIXME: ideally we want to handle start and end of all statements as well @@ -2277,18 +2381,16 @@ until TEST_CASE_FIXTURE(ACFixture, "if_then_else_elseif_completions") { - ScopedFastFlag sff{"ElseElseIfCompletionImprovements", true}; - check(R"( local elsewhere = false if true then return 1 -el +el@1 end )"); - auto ac = autocomplete(5, 2); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); CHECK(ac.entryMap.count("elsewhere") == 0); @@ -2300,11 +2402,11 @@ if true then return 1 else return 2 -el +el@1 end )"); - ac = autocomplete(7, 2); + ac = autocomplete('1'); CHECK(ac.entryMap.count("else") == 0); CHECK(ac.entryMap.count("elseif") == 0); CHECK(ac.entryMap.count("elsewhere")); @@ -2316,163 +2418,175 @@ if true then print("1") elif true then print("2") -el +el@1 end )"); - ac = autocomplete(7, 2); + ac = autocomplete('1'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); CHECK(ac.entryMap.count("elsewhere")); } -TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_not_the_var_we_are_defining") +TEST_CASE_FIXTURE(ACFixture, "not_the_var_we_are_defining") { - std::string_view source = "abc,de"; + fileResolver.source["Module/A"] = "abc,de"; - auto ac = autocompleteSource(frontend, source, Position{0, 6}, nullCallback).result; + auto ac = Luau::autocomplete(frontend, "Module/A", Position{0, 6}, nullCallback); CHECK(!ac.entryMap.count("de")); } -TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_recursive_function") +TEST_CASE_FIXTURE(ACFixture, "recursive_function_global") { - { - std::string_view global = R"(function abc() + fileResolver.source["global"] = R"(function abc() end )"; - auto ac = autocompleteSource(frontend, global, Position{1, 0}, nullCallback).result; - CHECK(ac.entryMap.count("abc")); - } + auto ac = Luau::autocomplete(frontend, "global", Position{1, 0}, nullCallback); + CHECK(ac.entryMap.count("abc")); +} - { - std::string_view local = R"(local function abc() + + +TEST_CASE_FIXTURE(ACFixture, "recursive_function_local") +{ + fileResolver.source["local"] = R"(local function abc() end )"; - auto ac = autocompleteSource(frontend, local, Position{1, 0}, nullCallback).result; - CHECK(ac.entryMap.count("abc")); - } + auto ac = Luau::autocomplete(frontend, "local", Position{1, 0}, nullCallback); + CHECK(ac.entryMap.count("abc")); } TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys") { check(R"( type Test = { first: number, second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - auto ac = autocomplete(2, 19); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); + CHECK_EQ(ac.context, AutocompleteContext::Property); // Intersection check(R"( type Test = { first: number } & { second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); + CHECK_EQ(ac.context, AutocompleteContext::Property); // Union check(R"( type Test = { first: number, second: number } | { second: number, third: number } -local t: Test = { s } +local t: Test = { s@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("second")); CHECK(!ac.entryMap.count("first")); CHECK(!ac.entryMap.count("third")); + CHECK_EQ(ac.context, AutocompleteContext::Property); // No parenthesis suggestion check(R"( type Test = { first: (number) -> number, second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap["first"].parens == ParenthesesRecommendation::None); + CHECK_EQ(ac.context, AutocompleteContext::Property); // When key is changed check(R"( type Test = { first: number, second: number } -local t: Test = { f = 2 } +local t: Test = { f@1 = 2 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); + CHECK_EQ(ac.context, AutocompleteContext::Property); // Alternative key syntax check(R"( type Test = { first: number, second: number } -local t: Test = { ["f"] } +local t: Test = { ["f@1"] } )"); - ac = autocomplete(2, 21); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); + CHECK_EQ(ac.context, AutocompleteContext::Property); // Not an alternative key syntax check(R"( type Test = { first: number, second: number } -local t: Test = { "f" } +local t: Test = { "f@1" } )"); - ac = autocomplete(2, 20); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("first")); CHECK(!ac.entryMap.count("second")); + CHECK_EQ(ac.context, AutocompleteContext::String); // Skip keys that are already defined check(R"( type Test = { first: number, second: number } -local t: Test = { first = 2, s } +local t: Test = { first = 2, s@1 } )"); - ac = autocomplete(2, 30); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); + CHECK_EQ(ac.context, AutocompleteContext::Property); // Don't skip active key check(R"( type Test = { first: number, second: number } -local t: Test = { first } +local t: Test = { first@1 } )"); - ac = autocomplete(2, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); + CHECK_EQ(ac.context, AutocompleteContext::Property); // Inference after first key check(R"( local t = { { first = 5, second = 10 }, - { f } + { f@1 } } )"); - ac = autocomplete(3, 7); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); + CHECK_EQ(ac.context, AutocompleteContext::Property); check(R"( local t = { [2] = { first = 5, second = 10 }, - [5] = { f } + [5] = { f@1 } } )"); - ac = autocomplete(3, 13); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); + CHECK_EQ(ac.context, AutocompleteContext::Property); } -TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") +TEST_CASE_FIXTURE(ACFixture, "autocomplete_documentation_symbols") { loadDefinition(R"( declare y: { @@ -2480,13 +2594,11 @@ TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") } )"); - fileResolver.source["Module/A"] = R"( - local a = y. - )"; + check(R"( + local a = y.@1 + )"); - frontend.check("Module/A"); - - auto ac = autocomplete(frontend, "Module/A", Position{1, 21}, nullCallback); + auto ac = autocomplete('1'); REQUIRE(ac.entryMap.count("x")); CHECK_EQ(ac.entryMap["x"].documentationSymbol, "@test/global/y.x"); @@ -2494,83 +2606,729 @@ TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - { - check(R"( + check(R"( local temp = false local even = true; local a = true -a = if t1@emp then t -a = if temp t2@ -a = if temp then e3@ -a = if temp then even e4@ -a = if temp then even elseif t5@ -a = if temp then even elseif true t6@ -a = if temp then even elseif true then t7@ -a = if temp then even elseif true then temp e8@ -a = if temp then even elseif true then temp else e9@ +a = if t@1emp then t +a = if temp t@2 +a = if temp then e@3 +a = if temp then even e@4 +a = if temp then even elseif t@5 +a = if temp then even elseif true t@6 +a = if temp then even elseif true then t@7 +a = if temp then even elseif true then temp e@8 +a = if temp then even elseif true then temp else e@9 + )"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("temp")); + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + CHECK_EQ(ac.context, AutocompleteContext::Expression); + + ac = autocomplete('2'); + CHECK(ac.entryMap.count("temp") == 0); + CHECK(ac.entryMap.count("true") == 0); + CHECK(ac.entryMap.count("then")); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + CHECK_EQ(ac.context, AutocompleteContext::Keyword); + + ac = autocomplete('3'); + CHECK(ac.entryMap.count("even")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + CHECK_EQ(ac.context, AutocompleteContext::Expression); + + ac = autocomplete('4'); + CHECK(ac.entryMap.count("even") == 0); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else")); + CHECK(ac.entryMap.count("elseif")); + CHECK_EQ(ac.context, AutocompleteContext::Keyword); + + ac = autocomplete('5'); + CHECK(ac.entryMap.count("temp")); + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + CHECK_EQ(ac.context, AutocompleteContext::Expression); + + ac = autocomplete('6'); + CHECK(ac.entryMap.count("temp") == 0); + CHECK(ac.entryMap.count("true") == 0); + CHECK(ac.entryMap.count("then")); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + CHECK_EQ(ac.context, AutocompleteContext::Keyword); + + ac = autocomplete('7'); + CHECK(ac.entryMap.count("temp")); + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + CHECK_EQ(ac.context, AutocompleteContext::Expression); + + ac = autocomplete('8'); + CHECK(ac.entryMap.count("even") == 0); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else")); + CHECK(ac.entryMap.count("elseif")); + CHECK_EQ(ac.context, AutocompleteContext::Keyword); + + ac = autocomplete('9'); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + CHECK_EQ(ac.context, AutocompleteContext::Expression); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_else_regression") +{ + check(R"( +local abcdef = 0; +local temp = false +local even = true; +local a +a = if temp then even else@1 +a = if temp then even else @2 +a = if temp then even else abc@3 + )"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("else") == 0); + ac = autocomplete('2'); + CHECK(ac.entryMap.count("else") == 0); + ac = autocomplete('3'); + CHECK(ac.entryMap.count("abcdef")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_constant") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + check(R"(f(`@1`))"); + auto ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); + + check(R"(f(`@1 {"a"}`))"); + ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); + + check(R"(f(`{"a"} @1`))"); + ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); + + check(R"(f(`{"a"} @1 {"b"}`))"); + ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_expression") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + check(R"(f(`expression = {@1}`))"); + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_expression_with_comments") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + check(R"(f(`expression = {--[[ bla bla bla ]]@1`))"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); + + check(R"(f(`expression = {@1 --[[ bla bla bla ]]`))"); + ac = autocomplete('1'); + CHECK(!ac.entryMap.empty()); + CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_as_singleton") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + check(R"( + --!strict + local function f(a: "cat" | "dog") end + + f(`@1`) + f(`uhhh{'try'}@2`) + )"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("cat")); + CHECK_EQ(ac.context, AutocompleteContext::String); + + ac = autocomplete('2'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") +{ + check(R"( +type A = () -> T... +local a: A<(number, s@1> + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); + CHECK_EQ(ac.context, AutocompleteContext::Type); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_first_function_arg_expected_type") +{ + check(R"( +local function foo1() return 1 end +local function foo2() return "1" end + +local function bar0() return "got" .. a end +local function bar1(a: number) return "got " .. a end +local function bar2(a: number, b: string) return "got " .. a .. b end + +local t = {} +function t:bar1(a: number) return "got " .. a end + +local r1 = bar0(@1) +local r2 = bar1(@2) +local r3 = bar2(@3) +local r4 = t:bar1(@4) + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::None); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('2'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('3'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('4'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_parameters") +{ + check(R"( +type A = () -> T + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); + CHECK_EQ(ac.context, AutocompleteContext::Type); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_pack_parameters") +{ + check(R"( +type A = () -> T + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); + CHECK_EQ(ac.context, AutocompleteContext::Type); +} + +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_oop_implicit_self") +{ + check(R"( +--!strict +local Class = {} +Class.__index = Class +type Class = typeof(setmetatable({} :: { x: number }, Class)) +function Class.new(x: number): Class + return setmetatable({x = x}, Class) +end +function Class.getx(self: Class) + return self.x +end +function test() + local c = Class.new(42) + local n = c:@1 + print(n) +end + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("getx")); +} + +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_on_string_singletons") +{ + check(R"( + --!strict + local foo: "hello" | "bye" = "hello" + foo:@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("format")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") +{ + check(R"( + type tag = "cat" | "dog" + local function f(a: tag) end + f("@1") + f(@2) + local x: tag = "@3" + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + CHECK_EQ(ac.context, AutocompleteContext::String); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("\"cat\"")); + CHECK(ac.entryMap.count("\"dog\"")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); + + ac = autocomplete('3'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + CHECK_EQ(ac.context, AutocompleteContext::String); + + check(R"( + type tagged = {tag:"cat", fieldx:number} | {tag:"dog", fieldy:number} + local x: tagged = {tag="@4"} + )"); + + ac = autocomplete('4'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + CHECK_EQ(ac.context, AutocompleteContext::String); +} + +TEST_CASE_FIXTURE(ACFixture, "string_singleton_as_table_key") +{ + ScopedFastFlag sff{"LuauCompleteTableKeysBetter", true}; + + check(R"( + type Direction = "up" | "down" + + local a: {[Direction]: boolean} = {[@1] = true} + local b: {[Direction]: boolean} = {["@2"] = true} + local c: {[Direction]: boolean} = {u@3 = true} + local d: {[Direction]: boolean} = {[u@4] = true} + + local e: {[Direction]: boolean} = {[@5]} + local f: {[Direction]: boolean} = {["@6"]} + local g: {[Direction]: boolean} = {u@7} + local h: {[Direction]: boolean} = {[u@8]} + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("\"up\"")); + CHECK(ac.entryMap.count("\"down\"")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("up")); + CHECK(ac.entryMap.count("down")); + + ac = autocomplete('3'); + + CHECK(ac.entryMap.count("up")); + CHECK(ac.entryMap.count("down")); + + ac = autocomplete('4'); + + CHECK(!ac.entryMap.count("up")); + CHECK(!ac.entryMap.count("down")); + + CHECK(ac.entryMap.count("\"up\"")); + CHECK(ac.entryMap.count("\"down\"")); + + ac = autocomplete('5'); + + CHECK(ac.entryMap.count("\"up\"")); + CHECK(ac.entryMap.count("\"down\"")); + + ac = autocomplete('6'); + + CHECK(ac.entryMap.count("up")); + CHECK(ac.entryMap.count("down")); + + ac = autocomplete('7'); + + CHECK(ac.entryMap.count("up")); + CHECK(ac.entryMap.count("down")); + + ac = autocomplete('8'); + + CHECK(!ac.entryMap.count("up")); + CHECK(!ac.entryMap.count("down")); + + CHECK(ac.entryMap.count("\"up\"")); + CHECK(ac.entryMap.count("\"down\"")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") +{ + check(R"( + type tagged = {tag:"cat", fieldx:number} | {tag:"dog", fieldy:number} + local x: tagged = {tag="cat", fieldx=2} + if x.tag == "@1" or "@2" ~= x.tag then end + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + // CLI-48823: assignment to x.tag should also autocomplete, but union l-values are not supported yet +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_boolean_singleton") +{ + check(R"( +local function f(x: true) end +f(@1) + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("true")); + CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); + REQUIRE(ac.entryMap.count("false")); + CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); + CHECK_EQ(ac.context, AutocompleteContext::Expression); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape") +{ + check(R"( + type tag = "strange\t\"cat\"" | 'nice\t"dog"' + local function f(x: tag) end + f(@1) + f("@2") + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("\"strange\\t\\\"cat\\\"\"")); + CHECK(ac.entryMap.count("\"nice\\t\\\"dog\\\"\"")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("strange\\t\\\"cat\\\"")); + CHECK(ac.entryMap.count("nice\\t\\\"dog\\\"")); +} + +TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses_2") +{ + check(R"( +local bar: ((number) -> number) & (number, number) -> number) +local abc = b@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("bar")); + CHECK(ac.entryMap["bar"].parens == ParenthesesRecommendation::CursorInside); +} + +TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_on_class") +{ + loadDefinition(R"( +declare class Foo + function one(self): number + two: () -> number +end + )"); + + { + check(R"( +local t: Foo +t:@1 )"); auto ac = autocomplete('1'); - CHECK(ac.entryMap.count("temp")); - CHECK(ac.entryMap.count("true")); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); - ac = autocomplete('2'); - CHECK(ac.entryMap.count("temp") == 0); - CHECK(ac.entryMap.count("true") == 0); - CHECK(ac.entryMap.count("then")); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); + REQUIRE(ac.entryMap.count("one")); + REQUIRE(ac.entryMap.count("two")); + CHECK(!ac.entryMap["one"].wrongIndexType); + CHECK(ac.entryMap["two"].wrongIndexType); + } - ac = autocomplete('3'); - CHECK(ac.entryMap.count("even")); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); + { + check(R"( +local t: Foo +t.@1 + )"); - ac = autocomplete('4'); - CHECK(ac.entryMap.count("even") == 0); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else")); - CHECK(ac.entryMap.count("elseif")); + auto ac = autocomplete('1'); - ac = autocomplete('5'); - CHECK(ac.entryMap.count("temp")); - CHECK(ac.entryMap.count("true")); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); - - ac = autocomplete('6'); - CHECK(ac.entryMap.count("temp") == 0); - CHECK(ac.entryMap.count("true") == 0); - CHECK(ac.entryMap.count("then")); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); - - ac = autocomplete('7'); - CHECK(ac.entryMap.count("temp")); - CHECK(ac.entryMap.count("true")); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); - - ac = autocomplete('8'); - CHECK(ac.entryMap.count("even") == 0); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else")); - CHECK(ac.entryMap.count("elseif")); - - ac = autocomplete('9'); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); + REQUIRE(ac.entryMap.count("one")); + REQUIRE(ac.entryMap.count("two")); + CHECK(ac.entryMap["one"].wrongIndexType); + CHECK(!ac.entryMap["two"].wrongIndexType); } } +TEST_CASE_FIXTURE(ACFixture, "do_compatible_self_calls") +{ + check(R"( +local t = {} +function t:m() end +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("m")); + CHECK(!ac.entryMap["m"].wrongIndexType); +} + +TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls") +{ + check(R"( +local t = {} +function t.m() end +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("m")); + CHECK(ac.entryMap["m"].wrongIndexType); +} + +TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_2") +{ + check(R"( +local f: (() -> number) & ((number) -> number) = function(x: number?) return 2 end +local t = {} +t.f = f +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("f")); + CHECK(ac.entryMap["f"].wrongIndexType); +} + +TEST_CASE_FIXTURE(ACFixture, "do_wrong_compatible_self_calls") +{ + check(R"( +local t = {} +function t.m(x: typeof(t)) end +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("m")); + // We can make changes to mark this as a wrong way to call even though it's compatible + CHECK(!ac.entryMap["m"].wrongIndexType); +} + +TEST_CASE_FIXTURE(ACFixture, "no_wrong_compatible_self_calls_with_generics") +{ + check(R"( +local t = {} +function t.m(a: T) end +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("m")); + // While this call is compatible with the type, this requires instantiation of a generic type which we don't perform + CHECK(ac.entryMap["m"].wrongIndexType); +} + +TEST_CASE_FIXTURE(ACFixture, "string_prim_self_calls_are_fine") +{ + check(R"( +local s = "hello" +s:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("byte")); + CHECK(ac.entryMap["byte"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("char")); + CHECK(ac.entryMap["char"].wrongIndexType == true); + REQUIRE(ac.entryMap.count("sub")); + CHECK(ac.entryMap["sub"].wrongIndexType == false); +} + +TEST_CASE_FIXTURE(ACFixture, "string_prim_non_self_calls_are_avoided") +{ + check(R"( +local s = "hello" +s.@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("char")); + CHECK(ac.entryMap["char"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("sub")); + CHECK(ac.entryMap["sub"].wrongIndexType == true); +} + +TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_non_self_calls_are_fine") +{ + check(R"( +string.@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("byte")); + CHECK(ac.entryMap["byte"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("char")); + CHECK(ac.entryMap["char"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("sub")); + CHECK(ac.entryMap["sub"].wrongIndexType == false); + + check(R"( +table.@1 + )"); + + ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("remove")); + CHECK(ac.entryMap["remove"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("getn")); + CHECK(ac.entryMap["getn"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("insert")); + CHECK(ac.entryMap["insert"].wrongIndexType == false); +} + +TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_self_calls_are_invalid") +{ + check(R"( +string:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("byte")); + CHECK(ac.entryMap["byte"].wrongIndexType == true); + REQUIRE(ac.entryMap.count("char")); + CHECK(ac.entryMap["char"].wrongIndexType == true); + + // We want the next test to evaluate to 'true', but we have to allow function defined with 'self' to be callable with ':' + // We may change the definition of the string metatable to not use 'self' types in the future (like byte/char/pack/unpack) + REQUIRE(ac.entryMap.count("sub")); + CHECK(ac.entryMap["sub"].wrongIndexType == false); +} + +TEST_CASE_FIXTURE(ACFixture, "source_module_preservation_and_invalidation") +{ + check(R"( +local a = { x = 2, y = 4 } +a.@1 + )"); + + frontend.clear(); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + frontend.check("MainModule", {}); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + frontend.markDirty("MainModule", nullptr); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + frontend.check("MainModule", {}); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); +} + +TEST_CASE_FIXTURE(ACFixture, "globals_are_order_independent") +{ + check(R"( + local myLocal = 4 + function abc0() + local myInnerLocal = 1 +@1 + end + + function abc1() + local myInnerLocal = 1 + end + )"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("myLocal")); + CHECK(ac.entryMap.count("myInnerLocal")); + CHECK(ac.entryMap.count("abc0")); + CHECK(ac.entryMap.count("abc1")); +} + TEST_SUITE_END(); diff --git a/tests/BuiltinDefinitions.test.cpp b/tests/BuiltinDefinitions.test.cpp index dbe80f2c..496df4b4 100644 --- a/tests/BuiltinDefinitions.test.cpp +++ b/tests/BuiltinDefinitions.test.cpp @@ -10,8 +10,10 @@ using namespace Luau; TEST_SUITE_BEGIN("BuiltinDefinitionsTest"); -TEST_CASE_FIXTURE(Fixture, "lib_documentation_symbols") +TEST_CASE_FIXTURE(BuiltinsFixture, "lib_documentation_symbols") { + CHECK(!typeChecker.globalScope->bindings.empty()); + for (const auto& [name, binding] : typeChecker.globalScope->bindings) { std::string nameString(name.c_str()); diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp new file mode 100644 index 00000000..18939e24 --- /dev/null +++ b/tests/ClassFixture.cpp @@ -0,0 +1,113 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "ClassFixture.h" + +#include "Luau/BuiltinDefinitions.h" + +using std::nullopt; + +namespace Luau +{ + +ClassFixture::ClassFixture() +{ + TypeArena& arena = typeChecker.globalTypes; + TypeId numberType = typeChecker.numberType; + + unfreeze(arena); + + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(baseClassInstanceType)->props = { + {"BaseMethod", {makeFunction(arena, baseClassInstanceType, {numberType}, {})}}, + {"BaseField", {numberType}}, + }; + + TypeId baseClassType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(baseClassType)->props = { + {"StaticMethod", {makeFunction(arena, nullopt, {}, {numberType})}}, + {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, + {"New", {makeFunction(arena, nullopt, {}, {baseClassInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + addGlobalBinding(frontend, "BaseClass", baseClassType, "@test"); + + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); + + getMutable(childClassInstanceType)->props = { + {"Method", {makeFunction(arena, childClassInstanceType, {}, {typeChecker.stringType})}}, + }; + + TypeId childClassType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); + getMutable(childClassType)->props = { + {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + addGlobalBinding(frontend, "ChildClass", childClassType, "@test"); + + TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); + + getMutable(grandChildInstanceType)->props = { + {"Method", {makeFunction(arena, grandChildInstanceType, {}, {typeChecker.stringType})}}, + }; + + TypeId grandChildType = arena.addType(ClassTypeVar{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); + getMutable(grandChildType)->props = { + {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; + addGlobalBinding(frontend, "GrandChild", childClassType, "@test"); + + TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); + + getMutable(anotherChildInstanceType)->props = { + {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {typeChecker.stringType})}}, + }; + + TypeId anotherChildType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); + getMutable(anotherChildType)->props = { + {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; + addGlobalBinding(frontend, "AnotherChild", childClassType, "@test"); + + TypeId unrelatedClassInstanceType = arena.addType(ClassTypeVar{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); + + TypeId unrelatedClassType = arena.addType(ClassTypeVar{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(unrelatedClassType)->props = { + {"New", {makeFunction(arena, nullopt, {}, {unrelatedClassInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["UnrelatedClass"] = TypeFun{{}, unrelatedClassInstanceType}; + addGlobalBinding(frontend, "UnrelatedClass", unrelatedClassType, "@test"); + + TypeId vector2MetaType = arena.addType(TableTypeVar{}); + + TypeId vector2InstanceType = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test"}); + getMutable(vector2InstanceType)->props = { + {"X", {numberType}}, + {"Y", {numberType}}, + }; + + TypeId vector2Type = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, nullopt, {}, {}, "Test"}); + getMutable(vector2Type)->props = { + {"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}}, + }; + getMutable(vector2MetaType)->props = { + {"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; + addGlobalBinding(frontend, "Vector2", vector2Type, "@test"); + + TypeId callableClassMetaType = arena.addType(TableTypeVar{}); + TypeId callableClassType = arena.addType(ClassTypeVar{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"}); + getMutable(callableClassMetaType)->props = { + {"__call", {makeFunction(arena, nullopt, {callableClassType, typeChecker.stringType}, {typeChecker.numberType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; + + for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) + persist(tf.type); + + freeze(arena); +} + +} // namespace Luau diff --git a/tests/ClassFixture.h b/tests/ClassFixture.h new file mode 100644 index 00000000..66aec764 --- /dev/null +++ b/tests/ClassFixture.h @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +namespace Luau +{ + +struct ClassFixture : BuiltinsFixture +{ + ClassFixture(); +}; + +} // namespace Luau diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp new file mode 100644 index 00000000..65b485a7 --- /dev/null +++ b/tests/CodeAllocator.test.cpp @@ -0,0 +1,480 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/AssemblyBuilderX64.h" +#include "Luau/AssemblyBuilderA64.h" +#include "Luau/CodeAllocator.h" +#include "Luau/CodeBlockUnwind.h" +#include "Luau/UnwindBuilder.h" +#include "Luau/UnwindBuilderDwarf2.h" +#include "Luau/UnwindBuilderWin.h" + +#include "doctest.h" + +#include +#include + +#include + +using namespace Luau::CodeGen; + +TEST_SUITE_BEGIN("CodeAllocation"); + +TEST_CASE("CodeAllocation") +{ + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + uint8_t* nativeData = nullptr; + size_t sizeNativeData = 0; + uint8_t* nativeEntry = nullptr; + + std::vector code; + code.resize(128); + + REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + CHECK(nativeData != nullptr); + CHECK(sizeNativeData == 128); + CHECK(nativeEntry != nullptr); + CHECK(nativeEntry == nativeData); + + std::vector data; + data.resize(8); + + REQUIRE(allocator.allocate(data.data(), data.size(), code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + CHECK(nativeData != nullptr); + CHECK(sizeNativeData == kCodeAlignment + 128); + CHECK(nativeEntry != nullptr); + CHECK(nativeEntry == nativeData + kCodeAlignment); +} + +TEST_CASE("CodeAllocationFailure") +{ + size_t blockSize = 3000; + size_t maxTotalSize = 7000; + CodeAllocator allocator(blockSize, maxTotalSize); + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + + std::vector code; + code.resize(4000); + + // allocation has to fit in a block + REQUIRE(!allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + + // each allocation exhausts a block, so third allocation fails + code.resize(2000); + REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(!allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); +} + +TEST_CASE("CodeAllocationWithUnwindCallbacks") +{ + struct Info + { + std::vector unwind; + uint8_t* block = nullptr; + bool destroyCalled = false; + }; + Info info; + info.unwind.resize(8); + + { + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + uint8_t* nativeData = nullptr; + size_t sizeNativeData = 0; + uint8_t* nativeEntry = nullptr; + + std::vector code; + code.resize(128); + + std::vector data; + data.resize(8); + + allocator.context = &info; + allocator.createBlockUnwindInfo = [](void* context, uint8_t* block, size_t blockSize, size_t& beginOffset) -> void* { + Info& info = *(Info*)context; + + CHECK(info.unwind.size() == 8); + memcpy(block, info.unwind.data(), info.unwind.size()); + beginOffset = 8; + + info.block = block; + + return new int(7); + }; + allocator.destroyBlockUnwindInfo = [](void* context, void* unwindData) { + Info& info = *(Info*)context; + + info.destroyCalled = true; + + CHECK(*(int*)unwindData == 7); + delete (int*)unwindData; + }; + + REQUIRE(allocator.allocate(data.data(), data.size(), code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + CHECK(nativeData != nullptr); + CHECK(sizeNativeData == kCodeAlignment + 128); + CHECK(nativeEntry != nullptr); + CHECK(nativeEntry == nativeData + kCodeAlignment); + CHECK(nativeData == info.block + kCodeAlignment); + } + + CHECK(info.destroyCalled); +} + +#if !defined(LUAU_BIG_ENDIAN) +TEST_CASE("WindowsUnwindCodesX64") +{ + UnwindBuilderWin unwind; + + unwind.start(); + unwind.spill(16, rdx); + unwind.spill(8, rcx); + unwind.save(rdi); + unwind.save(rsi); + unwind.save(rbx); + unwind.save(rbp); + unwind.save(r12); + unwind.save(r13); + unwind.save(r14); + unwind.save(r15); + unwind.allocStack(72); + unwind.setupFrameReg(rbp, 48); + unwind.finish(); + + std::vector data; + data.resize(unwind.getSize()); + unwind.finalize(data.data(), nullptr, 0); + + std::vector expected{0x01, 0x23, 0x0a, 0x35, 0x23, 0x33, 0x1e, 0x82, 0x1a, 0xf0, 0x18, 0xe0, 0x16, 0xd0, 0x14, 0xc0, 0x12, 0x50, 0x10, + 0x30, 0x0e, 0x60, 0x0c, 0x70}; + + REQUIRE(data.size() == expected.size()); + CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); +} +#endif + +TEST_CASE("Dwarf2UnwindCodesX64") +{ + UnwindBuilderDwarf2 unwind; + + unwind.start(); + unwind.save(rdi); + unwind.save(rsi); + unwind.save(rbx); + unwind.save(rbp); + unwind.save(r12); + unwind.save(r13); + unwind.save(r14); + unwind.save(r15); + unwind.allocStack(72); + unwind.setupFrameReg(rbp, 48); + unwind.finish(); + + std::vector data; + data.resize(unwind.getSize()); + unwind.finalize(data.data(), nullptr, 0); + + std::vector expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x05, 0x10, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x85, 0x02, 0x02, 0x02, 0x0e, 0x18, 0x84, 0x03, 0x02, 0x02, 0x0e, 0x20, 0x83, + 0x04, 0x02, 0x02, 0x0e, 0x28, 0x86, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02, 0x0e, 0x40, + 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x02, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00}; + + REQUIRE(data.size() == expected.size()); + CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); +} + +#if defined(__x86_64__) || defined(_M_X64) + +#if defined(_WIN32) +// Windows x64 ABI +constexpr RegisterX64 rArg1 = rcx; +constexpr RegisterX64 rArg2 = rdx; +constexpr RegisterX64 rArg3 = r8; +#else +// System V AMD64 ABI +constexpr RegisterX64 rArg1 = rdi; +constexpr RegisterX64 rArg2 = rsi; +constexpr RegisterX64 rArg3 = rdx; +#endif + +constexpr RegisterX64 rNonVol1 = r12; +constexpr RegisterX64 rNonVol2 = rbx; + +TEST_CASE("GeneratedCodeExecutionX64") +{ + AssemblyBuilderX64 build(/* logText= */ false); + + build.mov(rax, rArg1); + build.add(rax, rArg2); + build.imul(rax, rax, 7); + build.ret(); + + build.finalize(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, int64_t); + FunctionType* f = (FunctionType*)nativeEntry; + int64_t result = f(10, 20); + CHECK(result == 210); +} + +void throwing(int64_t arg) +{ + CHECK(arg == 25); + + throw std::runtime_error("testing"); +} + +TEST_CASE("GeneratedCodeExecutionWithThrowX64") +{ + AssemblyBuilderX64 build(/* logText= */ false); + +#if defined(_WIN32) + std::unique_ptr unwind = std::make_unique(); +#else + std::unique_ptr unwind = std::make_unique(); +#endif + + unwind->start(); + + // Prologue + build.push(rNonVol1); + unwind->save(rNonVol1); + build.push(rNonVol2); + unwind->save(rNonVol2); + build.push(rbp); + unwind->save(rbp); + + int stackSize = 32; + int localsSize = 16; + + build.sub(rsp, stackSize + localsSize); + unwind->allocStack(stackSize + localsSize); + + build.lea(rbp, addr[rsp + stackSize]); + unwind->setupFrameReg(rbp, stackSize); + + unwind->finish(); + + // Body + build.mov(rNonVol1, rArg1); + build.mov(rNonVol2, rArg2); + + build.add(rNonVol1, 15); + build.mov(rArg1, rNonVol1); + build.call(rNonVol2); + + // Epilogue + build.lea(rsp, addr[rbp + localsSize]); + build.pop(rbp); + build.pop(rNonVol2); + build.pop(rNonVol1); + build.ret(); + + build.finalize(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, void (*)(int64_t)); + FunctionType* f = (FunctionType*)nativeEntry; + + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } +} + +TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") +{ + AssemblyBuilderX64 build(/* logText= */ false); + +#if defined(_WIN32) + std::unique_ptr unwind = std::make_unique(); +#else + std::unique_ptr unwind = std::make_unique(); +#endif + + unwind->start(); + + // Prologue (some of these registers don't have to be saved, but we want to have a big prologue) + build.push(r10); + unwind->save(r10); + build.push(r11); + unwind->save(r11); + build.push(r12); + unwind->save(r12); + build.push(r13); + unwind->save(r13); + build.push(r14); + unwind->save(r14); + build.push(r15); + unwind->save(r15); + build.push(rbp); + unwind->save(rbp); + + int stackSize = 64; + int localsSize = 16; + + build.sub(rsp, stackSize + localsSize); + unwind->allocStack(stackSize + localsSize); + + build.lea(rbp, addr[rsp + stackSize]); + unwind->setupFrameReg(rbp, stackSize); + + unwind->finish(); + + size_t prologueSize = build.setLabel().location; + + // Body + build.mov(rax, rArg1); + build.mov(rArg1, 25); + build.jmp(rax); + + Label returnOffset = build.setLabel(); + + // Epilogue + build.lea(rsp, addr[rbp + localsSize]); + build.pop(rbp); + build.pop(r15); + build.pop(r14); + build.pop(r13); + build.pop(r12); + build.pop(r11); + build.pop(r10); + build.ret(); + + build.finalize(); + + size_t blockSize = 4096; // Force allocate to create a new block each time + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData1; + size_t sizeNativeData1; + uint8_t* nativeEntry1; + REQUIRE( + allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData1, sizeNativeData1, nativeEntry1)); + REQUIRE(nativeEntry1); + + // Now we set the offset at the begining so that functions in new blocks will not overlay the locations + // specified by the unwind information of the entry function + unwind->setBeginOffset(prologueSize); + + using FunctionType = int64_t(void*, void (*)(int64_t), void*); + FunctionType* f = (FunctionType*)nativeEntry1; + + uint8_t* nativeExit = nativeEntry1 + returnOffset.location; + + AssemblyBuilderX64 build2(/* logText= */ false); + + build2.mov(r12, rArg3); + build2.call(rArg2); + build2.jmp(r12); + + build2.finalize(); + + uint8_t* nativeData2; + size_t sizeNativeData2; + uint8_t* nativeEntry2; + REQUIRE(allocator.allocate( + build2.data.data(), build2.data.size(), build2.code.data(), build2.code.size(), nativeData2, sizeNativeData2, nativeEntry2)); + REQUIRE(nativeEntry2); + + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f(nativeEntry2, throwing, nativeExit); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } + + REQUIRE(nativeEntry2); +} + +#endif + +#if defined(__aarch64__) + +TEST_CASE("GeneratedCodeExecutionA64") +{ + AssemblyBuilderA64 build(/* logText= */ false); + + Label skip; + build.cbz(x1, skip); + build.ldrsw(x1, x1); + build.cbnz(x1, skip); + build.mov(x1, 0); // doesn't execute due to cbnz above + build.setLabel(skip); + + uint8_t one = 1; + build.adr(x2, &one, 1); + build.ldrb(w2, x2); + build.sub(x1, x1, x2); + + build.add(x1, x1, 2); + build.add(x0, x0, x1, /* LSL */ 1); + build.ret(); + + build.finalize(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), reinterpret_cast(build.code.data()), build.code.size() * 4, nativeData, + sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, int*); + FunctionType* f = (FunctionType*)nativeEntry; + int input = 10; + int64_t result = f(20, &input); + CHECK(result == 42); +} + +#endif + +TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 54a31a68..d2cf0ae8 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -10,17 +10,20 @@ #include #include -LUAU_FASTFLAG(LuauPreloadClosures) -LUAU_FASTFLAG(LuauPreloadClosuresFenv) -LUAU_FASTFLAG(LuauPreloadClosuresUpval) +namespace Luau +{ +std::string rep(const std::string& s, size_t n); +} using namespace Luau; -static std::string compileFunction(const char* source, uint32_t id) +static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1) { Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); - Luau::compileOrThrow(bcb, source); + Luau::CompileOptions options; + options.optimizationLevel = optimizationLevel; + Luau::compileOrThrow(bcb, source, options); return bcb.dumpFunction(id); } @@ -59,6 +62,29 @@ LOADN R0 5 LOADK R1 K0 RETURN R0 2 )"); + + CHECK_EQ("\n" + bcb.dumpEverything(), R"( +Function 0 (??): +LOADN R0 5 +LOADK R1 K0 +RETURN R0 2 + +)"); +} + +TEST_CASE("CompileError") +{ + std::string source = "local " + rep("a,", 300) + "a = ..."; + + // fails to parse + std::string bc1 = Luau::compile(source + " !#*$!#$^&!*#&$^*"); + + // parses, but fails to compile (too many locals) + std::string bc2 = Luau::compile(source); + + // 0 acts as a special marker for error bytecode + CHECK_EQ(bc1[0], 0); + CHECK_EQ(bc2[0], 0); } TEST_CASE("LocalsDirectReference") @@ -75,20 +101,10 @@ TEST_CASE("BasicFunction") bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); Luau::compileOrThrow(bcb, "local function foo(a, b) return b end"); - if (FFlag::LuauPreloadClosures) - { - CHECK_EQ("\n" + bcb.dumpFunction(1), R"( + CHECK_EQ("\n" + bcb.dumpFunction(1), R"( DUPCLOSURE R0 K0 RETURN R0 0 )"); - } - else - { - CHECK_EQ("\n" + bcb.dumpFunction(1), R"( -NEWCLOSURE R0 P0 -RETURN R0 0 -)"); - } CHECK_EQ("\n" + bcb.dumpFunction(0), R"( RETURN R1 1 @@ -169,11 +185,11 @@ TEST_CASE("ImportCall") { CHECK_EQ("\n" + compileFunction0("return math.max(1, 2)"), R"( LOADN R1 1 -FASTCALL2K 18 R1 K0 +4 +FASTCALL2K 18 R1 K0 L0 LOADK R2 K0 GETIMPORT R0 3 CALL R0 2 -1 -RETURN R0 -1 +L0: RETURN R0 -1 )"); } @@ -253,16 +269,16 @@ RETURN R0 1 TEST_CASE("RepeatLocals") { CHECK_EQ("\n" + compileFunction0("repeat local a a = 5 until a - 4 < 0 or a - 4 >= 0"), R"( -LOADNIL R0 +L0: LOADNIL R0 LOADN R0 5 SUBK R1 R0 K0 LOADN R2 0 -JUMPIFLT R1 R2 +6 +JUMPIFLT R1 R2 L1 SUBK R1 R0 K0 LOADN R2 0 -JUMPIFLE R2 R1 +2 -JUMPBACK -11 -RETURN R0 0 +JUMPIFLE R2 R1 L1 +JUMPBACK L0 +L1: RETURN R0 0 )"); } @@ -273,12 +289,12 @@ TEST_CASE("ForBytecode") LOADN R2 1 LOADN R0 5 LOADN R1 1 -FORNPREP R0 +5 -GETIMPORT R3 1 +FORNPREP R0 L1 +L0: GETIMPORT R3 1 MOVE R4 R2 CALL R3 1 0 -FORNLOOP R0 -5 -RETURN R0 0 +FORNLOOP R0 L0 +L1: RETURN R0 0 )"); // when you assign the variable internally, we freak out and copy the variable so that you aren't changing the loop behavior @@ -286,14 +302,14 @@ RETURN R0 0 LOADN R2 1 LOADN R0 5 LOADN R1 1 -FORNPREP R0 +7 -MOVE R3 R2 +FORNPREP R0 L1 +L0: MOVE R3 R2 LOADN R3 7 GETIMPORT R4 1 MOVE R5 R3 CALL R4 1 0 -FORNLOOP R0 -7 -RETURN R0 0 +FORNLOOP R0 L0 +L1: RETURN R0 0 )"); // basic for-in loop, generic version @@ -302,11 +318,11 @@ GETIMPORT R0 2 LOADK R1 K3 LOADK R2 K4 CALL R0 2 3 -JUMP +4 -GETIMPORT R5 6 +FORGPREP R0 L1 +L0: GETIMPORT R5 6 MOVE R6 R3 CALL R5 1 0 -FORGLOOP R0 -5 1 +L1: FORGLOOP R0 L0 1 RETURN R0 0 )"); @@ -315,12 +331,12 @@ RETURN R0 0 GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 -FORGPREP_INEXT R0 +5 -GETIMPORT R5 3 +FORGPREP_INEXT R0 L1 +L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -FORGLOOP_INEXT R0 -6 +L1: FORGLOOP R0 L0 2 [inext] RETURN R0 0 )"); @@ -329,12 +345,12 @@ RETURN R0 0 GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 -FORGPREP_NEXT R0 +5 -GETIMPORT R5 3 +FORGPREP_NEXT R0 L1 +L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -FORGLOOP_NEXT R0 -6 +L1: FORGLOOP R0 L0 2 RETURN R0 0 )"); @@ -342,12 +358,12 @@ RETURN R0 0 GETIMPORT R0 1 NEWTABLE R1 0 0 LOADNIL R2 -FORGPREP_NEXT R0 +5 -GETIMPORT R5 3 +FORGPREP_NEXT R0 L1 +L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -FORGLOOP_NEXT R0 -6 +L1: FORGLOOP R0 L0 2 RETURN R0 0 )"); } @@ -359,8 +375,8 @@ TEST_CASE("ForBytecodeBuiltin") GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 -FORGPREP_INEXT R0 +0 -FORGLOOP_INEXT R0 -1 +FORGPREP_INEXT R0 L0 +L0: FORGLOOP R0 L0 2 [inext] RETURN R0 0 )"); @@ -370,8 +386,8 @@ GETIMPORT R0 1 MOVE R1 R0 NEWTABLE R2 0 0 CALL R1 1 3 -FORGPREP_INEXT R1 +0 -FORGLOOP_INEXT R1 -1 +FORGPREP_INEXT R1 L0 +L0: FORGLOOP R1 L0 2 [inext] RETURN R0 0 )"); @@ -380,8 +396,8 @@ RETURN R0 0 GETUPVAL R0 0 NEWTABLE R1 0 0 CALL R0 1 3 -FORGPREP_INEXT R0 +0 -FORGLOOP_INEXT R0 -1 +FORGPREP_INEXT R0 L0 +L0: FORGLOOP R0 L0 2 [inext] RETURN R0 0 )"); @@ -392,8 +408,8 @@ GETIMPORT R0 3 MOVE R1 R0 NEWTABLE R2 0 0 CALL R1 1 3 -JUMP +0 -FORGLOOP R1 -1 2 +FORGPREP R1 L0 +L0: FORGLOOP R1 L0 2 RETURN R0 0 )"); @@ -404,8 +420,8 @@ SETGLOBAL R0 K2 GETGLOBAL R0 K2 NEWTABLE R1 0 0 CALL R0 1 3 -JUMP +0 -FORGLOOP R0 -1 2 +FORGPREP R0 L0 +L0: FORGLOOP R0 L0 2 RETURN R0 0 )"); @@ -414,8 +430,8 @@ RETURN R0 0 GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 -JUMP +0 -FORGLOOP R0 -1 2 +FORGPREP R0 L0 +L0: FORGLOOP R0 L0 2 RETURN R0 0 )"); } @@ -617,9 +633,40 @@ RETURN R0 1 )"); } -TEST_CASE("EmptyTableHashSizePredictionOptimization") +TEST_CASE("TableLiteralsIndexConstant") { - const char* hashSizeSource = R"( + // validate that we use SETTTABLEKS for constant variable keys + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = "key", "value" + return {[a] = 42, [b] = 0} +)"), + R"( +NEWTABLE R0 2 0 +LOADN R1 42 +SETTABLEKS R1 R0 K0 +LOADN R1 0 +SETTABLEKS R1 R0 K1 +RETURN R0 1 +)"); + + // validate that we use SETTABLEN for constant variable keys *and* that we predict array size + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = 1, 2 + return {[a] = 42, [b] = 0} +)"), + R"( +NEWTABLE R0 0 2 +LOADN R1 42 +SETTABLEN R1 R0 1 +LOADN R1 0 +SETTABLEN R1 R0 2 +RETURN R0 1 +)"); +} + +TEST_CASE("TableSizePredictionBasic") +{ + CHECK_EQ("\n" + compileFunction0(R"( local t = {} t.a = 1 t.b = 1 @@ -630,36 +677,8 @@ t.f = 1 t.g = 1 t.h = 1 t.i = 1 -)"; - - const char* hashSizeSource2 = R"( -local t = {} -t.x = 1 -t.x = 2 -t.x = 3 -t.x = 4 -t.x = 5 -t.x = 6 -t.x = 7 -t.x = 8 -t.x = 9 -)"; - - const char* arraySizeSource = R"( -local t = {} -t[1] = 1 -t[2] = 1 -t[3] = 1 -t[4] = 1 -t[5] = 1 -t[6] = 1 -t[7] = 1 -t[8] = 1 -t[9] = 1 -t[10] = 1 -)"; - - CHECK_EQ("\n" + compileFunction0(hashSizeSource), R"( +)"), + R"( NEWTABLE R0 16 0 LOADN R1 1 SETTABLEKS R1 R0 K0 @@ -682,7 +701,19 @@ SETTABLEKS R1 R0 K8 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction0(hashSizeSource2), R"( + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +t.x = 1 +t.x = 2 +t.x = 3 +t.x = 4 +t.x = 5 +t.x = 6 +t.x = 7 +t.x = 8 +t.x = 9 +)"), + R"( NEWTABLE R0 1 0 LOADN R1 1 SETTABLEKS R1 R0 K0 @@ -705,7 +736,20 @@ SETTABLEKS R1 R0 K0 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction0(arraySizeSource), R"( + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +t[1] = 1 +t[2] = 1 +t[3] = 1 +t[4] = 1 +t[5] = 1 +t[6] = 1 +t[7] = 1 +t[8] = 1 +t[9] = 1 +t[10] = 1 +)"), + R"( NEWTABLE R0 0 10 LOADN R1 1 SETTABLEN R1 R0 1 @@ -731,6 +775,27 @@ RETURN R0 0 )"); } +TEST_CASE("TableSizePredictionObject") +{ + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +t.field = 1 +function t:getfield() + return self.field +end +return t +)", + 1), + R"( +NEWTABLE R0 2 0 +LOADN R1 1 +SETTABLEKS R1 R0 K0 +DUPCLOSURE R1 K1 +SETTABLEKS R1 R0 K2 +RETURN R0 1 +)"); +} + TEST_CASE("TableSizePredictionSetMetatable") { CHECK_EQ("\n" + compileFunction0(R"( @@ -740,18 +805,41 @@ t.field2 = 2 return t )"), R"( -GETIMPORT R0 1 NEWTABLE R1 2 0 -LOADNIL R2 +FASTCALL2K 61 R1 K0 L0 +LOADK R2 K0 +GETIMPORT R0 2 CALL R0 2 1 -LOADN R1 1 -SETTABLEKS R1 R0 K2 -LOADN R1 2 +L0: LOADN R1 1 SETTABLEKS R1 R0 K3 +LOADN R1 2 +SETTABLEKS R1 R0 K4 RETURN R0 1 )"); } +TEST_CASE("TableSizePredictionLoop") +{ + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +for i=1,4 do + t[i] = 0 +end +return t +)"), + R"( +NEWTABLE R0 0 4 +LOADN R3 1 +LOADN R1 4 +LOADN R2 1 +FORNPREP R1 L1 +L0: LOADN R4 0 +SETTABLE R4 R0 R3 +FORNLOOP R1 L0 +L1: RETURN R0 1 +)"); +} + TEST_CASE("ReflectionEnums") { CHECK_EQ("\n" + compileFunction0("return Enum.EasingStyle.Linear"), R"( @@ -768,11 +856,11 @@ TEST_CASE("CaptureSelf") local MaterialsListClass = {} function MaterialsListClass:_MakeToolTip(guiElement, text) - local function updateTooltipPosition() - self._tweakingTooltipFrame = 5 - end + local function updateTooltipPosition() + self._tweakingTooltipFrame = 5 + end - updateTooltipPosition() + updateTooltipPosition() end return MaterialsListClass @@ -798,18 +886,18 @@ TEST_CASE("ConditionalBasic") { CHECK_EQ("\n" + compileFunction0("local a = ... if a then return 5 end"), R"( GETVARARGS R0 1 -JUMPIFNOT R0 +2 +JUMPIFNOT R0 L0 LOADN R1 5 RETURN R1 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = ... if not a then return 5 end"), R"( GETVARARGS R0 1 -JUMPIF R0 +2 +JUMPIF R0 L0 LOADN R1 5 RETURN R1 1 -RETURN R0 0 +L0: RETURN R0 0 )"); } @@ -817,50 +905,50 @@ TEST_CASE("ConditionalCompare") { CHECK_EQ("\n" + compileFunction0("local a, b = ... if a < b then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFNOTLT R0 R1 +3 +JUMPIFNOTLT R0 R1 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b = ... if a <= b then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFNOTLE R0 R1 +3 +JUMPIFNOTLE R0 R1 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b = ... if a > b then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFNOTLT R1 R0 +3 +JUMPIFNOTLT R1 R0 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b = ... if a >= b then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFNOTLE R1 R0 +3 +JUMPIFNOTLE R1 R0 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b = ... if a == b then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFNOTEQ R0 R1 +3 +JUMPIFNOTEQ R0 R1 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b = ... if a ~= b then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFEQ R0 R1 +3 +JUMPIFEQ R0 R1 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); } @@ -868,18 +956,18 @@ TEST_CASE("ConditionalNot") { CHECK_EQ("\n" + compileFunction0("local a, b = ... if not (not (a < b)) then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFNOTLT R0 R1 +3 +JUMPIFNOTLT R0 R1 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b = ... if not (not (not (a < b))) then return 5 end"), R"( GETVARARGS R0 2 -JUMPIFLT R0 R1 +3 +JUMPIFLT R0 R1 L0 LOADN R2 5 RETURN R2 1 -RETURN R0 0 +L0: RETURN R0 0 )"); } @@ -887,51 +975,51 @@ TEST_CASE("ConditionalAndOr") { CHECK_EQ("\n" + compileFunction0("local a, b, c = ... if a < b and b < c then return 5 end"), R"( GETVARARGS R0 3 -JUMPIFNOTLT R0 R1 +5 -JUMPIFNOTLT R1 R2 +3 +JUMPIFNOTLT R0 R1 L0 +JUMPIFNOTLT R1 R2 L0 LOADN R3 5 RETURN R3 1 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a, b, c = ... if a < b or b < c then return 5 end"), R"( GETVARARGS R0 3 -JUMPIFLT R0 R1 +3 -JUMPIFNOTLT R1 R2 +3 -LOADN R3 5 +JUMPIFLT R0 R1 L0 +JUMPIFNOTLT R1 R2 L1 +L0: LOADN R3 5 RETURN R3 1 -RETURN R0 0 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a,b,c,d = ... if (a or b) and not (c and d) then return 5 end"), R"( GETVARARGS R0 4 -JUMPIF R0 +1 -JUMPIFNOT R1 +4 -JUMPIFNOT R2 +1 -JUMPIF R3 +2 -LOADN R4 5 +JUMPIF R0 L0 +JUMPIFNOT R1 L2 +L0: JUMPIFNOT R2 L1 +JUMPIF R3 L2 +L1: LOADN R4 5 RETURN R4 1 -RETURN R0 0 +L2: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a,b,c = ... if a or not b or c then return 5 end"), R"( GETVARARGS R0 3 -JUMPIF R0 +2 -JUMPIFNOT R1 +1 -JUMPIFNOT R2 +2 -LOADN R3 5 +JUMPIF R0 L0 +JUMPIFNOT R1 L0 +JUMPIFNOT R2 L1 +L0: LOADN R3 5 RETURN R3 1 -RETURN R0 0 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a,b,c = ... if a and not b and c then return 5 end"), R"( GETVARARGS R0 3 -JUMPIFNOT R0 +4 -JUMPIF R1 +3 -JUMPIFNOT R2 +2 +JUMPIFNOT R0 L0 +JUMPIF R1 L0 +JUMPIFNOT R2 L0 LOADN R3 5 RETURN R3 1 -RETURN R0 0 +L0: RETURN R0 0 )"); } @@ -956,9 +1044,9 @@ LOADN R0 1 LOADN R1 2 SETGLOBAL R1 K0 MOVE R1 R0 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 GETGLOBAL R1 K0 -MOVE R0 R1 +L0: MOVE R0 R1 RETURN R0 1 )"); @@ -981,9 +1069,9 @@ LOADN R0 1 LOADN R1 2 SETGLOBAL R1 K0 MOVE R1 R0 -JUMPIF R1 +2 +JUMPIF R1 L0 GETGLOBAL R1 K0 -MOVE R0 R1 +L0: MOVE R0 R1 RETURN R0 1 )"); @@ -995,9 +1083,9 @@ MOVE R0 R0 LOADN R1 2 SETGLOBAL R1 K0 MOVE R1 R0 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 GETGLOBAL R1 K0 -RETURN R1 1 +L0: RETURN R1 1 )"); CHECK_EQ("\n" + compileFunction0("local a = 1 a = a b = 2 local c = a or b return c"), R"( @@ -1006,9 +1094,42 @@ MOVE R0 R0 LOADN R1 2 SETGLOBAL R1 K0 MOVE R1 R0 -JUMPIF R1 +2 +JUMPIF R1 L0 GETGLOBAL R1 K0 -RETURN R1 1 +L0: RETURN R1 1 +)"); +} + +TEST_CASE("AndOrFoldLeft") +{ + // constant folding and/or expression is possible even if just the left hand is constant + CHECK_EQ("\n" + compileFunction0("local a = false if a and b then b() end"), R"( +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = true if a or b then b() end"), R"( +GETIMPORT R0 1 +CALL R0 0 0 +RETURN R0 0 +)"); + + // however, if right hand side is constant we can't constant fold the entire expression + // (note that we don't need to evaluate the right hand side, but we do need a branch) + CHECK_EQ("\n" + compileFunction0("local a = false if b and a then b() end"), R"( +GETIMPORT R0 1 +JUMPIFNOT R0 L0 +RETURN R0 0 +GETIMPORT R0 1 +CALL R0 0 0 +L0: RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = true if b or a then b() end"), R"( +GETIMPORT R0 1 +JUMPIF R0 L0 +L0: GETIMPORT R0 1 +CALL R0 0 0 +RETURN R0 0 )"); } @@ -1027,27 +1148,24 @@ GETIMPORT R3 1 SUB R1 R2 R3 GETIMPORT R3 4 ADDK R2 R3 K2 -JUMPIFNOTLT R1 R2 +4 +JUMPIFNOTLT R1 R2 L0 GETIMPORT R0 8 -JUMPIF R0 +15 -GETIMPORT R1 10 +JUMPIF R0 L2 +L0: GETIMPORT R1 10 LOADN R2 0 -JUMPIFNOTLT R2 R1 +9 +JUMPIFNOTLT R2 R1 L1 GETIMPORT R1 10 LOADN R2 1 -JUMPIFNOTLT R1 R2 +4 +JUMPIFNOTLT R1 R2 L1 GETIMPORT R0 8 -JUMPIF R0 +2 -GETIMPORT R0 12 -RETURN R0 1 +JUMPIF R0 L2 +L1: GETIMPORT R0 12 +L2: RETURN R0 1 )"); } TEST_CASE("IfElseExpression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - // codegen for a true constant condition CHECK_EQ("\n" + compileFunction0("return if true then 10 else 20"), R"( LOADN R0 10 @@ -1058,6 +1176,18 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0("return if false then 10 else 20"), R"( LOADN R0 20 RETURN R0 1 +)"); + + // codegen for a true constant condition with non-constant expressions + CHECK_EQ("\n" + compileFunction0("return if true then {} else error()"), R"( +NEWTABLE R0 0 0 +RETURN R0 1 +)"); + + // codegen for a false constant condition with non-constant expressions + CHECK_EQ("\n" + compileFunction0("return if false then error() else {}"), R"( +NEWTABLE R0 0 0 +RETURN R0 1 )"); // codegen for a false (in this case 'nil') constant condition @@ -1076,50 +1206,119 @@ RETURN R0 1 // codegen for a non-constant condition CHECK_EQ("\n" + compileFunction0("return if condition then 10 else 20"), R"( GETIMPORT R1 1 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 LOADN R0 10 RETURN R0 1 -LOADN R0 20 +L0: LOADN R0 20 RETURN R0 1 )"); // codegen for a non-constant condition using an assignment CHECK_EQ("\n" + compileFunction0("result = if condition then 10 else 20"), R"( GETIMPORT R1 1 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 LOADN R0 10 -JUMP +1 -LOADN R0 20 -SETGLOBAL R0 K2 +JUMP L1 +L0: LOADN R0 20 +L1: SETGLOBAL R0 K2 RETURN R0 0 )"); // codegen for a non-constant condition using an assignment to a local variable CHECK_EQ("\n" + compileFunction0("local result = if condition then 10 else 20"), R"( GETIMPORT R1 1 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 LOADN R0 10 RETURN R0 0 -LOADN R0 20 +L0: LOADN R0 20 RETURN R0 0 )"); // codegen for an if-else expression with multiple elseif's CHECK_EQ("\n" + compileFunction0("result = if condition1 then 10 elseif condition2 then 20 elseif condition3 then 30 else 40"), R"( GETIMPORT R1 1 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 LOADN R0 10 -JUMP +11 -GETIMPORT R1 3 -JUMPIFNOT R1 +2 +JUMP L3 +L0: GETIMPORT R1 3 +JUMPIFNOT R1 L1 LOADN R0 20 -JUMP +6 -GETIMPORT R1 5 -JUMPIFNOT R1 +2 +JUMP L3 +L1: GETIMPORT R1 5 +JUMPIFNOT R1 L2 LOADN R0 30 -JUMP +1 -LOADN R0 40 -SETGLOBAL R0 K6 +JUMP L3 +L2: LOADN R0 40 +L3: SETGLOBAL R0 K6 +RETURN R0 0 +)"); +} + +TEST_CASE("UnaryBasic") +{ + CHECK_EQ("\n" + compileFunction0("local a = ... return not a"), R"( +GETVARARGS R0 1 +NOT R1 R0 +RETURN R1 1 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = ... return -a"), R"( +GETVARARGS R0 1 +MINUS R1 R0 +RETURN R1 1 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = ... return #a"), R"( +GETVARARGS R0 1 +LENGTH R1 R0 +RETURN R1 1 +)"); +} + +TEST_CASE("InterpStringWithNoExpressions") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + CHECK_EQ(compileFunction0(R"(return "hello")"), compileFunction0("return `hello`")); +} + +TEST_CASE("InterpStringZeroCost") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + CHECK_EQ("\n" + compileFunction0(R"(local _ = `hello, {"world"}!`)"), + R"( +LOADK R1 K0 +LOADK R3 K1 +NAMECALL R1 R1 K2 +CALL R1 2 1 +MOVE R0 R1 +RETURN R0 0 +)"); +} + +TEST_CASE("InterpStringRegisterCleanup") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + CHECK_EQ("\n" + compileFunction0(R"( + local a, b, c = nil, "um", "uh oh" + a = `foo{"bar"}` + print(a) + )"), + + R"( +LOADNIL R0 +LOADK R1 K0 +LOADK R2 K1 +LOADK R3 K2 +LOADK R5 K3 +NAMECALL R3 R3 K4 +CALL R3 2 1 +MOVE R0 R3 +GETIMPORT R3 6 +MOVE R4 R0 +CALL R3 1 0 RETURN R0 0 )"); } @@ -1168,6 +1367,17 @@ RETURN R0 1 )"); } +TEST_CASE("ConstantFoldStringLen") +{ + CHECK_EQ("\n" + compileFunction0("return #'string', #'', #'a', #('b')"), R"( +LOADN R0 6 +LOADN R1 0 +LOADN R2 1 +LOADN R3 1 +RETURN R0 4 +)"); +} + TEST_CASE("ConstantFoldCompare") { // ordered comparisons @@ -1346,16 +1556,16 @@ RETURN R0 1 // constant fold parts in chains of and/or statements CHECK_EQ("\n" + compileFunction0("return a and true and b"), R"( GETIMPORT R0 1 -JUMPIFNOT R0 +2 +JUMPIFNOT R0 L0 GETIMPORT R0 3 -RETURN R0 1 +L0: RETURN R0 1 )"); CHECK_EQ("\n" + compileFunction0("return a or false or b"), R"( GETIMPORT R0 1 -JUMPIF R0 +2 +JUMPIF R0 L0 GETIMPORT R0 3 -RETURN R0 1 +L0: RETURN R0 1 )"); } @@ -1363,38 +1573,38 @@ TEST_CASE("ConstantFoldConditionalAndOr") { CHECK_EQ("\n" + compileFunction0("local a = ... if false or a then print(1) end"), R"( GETVARARGS R0 1 -JUMPIFNOT R0 +4 +JUMPIFNOT R0 L0 GETIMPORT R1 1 LOADN R2 1 CALL R1 1 0 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = ... if not (false or a) then print(1) end"), R"( GETVARARGS R0 1 -JUMPIF R0 +4 +JUMPIF R0 L0 GETIMPORT R1 1 LOADN R2 1 CALL R1 1 0 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = ... if true and a then print(1) end"), R"( GETVARARGS R0 1 -JUMPIFNOT R0 +4 +JUMPIFNOT R0 L0 GETIMPORT R1 1 LOADN R2 1 CALL R1 1 0 -RETURN R0 0 +L0: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = ... if not (true and a) then print(1) end"), R"( GETVARARGS R0 1 -JUMPIF R0 +4 +JUMPIF R0 L0 GETIMPORT R1 1 LOADN R2 1 CALL R1 1 0 -RETURN R0 0 +L0: RETURN R0 0 )"); } @@ -1428,10 +1638,10 @@ RETURN R0 0 // while CHECK_EQ("\n" + compileFunction0("while true do print(1) end"), R"( -GETIMPORT R0 1 +L0: GETIMPORT R0 1 LOADN R1 1 CALL R0 1 0 -JUMPBACK -5 +JUMPBACK L0 RETURN R0 0 )"); @@ -1448,21 +1658,21 @@ RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("repeat print(1) until false"), R"( -GETIMPORT R0 1 +L0: GETIMPORT R0 1 LOADN R1 1 CALL R0 1 0 -JUMPBACK -5 +JUMPBACK L0 RETURN R0 0 )"); // there's an odd case in repeat..until compilation where we evaluate the expression that is always false for side-effects of the left hand side CHECK_EQ("\n" + compileFunction0("repeat print(1) until five and false"), R"( -GETIMPORT R0 1 +L0: GETIMPORT R0 1 LOADN R1 1 CALL R0 1 0 GETIMPORT R0 3 -JUMPIFNOT R0 +0 -JUMPBACK -8 +JUMPIFNOT R0 L1 +L1: JUMPBACK L0 RETURN R0 0 )"); } @@ -1471,24 +1681,24 @@ TEST_CASE("LoopBreak") { // default codegen: compile breaks as unconditional jumps CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break else end end"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFNOTLT R0 R1 +3 +JUMPIFNOTLT R0 R1 L1 RETURN R0 0 -JUMP +0 -JUMPBACK -9 +JUMP L1 +L1: JUMPBACK L0 RETURN R0 0 )"); // optimization: if then body is a break statement, flip the branches CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break end end"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFLT R0 R1 +2 -JUMPBACK -7 -RETURN R0 0 +JUMPIFLT R0 R1 L1 +JUMPBACK L0 +L1: RETURN R0 0 )"); } @@ -1496,28 +1706,28 @@ TEST_CASE("LoopContinue") { // default codegen: compile continue as unconditional jumps CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue else end break until false error()"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFNOTLT R0 R1 +5 -JUMP +2 -JUMP +2 -JUMP +1 -JUMPBACK -10 -GETIMPORT R0 5 +JUMPIFNOTLT R0 R1 L2 +JUMP L1 +JUMP L2 +JUMP L2 +L1: JUMPBACK L0 +L2: GETIMPORT R0 5 CALL R0 0 0 RETURN R0 0 )"); // optimization: if then body is a continue statement, flip the branches CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue end break until false error()"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFLT R0 R1 +2 -JUMP +1 -JUMPBACK -8 -GETIMPORT R0 5 +JUMPIFLT R0 R1 L1 +JUMP L2 +L1: JUMPBACK L0 +L2: GETIMPORT R0 5 CALL R0 0 0 RETURN R0 0 )"); @@ -1527,15 +1737,15 @@ TEST_CASE("LoopContinueUntil") { // it's valid to use locals defined inside the loop in until expression if they're defined before continue CHECK_EQ("\n" + compileFunction0("repeat local r = math.random() if r > 0.5 then continue end r = r + 0.3 until r < 0.5"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFLT R1 R0 +2 +JUMPIFLT R1 R0 L1 ADDK R0 R0 K4 -LOADK R1 K3 -JUMPIFLT R0 R1 +2 -JUMPBACK -11 -RETURN R0 0 +L1: LOADK R1 K3 +JUMPIFLT R0 R1 L2 +JUMPBACK L0 +L2: RETURN R0 0 )"); // it's however invalid to use locals if they are defined after continue @@ -1566,16 +1776,16 @@ until rr < 0.5 CHECK_EQ("\n" + compileFunction0( "repeat local r = math.random() repeat if r > 0.5 then continue end r = r - 0.1 until true r = r + 0.3 until r < 0.5"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFLT R1 R0 +2 +JUMPIFLT R1 R0 L1 SUBK R0 R0 K4 -ADDK R0 R0 K5 +L1: ADDK R0 R0 K5 LOADK R1 K3 -JUMPIFLT R0 R1 +2 -JUMPBACK -12 -RETURN R0 0 +JUMPIFLT R0 R1 L2 +JUMPBACK L0 +L2: RETURN R0 0 )"); // and it's also okay to use a local defined in the until expression as long as it's inside a function! @@ -1583,20 +1793,20 @@ RETURN R0 0 "\n" + compileFunction( "repeat local r = math.random() if r > 0.5 then continue end r = r + 0.3 until (function() local a = r return a < 0.5 end)()", 1), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFNOTLT R1 R0 +3 +JUMPIFNOTLT R1 R0 L1 CLOSEUPVALS R0 -JUMP +1 -ADDK R0 R0 K4 -NEWCLOSURE R1 P0 +JUMP L2 +L1: ADDK R0 R0 K4 +L2: NEWCLOSURE R1 P0 CAPTURE REF R0 CALL R1 0 1 -JUMPIF R1 +2 -CLOSEUPVALS R0 -JUMPBACK -15 +JUMPIF R1 L3 CLOSEUPVALS R0 +JUMPBACK L0 +L3: CLOSEUPVALS R0 RETURN R0 0 )"); @@ -1627,17 +1837,17 @@ until (function() return rr end)() < 0.5 CHECK_EQ("\n" + compileFunction0("local stop = false stop = true function test() repeat local r = math.random() if r > 0.5 then " "continue end r = r + 0.3 until stop or r < 0.5 end"), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFLT R1 R0 +2 +JUMPIFLT R1 R0 L1 ADDK R0 R0 K4 -GETUPVAL R1 0 -JUMPIF R1 +4 +L1: GETUPVAL R1 0 +JUMPIF R1 L2 LOADK R1 K3 -JUMPIFLT R0 R1 +2 -JUMPBACK -13 -RETURN R0 0 +JUMPIFLT R0 R1 L2 +JUMPBACK L0 +L2: RETURN R0 0 )"); // including upvalue references from a function expression @@ -1645,21 +1855,21 @@ RETURN R0 0 "end r = r + 0.3 until (function() return stop or r < 0.5 end)() end", 1), R"( -GETIMPORT R0 2 +L0: GETIMPORT R0 2 CALL R0 0 1 LOADK R1 K3 -JUMPIFNOTLT R1 R0 +3 +JUMPIFNOTLT R1 R0 L1 CLOSEUPVALS R0 -JUMP +1 -ADDK R0 R0 K4 -NEWCLOSURE R1 P0 +JUMP L2 +L1: ADDK R0 R0 K4 +L2: NEWCLOSURE R1 P0 CAPTURE UPVAL U0 CAPTURE REF R0 CALL R1 0 1 -JUMPIF R1 +2 -CLOSEUPVALS R0 -JUMPBACK -16 +JUMPIF R1 L3 CLOSEUPVALS R0 +JUMPBACK L0 +L3: CLOSEUPVALS R0 RETURN R0 0 )"); } @@ -1700,11 +1910,11 @@ ORK R2 R1 K0 SUB R0 R0 R2 LOADN R4 1 LOADN R8 0 -JUMPIFNOTLT R0 R8 +3 +JUMPIFNOTLT R0 R8 L0 MINUS R7 R0 -JUMPIF R7 +1 -MOVE R7 R0 -MULK R6 R7 K1 +JUMPIF R7 L1 +L0: MOVE R7 R0 +L1: MULK R6 R7 K1 LOADN R8 1 SUB R7 R8 R2 DIV R5 R6 R7 @@ -1724,14 +1934,14 @@ LOADB R2 0 LOADK R4 K0 MULK R5 R1 K1 SUB R3 R4 R5 -JUMPIFNOTLT R3 R0 +8 +JUMPIFNOTLT R3 R0 L1 LOADK R4 K0 MULK R5 R1 K1 ADD R3 R4 R5 -JUMPIFLT R0 R3 +2 +JUMPIFLT R0 R3 L0 LOADB R2 0 +1 -LOADB R2 1 -RETURN R2 1 +L0: LOADB R2 1 +L1: RETURN R2 1 )"); // sometimes we need to compute a boolean; this uses LOADB with an offset for the last op, note that first op is compiled better @@ -1746,14 +1956,14 @@ LOADB R2 1 LOADK R4 K0 MULK R5 R1 K1 SUB R3 R4 R5 -JUMPIFLT R0 R3 +8 +JUMPIFLT R0 R3 L1 LOADK R4 K0 MULK R5 R1 K1 ADD R3 R4 R5 -JUMPIFLT R3 R0 +2 +JUMPIFLT R3 R0 L0 LOADB R2 0 +1 -LOADB R2 1 -RETURN R2 1 +L0: LOADB R2 1 +L1: RETURN R2 1 )"); // trivial ternary if with constants @@ -1764,10 +1974,10 @@ end )", 0), R"( -JUMPIFNOT R0 +2 +JUMPIFNOT R0 L0 LOADN R1 1 RETURN R1 1 -LOADN R1 0 +L0: LOADN R1 0 RETURN R1 1 )"); @@ -1780,14 +1990,14 @@ end 0), R"( LOADN R2 0 -JUMPIFNOTLT R0 R2 +3 +JUMPIFNOTLT R0 R2 L0 LOADN R1 0 RETURN R1 1 -LOADN R2 1 -JUMPIFNOTLT R2 R0 +3 +L0: LOADN R2 1 +JUMPIFNOTLT R2 R0 L1 LOADN R1 1 RETURN R1 1 -MOVE R1 R0 +L1: MOVE R1 R0 RETURN R1 1 )"); } @@ -1797,24 +2007,24 @@ TEST_CASE("JumpFold") // jump-to-return folding to return CHECK_EQ("\n" + compileFunction0("return a and 1 or 0"), R"( GETIMPORT R1 1 -JUMPIFNOT R1 +2 +JUMPIFNOT R1 L0 LOADN R0 1 RETURN R0 1 -LOADN R0 0 +L0: LOADN R0 0 RETURN R0 1 )"); // conditional jump in the inner if() folding to jump out of the expression (JUMPIFNOT+5 skips over all jumps, JUMP+1 skips over JUMP+0) CHECK_EQ("\n" + compileFunction0("if a then if b then b() else end else end d()"), R"( GETIMPORT R0 1 -JUMPIFNOT R0 +8 +JUMPIFNOT R0 L0 GETIMPORT R0 3 -JUMPIFNOT R0 +5 +JUMPIFNOT R0 L0 GETIMPORT R0 3 CALL R0 0 0 -JUMP +1 -JUMP +0 -GETIMPORT R0 5 +JUMP L0 +JUMP L0 +L0: GETIMPORT R0 5 CALL R0 0 0 RETURN R0 0 )"); @@ -1822,14 +2032,14 @@ RETURN R0 0 // same as example before but the unconditional jumps are folded with RETURN CHECK_EQ("\n" + compileFunction0("if a then if b then b() else end else end"), R"( GETIMPORT R0 1 -JUMPIFNOT R0 +8 +JUMPIFNOT R0 L0 GETIMPORT R0 3 -JUMPIFNOT R0 +5 +JUMPIFNOT R0 L0 GETIMPORT R0 3 CALL R0 0 0 RETURN R0 0 RETURN R0 0 -RETURN R0 0 +L0: RETURN R0 0 )"); // in this example, we do *not* have a JUMP after RETURN in the if branch @@ -1849,7 +2059,7 @@ end R"( ORK R6 R3 K0 ORK R7 R4 K1 -JUMPIF R5 +19 +JUMPIF R5 L0 GETIMPORT R10 5 DIV R13 R0 R7 MULK R14 R6 K6 @@ -1866,7 +2076,7 @@ CALL R10 3 1 MULK R9 R10 K2 ADDK R8 R9 K2 RETURN R8 1 -GETIMPORT R8 5 +L0: GETIMPORT R8 5 DIV R11 R0 R7 MULK R12 R6 K6 ADD R10 R11 R12 @@ -1883,15 +2093,6 @@ RETURN R8 -1 )"); } -static std::string rep(const std::string& s, size_t n) -{ - std::string r; - r.reserve(s.length() * n); - for (size_t i = 0; i < n; ++i) - r += s; - return r; -} - TEST_CASE("RecursionParse") { // The test forcibly pushes the stack limit during compilation; in NoOpt, the stack consumption is much larger so we need to reduce the limit to @@ -1953,6 +2154,66 @@ TEST_CASE("RecursionParse") { CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile"); } + + try + { + Luau::compileOrThrow(bcb, rep("a(", 1500) + "42" + rep(")", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "return " + rep("{", 1500) + "42" + rep("}", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, rep("while true do ", 1500) + "print()" + rep(" end", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, rep("for i=1,1 do ", 1500) + "print()" + rep(" end", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, rep("function a() ", 1500) + "print()" + rep(" end", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "return " + rep("function() return ", 1500) + "42" + rep(" end", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile"); + } } TEST_CASE("ArrayIndexLiteral") @@ -1984,16 +2245,16 @@ RETURN R0 0 TEST_CASE("NestedFunctionCalls") { CHECK_EQ("\n" + compileFunction0("function clamp(t,a,b) return math.min(math.max(t,a),b) end"), R"( -FASTCALL2 18 R0 R1 +5 +FASTCALL2 18 R0 R1 L0 MOVE R5 R0 MOVE R6 R1 GETIMPORT R4 2 CALL R4 2 1 -FASTCALL2 19 R4 R2 +4 +L0: FASTCALL2 19 R4 R2 L1 MOVE R5 R2 GETIMPORT R3 4 CALL R3 2 -1 -RETURN R3 -1 +L1: RETURN R3 -1 )"); } @@ -2001,14 +2262,14 @@ TEST_CASE("UpvaluesLoopsBytecode") { CHECK_EQ("\n" + compileFunction(R"( function test() - for i=1,10 do + for i=1,10 do i = i - foo(function() return i end) - if bar then - break - end - end - return 0 + foo(function() return i end) + if bar then + break + end + end + return 0 end )", 1), @@ -2016,33 +2277,33 @@ end LOADN R2 1 LOADN R0 10 LOADN R1 1 -FORNPREP R0 +14 -MOVE R3 R2 +FORNPREP R0 L2 +L0: MOVE R3 R2 MOVE R3 R3 GETIMPORT R4 1 NEWCLOSURE R5 P0 CAPTURE REF R3 CALL R4 1 0 GETIMPORT R4 3 -JUMPIFNOT R4 +2 +JUMPIFNOT R4 L1 CLOSEUPVALS R3 -JUMP +2 -CLOSEUPVALS R3 -FORNLOOP R0 -14 -LOADN R0 0 +JUMP L2 +L1: CLOSEUPVALS R3 +FORNLOOP R0 L0 +L2: LOADN R0 0 RETURN R0 1 )"); CHECK_EQ("\n" + compileFunction(R"( function test() - for i in ipairs(data) do + for i in ipairs(data) do i = i - foo(function() return i end) - if bar then - break - end - end - return 0 + foo(function() return i end) + if bar then + break + end + end + return 0 end )", 1), @@ -2050,42 +2311,42 @@ end GETIMPORT R0 1 GETIMPORT R1 3 CALL R0 1 3 -FORGPREP_INEXT R0 +12 -MOVE R3 R3 +FORGPREP_INEXT R0 L2 +L0: MOVE R3 R3 GETIMPORT R5 5 NEWCLOSURE R6 P0 CAPTURE REF R3 CALL R5 1 0 GETIMPORT R5 7 -JUMPIFNOT R5 +2 +JUMPIFNOT R5 L1 CLOSEUPVALS R3 -JUMP +2 -CLOSEUPVALS R3 -FORGLOOP_INEXT R0 -13 -LOADN R0 0 +JUMP L3 +L1: CLOSEUPVALS R3 +L2: FORGLOOP R0 L0 1 [inext] +L3: LOADN R0 0 RETURN R0 1 )"); CHECK_EQ("\n" + compileFunction(R"( function test() - local i = 0 - while i < 5 do - local j + local i = 0 + while i < 5 do + local j j = i - foo(function() return j end) - i = i + 1 - if bar then - break - end - end - return 0 + foo(function() return j end) + i = i + 1 + if bar then + break + end + end + return 0 end )", 1), R"( LOADN R0 0 -LOADN R1 5 -JUMPIFNOTLT R0 R1 +16 +L0: LOADN R1 5 +JUMPIFNOTLT R0 R1 L2 LOADNIL R1 MOVE R1 R0 GETIMPORT R2 1 @@ -2094,34 +2355,34 @@ CAPTURE REF R1 CALL R2 1 0 ADDK R0 R0 K2 GETIMPORT R2 4 -JUMPIFNOT R2 +2 +JUMPIFNOT R2 L1 CLOSEUPVALS R1 -JUMP +2 -CLOSEUPVALS R1 -JUMPBACK -18 -LOADN R1 0 +JUMP L2 +L1: CLOSEUPVALS R1 +JUMPBACK L0 +L2: LOADN R1 0 RETURN R1 1 )"); CHECK_EQ("\n" + compileFunction(R"( function test() - local i = 0 - repeat - local j + local i = 0 + repeat + local j j = i - foo(function() return j end) - i = i + 1 - if bar then - break - end - until i < 5 - return 0 + foo(function() return j end) + i = i + 1 + if bar then + break + end + until i < 5 + return 0 end )", 1), R"( LOADN R0 0 -LOADNIL R1 +L0: LOADNIL R1 MOVE R1 R0 GETIMPORT R2 1 NEWCLOSURE R3 P0 @@ -2129,15 +2390,15 @@ CAPTURE REF R1 CALL R2 1 0 ADDK R0 R0 K2 GETIMPORT R2 4 -JUMPIFNOT R2 +2 +JUMPIFNOT R2 L1 CLOSEUPVALS R1 -JUMP +6 -LOADN R2 5 -JUMPIFLT R0 R2 +3 +JUMP L3 +L1: LOADN R2 5 +JUMPIFLT R0 R2 L2 CLOSEUPVALS R1 -JUMPBACK -18 -CLOSEUPVALS R1 -LOADN R1 0 +JUMPBACK L0 +L2: CLOSEUPVALS R1 +L3: LOADN R1 0 RETURN R1 1 )"); } @@ -2197,11 +2458,11 @@ return result 14: GETIMPORT R2 11 14: MOVE R3 R0 14: CALL R2 1 3 -14: FORGPREP_NEXT R2 +3 -15: MOVE R7 R1 +14: FORGPREP_NEXT R2 L1 +15: L0: MOVE R7 R1 15: MOVE R8 R5 15: CONCAT R1 R7 R8 -14: FORGLOOP_NEXT R2 -4 +14: L1: FORGLOOP R2 L0 1 17: RETURN R1 1 )"); } @@ -2228,11 +2489,11 @@ end 5: LOADN R0 1 7: LOADN R1 2 9: LOADN R2 3 -9: JUMP +4 -11: GETIMPORT R5 1 +9: FORGPREP R0 L1 +11: L0: GETIMPORT R5 1 11: MOVE R6 R3 11: CALL R5 1 0 -2: FORGLOOP R0 -5 1 +2: L1: FORGLOOP R0 L0 1 13: RETURN R0 0 )"); } @@ -2254,14 +2515,14 @@ end CHECK_EQ("\n" + bcb.dumpFunction(0), R"( 2: LOADN R0 0 -4: ADDK R0 R0 K0 +4: L0: ADDK R0 R0 K0 5: LOADN R1 1 -5: JUMPIFNOTLT R1 R0 +6 +5: JUMPIFNOTLT R1 R0 L1 6: GETIMPORT R1 2 6: LOADK R2 K3 6: CALL R1 1 0 10: RETURN R0 0 -3: JUMPBACK -10 +3: L1: JUMPBACK L0 10: RETURN R0 0 )"); } @@ -2272,7 +2533,7 @@ TEST_CASE("DebugLineInfoRepeatUntil") local f = 0 repeat f += 1 - if f == 1 then + if f == 1 then print(f) else f = 0 @@ -2282,16 +2543,16 @@ until f == 0 0), R"( 2: LOADN R0 0 -4: ADDK R0 R0 K0 -5: JUMPIFNOTEQK R0 K0 +6 +4: L0: ADDK R0 R0 K0 +5: JUMPXEQKN R0 K0 L1 NOT 6: GETIMPORT R1 2 6: MOVE R2 R0 6: CALL R1 1 0 -6: JUMP +1 -8: LOADN R0 0 -10: JUMPIFEQK R0 K3 +2 -10: JUMPBACK -12 -11: RETURN R0 0 +6: JUMP L2 +8: L1: LOADN R0 0 +10: L2: JUMPXEQKN R0 K3 L3 +10: JUMPBACK L0 +11: L3: RETURN R0 0 )"); } @@ -2304,10 +2565,10 @@ local Value1, Value2, Value3 = ... local Table = {} Table.SubTable["Key"] = { - Key1 = Value1, - Key2 = Value2, - Key3 = Value3, - Key4 = true, + Key1 = Value1, + Key2 = Value2, + Key3 = Value3, + Key4 = true, } )"); @@ -2350,6 +2611,87 @@ Foo:Bar( )"); } +TEST_CASE("DebugLineInfoCallChain") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +local Foo = ... + +Foo +:Bar(1) +:Baz(2) +.Qux(3) +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: GETVARARGS R0 1 +5: LOADN R4 1 +5: NAMECALL R2 R0 K0 +5: CALL R2 2 1 +6: LOADN R4 2 +6: NAMECALL R2 R2 K1 +6: CALL R2 2 1 +7: GETTABLEKS R1 R2 K2 +7: LOADN R2 3 +7: CALL R1 1 0 +8: RETURN R0 0 +)"); +} + +TEST_CASE("DebugLineInfoFastCall") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +local Foo, Bar = ... + +return + math.max( + Foo, + Bar) +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: GETVARARGS R0 2 +5: FASTCALL2 18 R0 R1 L0 +5: MOVE R3 R0 +5: MOVE R4 R1 +5: GETIMPORT R2 2 +5: CALL R2 2 -1 +5: L0: RETURN R2 -1 +)"); +} + +TEST_CASE("DebugLineInfoAssignment") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( + local a = { b = { c = { d = 3 } } } + +a +["b"] +["c"] +["d"] = 4 +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: DUPTABLE R0 1 +2: DUPTABLE R1 3 +2: DUPTABLE R2 5 +2: LOADN R3 3 +2: SETTABLEKS R3 R2 K4 +2: SETTABLEKS R2 R1 K2 +2: SETTABLEKS R1 R0 K0 +5: GETTABLEKS R2 R0 K0 +6: GETTABLEKS R1 R2 K2 +7: LOADN R2 4 +7: SETTABLEKS R2 R1 K4 +8: RETURN R0 0 +)"); +} + TEST_CASE("DebugSource") { const char* source = R"( @@ -2413,13 +2755,13 @@ LOADK R1 K9 GETIMPORT R2 11 MOVE R3 R0 CALL R2 1 3 -FORGPREP_NEXT R2 +3 +FORGPREP_NEXT R2 L1 15: result = result .. k -MOVE R7 R1 +L0: MOVE R7 R1 MOVE R8 R5 CONCAT R1 R7 R8 14: for k in pairs(kSelectedBiomes) do -FORGLOOP_NEXT R2 -4 +L1: FORGLOOP R2 L0 1 17: return result RETURN R1 1 )"); @@ -2464,29 +2806,29 @@ end local 0: reg 5, start pc 5 line 5, end pc 8 line 5 local 1: reg 6, start pc 14 line 8, end pc 18 line 8 local 2: reg 7, start pc 14 line 8, end pc 18 line 8 -local 3: reg 3, start pc 21 line 12, end pc 24 line 12 -local 4: reg 3, start pc 26 line 16, end pc 30 line 16 -local 5: reg 0, start pc 0 line 3, end pc 34 line 21 -local 6: reg 1, start pc 0 line 3, end pc 34 line 21 -local 7: reg 2, start pc 1 line 4, end pc 34 line 21 -local 8: reg 3, start pc 34 line 21, end pc 34 line 21 +local 3: reg 3, start pc 22 line 12, end pc 25 line 12 +local 4: reg 3, start pc 27 line 16, end pc 31 line 16 +local 5: reg 0, start pc 0 line 3, end pc 35 line 21 +local 6: reg 1, start pc 0 line 3, end pc 35 line 21 +local 7: reg 2, start pc 1 line 4, end pc 35 line 21 +local 8: reg 3, start pc 35 line 21, end pc 35 line 21 3: LOADN R2 1 4: LOADN R5 1 4: LOADN R3 3 4: LOADN R4 1 -4: FORNPREP R3 +5 -5: GETIMPORT R6 1 +4: FORNPREP R3 L1 +5: L0: GETIMPORT R6 1 5: MOVE R7 R5 5: CALL R6 1 0 -4: FORNLOOP R3 -5 -7: GETIMPORT R3 3 +4: FORNLOOP R3 L0 +7: L1: GETIMPORT R3 3 7: CALL R3 0 3 -7: FORGPREP_NEXT R3 +5 -8: GETIMPORT R8 1 +7: FORGPREP_NEXT R3 L3 +8: L2: GETIMPORT R8 1 8: MOVE R9 R6 8: MOVE R10 R7 8: CALL R8 2 0 -7: FORGLOOP_NEXT R3 -6 +7: L3: FORGLOOP R3 L2 2 11: LOADN R3 2 12: GETIMPORT R4 1 12: LOADN R5 2 @@ -2502,48 +2844,108 @@ local 8: reg 3, start pc 34 line 21, end pc 34 line 21 )"); } +TEST_CASE("DebugRemarks") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Remarks); + + uint32_t fid = bcb.beginFunction(0); + + bcb.addDebugRemark("test remark #%d", 1); + bcb.emitABC(LOP_LOADNIL, 0, 0, 0); + bcb.addDebugRemark("test remark #%d", 2); + bcb.addDebugRemark("test remark #%d", 3); + bcb.emitABC(LOP_RETURN, 0, 1, 0); + + bcb.endFunction(1, 0); + + bcb.setMainFunction(fid); + bcb.finalize(); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +REMARK test remark #1 +LOADNIL R0 +REMARK test remark #2 +REMARK test remark #3 +RETURN R0 0 +)"); +} + +TEST_CASE("SourceRemarks") +{ + const char* source = R"( +local a, b = ... + +local function foo(x) + return(math.abs(x)) +end + +return foo(a) + foo(assert(b)) +)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); + bcb.setDumpSource(source); + + Luau::CompileOptions options; + options.optimizationLevel = 2; + + Luau::compileOrThrow(bcb, source, options); + + std::string remarks = bcb.dumpSourceRemarks(); + + CHECK_EQ(remarks, R"( +local a, b = ... + +local function foo(x) + -- remark: builtin math.abs/1 + return(math.abs(x)) +end + +-- remark: builtin assert/1 +-- remark: inlining succeeded (cost 2, profit 2.50x, depth 0) +return foo(a) + foo(assert(b)) +)"); +} + TEST_CASE("AssignmentConflict") { // assignments are left to right CHECK_EQ("\n" + compileFunction0("local a, b a, b = 1, 2"), R"( LOADNIL R0 LOADNIL R1 -LOADN R2 1 -LOADN R3 2 -MOVE R0 R2 -MOVE R1 R3 +LOADN R0 1 +LOADN R1 2 RETURN R0 0 )"); - // if assignment of a local invalidates a direct register reference in later assignments, the value is evacuated to a temp register + // if assignment of a local invalidates a direct register reference in later assignments, the value is assigned to a temp register first CHECK_EQ("\n" + compileFunction0("local a a, a[1] = 1, 2"), R"( LOADNIL R0 -MOVE R1 R0 -LOADN R2 1 -LOADN R3 2 -MOVE R0 R2 -SETTABLEN R3 R1 1 +LOADN R1 1 +LOADN R2 2 +SETTABLEN R2 R0 1 +MOVE R0 R1 RETURN R0 0 )"); // note that this doesn't happen if the local assignment happens last naturally CHECK_EQ("\n" + compileFunction0("local a a[1], a = 1, 2"), R"( LOADNIL R0 -LOADN R1 1 -LOADN R2 2 -SETTABLEN R1 R0 1 -MOVE R0 R2 +LOADN R2 1 +LOADN R1 2 +SETTABLEN R2 R0 1 +MOVE R0 R1 RETURN R0 0 )"); // this will happen if assigned register is used in any table expression, including as an object... CHECK_EQ("\n" + compileFunction0("local a a, a.foo = 1, 2"), R"( LOADNIL R0 -MOVE R1 R0 -LOADN R2 1 -LOADN R3 2 -MOVE R0 R2 -SETTABLEKS R3 R1 K0 +LOADN R1 1 +LOADN R2 2 +SETTABLEKS R2 R0 K0 +MOVE R0 R1 RETURN R0 0 )"); @@ -2551,22 +2953,20 @@ RETURN R0 0 CHECK_EQ("\n" + compileFunction0("local a a, foo[a] = 1, 2"), R"( LOADNIL R0 GETIMPORT R1 1 -MOVE R2 R0 -LOADN R3 1 -LOADN R4 2 -MOVE R0 R3 -SETTABLE R4 R1 R2 +LOADN R2 1 +LOADN R3 2 +SETTABLE R3 R1 R0 +MOVE R0 R2 RETURN R0 0 )"); // ... or both ... CHECK_EQ("\n" + compileFunction0("local a a, a[a] = 1, 2"), R"( LOADNIL R0 -MOVE R1 R0 -LOADN R2 1 -LOADN R3 2 -MOVE R0 R2 -SETTABLE R3 R1 R1 +LOADN R1 1 +LOADN R2 2 +SETTABLE R2 R0 R0 +MOVE R0 R1 RETURN R0 0 )"); @@ -2574,14 +2974,12 @@ RETURN R0 0 CHECK_EQ("\n" + compileFunction0("local a, b a, b, a[b] = 1, 2, 3"), R"( LOADNIL R0 LOADNIL R1 -MOVE R2 R0 -MOVE R3 R1 -LOADN R4 1 -LOADN R5 2 -LOADN R6 3 -MOVE R0 R4 -MOVE R1 R5 -SETTABLE R6 R2 R3 +LOADN R2 1 +LOADN R3 2 +LOADN R4 3 +SETTABLE R4 R0 R1 +MOVE R0 R2 +MOVE R1 R3 RETURN R0 0 )"); @@ -2591,10 +2989,9 @@ RETURN R0 0 LOADNIL R0 GETIMPORT R1 1 ADDK R2 R0 K2 -LOADN R3 1 -LOADN R4 2 -MOVE R0 R3 -SETTABLE R4 R1 R2 +LOADN R0 1 +LOADN R3 2 +SETTABLE R3 R1 R2 RETURN R0 0 )"); } @@ -2604,29 +3001,29 @@ TEST_CASE("FastcallBytecode") // direct global call CHECK_EQ("\n" + compileFunction0("return math.abs(-5)"), R"( LOADN R1 -5 -FASTCALL1 2 R1 +2 +FASTCALL1 2 R1 L0 GETIMPORT R0 2 CALL R0 1 -1 -RETURN R0 -1 +L0: RETURN R0 -1 )"); // call through a local variable CHECK_EQ("\n" + compileFunction0("local abs = math.abs return abs(-5)"), R"( GETIMPORT R0 2 LOADN R2 -5 -FASTCALL1 2 R2 +1 +FASTCALL1 2 R2 L0 MOVE R1 R0 CALL R1 1 -1 -RETURN R1 -1 +L0: RETURN R1 -1 )"); // call through an upvalue CHECK_EQ("\n" + compileFunction0("local abs = math.abs function foo() return abs(-5) end return foo()"), R"( LOADN R1 -5 -FASTCALL1 2 R1 +1 +FASTCALL1 2 R1 L0 GETUPVAL R0 0 CALL R0 1 -1 -RETURN R0 -1 +L0: RETURN R0 -1 )"); // mutating the global in the script breaks the optimization @@ -2663,6 +3060,75 @@ RETURN R1 -1 )"); } +TEST_CASE("FastcallSelect") +{ + // select(_, ...) compiles to a builtin call + CHECK_EQ("\n" + compileFunction0("return (select('#', ...))"), R"( +LOADK R1 K0 +FASTCALL1 57 R1 L0 +GETIMPORT R0 2 +GETVARARGS R2 -1 +CALL R0 -1 1 +L0: RETURN R0 1 +)"); + + // more complex example: select inside a for loop bound + select from a iterator + CHECK_EQ("\n" + compileFunction0(R"( +local sum = 0 +for i=1, select('#', ...) do + sum += select(i, ...) +end +return sum +)"), + R"( +LOADN R0 0 +LOADN R3 1 +LOADK R5 K0 +FASTCALL1 57 R5 L0 +GETIMPORT R4 2 +GETVARARGS R6 -1 +CALL R4 -1 1 +L0: MOVE R1 R4 +LOADN R2 1 +FORNPREP R1 L3 +L1: FASTCALL1 57 R3 L2 +GETIMPORT R4 2 +MOVE R5 R3 +GETVARARGS R6 -1 +CALL R4 -1 1 +L2: ADD R0 R0 R4 +FORNLOOP R1 L1 +L3: RETURN R0 1 +)"); + + // currently we assume a single value return to avoid dealing with stack resizing + CHECK_EQ("\n" + compileFunction0("return select('#', ...)"), R"( +GETIMPORT R0 1 +LOADK R1 K2 +GETVARARGS R2 -1 +CALL R0 -1 -1 +RETURN R0 -1 +)"); + + // note that select with a non-variadic second argument doesn't get optimized + CHECK_EQ("\n" + compileFunction0("return select('#')"), R"( +GETIMPORT R0 1 +LOADK R1 K2 +CALL R0 1 -1 +RETURN R0 -1 +)"); + + // note that select with a non-variadic second argument doesn't get optimized + CHECK_EQ("\n" + compileFunction0("return select('#', foo())"), R"( +GETIMPORT R0 1 +LOADK R1 K2 +GETIMPORT R2 4 +CALL R2 0 -1 +CALL R0 -1 -1 +RETURN R0 -1 +)"); +} + TEST_CASE("LotsOfParameters") { const char* source = R"( @@ -2785,47 +3251,35 @@ CAPTURE UPVAL U1 RETURN R0 1 )"); - if (FFlag::LuauPreloadClosuresUpval) - { - // recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( + // recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( DUPCLOSURE R0 K0 CAPTURE VAL R0 RETURN R0 0 )"); - // multi-level recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return function() return foo() end end", 1), R"( + // multi-level recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return function() return foo() end end", 1), R"( DUPCLOSURE R0 K0 CAPTURE UPVAL U0 RETURN R0 1 )"); - // multi-level recursive capture where function isn't top-level - // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler - CHECK_EQ("\n" + compileFunction(R"( + // multi-level recursive capture where function isn't top-level + // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler + CHECK_EQ("\n" + compileFunction(R"( local function foo() local function bar() return function() return bar() end end end )", - 1), - R"( + 1), + R"( NEWCLOSURE R0 P0 CAPTURE UPVAL U0 RETURN R0 1 )"); - } - else - { - // recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( -NEWCLOSURE R0 P0 -CAPTURE VAL R0 -RETURN R0 0 -)"); - } } TEST_CASE("OutOfLocals") @@ -2947,15 +3401,24 @@ TEST_CASE("FastCallImportFallback") std::vector insns = Luau::split(code, '\n'); - CHECK_EQ(insns[insns.size() - 9], "LOADN R1 1024"); - CHECK_EQ(insns[insns.size() - 8], "LOADK R2 K1023"); - CHECK_EQ(insns[insns.size() - 7], "SETTABLE R2 R0 R1"); - CHECK_EQ(insns[insns.size() - 6], "LOADN R2 -1"); - CHECK_EQ(insns[insns.size() - 5], "FASTCALL1 2 R2 +4"); - CHECK_EQ(insns[insns.size() - 4], "GETGLOBAL R3 K1024"); // note: it's important that this doesn't overwrite R2 - CHECK_EQ(insns[insns.size() - 3], "GETTABLEKS R1 R3 K1025"); - CHECK_EQ(insns[insns.size() - 2], "CALL R1 1 -1"); - CHECK_EQ(insns[insns.size() - 1], "RETURN R1 -1"); + std::string fragment; + for (size_t i = 9; i > 1; --i) + { + fragment += std::string(insns[insns.size() - i]); + fragment += "\n"; + } + + // note: it's important that GETGLOBAL below doesn't overwrite R2 + CHECK_EQ("\n" + fragment, R"( +LOADN R1 1024 +LOADK R2 K1023 +SETTABLE R2 R0 R1 +LOADN R2 -1 +FASTCALL1 2 R2 L0 +GETGLOBAL R3 K1024 +GETTABLEKS R1 R3 K1025 +CALL R1 1 -1 +)"); } TEST_CASE("CompoundAssignment") @@ -3009,7 +3472,7 @@ RETURN R0 0 // table variants (indexed by string, number, variable) CHECK_EQ("\n" + compileFunction0("local a = {} a.foo += 5"), R"( -NEWTABLE R0 1 0 +NEWTABLE R0 0 0 GETTABLEKS R1 R0 K0 ADDK R1 R1 K1 SETTABLEKS R1 R0 K0 @@ -3017,7 +3480,7 @@ RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = {} a[1] += 5"), R"( -NEWTABLE R0 0 1 +NEWTABLE R0 0 0 GETTABLEN R1 R0 1 ADDK R1 R1 K0 SETTABLEN R1 R0 1 @@ -3112,36 +3575,47 @@ TEST_CASE("JumpTrampoline") insns.push_back(insn); // FORNPREP and early JUMPs (break) need to go through a trampoline - CHECK_EQ(insns[0], "LOADN R0 0"); - CHECK_EQ(insns[1], "LOADN R3 1"); - CHECK_EQ(insns[2], "LOADN R1 3"); - CHECK_EQ(insns[3], "LOADN R2 1"); - CHECK_EQ(insns[4], "JUMP +1"); - CHECK_EQ(insns[5], "JUMPX +54542"); - CHECK_EQ(insns[6], "FORNPREP R1 -2"); - CHECK_EQ(insns[7], "ADD R0 R0 R3"); - CHECK_EQ(insns[8], "LOADK R4 K0"); - CHECK_EQ(insns[9], "JUMP +1"); - CHECK_EQ(insns[10], "JUMPX +54537"); - CHECK_EQ(insns[11], "JUMPIFLT R4 R0 -2"); - CHECK_EQ(insns[12], "ADD R0 R0 R3"); - CHECK_EQ(insns[13], "LOADK R4 K0"); - CHECK_EQ(insns[14], "JUMP +1"); - CHECK_EQ(insns[15], "JUMPX +54531"); - CHECK_EQ(insns[16], "JUMPIFLT R4 R0 -2"); + std::string head; + for (size_t i = 0; i < 16; ++i) + head += insns[i] + "\n"; + + CHECK_EQ("\n" + head, R"( +LOADN R0 0 +LOADN R3 1 +LOADN R1 3 +LOADN R2 1 +JUMP L1 +L0: JUMPX L14543 +L1: FORNPREP R1 L0 +L2: ADD R0 R0 R3 +LOADK R4 K0 +JUMP L4 +L3: JUMPX L14543 +L4: JUMPIFLT R4 R0 L3 +ADD R0 R0 R3 +LOADK R4 K0 +JUMP L6 +L5: JUMPX L14543 +)"); // FORNLOOP has to go through a trampoline since the jump is back to the beginning of the function // however, late JUMPs (break) don't need a trampoline since the loop end is really close by - CHECK_EQ(insns[44539], "ADD R0 R0 R3"); - CHECK_EQ(insns[44540], "LOADK R4 K0"); - CHECK_EQ(insns[44541], "JUMPIFLT R4 R0 +8"); - CHECK_EQ(insns[44542], "ADD R0 R0 R3"); - CHECK_EQ(insns[44543], "LOADK R4 K0"); - CHECK_EQ(insns[44544], "JUMPIFLT R4 R0 +4"); - CHECK_EQ(insns[44545], "JUMP +1"); - CHECK_EQ(insns[44546], "JUMPX -54540"); - CHECK_EQ(insns[44547], "FORNLOOP R1 -2"); - CHECK_EQ(insns[44548], "RETURN R0 1"); + std::string tail; + for (size_t i = 44539; i < insns.size(); ++i) + tail += insns[i] + "\n"; + + CHECK_EQ("\n" + tail, R"( +ADD R0 R0 R3 +LOADK R4 K0 +JUMPIFLT R4 R0 L14543 +ADD R0 R0 R3 +LOADK R4 K0 +JUMPIFLT R4 R0 L14543 +JUMP L14542 +L14541: JUMPX L2 +L14542: FORNLOOP R1 L14541 +L14543: RETURN R0 1 +)"); } TEST_CASE("CompileBytecode") @@ -3216,10 +3690,10 @@ local b = obj == 1 )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 +2 +JUMPXEQKN R0 K0 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0(R"( @@ -3228,10 +3702,10 @@ local b = 1 == obj )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 +2 +JUMPXEQKN R0 K0 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0(R"( @@ -3240,10 +3714,10 @@ local b = "Hello, Sailor!" == obj )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 +2 +JUMPXEQKS R0 K0 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0(R"( @@ -3252,10 +3726,10 @@ local b = nil == obj )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 +2 +JUMPXEQKNIL R0 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0(R"( @@ -3264,10 +3738,10 @@ local b = true == obj )"), R"( GETVARARGS R0 1 -JUMPIFEQK R0 K0 +2 +JUMPXEQKB R0 1 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0(R"( @@ -3276,10 +3750,10 @@ local b = nil ~= obj )"), R"( GETVARARGS R0 1 -JUMPIFNOTEQK R0 K0 +2 +JUMPXEQKNIL R0 L0 NOT LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); // table literals should not generate IFEQK variants @@ -3290,10 +3764,10 @@ local b = obj == {} R"( GETVARARGS R0 1 NEWTABLE R2 0 0 -JUMPIFEQ R0 R2 +2 +JUMPIFEQ R0 R2 L0 LOADB R1 0 +1 -LOADB R1 1 -RETURN R0 0 +L0: LOADB R1 1 +L1: RETURN R0 0 )"); } @@ -3355,13 +3829,13 @@ end R"( 2: COVERAGE 2: GETIMPORT R0 1 -2: JUMPIFNOT R0 +6 +2: JUMPIFNOT R0 L0 3: COVERAGE 3: GETIMPORT R0 3 3: LOADN R1 1 3: CALL R0 1 0 7: RETURN R0 0 -5: COVERAGE +5: L0: COVERAGE 5: GETIMPORT R0 3 5: LOADN R1 2 5: CALL R0 1 0 @@ -3383,13 +3857,13 @@ end R"( 2: COVERAGE 2: GETIMPORT R0 1 -2: JUMPIFNOT R0 +6 +2: JUMPIFNOT R0 L0 4: COVERAGE 4: GETIMPORT R0 3 4: LOADN R1 1 4: CALL R0 1 0 9: RETURN R0 0 -7: COVERAGE +7: L0: COVERAGE 7: GETIMPORT R0 3 7: LOADN R1 2 7: CALL R0 1 0 @@ -3430,8 +3904,6 @@ local t = { TEST_CASE("ConstantClosure") { - ScopedFastFlag sff("LuauPreloadClosures", true); - // closures without upvalues are created when bytecode is loaded CHECK_EQ("\n" + compileFunction(R"( return function() end @@ -3467,15 +3939,13 @@ CAPTURE VAL R0 RETURN R1 1 )"); - if (FFlag::LuauPreloadClosuresFenv) - { - // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion - CHECK_EQ("\n" + compileFunction(R"( + // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion + CHECK_EQ("\n" + compileFunction(R"( setfenv(1, {}) return function() print("hi") end )", - 1), - R"( + 1), + R"( GETIMPORT R0 1 LOADN R1 1 NEWTABLE R2 0 0 @@ -3484,24 +3954,20 @@ NEWCLOSURE R0 P0 RETURN R0 1 )"); - // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature - CHECK_EQ("\n" + compileFunction(R"( + // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature + CHECK_EQ("\n" + compileFunction(R"( if false then setfenv(1, {}) end return function() print("hi") end )", - 1), - R"( + 1), + R"( NEWCLOSURE R0 P0 RETURN R0 1 )"); - } } TEST_CASE("SharedClosure") { - ScopedFastFlag sff1("LuauPreloadClosures", true); - ScopedFastFlag sff2("LuauPreloadClosuresUpval", true); - // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( local val = ... @@ -3630,33 +4096,2698 @@ end LOADN R2 1 LOADN R0 10 LOADN R1 1 -FORNPREP R0 +6 -GETIMPORT R3 1 +FORNPREP R0 L1 +L0: GETIMPORT R3 1 NEWCLOSURE R4 P0 CAPTURE VAL R2 CALL R3 1 0 -FORNLOOP R0 -6 -GETIMPORT R0 3 +FORNLOOP R0 L0 +L1: GETIMPORT R0 3 GETVARARGS R1 -1 CALL R0 -1 3 -FORGPREP_NEXT R0 +5 -GETIMPORT R5 1 +FORGPREP_NEXT R0 L3 +L2: GETIMPORT R5 1 NEWCLOSURE R6 P1 CAPTURE VAL R3 CALL R5 1 0 -FORGLOOP_NEXT R0 -6 +L3: FORGLOOP R0 L2 2 LOADN R2 1 LOADN R0 10 LOADN R1 1 -FORNPREP R0 +7 -MOVE R3 R2 -GETIMPORT R4 1 -NEWCLOSURE R5 P2 -CAPTURE VAL R3 -CALL R4 1 0 -FORNLOOP R0 -7 +FORNPREP R0 L5 +L4: GETIMPORT R3 1 +NEWCLOSURE R4 P2 +CAPTURE VAL R2 +CALL R3 1 0 +FORNLOOP R0 L4 +L5: RETURN R0 0 +)"); +} + +TEST_CASE("MutableGlobals") +{ + const char* source = R"( +print() +Game.print() +Workspace.print() +_G.print() +game.print() +plugin.print() +script.print() +shared.print() +workspace.print() +)"; + + // Check Roblox globals are no longer here + CHECK_EQ("\n" + compileFunction0(source), R"( +GETIMPORT R0 1 +CALL R0 0 0 +GETIMPORT R0 3 +CALL R0 0 0 +GETIMPORT R0 5 +CALL R0 0 0 +GETIMPORT R1 7 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R0 9 +CALL R0 0 0 +GETIMPORT R0 11 +CALL R0 0 0 +GETIMPORT R0 13 +CALL R0 0 0 +GETIMPORT R0 15 +CALL R0 0 0 +GETIMPORT R0 17 +CALL R0 0 0 +RETURN R0 0 +)"); + + // Check we can add them back + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + const char* mutableGlobals[] = {"Game", "Workspace", "game", "plugin", "script", "shared", "workspace", NULL}; + options.mutableGlobals = &mutableGlobals[0]; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +GETIMPORT R0 1 +CALL R0 0 0 +GETIMPORT R1 3 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 5 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 7 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 9 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 11 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 13 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 15 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 17 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 RETURN R0 0 )"); } +TEST_CASE("ConstantsNoFolding") +{ + const char* source = "return nil, true, 42, 'hello'"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.optimizationLevel = 0; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +LOADNIL R0 +LOADB R1 1 +LOADK R2 K0 +LOADK R3 K1 +RETURN R0 4 +)"); +} + +TEST_CASE("VectorFastCall") +{ + const char* source = "return Vector3.new(1, 2, 3)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.vectorLib = "Vector3"; + options.vectorCtor = "new"; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +LOADN R1 1 +LOADN R2 2 +LOADN R3 3 +FASTCALL 54 L0 +GETIMPORT R0 2 +CALL R0 3 -1 +L0: RETURN R0 -1 +)"); +} + +TEST_CASE("TypeAssertion") +{ + // validate that type assertions work with the compiler and that the code inside type assertion isn't evaluated + CHECK_EQ("\n" + compileFunction0(R"( +print(foo() :: typeof(error("compile time"))) +)"), + R"( +GETIMPORT R0 1 +GETIMPORT R1 3 +CALL R1 0 1 +CALL R0 1 0 +RETURN R0 0 +)"); + + // note that above, foo() is treated as single-arg function; removing type assertion changes the bytecode + CHECK_EQ("\n" + compileFunction0(R"( +print(foo()) +)"), + R"( +GETIMPORT R0 1 +GETIMPORT R1 3 +CALL R1 0 -1 +CALL R0 -1 0 +RETURN R0 0 +)"); +} + +TEST_CASE("Arithmetics") +{ + // basic arithmetics codegen with non-constants + CHECK_EQ("\n" + compileFunction0(R"( +local a, b = ... +return a + b, a - b, a / b, a * b, a % b, a ^ b +)"), + R"( +GETVARARGS R0 2 +ADD R2 R0 R1 +SUB R3 R0 R1 +DIV R4 R0 R1 +MUL R5 R0 R1 +MOD R6 R0 R1 +POW R7 R0 R1 +RETURN R2 6 +)"); + + // basic arithmetics codegen with constants on the right side + // note that we don't simplify these expressions as we don't know the type of a + CHECK_EQ("\n" + compileFunction0(R"( +local a = ... +return a + 1, a - 1, a / 1, a * 1, a % 1, a ^ 1 +)"), + R"( +GETVARARGS R0 1 +ADDK R1 R0 K0 +SUBK R2 R0 K0 +DIVK R3 R0 K0 +MULK R4 R0 K0 +MODK R5 R0 K0 +POWK R6 R0 K0 +RETURN R1 6 +)"); +} + +TEST_CASE("LoopUnrollBasic") +{ + // forward loops + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,2 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 2 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 2 +SETTABLEN R1 R0 2 +RETURN R0 1 +)"); + + // backward loops + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=2,1,-1 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 2 +SETTABLEN R1 R0 2 +LOADN R1 1 +SETTABLEN R1 R0 1 +RETURN R0 1 +)"); + + // loops with step that doesn't divide to-from + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,4,2 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 3 +SETTABLEN R1 R0 3 +RETURN R0 1 +)"); + + // empty loops + CHECK_EQ("\n" + compileFunction(R"( +for i=2,1 do +end +)", + 0, 2), + R"( +RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollNested") +{ + // we can unroll nested loops just fine + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=0,1 do + for j=0,1 do + t[i*2+(j+1)] = 0 + end +end +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 0 +SETTABLEN R1 R0 1 +LOADN R1 0 +SETTABLEN R1 R0 2 +LOADN R1 0 +SETTABLEN R1 R0 3 +LOADN R1 0 +SETTABLEN R1 R0 4 +RETURN R0 0 +)"); + + // if the inner loop is too expensive, we won't unroll the outer loop though, but we'll still unroll the inner loop! + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=0,3 do + for j=0,3 do + t[i*4+(j+1)] = 0 + end +end +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R3 0 +LOADN R1 3 +LOADN R2 1 +FORNPREP R1 L1 +L0: MULK R5 R3 K1 +ADDK R4 R5 K0 +LOADN R5 0 +SETTABLE R5 R0 R4 +MULK R5 R3 K1 +ADDK R4 R5 K2 +LOADN R5 0 +SETTABLE R5 R0 R4 +MULK R5 R3 K1 +ADDK R4 R5 K3 +LOADN R5 0 +SETTABLE R5 R0 R4 +MULK R5 R3 K1 +ADDK R4 R5 K1 +LOADN R5 0 +SETTABLE R5 R0 R4 +FORNLOOP R1 L0 +L1: RETURN R0 0 +)"); + + // note, we sometimes can even unroll a loop with varying internal iterations + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=0,1 do + for j=0,i do + t[i*2+(j+1)] = 0 + end +end +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 0 +SETTABLEN R1 R0 1 +LOADN R1 0 +SETTABLEN R1 R0 3 +LOADN R1 0 +SETTABLEN R1 R0 4 +RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollUnsupported") +{ + // can't unroll loops with non-constant bounds + CHECK_EQ("\n" + compileFunction(R"( +for i=x,y,z do +end +)", + 0, 2), + R"( +GETIMPORT R2 1 +GETIMPORT R0 3 +GETIMPORT R1 5 +FORNPREP R0 L1 +L0: FORNLOOP R0 L0 +L1: RETURN R0 0 +)"); + + // can't unroll loops with bounds where we can't compute trip count + CHECK_EQ("\n" + compileFunction(R"( +for i=1,1,0 do +end +)", + 0, 2), + R"( +LOADN R2 1 +LOADN R0 1 +LOADN R1 0 +FORNPREP R0 L1 +L0: FORNLOOP R0 L0 +L1: RETURN R0 0 +)"); + + // can't unroll loops with bounds that might be imprecise (non-integer) + CHECK_EQ("\n" + compileFunction(R"( +for i=1,2,0.1 do +end +)", + 0, 2), + R"( +LOADN R2 1 +LOADN R0 2 +LOADK R1 K0 +FORNPREP R0 L1 +L0: FORNLOOP R0 L0 +L1: RETURN R0 0 +)"); + + // can't unroll loops if the bounds are too large, as it might overflow trip count math + CHECK_EQ("\n" + compileFunction(R"( +for i=4294967295,4294967296 do +end +)", + 0, 2), + R"( +LOADK R2 K0 +LOADK R0 K1 +LOADN R1 1 +FORNPREP R0 L1 +L0: FORNLOOP R0 L0 +L1: RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollControlFlow") +{ + ScopedFastInt sfis[] = { + {"LuauCompileLoopUnrollThreshold", 50}, + {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, + }; + + // break jumps to the end + CHECK_EQ("\n" + compileFunction(R"( +for i=1,3 do + if math.random() < 0.5 then + break + end +end +)", + 0, 2), + R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 L0 +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 L0 +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 L0 +L0: RETURN R0 0 +)"); + + // continue jumps to the next iteration + CHECK_EQ("\n" + compileFunction(R"( +for i=1,3 do + if math.random() < 0.5 then + continue + end + print(i) +end +)", + 0, 2), + R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 L0 +GETIMPORT R0 5 +LOADN R1 1 +CALL R0 1 0 +L0: GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 L1 +GETIMPORT R0 5 +LOADN R1 2 +CALL R0 1 0 +L1: GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 L2 +GETIMPORT R0 5 +LOADN R1 3 +CALL R0 1 0 +L2: RETURN R0 0 +)"); + + // continue needs to properly close upvalues + CHECK_EQ("\n" + compileFunction(R"( +for i=1,1 do + local j = global(i) + print(function() return j end) + if math.random() < 0.5 then + continue + end + j += 1 +end +)", + 1, 2), + R"( +GETIMPORT R0 1 +LOADN R1 1 +CALL R0 1 1 +GETIMPORT R1 3 +NEWCLOSURE R2 P0 +CAPTURE REF R0 +CALL R1 1 0 +GETIMPORT R1 6 +CALL R1 0 1 +LOADK R2 K7 +JUMPIFNOTLT R1 R2 L0 +CLOSEUPVALS R0 +RETURN R0 0 +L0: ADDK R0 R0 K8 +CLOSEUPVALS R0 +RETURN R0 0 +)"); + + // this weird contraption just disappears + CHECK_EQ("\n" + compileFunction(R"( +for i=1,1 do + for j=1,1 do + if i == 1 then + continue + else + break + end + end +end +)", + 0, 2), + R"( +RETURN R0 0 +RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollNestedClosure") +{ + // if the body has functions that refer to loop variables, we unroll the loop and use MOVE+CAPTURE for upvalues + CHECK_EQ("\n" + compileFunction(R"( +for i=1,2 do + local x = function() return i end +end +)", + 1, 2), + R"( +LOADN R1 1 +NEWCLOSURE R0 P0 +CAPTURE VAL R1 +LOADN R1 2 +NEWCLOSURE R0 P0 +CAPTURE VAL R1 +RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollCost") +{ + ScopedFastInt sfis[] = { + {"LuauCompileLoopUnrollThreshold", 25}, + {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, + }; + + // loops with short body + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,10 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 10 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 2 +SETTABLEN R1 R0 2 +LOADN R1 3 +SETTABLEN R1 R0 3 +LOADN R1 4 +SETTABLEN R1 R0 4 +LOADN R1 5 +SETTABLEN R1 R0 5 +LOADN R1 6 +SETTABLEN R1 R0 6 +LOADN R1 7 +SETTABLEN R1 R0 7 +LOADN R1 8 +SETTABLEN R1 R0 8 +LOADN R1 9 +SETTABLEN R1 R0 9 +LOADN R1 10 +SETTABLEN R1 R0 10 +RETURN R0 1 +)"); + + // loops with body that's too long + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,100 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R3 1 +LOADN R1 100 +LOADN R2 1 +FORNPREP R1 L1 +L0: SETTABLE R3 R0 R3 +FORNLOOP R1 L0 +L1: RETURN R0 1 +)"); + + // loops with body that's long but has a high boost factor due to constant folding + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,25 do + t[i] = i * i * i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 8 +SETTABLEN R1 R0 2 +LOADN R1 27 +SETTABLEN R1 R0 3 +LOADN R1 64 +SETTABLEN R1 R0 4 +LOADN R1 125 +SETTABLEN R1 R0 5 +LOADN R1 216 +SETTABLEN R1 R0 6 +LOADN R1 343 +SETTABLEN R1 R0 7 +LOADN R1 512 +SETTABLEN R1 R0 8 +LOADN R1 729 +SETTABLEN R1 R0 9 +LOADN R1 1000 +SETTABLEN R1 R0 10 +LOADN R1 1331 +SETTABLEN R1 R0 11 +LOADN R1 1728 +SETTABLEN R1 R0 12 +LOADN R1 2197 +SETTABLEN R1 R0 13 +LOADN R1 2744 +SETTABLEN R1 R0 14 +LOADN R1 3375 +SETTABLEN R1 R0 15 +LOADN R1 4096 +SETTABLEN R1 R0 16 +LOADN R1 4913 +SETTABLEN R1 R0 17 +LOADN R1 5832 +SETTABLEN R1 R0 18 +LOADN R1 6859 +SETTABLEN R1 R0 19 +LOADN R1 8000 +SETTABLEN R1 R0 20 +LOADN R1 9261 +SETTABLEN R1 R0 21 +LOADN R1 10648 +SETTABLEN R1 R0 22 +LOADN R1 12167 +SETTABLEN R1 R0 23 +LOADN R1 13824 +SETTABLEN R1 R0 24 +LOADN R1 15625 +SETTABLEN R1 R0 25 +RETURN R0 1 +)"); + + // loops with body that's long and doesn't have a high boost factor + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,10 do + t[i] = math.abs(math.sin(i)) +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 10 +LOADN R3 1 +LOADN R1 10 +LOADN R2 1 +FORNPREP R1 L3 +L0: FASTCALL1 24 R3 L1 +MOVE R6 R3 +GETIMPORT R5 2 +CALL R5 1 -1 +L1: FASTCALL 2 L2 +GETIMPORT R4 4 +CALL R4 -1 1 +L2: SETTABLE R4 R0 R3 +FORNLOOP R1 L0 +L3: RETURN R0 1 +)"); +} + +TEST_CASE("LoopUnrollMutable") +{ + // can't unroll loops that mutate iteration variable + CHECK_EQ("\n" + compileFunction(R"( +for i=1,3 do + i = 3 + print(i) -- should print 3 three times in a row +end +)", + 0, 2), + R"( +LOADN R2 1 +LOADN R0 3 +LOADN R1 1 +FORNPREP R0 L1 +L0: MOVE R3 R2 +LOADN R3 3 +GETIMPORT R4 1 +MOVE R5 R3 +CALL R4 1 0 +FORNLOOP R0 L0 +L1: RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollCostBuiltins") +{ + ScopedFastInt sfis[] = { + {"LuauCompileLoopUnrollThreshold", 25}, + {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, + }; + + // this loop uses builtins and is close to the cost budget so it's important that we model builtins as cheaper than regular calls + CHECK_EQ("\n" + compileFunction(R"( +function cipher(block, nonce) + for i = 0,3 do + block[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff) + end +end +)", + 0, 2), + R"( +FASTCALL2K 39 R1 K0 L0 +MOVE R4 R1 +LOADK R5 K0 +GETIMPORT R3 3 +CALL R3 2 1 +L0: FASTCALL2K 29 R3 K4 L1 +LOADK R4 K4 +GETIMPORT R2 6 +CALL R2 2 1 +L1: SETTABLEN R2 R0 1 +FASTCALL2K 39 R1 K7 L2 +MOVE R4 R1 +LOADK R5 K7 +GETIMPORT R3 3 +CALL R3 2 1 +L2: FASTCALL2K 29 R3 K4 L3 +LOADK R4 K4 +GETIMPORT R2 6 +CALL R2 2 1 +L3: SETTABLEN R2 R0 2 +FASTCALL2K 39 R1 K8 L4 +MOVE R4 R1 +LOADK R5 K8 +GETIMPORT R3 3 +CALL R3 2 1 +L4: FASTCALL2K 29 R3 K4 L5 +LOADK R4 K4 +GETIMPORT R2 6 +CALL R2 2 1 +L5: SETTABLEN R2 R0 3 +FASTCALL2K 39 R1 K9 L6 +MOVE R4 R1 +LOADK R5 K9 +GETIMPORT R3 3 +CALL R3 2 1 +L6: FASTCALL2K 29 R3 K4 L7 +LOADK R4 K4 +GETIMPORT R2 6 +CALL R2 2 1 +L7: SETTABLEN R2 R0 4 +RETURN R0 0 +)"); + + // note that if we break compiler's ability to reason about bit32 builtin the loop is no longer unrolled as it's too expensive + CHECK_EQ("\n" + compileFunction(R"( +bit32 = {} + +function cipher(block, nonce) + for i = 0,3 do + block[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff) + end +end +)", + 0, 2), + R"( +LOADN R4 0 +LOADN R2 3 +LOADN R3 1 +FORNPREP R2 L1 +L0: ADDK R5 R4 K0 +GETGLOBAL R7 K1 +GETTABLEKS R6 R7 K2 +GETGLOBAL R8 K1 +GETTABLEKS R7 R8 K3 +MOVE R8 R1 +MULK R9 R4 K4 +CALL R7 2 1 +LOADN R8 255 +CALL R6 2 1 +SETTABLE R6 R0 R5 +FORNLOOP R2 L0 +L1: RETURN R0 0 +)"); + + // additionally, if we pass too many constants the builtin stops being cheap because of argument setup + CHECK_EQ("\n" + compileFunction(R"( +function cipher(block, nonce) + for i = 0,3 do + block[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff, 0xff, 0xff, 0xff, 0xff) + end +end +)", + 0, 2), + R"( +LOADN R4 0 +LOADN R2 3 +LOADN R3 1 +FORNPREP R2 L3 +L0: ADDK R5 R4 K0 +MULK R9 R4 K1 +FASTCALL2 39 R1 R9 L1 +MOVE R8 R1 +GETIMPORT R7 4 +CALL R7 2 1 +L1: LOADN R8 255 +LOADN R9 255 +LOADN R10 255 +LOADN R11 255 +LOADN R12 255 +FASTCALL 29 L2 +GETIMPORT R6 6 +CALL R6 6 1 +L2: SETTABLE R6 R0 R5 +FORNLOOP R2 L0 +L3: RETURN R0 0 +)"); +} + +TEST_CASE("InlineBasic") +{ + // inline function that returns a constant + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return 42 +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // inline function that returns the argument + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = foo(42) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // inline function that returns one of the two arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b, c) + if a then + return b + else + return c + end +end + +local x = foo(true, math.random(), 5) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETIMPORT R2 3 +CALL R2 0 1 +MOVE R1 R2 +RETURN R1 1 +RETURN R1 1 +)"); + + // inline function that returns one of the two arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b, c) + if a then + return b + else + return c + end +end + +local x = foo(true, 5, math.random()) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETIMPORT R2 3 +CALL R2 0 1 +LOADN R1 5 +RETURN R1 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineBasicProhibited") +{ + // we can't inline variadic functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(...) + return 42 +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // we can't inline any functions in modules with getfenv/setfenv + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return 42 +end + +local x = foo() +getfenv() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +GETIMPORT R2 2 +CALL R2 0 0 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineNestedLoops") +{ + // functions with basic loops get inlined + CHECK_EQ("\n" + compileFunction(R"( +local function foo(t) + for i=1,3 do + t[i] = i + end + return t +end + +local x = foo({}) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +NEWTABLE R2 0 0 +LOADN R3 1 +SETTABLEN R3 R2 1 +LOADN R3 2 +SETTABLEN R3 R2 2 +LOADN R3 3 +SETTABLEN R3 R2 3 +MOVE R1 R2 +RETURN R1 1 +)"); + + // we can even unroll the loops based on inline argument + CHECK_EQ("\n" + compileFunction(R"( +local function foo(t, n) + for i=1, n do + t[i] = i + end + return t +end + +local x = foo({}, 3) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +NEWTABLE R2 0 0 +LOADN R3 1 +SETTABLEN R3 R2 1 +LOADN R3 2 +SETTABLEN R3 R2 2 +LOADN R3 3 +SETTABLEN R3 R2 3 +MOVE R1 R2 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineNestedClosures") +{ + // we can inline functions that contain/return functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(x) + return function(y) return x + y end +end + +local x = foo(1)(2) +return x +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 1 +NEWCLOSURE R1 P1 +CAPTURE VAL R2 +LOADN R2 2 +CALL R1 1 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineMutate") +{ + // if the argument is mutated, it gets a register even if the value is constant + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 5 + return a +end + +local x = foo(42) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 42 +ORK R2 R2 K1 +MOVE R1 R2 +RETURN R1 1 +)"); + + // if the argument is a local, it can be used directly + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = ... +local y = foo(x) +return y +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R2 R1 +RETURN R2 1 +)"); + + // ... but if it's mutated, we move it in case it is mutated through a capture during the inlined function + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = ... +x = nil +local y = foo(x) +return y +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +LOADNIL R1 +MOVE R3 R1 +MOVE R2 R3 +RETURN R2 1 +)"); + + // we also don't inline functions if they have been assigned to + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +foo = foo + +local x = foo(42) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R0 R0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineUpval") +{ + // if the argument is an upvalue, we naturally need to copy it to a local + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local b = ... + +function bar() + local x = foo(b) + return x +end +)", + 1, 2), + R"( +GETUPVAL R1 0 +MOVE R0 R1 +RETURN R0 1 +)"); + + // if the function uses an upvalue it's more complicated, because the lexical upvalue may become a local + CHECK_EQ("\n" + compileFunction(R"( +local b = ... + +local function foo(a) + return a + b +end + +local x = foo(42) +return x +)", + 1, 2), + R"( +GETVARARGS R0 1 +DUPCLOSURE R1 K0 +CAPTURE VAL R0 +LOADN R3 42 +ADD R2 R3 R0 +RETURN R2 1 +)"); + + // sometimes the lexical upvalue is deep enough that it's still an upvalue though + CHECK_EQ("\n" + compileFunction(R"( +local b = ... + +function bar() + local function foo(a) + return a + b + end + + local x = foo(42) + return x +end +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +CAPTURE UPVAL U0 +LOADN R2 42 +GETUPVAL R3 0 +ADD R1 R2 R3 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineCapture") +{ + // if the argument is captured by a nested closure, normally we can rely on capture by value + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local x = ... +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +NEWCLOSURE R2 P1 +CAPTURE VAL R1 +RETURN R2 1 +)"); + + // if the argument is a constant, we move it to a register so that capture by value can happen + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local y = foo(42) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 42 +NEWCLOSURE R1 P1 +CAPTURE VAL R2 +RETURN R1 1 +)"); + + // if the argument is an externally mutated variable, we copy it to an argument and capture it by value + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local x x = 42 +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADN R1 42 +MOVE R3 R1 +NEWCLOSURE R2 P1 +CAPTURE VAL R3 +RETURN R2 1 +)"); + + // finally, if the argument is mutated internally, we must capture it by reference and close the upvalue + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 42 + return function() return a end +end + +local y = foo() +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R2 +ORK R2 R2 K1 +NEWCLOSURE R1 P1 +CAPTURE REF R2 +CLOSEUPVALS R2 +RETURN R1 1 +)"); + + // note that capture might need to be performed during the fallthrough block + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 42 + print(function() return a end) +end + +local x = ... +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R3 R1 +ORK R3 R3 K1 +GETIMPORT R4 3 +NEWCLOSURE R5 P1 +CAPTURE REF R3 +CALL R4 1 0 +LOADNIL R2 +CLOSEUPVALS R3 +RETURN R2 1 +)"); + + // note that mutation and capture might be inside internal control flow + // TODO: this has an oddly redundant CLOSEUPVALS after JUMP; it's not due to inlining, and is an artifact of how StatBlock/StatReturn interact + // fixing this would reduce the number of redundant CLOSEUPVALS a bit but it only affects bytecode size as these instructions aren't executed + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + if not a then + local b b = 42 + return function() return b end + end +end + +local x = ... +local y = foo(x) +return y, x +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +JUMPIF R1 L0 +LOADNIL R3 +LOADN R3 42 +NEWCLOSURE R2 P1 +CAPTURE REF R3 +CLOSEUPVALS R3 +JUMP L1 +CLOSEUPVALS R3 +L0: LOADNIL R2 +L1: MOVE R3 R2 +MOVE R4 R1 +RETURN R3 2 +)"); +} + +TEST_CASE("InlineFallthrough") +{ + // if the function doesn't return, we still fill the results with nil + CHECK_EQ("\n" + compileFunction(R"( +local function foo() +end + +local a, b = foo() +return a, b +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADNIL R2 +RETURN R1 2 +)"); + + // this happens even if the function returns conditionally + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + if a then return 42 end +end + +local a, b = foo(false) +return a, b +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADNIL R2 +RETURN R1 2 +)"); + + // note though that we can't inline a function like this in multret context + // this is because we don't have a SETTOP instruction + CHECK_EQ("\n" + compileFunction(R"( +local function foo() +end + +return foo() +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 -1 +RETURN R1 -1 +)"); +} + +TEST_CASE("InlineArgMismatch") +{ + // when inlining a function, we must respect all the usual rules + + // caller might not have enough arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +RETURN R1 1 +)"); + + // caller might be using multret for arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x = foo(math.modf(1.5)) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADK R3 K1 +FASTCALL1 20 R3 L0 +GETIMPORT R2 4 +CALL R2 1 2 +L0: ADD R1 R2 R3 +RETURN R1 1 +)"); + + // caller might be using varargs for arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x = foo(...) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R2 2 +ADD R1 R2 R3 +RETURN R1 1 +)"); + + // caller might have too many arguments, but we still need to compute them for side effects + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = foo(42, print()) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETIMPORT R2 2 +CALL R2 0 1 +LOADN R1 42 +RETURN R1 1 +)"); + + // caller might not have enough arguments, and the arg might be mutated so it needs a register + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = 42 + return a +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R2 +LOADN R2 42 +MOVE R1 R2 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineMultiple") +{ + // we call this with a different set of variable/constant args + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x, y = ... +local a = foo(x, 1) +local b = foo(1, x) +local c = foo(1, 2) +local d = foo(x, y) +return a, b, c, d +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 2 +ADDK R3 R1 K1 +LOADN R5 1 +ADD R4 R5 R1 +LOADN R5 3 +ADD R6 R1 R2 +RETURN R3 4 +)"); +} + +TEST_CASE("InlineChain") +{ + // inline a chain of functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local function bar(x) + return foo(x, 1) * foo(x, -1) +end + +local function baz() + return (bar(42)) +end + +return (baz()) +)", + 3, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +DUPCLOSURE R2 K2 +LOADN R4 43 +LOADN R5 41 +MUL R3 R4 R5 +RETURN R3 1 +)"); +} + +TEST_CASE("InlineThresholds") +{ + ScopedFastInt sfis[] = { + {"LuauCompileInlineThreshold", 25}, + {"LuauCompileInlineThresholdMaxBoost", 300}, + {"LuauCompileInlineDepth", 2}, + }; + + // this function has enormous register pressure (50 regs) so we choose not to inline it + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return {{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}} +end + +return (foo()) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // this function has less register pressure but a large cost + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return {},{},{},{},{} +end + +return (foo()) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // this chain of function is of length 3 but our limit in this test is 2, so we call foo twice + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local function bar(x) + return foo(x, 1) * foo(x, -1) +end + +local function baz() + return (bar(42)) +end + +return (baz()) +)", + 3, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +DUPCLOSURE R2 K2 +MOVE R4 R0 +LOADN R5 42 +LOADN R6 1 +CALL R4 2 1 +MOVE R5 R0 +LOADN R6 42 +LOADN R7 -1 +CALL R5 2 1 +MUL R3 R4 R5 +RETURN R3 1 +)"); +} + +TEST_CASE("InlineIIFE") +{ + // IIFE with arguments + CHECK_EQ("\n" + compileFunction(R"( +function choose(a, b, c) + return ((function(a, b, c) if a then return b else return c end end)(a, b, c)) +end +)", + 1, 2), + R"( +JUMPIFNOT R0 L0 +MOVE R3 R1 +RETURN R3 1 +L0: MOVE R3 R2 +RETURN R3 1 +RETURN R3 1 +)"); + + // IIFE with upvalues + CHECK_EQ("\n" + compileFunction(R"( +function choose(a, b, c) + return ((function() if a then return b else return c end end)()) +end +)", + 1, 2), + R"( +JUMPIFNOT R0 L0 +MOVE R3 R1 +RETURN R3 1 +L0: MOVE R3 R2 +RETURN R3 1 +RETURN R3 1 +)"); +} + +TEST_CASE("InlineRecurseArguments") +{ + // we can't inline a function if it's used to compute its own arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) +end +foo(foo(foo,foo(foo,foo))[foo]) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R2 R0 +MOVE R3 R0 +MOVE R4 R0 +MOVE R5 R0 +MOVE R6 R0 +CALL R4 2 -1 +CALL R2 -1 1 +GETTABLE R1 R2 R0 +RETURN R0 0 +)"); +} + +TEST_CASE("InlineFastCallK") +{ + CHECK_EQ("\n" + compileFunction(R"( +local function set(l0) + rawset({}, l0) +end + +set(false) +set({}) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +NEWTABLE R2 0 0 +FASTCALL2K 49 R2 K1 L0 +LOADK R3 K1 +GETIMPORT R1 3 +CALL R1 2 0 +L0: NEWTABLE R1 0 0 +NEWTABLE R3 0 0 +FASTCALL2 49 R3 R1 L1 +MOVE R4 R1 +GETIMPORT R2 3 +CALL R2 2 0 +L1: RETURN R0 0 +)"); +} + +TEST_CASE("InlineExprIndexK") +{ + CHECK_EQ("\n" + compileFunction(R"( +local _ = function(l0) +local _ = nil +while _(_)[_] do +end +end +local _ = _(0)[""] +if _ then +do +for l0=0,8 do +end +end +elseif _ then +_ = nil +do +for l0=0,8 do +return true +end +end +end +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +L0: LOADNIL R4 +LOADNIL R5 +CALL R4 1 1 +LOADNIL R5 +GETTABLE R3 R4 R5 +JUMPIFNOT R3 L1 +JUMPBACK L0 +L1: LOADNIL R2 +GETTABLEKS R1 R2 K1 +JUMPIFNOT R1 L2 +RETURN R0 0 +L2: JUMPIFNOT R1 L3 +LOADNIL R1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +L3: RETURN R0 0 +)"); +} + +TEST_CASE("InlineHiddenMutation") +{ + // when the argument is assigned inside the function, we can't reuse the local + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = 42 + return a +end + +local x = ... +local y = foo(x :: number) +return y +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R3 R1 +LOADN R3 42 +MOVE R2 R3 +RETURN R2 1 +)"); + + // and neither can we do that when it's assigned outside the function + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + mutator() + return a +end + +local x = ... +mutator = function() x = 42 end + +local y = foo(x :: number) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +NEWCLOSURE R2 P1 +CAPTURE REF R1 +SETGLOBAL R2 K1 +MOVE R3 R1 +GETGLOBAL R4 K1 +CALL R4 0 0 +MOVE R2 R3 +CLOSEUPVALS R1 +RETURN R2 1 +)"); +} + +TEST_CASE("InlineMultret") +{ + // inlining a function in multret context is prohibited since we can't adjust L->top outside of CALL/GETVARARGS + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a() +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // however, if we can deduce statically that a function always returns a single value, the inlining will work + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // this analysis will also propagate through other functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local function bar(a) + return foo(a) +end + +return bar(42) +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +LOADN R2 42 +RETURN R2 1 +)"); + + // we currently don't do this analysis fully for recursive functions since they can't be inlined anyway + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return foo(a) +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +CAPTURE VAL R0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // and unfortunately we can't do this analysis for builtins or method calls due to getfenv + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return math.abs(a) +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); +} + +TEST_CASE("ReturnConsecutive") +{ + // we can return a single local directly + CHECK_EQ("\n" + compileFunction0(R"( +local x = ... +return x +)"), + R"( +GETVARARGS R0 1 +RETURN R0 1 +)"); + + // or multiple, when they are allocated in consecutive registers + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return x, y +)"), + R"( +GETVARARGS R0 2 +RETURN R0 2 +)"); + + // but not if it's an expression + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return x, y + 1 +)"), + R"( +GETVARARGS R0 2 +MOVE R2 R0 +ADDK R3 R1 K0 +RETURN R2 2 +)"); + + // or a local with wrong register number + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return y, x +)"), + R"( +GETVARARGS R0 2 +MOVE R2 R1 +MOVE R3 R0 +RETURN R2 2 +)"); + + // also double check the optimization doesn't trip on no-argument return (these are rare) + CHECK_EQ("\n" + compileFunction0(R"( +return +)"), + R"( +RETURN R0 0 +)"); + + // this optimization also works in presence of group / type casts + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return (x), y :: number +)"), + R"( +GETVARARGS R0 2 +RETURN R0 2 +)"); +} + +TEST_CASE("OptimizationLevel") +{ + // at optimization level 1, no inlining is performed + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +return foo(42) +)", + 1, 1), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // you can override the level from 1 to 2 to force it + CHECK_EQ("\n" + compileFunction(R"( +--!optimize 2 +local function foo(a) + return a +end + +return foo(42) +)", + 1, 1), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // you can also override it externally + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // ... after which you can downgrade it back via hot comment + CHECK_EQ("\n" + compileFunction(R"( +--!optimize 1 +local function foo(a) + return a +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); +} + +TEST_CASE("BuiltinFolding") +{ + CHECK_EQ("\n" + compileFunction(R"( +return + math.abs(-42), + math.acos(1), + math.asin(0), + math.atan2(0, 1), + math.atan(0), + math.ceil(1.5), + math.cosh(0), + math.cos(0), + math.deg(3.14159265358979323846), + math.exp(0), + math.floor(-1.5), + math.fmod(7, 3), + math.ldexp(0.5, 3), + math.log10(100), + math.log(1), + math.log(4, 2), + math.log(27, 3), + math.max(1, 2, 3), + math.min(1, 2, 3), + math.pow(3, 3), + math.floor(math.rad(180)), + math.sinh(0), + math.sin(0), + math.sqrt(9), + math.tanh(0), + math.tan(0), + bit32.arshift(-10, 1), + bit32.arshift(10, 1), + bit32.band(1, 3), + bit32.bnot(-2), + bit32.bor(1, 2), + bit32.bxor(3, 7), + bit32.btest(1, 3), + bit32.extract(100, 1, 3), + bit32.lrotate(100, -1), + bit32.lshift(100, 1), + bit32.replace(100, 5, 1, 3), + bit32.rrotate(100, -1), + bit32.rshift(100, 1), + type(100), + string.byte("a"), + string.byte("abc", 2), + string.len("abc"), + typeof(true), + math.clamp(-1, 0, 1), + math.sign(77), + math.round(7.6), + bit32.extract(-1, 31), + bit32.replace(100, 1, 0), + math.log(100, 10), + typeof(nil), + (type("fin")) +)", + 0, 2), + R"( +LOADN R0 42 +LOADN R1 0 +LOADN R2 0 +LOADN R3 0 +LOADN R4 0 +LOADN R5 2 +LOADN R6 1 +LOADN R7 1 +LOADN R8 180 +LOADN R9 1 +LOADN R10 -2 +LOADN R11 1 +LOADN R12 4 +LOADN R13 2 +LOADN R14 0 +LOADN R15 2 +LOADN R16 3 +LOADN R17 3 +LOADN R18 1 +LOADN R19 27 +LOADN R20 3 +LOADN R21 0 +LOADN R22 0 +LOADN R23 3 +LOADN R24 0 +LOADN R25 0 +LOADK R26 K0 +LOADN R27 5 +LOADN R28 1 +LOADN R29 1 +LOADN R30 3 +LOADN R31 4 +LOADB R32 1 +LOADN R33 2 +LOADN R34 50 +LOADN R35 200 +LOADN R36 106 +LOADN R37 200 +LOADN R38 50 +LOADK R39 K1 +LOADN R40 97 +LOADN R41 98 +LOADN R42 3 +LOADK R43 K2 +LOADN R44 0 +LOADN R45 1 +LOADN R46 8 +LOADN R47 1 +LOADN R48 101 +LOADN R49 2 +LOADK R50 K3 +LOADK R51 K4 +RETURN R0 52 +)"); +} + +TEST_CASE("BuiltinFoldingProhibited") +{ + CHECK_EQ("\n" + compileFunction(R"( +return + math.abs(), + math.max(1, true), + string.byte("abc", 42), + bit32.rshift(10, 42), + bit32.extract(1, 2, "3"), + bit32.bor(1, true), + bit32.band(1, true), + bit32.bxor(1, true), + bit32.btest(1, true), + math.min(1, true) +)", + 0, 2), + R"( +FASTCALL 2 L0 +GETIMPORT R0 2 +CALL R0 0 1 +L0: LOADN R2 1 +FASTCALL2K 18 R2 K3 L1 +LOADK R3 K3 +GETIMPORT R1 5 +CALL R1 2 1 +L1: LOADK R3 K6 +FASTCALL2K 41 R3 K7 L2 +LOADK R4 K7 +GETIMPORT R2 10 +CALL R2 2 1 +L2: LOADN R4 10 +FASTCALL2K 39 R4 K7 L3 +LOADK R5 K7 +GETIMPORT R3 13 +CALL R3 2 1 +L3: LOADN R5 1 +LOADN R6 2 +LOADK R7 K14 +FASTCALL 34 L4 +GETIMPORT R4 16 +CALL R4 3 1 +L4: LOADN R6 1 +FASTCALL2K 31 R6 K3 L5 +LOADK R7 K3 +GETIMPORT R5 18 +CALL R5 2 1 +L5: LOADN R7 1 +FASTCALL2K 29 R7 K3 L6 +LOADK R8 K3 +GETIMPORT R6 20 +CALL R6 2 1 +L6: LOADN R8 1 +FASTCALL2K 32 R8 K3 L7 +LOADK R9 K3 +GETIMPORT R7 22 +CALL R7 2 1 +L7: LOADN R9 1 +FASTCALL2K 33 R9 K3 L8 +LOADK R10 K3 +GETIMPORT R8 24 +CALL R8 2 1 +L8: LOADN R10 1 +FASTCALL2K 19 R10 K3 L9 +LOADK R11 K3 +GETIMPORT R9 26 +CALL R9 2 -1 +L9: RETURN R0 -1 +)"); +} + +TEST_CASE("BuiltinFoldingProhibitedCoverage") +{ + const char* builtins[] = { + "math.abs", + "math.acos", + "math.asin", + "math.atan2", + "math.atan", + "math.ceil", + "math.cosh", + "math.cos", + "math.deg", + "math.exp", + "math.floor", + "math.fmod", + "math.ldexp", + "math.log10", + "math.log", + "math.max", + "math.min", + "math.pow", + "math.rad", + "math.sinh", + "math.sin", + "math.sqrt", + "math.tanh", + "math.tan", + "bit32.arshift", + "bit32.band", + "bit32.bnot", + "bit32.bor", + "bit32.bxor", + "bit32.btest", + "bit32.extract", + "bit32.lrotate", + "bit32.lshift", + "bit32.replace", + "bit32.rrotate", + "bit32.rshift", + "type", + "string.byte", + "string.len", + "typeof", + "math.clamp", + "math.sign", + "math.round", + }; + + for (const char* func : builtins) + { + std::string source = "return "; + source += func; + source += "()"; + + std::string bc = compileFunction(source.c_str(), 0, 2); + + CHECK(bc.find("FASTCALL") != std::string::npos); + } +} + +TEST_CASE("BuiltinFoldingMultret") +{ + CHECK_EQ("\n" + compileFunction(R"( +local NoLanes: Lanes = --[[ ]] 0b0000000000000000000000000000000 +local OffscreenLane: Lane = --[[ ]] 0b1000000000000000000000000000000 + +local function getLanesToRetrySynchronouslyOnError(root: FiberRoot): Lanes + local everythingButOffscreen = bit32.band(root.pendingLanes, bit32.bnot(OffscreenLane)) + if everythingButOffscreen ~= NoLanes then + return everythingButOffscreen + end + if bit32.band(everythingButOffscreen, OffscreenLane) ~= 0 then + return OffscreenLane + end + return NoLanes +end +)", + 0, 2), + R"( +GETTABLEKS R2 R0 K0 +FASTCALL2K 29 R2 K1 L0 +LOADK R3 K1 +GETIMPORT R1 4 +CALL R1 2 1 +L0: JUMPXEQKN R1 K5 L1 +RETURN R1 1 +L1: FASTCALL2K 29 R1 K6 L2 +MOVE R3 R1 +LOADK R4 K6 +GETIMPORT R2 4 +CALL R2 2 1 +L2: JUMPXEQKN R2 K5 L3 +LOADK R2 K6 +RETURN R2 1 +L3: LOADN R2 0 +RETURN R2 1 +)"); + + // Note: similarly, here we should have folded the return value but haven't because it's the last call in the sequence + CHECK_EQ("\n" + compileFunction(R"( +return math.abs(-42) +)", + 0, 2), + R"( +LOADN R0 42 +RETURN R0 1 +)"); +} + +TEST_CASE("LocalReassign") +{ + // locals can be re-assigned and the register gets reused + CHECK_EQ("\n" + compileFunction0(R"( +local function test(a, b) + local c = a + return c + b +end +)"), + R"( +ADD R2 R0 R1 +RETURN R2 1 +)"); + + // this works if the expression is using type casts or grouping + CHECK_EQ("\n" + compileFunction0(R"( +local function test(a, b) + local c = (a :: number) + return c + b +end +)"), + R"( +ADD R2 R0 R1 +RETURN R2 1 +)"); + + // the optimization requires that neither local is mutated + CHECK_EQ("\n" + compileFunction0(R"( +local function test(a, b) + local c = a + c += 0 + local d = b + b += 0 + return c + d +end +)"), + R"( +MOVE R2 R0 +ADDK R2 R2 K0 +MOVE R3 R1 +ADDK R1 R1 K0 +ADD R4 R2 R3 +RETURN R4 1 +)"); + + // sanity check for two values + CHECK_EQ("\n" + compileFunction0(R"( +local function test(a, b) + local c = a + local d = b + return c + d +end +)"), + R"( +ADD R2 R0 R1 +RETURN R2 1 +)"); + + // note: we currently only support this for single assignments + CHECK_EQ("\n" + compileFunction0(R"( +local function test(a, b) + local c, d = a, b + return c + d +end +)"), + R"( +MOVE R2 R0 +MOVE R3 R1 +ADD R4 R2 R3 +RETURN R4 1 +)"); + + // of course, captures capture the original register as well (by value since it's immutable) + CHECK_EQ("\n" + compileFunction(R"( +local function test(a, b) + local c = a + local d = b + return function() return c + d end +end +)", + 1), + R"( +NEWCLOSURE R2 P0 +CAPTURE VAL R0 +CAPTURE VAL R1 +RETURN R2 1 +)"); +} + +TEST_CASE("MultipleAssignments") +{ + // order of assignments is left to right + CHECK_EQ("\n" + compileFunction0(R"( + local a, b + a, b = f(1), f(2) + )"), + R"( +LOADNIL R0 +LOADNIL R1 +GETIMPORT R2 1 +LOADN R3 1 +CALL R2 1 1 +MOVE R0 R2 +GETIMPORT R2 1 +LOADN R3 2 +CALL R2 1 1 +MOVE R1 R2 +RETURN R0 0 +)"); + + // this includes table assignments + CHECK_EQ("\n" + compileFunction0(R"( + local t + t[1], t[2] = 3, 4 + )"), + R"( +LOADNIL R0 +LOADNIL R1 +LOADN R2 3 +LOADN R3 4 +SETTABLEN R2 R0 1 +SETTABLEN R3 R1 2 +RETURN R0 0 +)"); + + // semantically, we evaluate the right hand side first; this allows us to e.g swap elements in a table easily + CHECK_EQ("\n" + compileFunction0(R"( + local t = ... + t[1], t[2] = t[2], t[1] + )"), + R"( +GETVARARGS R0 1 +GETTABLEN R1 R0 2 +GETTABLEN R2 R0 1 +SETTABLEN R1 R0 1 +SETTABLEN R2 R0 2 +RETURN R0 0 +)"); + + // however, we need to optimize local assignments; to do this well, we need to handle assignment conflicts + // let's first go through a few cases where there are no conflicts: + + // when multiple assignments have no conflicts (all local vars are read after being assigned), codegen is the same as a series of single + // assignments + CHECK_EQ("\n" + compileFunction0(R"( + local xm1, x, xp1, xi = ... + + xm1,x,xp1,xi = x,xp1,xp1+1,xi-1 + )"), + R"( +GETVARARGS R0 4 +MOVE R0 R1 +MOVE R1 R2 +ADDK R2 R2 K0 +SUBK R3 R3 K0 +RETURN R0 0 +)"); + + // similar example to above from a more complex case + CHECK_EQ("\n" + compileFunction0(R"( + local a, b, c, d, e, f, g, h, t1, t2 = ... + + h, g, f, e, d, c, b, a = g, f, e, d + t1, c, b, a, t1 + t2 + )"), + R"( +GETVARARGS R0 10 +MOVE R7 R6 +MOVE R6 R5 +MOVE R5 R4 +ADD R4 R3 R8 +MOVE R3 R2 +MOVE R2 R1 +MOVE R1 R0 +ADD R0 R8 R9 +RETURN R0 0 +)"); + + // when locals have a conflict, we assign temporaries instead of locals, and at the end copy the values back + // the basic example of this is a swap/rotate + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = ... + a, b = b, a + )"), + R"( +GETVARARGS R0 2 +MOVE R2 R1 +MOVE R1 R0 +MOVE R0 R2 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0(R"( + local a, b, c = ... + a, b, c = c, a, b + )"), + R"( +GETVARARGS R0 3 +MOVE R3 R2 +MOVE R4 R0 +MOVE R2 R1 +MOVE R0 R3 +MOVE R1 R4 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0(R"( + local a, b, c = ... + a, b, c = b, c, a + )"), + R"( +GETVARARGS R0 3 +MOVE R3 R1 +MOVE R1 R2 +MOVE R2 R0 +MOVE R0 R3 +RETURN R0 0 +)"); + + // multiple assignments with multcall handling - foo() evalutes to temporary registers and they are copied out to target + CHECK_EQ("\n" + compileFunction0(R"( + local a, b, c, d = ... + a, b, c, d = 1, foo() + )"), + R"( +GETVARARGS R0 4 +LOADN R0 1 +GETIMPORT R4 1 +CALL R4 0 3 +MOVE R1 R4 +MOVE R2 R5 +MOVE R3 R6 +RETURN R0 0 +)"); + + // note that during this we still need to handle local reassignment, eg when table assignments are performed + CHECK_EQ("\n" + compileFunction0(R"( + local a, b, c, d = ... + a, b[a], c[d], d = 1, foo() + )"), + R"( +GETVARARGS R0 4 +LOADN R4 1 +GETIMPORT R6 1 +CALL R6 0 3 +SETTABLE R6 R1 R0 +SETTABLE R7 R2 R3 +MOVE R0 R4 +MOVE R3 R8 +RETURN R0 0 +)"); + + // multiple assignments with multcall handling - foo evaluates to a single argument so all remaining locals are assigned to nil + // note that here we don't assign the locals directly, as this case is very rare so we use the similar code path as above + CHECK_EQ("\n" + compileFunction0(R"( + local a, b, c, d = ... + a, b, c, d = 1, foo + )"), + R"( +GETVARARGS R0 4 +LOADN R0 1 +GETIMPORT R4 1 +LOADNIL R5 +LOADNIL R6 +MOVE R1 R4 +MOVE R2 R5 +MOVE R3 R6 +RETURN R0 0 +)"); + + // note that we also try to use locals as a source of assignment directly when assigning fields; this works using old local value when possible + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = ... + a[1], a[2] = b, b + 1 + )"), + R"( +GETVARARGS R0 2 +ADDK R2 R1 K0 +SETTABLEN R1 R0 1 +SETTABLEN R2 R0 2 +RETURN R0 0 +)"); + + // ... of course if the local is reassigned, we defer the assignment until later + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = ... + b, a[1] = 42, b + )"), + R"( +GETVARARGS R0 2 +LOADN R2 42 +SETTABLEN R1 R0 1 +MOVE R1 R2 +RETURN R0 0 +)"); + + // when there are more expressions when values, we evalute them for side effects, but they also participate in conflict handling + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = ... + a, b = 1, 2, a + b + )"), + R"( +GETVARARGS R0 2 +LOADN R2 1 +LOADN R3 2 +ADD R4 R0 R1 +MOVE R0 R2 +MOVE R1 R3 +RETURN R0 0 +)"); + + ScopedFastFlag luauMultiAssignmentConflictFix{"LuauMultiAssignmentConflictFix", true}; + + // because we perform assignments to complex l-values after assignments to locals, we make sure register conflicts are tracked accordingly + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = ... + a[1], b = b, b + 1 + )"), + R"( +GETVARARGS R0 2 +ADDK R2 R1 K0 +SETTABLEN R1 R0 1 +MOVE R1 R2 +RETURN R0 0 +)"); +} + +TEST_CASE("BuiltinExtractK") +{ + // below, K0 refers to a packed f+w constant for bit32.extractk builtin + // K1 and K2 refer to 1 and 3 and are only used during fallback path + CHECK_EQ("\n" + compileFunction0(R"( +local v = ... + +return bit32.extract(v, 1, 3) +)"), + R"( +GETVARARGS R0 1 +FASTCALL2K 59 R0 K0 L0 +MOVE R2 R0 +LOADK R3 K1 +LOADK R4 K2 +GETIMPORT R1 5 +CALL R1 3 -1 +L0: RETURN R1 -1 +)"); +} + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 5a697a49..6be838dd 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -1,21 +1,35 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Compiler.h" +#include "lua.h" +#include "lualib.h" +#include "luacode.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/ModuleResolver.h" #include "Luau/TypeInfer.h" #include "Luau/StringUtils.h" #include "Luau/BytecodeBuilder.h" +#include "Luau/CodeGen.h" #include "doctest.h" #include "ScopedFlags.h" -#include "lua.h" -#include "lualib.h" - #include +#include #include +extern bool verbose; +extern bool codegen; +extern int optimizationLevel; + +static lua_CompileOptions defaultOptions() +{ + lua_CompileOptions copts = {}; + copts.optimizationLevel = optimizationLevel; + copts.debugLevel = 1; + + return copts; +} + static int lua_collectgarbage(lua_State* L) { static const char* const opts[] = {"stop", "restart", "collect", "count", "isrunning", "step", "setgoal", "setstepmul", "setstepsize", nullptr}; @@ -49,13 +63,17 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); - std::string bytecode = Luau::compile(std::string(s, l)); - if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + size_t bytecodeSize = 0; + char* bytecode = luau_compile(s, l, nullptr, &bytecodeSize); + int result = luau_load(L, chunkname, bytecode, bytecodeSize, 0); + free(bytecode); + + if (result == 0) return 1; lua_pushnil(L); - lua_insert(L, -2); /* put before error message */ - return 2; /* return nil plus error message */ + lua_insert(L, -2); // put before error message + return 2; // return nil plus error message } static int lua_vector(lua_State* L) @@ -64,44 +82,42 @@ static int lua_vector(lua_State* L) double y = luaL_checknumber(L, 2); double z = luaL_checknumber(L, 3); +#if LUA_VECTOR_SIZE == 4 + double w = luaL_optnumber(L, 4, 0.0); + lua_pushvector(L, float(x), float(y), float(z), float(w)); +#else lua_pushvector(L, float(x), float(y), float(z)); +#endif return 1; } static int lua_vector_dot(lua_State* L) { - const float* a = lua_tovector(L, 1); - const float* b = lua_tovector(L, 2); + const float* a = luaL_checkvector(L, 1); + const float* b = luaL_checkvector(L, 2); - if (a && b) - { - lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); - return 1; - } - - throw std::runtime_error("invalid arguments to vector:Dot"); + lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); + return 1; } static int lua_vector_index(lua_State* L) { + const float* v = luaL_checkvector(L, 1); const char* name = luaL_checkstring(L, 2); - if (const float* v = lua_tovector(L, 1)) + if (strcmp(name, "Magnitude") == 0) { - if (strcmp(name, "Magnitude") == 0) - { - lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); - return 1; - } - - if (strcmp(name, "Dot") == 0) - { - lua_pushcfunction(L, lua_vector_dot, "Dot"); - return 1; - } + lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); + return 1; } - throw std::runtime_error(Luau::format("%s is not a valid member of vector", name)); + if (strcmp(name, "Dot") == 0) + { + lua_pushcfunction(L, lua_vector_dot, "Dot"); + return 1; + } + + luaL_error(L, "%s is not a valid member of vector", name); } static int lua_vector_namecall(lua_State* L) @@ -112,7 +128,7 @@ static int lua_vector_namecall(lua_State* L) return lua_vector_dot(L); } - throw std::runtime_error(Luau::format("%s is not a valid method of vector", luaL_checkstring(L, 1))); + luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); } int lua_silence(lua_State* L) @@ -122,8 +138,8 @@ int lua_silence(lua_State* L) using StateRef = std::unique_ptr; -static StateRef runConformance( - const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, lua_State* initialLuaState = nullptr) +static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, + lua_State* initialLuaState = nullptr, lua_CompileOptions* options = nullptr, bool skipCodegen = false) { std::string path = __FILE__; path.erase(path.find_last_of("\\/")); @@ -142,18 +158,27 @@ static StateRef runConformance( StateRef globalState(initialLuaState, lua_close); lua_State* L = globalState.get(); + if (codegen && !skipCodegen && Luau::CodeGen::isSupported()) + Luau::CodeGen::create(L); + luaL_openlibs(L); // Register a few global functions for conformance tests - static const luaL_Reg funcs[] = { + std::vector funcs = { {"collectgarbage", lua_collectgarbage}, {"loadstring", lua_loadstring}, - {"print", lua_silence}, // Disable print() by default; comment this out to enable debug prints in tests - {nullptr, nullptr}, }; + if (!verbose) + { + funcs.push_back({"print", lua_silence}); + } + + // "null" terminate the list of functions to register + funcs.push_back({nullptr, nullptr}); + lua_pushvalue(L, LUA_GLOBALSINDEX); - luaL_register(L, nullptr, funcs); + luaL_register(L, nullptr, funcs.data()); lua_pop(L, 1); // In some configurations we have a larger C stack consumption which trips some conformance tests @@ -179,21 +204,18 @@ static StateRef runConformance( std::string chunkname = "=" + std::string(name); - Luau::CompileOptions copts; - copts.debugLevel = 2; // for debugger tests - copts.vectorCtor = "vector"; // for vector tests + // note: luau_compile supports nullptr options, but we need to customize our defaults to improve test coverage + lua_CompileOptions opts = options ? *options : defaultOptions(); - std::string bytecode = Luau::compile(source, copts); - int status = 0; + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), &opts, &bytecodeSize); + int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); + free(bytecode); - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) - { - status = lua_resume(L, nullptr, 0); - } - else - { - status = LUA_ERRSYNTAX; - } + if (result == 0 && codegen && !skipCodegen && Luau::CodeGen::isSupported()) + Luau::CodeGen::compile(L, -1); + + int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX; while (yield && (status == LUA_YIELD || status == LUA_BREAK)) { @@ -238,11 +260,19 @@ TEST_CASE("Math") runConformance("math.lua"); } -TEST_CASE("Table") +TEST_CASE("Tables") { - ScopedFastFlag sff("LuauTableFreeze", true); - - runConformance("nextvar.lua"); + runConformance("tables.lua", [](lua_State* L) { + lua_pushcfunction( + L, + [](lua_State* L) { + unsigned v = luaL_checkunsigned(L, 1); + lua_pushlightuserdata(L, reinterpret_cast(uintptr_t(v))); + return 1; + }, + "makelud"); + lua_setglobal(L, "makelud"); + }); } TEST_CASE("PatternMatch") @@ -270,6 +300,13 @@ TEST_CASE("Strings") runConformance("strings.lua"); } +TEST_CASE("StringInterp") +{ + ScopedFastFlag sffInterpStrings{"LuauInterpolatedStringBaseSupport", true}; + + runConformance("stringinterp.lua"); +} + TEST_CASE("VarArg") { runConformance("vararg.lua"); @@ -335,24 +372,30 @@ TEST_CASE("Coroutine") runConformance("coroutine.lua"); } +static int cxxthrow(lua_State* L) +{ +#if LUA_USE_LONGJMP + luaL_error(L, "oops"); +#else + throw std::runtime_error("oops"); +#endif +} + TEST_CASE("PCall") { runConformance("pcall.lua", [](lua_State* L) { - lua_pushcfunction(L, [](lua_State* L) -> int { -#if LUA_USE_LONGJMP - luaL_error(L, "oops"); -#else - throw std::runtime_error("oops"); -#endif - }); + lua_pushcfunction(L, cxxthrow, "cxxthrow"); lua_setglobal(L, "cxxthrow"); - lua_pushcfunction(L, [](lua_State* L) -> int { - lua_State* co = lua_tothread(L, 1); - lua_xmove(L, co, 1); - lua_resumeerror(co, L); - return 0; - }); + lua_pushcfunction( + L, + [](lua_State* L) -> int { + lua_State* co = lua_tothread(L, 1); + lua_xmove(L, co, 1); + lua_resumeerror(co, L); + return 0; + }, + "resumeerror"); lua_setglobal(L, "resumeerror"); }); } @@ -364,25 +407,35 @@ TEST_CASE("Pack") TEST_CASE("Vector") { - runConformance("vector.lua", [](lua_State* L) { - lua_pushcfunction(L, lua_vector); - lua_setglobal(L, "vector"); + lua_CompileOptions copts = defaultOptions(); + copts.vectorCtor = "vector"; - lua_pushvector(L, 0.0f, 0.0f, 0.0f); - luaL_newmetatable(L, "vector"); + runConformance( + "vector.lua", + [](lua_State* L) { + lua_pushcfunction(L, lua_vector, "vector"); + lua_setglobal(L, "vector"); - lua_pushstring(L, "__index"); - lua_pushcfunction(L, lua_vector_index); - lua_settable(L, -3); +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); +#else + lua_pushvector(L, 0.0f, 0.0f, 0.0f); +#endif + luaL_newmetatable(L, "vector"); - lua_pushstring(L, "__namecall"); - lua_pushcfunction(L, lua_vector_namecall); - lua_settable(L, -3); + lua_pushstring(L, "__index"); + lua_pushcfunction(L, lua_vector_index, nullptr); + lua_settable(L, -3); - lua_setreadonly(L, -1, true); - lua_setmetatable(L, -2); - lua_pop(L, 1); - }); + lua_pushstring(L, "__namecall"); + lua_pushcfunction(L, lua_vector_namecall, nullptr); + lua_settable(L, -3); + + lua_setreadonly(L, -1, true); + lua_setmetatable(L, -2); + lua_pop(L, 1); + }, + nullptr, nullptr, &copts); } static void populateRTTI(lua_State* L, Luau::TypeId type) @@ -451,9 +504,10 @@ TEST_CASE("Types") runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; - Luau::TypeChecker env(&moduleResolver, &iceHandler); + Luau::SingletonTypes singletonTypes; + Luau::TypeChecker env(&moduleResolver, Luau::NotNull{&singletonTypes}, &iceHandler); - Luau::registerBuiltinTypes(env); + Luau::registerBuiltinGlobals(env); Luau::freeze(env.globalTypes); lua_newtable(L); @@ -482,19 +536,44 @@ TEST_CASE("Debugger") { static int breakhits = 0; static lua_State* interruptedthread = nullptr; + static bool singlestep = false; + static int stephits = 0; + + SUBCASE("") + { + singlestep = false; + } + SUBCASE("SingleStep") + { + singlestep = true; + } breakhits = 0; interruptedthread = nullptr; + stephits = 0; + + lua_CompileOptions copts = defaultOptions(); + copts.debugLevel = 2; runConformance( "debugger.lua", [](lua_State* L) { lua_Callbacks* cb = lua_callbacks(L); + lua_singlestep(L, singlestep); + + // this will only be called in single-step mode + cb->debugstep = [](lua_State* L, lua_Debug* ar) { + stephits++; + }; + // for breakpoints to work we should make sure debugbreak is installed cb->debugbreak = [](lua_State* L, lua_Debug* ar) { breakhits++; + // make sure we can trace the stack for every breakpoint we hit + lua_debugtrace(L); + // for every breakpoint, we break on the first invocation and continue on second // this allows us to easily step off breakpoints // (real implementaiton may require singlestepping) @@ -511,15 +590,19 @@ TEST_CASE("Debugger") }; // add breakpoint() function - lua_pushcfunction(L, [](lua_State* L) -> int { - int line = luaL_checkinteger(L, 1); + lua_pushcclosurek( + L, + [](lua_State* L) -> int { + int line = luaL_checkinteger(L, 1); + bool enabled = luaL_optboolean(L, 2, true); - lua_Debug ar = {}; - lua_getinfo(L, 1, "f", &ar); + lua_Debug ar = {}; + lua_getinfo(L, lua_stackdepth(L) - 1, "f", &ar); - lua_breakpoint(L, -1, line, true); - return 0; - }); + lua_breakpoint(L, -1, line, enabled); + return 0; + }, + "breakpoint", 0, nullptr); lua_setglobal(L, "breakpoint"); }, [](lua_State* L) { @@ -535,6 +618,11 @@ TEST_CASE("Debugger") CHECK(lua_tointeger(L, -1) == 50); lua_pop(L, 1); + int v = lua_getargument(L, 0, 2); + REQUIRE(v); + CHECK(lua_tointeger(L, -1) == 42); + lua_pop(L, 1); + // test lua_getlocal const char* l = lua_getlocal(L, 0, 1); REQUIRE(l); @@ -594,9 +682,13 @@ TEST_CASE("Debugger") lua_resume(interruptedthread, nullptr, 0); interruptedthread = nullptr; } - }); + }, + nullptr, &copts, /* skipCodegen */ true); // Native code doesn't support debugging yet - CHECK(breakhits == 10); // 2 hits per breakpoint + CHECK(breakhits == 12); // 2 hits per breakpoint + + if (singlestep) + CHECK(stephits > 100); // note; this will depend on number of instructions which can vary, so we just make sure the callback gets hit often } TEST_CASE("SameHash") @@ -617,31 +709,6 @@ TEST_CASE("SameHash") CHECK(luaS_hash(buf + 1, 120) == luaS_hash(buf + 2, 120)); } -TEST_CASE("InlineDtor") -{ - static int dtorhits = 0; - - dtorhits = 0; - - { - StateRef globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - void* u1 = lua_newuserdatadtor(L, 4, [](void* data) { - dtorhits += *(int*)data; - }); - - void* u2 = lua_newuserdatadtor(L, 1, [](void* data) { - dtorhits += *(char*)data; - }); - - *(int*)u1 = 39; - *(char*)u2 = 3; - } - - CHECK(dtorhits == 42); -} - TEST_CASE("Reference") { static int dtorhits = 0; @@ -681,26 +748,248 @@ TEST_CASE("Reference") CHECK(dtorhits == 2); } -TEST_CASE("ApiFunctionCalls") +TEST_CASE("NewUserdataOverflow") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_pushcfunction( + L, + [](lua_State* L1) { + // The following userdata request might cause an overflow. + lua_newuserdatadtor(L1, SIZE_MAX, [](void* d) {}); + // The overflow might segfault in the following call. + lua_getmetatable(L1, -1); + return 0; + }, + nullptr); + + CHECK(lua_pcall(L, 0, 0, 0) == LUA_ERRRUN); + CHECK(strcmp(lua_tostring(L, -1), "memory allocation error: block too big") == 0); +} + +TEST_CASE("ApiTables") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_newtable(L); + lua_pushnumber(L, 123.0); + lua_setfield(L, -2, "key"); + lua_pushnumber(L, 456.0); + lua_rawsetfield(L, -2, "key2"); + lua_pushstring(L, "test"); + lua_rawseti(L, -2, 5); + + // lua_gettable + lua_pushstring(L, "key"); + CHECK(lua_gettable(L, -2) == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // lua_getfield + CHECK(lua_getfield(L, -1, "key") == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // lua_rawgetfield + CHECK(lua_rawgetfield(L, -1, "key2") == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 456.0); + lua_pop(L, 1); + + // lua_rawget + lua_pushstring(L, "key"); + CHECK(lua_rawget(L, -2) == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // lua_rawgeti + CHECK(lua_rawgeti(L, -1, 5) == LUA_TSTRING); + CHECK(strcmp(lua_tostring(L, -1), "test") == 0); + lua_pop(L, 1); + + // lua_cleartable + lua_cleartable(L, -1); + lua_pushnil(L); + CHECK(lua_next(L, -2) == 0); + + lua_pop(L, 1); +} + +TEST_CASE("ApiIter") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_newtable(L); + lua_pushnumber(L, 123.0); + lua_setfield(L, -2, "key"); + lua_pushnumber(L, 456.0); + lua_rawsetfield(L, -2, "key2"); + lua_pushstring(L, "test"); + lua_rawseti(L, -2, 1); + + // Lua-compatible iteration interface: lua_next + double sum1 = 0; + lua_pushnil(L); + while (lua_next(L, -2)) + { + sum1 += lua_tonumber(L, -2); // key + sum1 += lua_tonumber(L, -1); // value + lua_pop(L, 1); // pop value, key is used by lua_next + } + CHECK(sum1 == 580); + + // Luau iteration interface: lua_rawiter (faster and preferable to lua_next) + double sum2 = 0; + for (int index = 0; index = lua_rawiter(L, -1, index), index >= 0;) + { + sum2 += lua_tonumber(L, -2); // key + sum2 += lua_tonumber(L, -1); // value + lua_pop(L, 2); // pop both key and value + } + CHECK(sum2 == 580); + + // pop table + lua_pop(L, 1); +} + +TEST_CASE("ApiCalls") { StateRef globalState = runConformance("apicalls.lua"); lua_State* L = globalState.get(); - lua_getfield(L, LUA_GLOBALSINDEX, "add"); - lua_pushnumber(L, 40); - lua_pushnumber(L, 2); - lua_call(L, 2, 1); - CHECK(lua_isnumber(L, -1)); - CHECK(lua_tonumber(L, -1) == 42); - lua_pop(L, 1); + // lua_call + { + lua_getfield(L, LUA_GLOBALSINDEX, "add"); + lua_pushnumber(L, 40); + lua_pushnumber(L, 2); + lua_call(L, 2, 1); + CHECK(lua_isnumber(L, -1)); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + } - lua_getfield(L, LUA_GLOBALSINDEX, "add"); - lua_pushnumber(L, 40); - lua_pushnumber(L, 2); - lua_pcall(L, 2, 1, 0); - CHECK(lua_isnumber(L, -1)); - CHECK(lua_tonumber(L, -1) == 42); - lua_pop(L, 1); + // lua_pcall + { + lua_getfield(L, LUA_GLOBALSINDEX, "add"); + lua_pushnumber(L, 40); + lua_pushnumber(L, 2); + lua_pcall(L, 2, 1, 0); + CHECK(lua_isnumber(L, -1)); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + } + + // lua_equal with a sleeping thread wake up + { + lua_State* L2 = lua_newthread(L); + + lua_getfield(L2, LUA_GLOBALSINDEX, "create_with_tm"); + lua_pushnumber(L2, 42); + lua_pcall(L2, 1, 1, 0); + + lua_getfield(L2, LUA_GLOBALSINDEX, "create_with_tm"); + lua_pushnumber(L2, 42); + lua_pcall(L2, 1, 1, 0); + + // Reset GC + lua_gc(L2, LUA_GCCOLLECT, 0); + + // Try to mark 'L2' as sleeping + // Can't control GC precisely, even in tests + lua_gc(L2, LUA_GCSTEP, 8); + + CHECK(lua_equal(L2, -1, -2) == 1); + lua_pop(L2, 2); + } + + // lua_clonefunction + fenv + { + lua_getfield(L, LUA_GLOBALSINDEX, "getpi"); + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3.1415926); + lua_pop(L, 1); + + lua_getfield(L, LUA_GLOBALSINDEX, "getpi"); + + // clone & override env + lua_clonefunction(L, -1); + lua_newtable(L); + lua_pushnumber(L, 42); + lua_setfield(L, -2, "pi"); + lua_setfenv(L, -2); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + + // this one calls original function again + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3.1415926); + lua_pop(L, 1); + } + + // lua_clonefunction + upvalues + { + lua_getfield(L, LUA_GLOBALSINDEX, "incuv"); + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 1); + lua_pop(L, 1); + + lua_getfield(L, LUA_GLOBALSINDEX, "incuv"); + // two clones + lua_clonefunction(L, -1); + lua_clonefunction(L, -2); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 2); + lua_pop(L, 1); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3); + lua_pop(L, 1); + + // this one calls original function again + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 4); + lua_pop(L, 1); + } +} + +TEST_CASE("ApiAtoms") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_callbacks(L)->useratom = [](const char* s, size_t l) -> int16_t { + if (strcmp(s, "string") == 0) + return 0; + if (strcmp(s, "important") == 0) + return 1; + + return -1; + }; + + lua_pushstring(L, "string"); + lua_pushstring(L, "import"); + lua_pushstring(L, "ant"); + lua_concat(L, 2); + lua_pushstring(L, "unimportant"); + + int a1, a2, a3; + const char* s1 = lua_tostringatom(L, -3, &a1); + const char* s2 = lua_tostringatom(L, -2, &a2); + const char* s3 = lua_tostringatom(L, -1, &a3); + + CHECK(strcmp(s1, "string") == 0); + CHECK(a1 == 0); + + CHECK(strcmp(s2, "important") == 0); + CHECK(a2 == 1); + + CHECK(strcmp(s3, "unimportant") == 0); + CHECK(a3 == -1); } static bool endsWith(const std::string& str, const std::string& suffix) @@ -714,8 +1003,6 @@ static bool endsWith(const std::string& str, const std::string& suffix) #if !LUA_USE_LONGJMP TEST_CASE("ExceptionObject") { - ScopedFastFlag sff("LuauExceptionMessageFix", true); - struct ExceptionResult { bool exceptionGenerated; @@ -738,11 +1025,11 @@ TEST_CASE("ExceptionObject") return ExceptionResult{false, ""}; }; - auto reallocFunc = [](lua_State* L, void* /*ud*/, void* ptr, size_t /*osize*/, size_t nsize) -> void* { + auto reallocFunc = [](void* /*ud*/, void* ptr, size_t /*osize*/, size_t nsize) -> void* { if (nsize == 0) { free(ptr); - return NULL; + return nullptr; } else if (nsize > 512 * 1024) { @@ -796,9 +1083,518 @@ TEST_CASE("ExceptionObject") TEST_CASE("IfElseExpression") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - runConformance("ifelseexpr.lua"); } +TEST_CASE("TagMethodError") +{ + runConformance("tmerror.lua", [](lua_State* L) { + auto* cb = lua_callbacks(L); + + cb->debugprotectederror = [](lua_State* L) { + CHECK(lua_isyieldable(L)); + }; + }); +} + +TEST_CASE("Coverage") +{ + lua_CompileOptions copts = defaultOptions(); + copts.optimizationLevel = 1; // disable inlining to get fixed expected hit results + copts.coverageLevel = 2; + + runConformance( + "coverage.lua", + [](lua_State* L) { + lua_pushcfunction( + L, + [](lua_State* L) -> int { + luaL_argexpected(L, lua_isLfunction(L, 1), 1, "function"); + + lua_newtable(L); + lua_getcoverage(L, 1, L, [](void* context, const char* function, int linedefined, int depth, const int* hits, size_t size) { + lua_State* L = static_cast(context); + + lua_newtable(L); + + lua_pushstring(L, function); + lua_setfield(L, -2, "name"); + + lua_pushinteger(L, linedefined); + lua_setfield(L, -2, "linedefined"); + + lua_pushinteger(L, depth); + lua_setfield(L, -2, "depth"); + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + { + lua_pushinteger(L, hits[i]); + lua_rawseti(L, -2, int(i)); + } + + lua_rawseti(L, -2, lua_objlen(L, -2) + 1); + }); + + return 1; + }, + "getcoverage"); + lua_setglobal(L, "getcoverage"); + }, + nullptr, nullptr, &copts); +} + +TEST_CASE("StringConversion") +{ + runConformance("strconv.lua"); +} + +TEST_CASE("GCDump") +{ + // internal function, declared in lgc.h - not exposed via lua.h + extern void luaC_dump(lua_State * L, void* file, const char* (*categoryName)(lua_State * L, uint8_t memcat)); + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // push various objects on stack to cover different paths + lua_createtable(L, 1, 2); + lua_pushstring(L, "value"); + lua_setfield(L, -2, "key"); + + lua_pushinteger(L, 42); + lua_rawseti(L, -2, 1000); + + lua_pushinteger(L, 42); + lua_rawseti(L, -2, 1); + + lua_pushvalue(L, -1); + lua_setmetatable(L, -2); + + lua_newuserdata(L, 42); + lua_pushvalue(L, -2); + lua_setmetatable(L, -2); + + lua_pushinteger(L, 1); + lua_pushcclosure(L, lua_silence, "test", 1); + + lua_State* CL = lua_newthread(L); + + lua_pushstring(CL, "local x x = {} local function f() x[1] = math.abs(42) end function foo() coroutine.yield() end foo() return f"); + lua_loadstring(CL); + lua_resume(CL, nullptr, 0); + +#ifdef _WIN32 + const char* path = "NUL"; +#else + const char* path = "/dev/null"; +#endif + + FILE* f = fopen(path, "w"); + REQUIRE(f); + + luaC_dump(L, f, nullptr); + + fclose(f); +} + +TEST_CASE("Interrupt") +{ + lua_CompileOptions copts = defaultOptions(); + copts.optimizationLevel = 1; // disable loop unrolling to get fixed expected hit results + + static const int expectedhits[] = { + 2, + 9, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 6, + 18, + 13, + 13, + 13, + 13, + 16, + 20, + }; + static int index; + + index = 0; + + runConformance( + "interrupt.lua", + [](lua_State* L) { + auto* cb = lua_callbacks(L); + + // note: for simplicity here we setup the interrupt callback once + // however, this carries a noticeable performance cost. in a real application, + // it's advised to set interrupt callback on a timer from a different thread, + // and set it back to nullptr once the interrupt triggered. + cb->interrupt = [](lua_State* L, int gc) { + if (gc >= 0) + return; + + CHECK(index < int(std::size(expectedhits))); + + lua_Debug ar = {}; + lua_getinfo(L, 0, "l", &ar); + + CHECK(ar.currentline == expectedhits[index]); + + index++; + + // check that we can yield inside an interrupt + if (index == 5) + lua_yield(L, 0); + }; + }, + [](lua_State* L) { + CHECK(index == 5); // a single yield point + }, + nullptr, &copts); + + CHECK(index == int(std::size(expectedhits))); +} + +TEST_CASE("UserdataApi") +{ + static int dtorhits = 0; + + dtorhits = 0; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // setup dtor for tag 42 (created later) + lua_setuserdatadtor(L, 42, [](lua_State* l, void* data) { + dtorhits += *(int*)data; + }); + + // light user data + int lud; + lua_pushlightuserdata(L, &lud); + + CHECK(lua_tolightuserdata(L, -1) == &lud); + CHECK(lua_touserdata(L, -1) == &lud); + CHECK(lua_topointer(L, -1) == &lud); + + // regular user data + int* ud1 = (int*)lua_newuserdata(L, 4); + *ud1 = 42; + + CHECK(lua_tolightuserdata(L, -1) == nullptr); + CHECK(lua_touserdata(L, -1) == ud1); + CHECK(lua_topointer(L, -1) == ud1); + + // tagged user data + int* ud2 = (int*)lua_newuserdatatagged(L, 4, 42); + *ud2 = -4; + + CHECK(lua_touserdatatagged(L, -1, 42) == ud2); + CHECK(lua_touserdatatagged(L, -1, 41) == nullptr); + CHECK(lua_userdatatag(L, -1) == 42); + + lua_setuserdatatag(L, -1, 43); + CHECK(lua_userdatatag(L, -1) == 43); + lua_setuserdatatag(L, -1, 42); + + // user data with inline dtor + void* ud3 = lua_newuserdatadtor(L, 4, [](void* data) { + dtorhits += *(int*)data; + }); + + void* ud4 = lua_newuserdatadtor(L, 1, [](void* data) { + dtorhits += *(char*)data; + }); + + *(int*)ud3 = 43; + *(char*)ud4 = 3; + + // user data with named metatable + luaL_newmetatable(L, "udata1"); + luaL_newmetatable(L, "udata2"); + + void* ud5 = lua_newuserdata(L, 0); + lua_getfield(L, LUA_REGISTRYINDEX, "udata1"); + lua_setmetatable(L, -2); + + void* ud6 = lua_newuserdata(L, 0); + lua_getfield(L, LUA_REGISTRYINDEX, "udata2"); + lua_setmetatable(L, -2); + + CHECK(luaL_checkudata(L, -2, "udata1") == ud5); + CHECK(luaL_checkudata(L, -1, "udata2") == ud6); + + globalState.reset(); + + CHECK(dtorhits == 42); +} + +TEST_CASE("Iter") +{ + runConformance("iter.lua"); +} + +const int kInt64Tag = 1; +static int gInt64MT = -1; + +static int64_t getInt64(lua_State* L, int idx) +{ + if (void* p = lua_touserdatatagged(L, idx, kInt64Tag)) + return *static_cast(p); + + if (lua_isnumber(L, idx)) + return lua_tointeger(L, idx); + + luaL_typeerror(L, 1, "int64"); +} + +static void pushInt64(lua_State* L, int64_t value) +{ + void* p = lua_newuserdatatagged(L, sizeof(int64_t), kInt64Tag); + + lua_getref(L, gInt64MT); + lua_setmetatable(L, -2); + + *static_cast(p) = value; +} + +TEST_CASE("Userdata") +{ + runConformance("userdata.lua", [](lua_State* L) { + // create metatable with all the metamethods + lua_newtable(L); + gInt64MT = lua_ref(L, -1); + + // __index + lua_pushcfunction( + L, + [](lua_State* L) { + void* p = lua_touserdatatagged(L, 1, kInt64Tag); + if (!p) + luaL_typeerror(L, 1, "int64"); + + const char* name = luaL_checkstring(L, 2); + + if (strcmp(name, "value") == 0) + { + lua_pushnumber(L, double(*static_cast(p))); + return 1; + } + + luaL_error(L, "unknown field %s", name); + }, + nullptr); + lua_setfield(L, -2, "__index"); + + // __newindex + lua_pushcfunction( + L, + [](lua_State* L) { + void* p = lua_touserdatatagged(L, 1, kInt64Tag); + if (!p) + luaL_typeerror(L, 1, "int64"); + + const char* name = luaL_checkstring(L, 2); + + if (strcmp(name, "value") == 0) + { + double value = luaL_checknumber(L, 3); + *static_cast(p) = int64_t(value); + return 0; + } + + luaL_error(L, "unknown field %s", name); + }, + nullptr); + lua_setfield(L, -2, "__newindex"); + + // __eq + lua_pushcfunction( + L, + [](lua_State* L) { + lua_pushboolean(L, getInt64(L, 1) == getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__eq"); + + // __lt + lua_pushcfunction( + L, + [](lua_State* L) { + lua_pushboolean(L, getInt64(L, 1) < getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__lt"); + + // __le + lua_pushcfunction( + L, + [](lua_State* L) { + lua_pushboolean(L, getInt64(L, 1) <= getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__le"); + + // __add + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, getInt64(L, 1) + getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__add"); + + // __sub + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, getInt64(L, 1) - getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__sub"); + + // __mul + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, getInt64(L, 1) * getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__mul"); + + // __div + lua_pushcfunction( + L, + [](lua_State* L) { + // ideally we'd guard against 0 but it's a test so eh + pushInt64(L, getInt64(L, 1) / getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__div"); + + // __mod + lua_pushcfunction( + L, + [](lua_State* L) { + // ideally we'd guard against 0 and INT64_MIN but it's a test so eh + pushInt64(L, getInt64(L, 1) % getInt64(L, 2)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__mod"); + + // __pow + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, int64_t(pow(double(getInt64(L, 1)), double(getInt64(L, 2))))); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__pow"); + + // __unm + lua_pushcfunction( + L, + [](lua_State* L) { + pushInt64(L, -getInt64(L, 1)); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__unm"); + + // __tostring + lua_pushcfunction( + L, + [](lua_State* L) { + int64_t value = getInt64(L, 1); + std::string str = std::to_string(value); + lua_pushlstring(L, str.c_str(), str.length()); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__tostring"); + + // ctor + lua_pushcfunction( + L, + [](lua_State* L) { + double v = luaL_checknumber(L, 1); + pushInt64(L, int64_t(v)); + return 1; + }, + "int64"); + lua_setglobal(L, "int64"); + }); +} + +TEST_CASE("SafeEnv") +{ + runConformance("safeenv.lua"); +} + +TEST_CASE("HugeFunction") +{ + std::string source; + + // add non-executed block that requires JUMPKX and generates a lot of constants that take available short (15-bit) constant space + source += "if ... then\n"; + source += "local _ = {\n"; + + for (int i = 0; i < 40000; ++i) + { + source += "0."; + source += std::to_string(i); + source += ","; + } + + source += "}\n"; + source += "end\n"; + + // use failed fast-calls with imports and constants to exercise all of the more complex fallback sequences + source += "return bit32.lshift('84', -1)"; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + if (codegen && Luau::CodeGen::isSupported()) + Luau::CodeGen::create(L); + + luaL_openlibs(L); + luaL_sandbox(L); + luaL_sandboxthread(L); + + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), nullptr, &bytecodeSize); + int result = luau_load(L, "=HugeFunction", bytecode, bytecodeSize, 0); + free(bytecode); + + REQUIRE(result == 0); + + if (codegen && Luau::CodeGen::isSupported()) + Luau::CodeGen::compile(L, -1); + + int status = lua_resume(L, nullptr, 0); + REQUIRE(status == 0); + + CHECK(lua_tonumber(L, -1) == 42); +} + TEST_SUITE_END(); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp new file mode 100644 index 00000000..30e1b2e6 --- /dev/null +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -0,0 +1,34 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ConstraintGraphBuilderFixture.h" + +namespace Luau +{ + +ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() + : Fixture() + , mainModule(new Module) + , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} +{ + BlockedTypeVar::nextIndex = 0; + BlockedTypePack::nextIndex = 0; +} + +void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) +{ + AstStatBlock* root = parse(code); + dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); + cgb = std::make_unique("MainModule", mainModule, &arena, NotNull(&moduleResolver), singletonTypes, NotNull(&ice), + frontend.getGlobalScope(), &logger, NotNull{dfg.get()}); + cgb->visit(root); + rootScope = cgb->rootScope; + constraints = Luau::borrowConstraints(cgb->constraints); +} + +void ConstraintGraphBuilderFixture::solve(const std::string& code) +{ + generateConstraints(code); + ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger}; + cs.run(); +} + +} // namespace Luau diff --git a/tests/ConstraintGraphBuilderFixture.h b/tests/ConstraintGraphBuilderFixture.h new file mode 100644 index 00000000..9785a838 --- /dev/null +++ b/tests/ConstraintGraphBuilderFixture.h @@ -0,0 +1,38 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ConstraintSolver.h" +#include "Luau/DcrLogger.h" +#include "Luau/TypeArena.h" +#include "Luau/Module.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +namespace Luau +{ + +struct ConstraintGraphBuilderFixture : Fixture +{ + TypeArena arena; + ModulePtr mainModule; + DcrLogger logger; + UnifierSharedState sharedState{&ice}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + + std::unique_ptr dfg; + std::unique_ptr cgb; + Scope* rootScope = nullptr; + + std::vector> constraints; + + ScopedFastFlag forceTheFlag; + + ConstraintGraphBuilderFixture(); + + void generateConstraints(const std::string& code); + void solve(const std::string& code); +}; + +} // namespace Luau diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp new file mode 100644 index 00000000..eaa0d41a --- /dev/null +++ b/tests/ConstraintSolver.test.cpp @@ -0,0 +1,63 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "ConstraintGraphBuilderFixture.h" +#include "Fixture.h" +#include "doctest.h" + +using namespace Luau; + +static TypeId requireBinding(Scope* scope, const char* name) +{ + auto b = linearSearchForBinding(scope, name); + LUAU_ASSERT(b.has_value()); + return *b; +} + +TEST_SUITE_BEGIN("ConstraintSolver"); + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") +{ + solve(R"( + local a = 55 + local b = a + )"); + + TypeId bType = requireBinding(rootScope, "b"); + + CHECK("number" == toString(bType)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") +{ + solve(R"( + local function id(a) + return a + end + )"); + + TypeId idType = requireBinding(rootScope, "id"); + + CHECK("(a) -> a" == toString(idType)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") +{ + solve(R"( + local function a(c) + local function d(e) + return c + end + + return d + end + + local b = a(5) + )"); + + TypeId idType = requireBinding(rootScope, "b"); + + ToStringOptions opts; + CHECK("(a) -> number" == toString(idType, opts)); +} + +TEST_SUITE_END(); diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp new file mode 100644 index 00000000..d82d5d83 --- /dev/null +++ b/tests/CostModel.test.cpp @@ -0,0 +1,245 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" + +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +namespace Luau +{ +namespace Compile +{ + +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount, const DenseHashMap& builtins); +int computeCost(uint64_t model, const bool* varsConst, size_t varCount); + +} // namespace Compile +} // namespace Luau + +TEST_SUITE_BEGIN("CostModel"); + +static uint64_t modelFunction(const char* source) +{ + Allocator allocator; + AstNameTable names(allocator); + + ParseResult result = Parser::parse(source, strlen(source), names, allocator); + REQUIRE(result.root != nullptr); + + AstStatFunction* func = result.root->body.data[0]->as(); + REQUIRE(func); + + return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size, {nullptr}); +} + +TEST_CASE("Expression") +{ + uint64_t model = modelFunction(R"( +function test(a, b, c) + return a + (b + 1) * (b + 1) - c +end +)"); + + const bool args1[] = {false, false, false}; + const bool args2[] = {false, true, false}; + + CHECK_EQ(5, Luau::Compile::computeCost(model, args1, 3)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 3)); +} + +TEST_CASE("PropagateVariable") +{ + uint64_t model = modelFunction(R"( +function test(a) + local b = a * a * a + return b * b +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(0, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("LoopAssign") +{ + uint64_t model = modelFunction(R"( +function test(a) + for i=1,3 do + a[i] = i + end +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + // loop baseline cost is 5 + CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("MutableVariable") +{ + uint64_t model = modelFunction(R"( +function test(a, b) + local x = a * a + x += b + return x * x +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("ImportCall") +{ + uint64_t model = modelFunction(R"( +function test(a) + return Instance.new(a) +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("FastCall") +{ + uint64_t model = modelFunction(R"( +function test(a) + return math.abs(a + 1) +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + // note: we currently don't treat fast calls differently from cost model perspective + CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(5, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("ControlFlow") +{ + uint64_t model = modelFunction(R"( +function test(a) + while a < 0 do + a += 1 + end + for i=10,1,-1 do + a += 1 + end + for i in pairs({}) do + a += 1 + if a % 2 == 0 then continue end + end + repeat + a += 1 + if a % 2 == 0 then break end + until a > 10 + return a +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(82, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(79, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("Conditional") +{ + uint64_t model = modelFunction(R"( +function test(a) + return if a < 0 then -a else a +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(4, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("VarArgs") +{ + uint64_t model = modelFunction(R"( +function test(...) + return select('#', ...) :: number +end +)"); + + CHECK_EQ(8, Luau::Compile::computeCost(model, nullptr, 0)); +} + +TEST_CASE("TablesFunctions") +{ + uint64_t model = modelFunction(R"( +function test() + return { 42, op = function() end } +end +)"); + + CHECK_EQ(22, Luau::Compile::computeCost(model, nullptr, 0)); +} + +TEST_CASE("CostOverflow") +{ + uint64_t model = modelFunction(R"( +function test() + return {{{{{{{{{{{{{{{}}}}}}}}}}}}}}} +end +)"); + + CHECK_EQ(127, Luau::Compile::computeCost(model, nullptr, 0)); +} + +TEST_CASE("TableAssign") +{ + uint64_t model = modelFunction(R"( +function test(a) + for i=1,#a do + a[i] = i + end +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(7, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("InterpString") +{ + ScopedFastFlag sff("LuauInterpolatedStringBaseSupport", true); + + uint64_t model = modelFunction(R"( +function test(a) + return `hello, {a}!` +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_SUITE_END(); diff --git a/tests/DataFlowGraph.test.cpp b/tests/DataFlowGraph.test.cpp new file mode 100644 index 00000000..d8230700 --- /dev/null +++ b/tests/DataFlowGraph.test.cpp @@ -0,0 +1,104 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/DataFlowGraph.h" +#include "Luau/Error.h" +#include "Luau/Parser.h" + +#include "AstQueryDsl.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +class DataFlowGraphFixture +{ + // Only needed to fix the operator== reflexivity of an empty Symbol. + ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true}; + + InternalErrorReporter handle; + + Allocator allocator; + AstNameTable names{allocator}; + AstStatBlock* module; + + std::optional graph; + +public: + void dfg(const std::string& code) + { + ParseResult parseResult = Parser::parse(code.c_str(), code.size(), names, allocator); + if (!parseResult.errors.empty()) + throw ParseErrors(std::move(parseResult.errors)); + module = parseResult.root; + graph = DataFlowGraphBuilder::build(module, NotNull{&handle}); + } + + template + std::optional getDef(const std::vector& nths = {nth(N)}) + { + T* node = query(module, nths); + REQUIRE(node); + return graph->getDef(node); + } + + template + DefId requireDef(const std::vector& nths = {nth(N)}) + { + auto loc = getDef(nths); + REQUIRE(loc); + return NotNull{*loc}; + } +}; + +TEST_SUITE_BEGIN("DataFlowGraphBuilder"); + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_locals_in_local_stat") +{ + dfg(R"( + local x = 5 + local y = x + )"); + + REQUIRE(getDef()); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "define_parameters_in_functions") +{ + dfg(R"( + local function f(x) + local y = x + end + )"); + + REQUIRE(getDef()); +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "find_aliases") +{ + dfg(R"( + local x = 5 + local y = x + local z = y + )"); + + DefId x = requireDef(); + DefId y = requireDef(); + REQUIRE(x != y); // TODO: they should be equal but it's not just locals that can alias, so we'll support this later. +} + +TEST_CASE_FIXTURE(DataFlowGraphFixture, "independent_locals") +{ + dfg(R"( + local x = 5 + local y = 5 + + local a = x + local b = y + )"); + + DefId x = requireDef(); + DefId y = requireDef(); + REQUIRE(x != y); +} + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 26bc77f7..eb77ce52 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -2,49 +2,112 @@ #include "Fixture.h" #include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Constraint.h" +#include "Luau/ModuleResolver.h" +#include "Luau/NotNull.h" +#include "Luau/Parser.h" #include "Luau/TypeVar.h" #include "Luau/TypeAttach.h" #include "Luau/Transpiler.h" -#include "Luau/BuiltinDefinitions.h" - #include "doctest.h" #include #include #include +#include static const char* mainModuleName = "MainModule"; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(LuauReportShadowedTypeAlias) + +extern std::optional randomSeed; // tests/main.cpp + namespace Luau { -std::optional TestFileResolver::fromAstFragment(AstExpr* expr) const +std::optional TestFileResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) { - auto g = expr->as(); - if (!g) - return std::nullopt; - - std::string_view value = g->name.value; - if (value == "game" || value == "Game" || value == "workspace" || value == "Workspace" || value == "script" || value == "Script") - return ModuleName(value); + if (auto name = pathExprToModuleName(currentModuleName, pathExpr)) + return {{*name, false}}; return std::nullopt; } -ModuleName TestFileResolver::concat(const ModuleName& lhs, std::string_view rhs) const +const ModulePtr TestFileResolver::getModule(const ModuleName& moduleName) const { - return lhs + "/" + ModuleName(rhs); + LUAU_ASSERT(false); + return nullptr; } -std::optional TestFileResolver::getParentModuleName(const ModuleName& name) const +bool TestFileResolver::moduleExists(const ModuleName& moduleName) const { - std::string_view view = name; - const size_t lastSeparatorIndex = view.find_last_of('/'); + auto it = source.find(moduleName); + return (it != source.end()); +} - if (lastSeparatorIndex != std::string_view::npos) +std::optional TestFileResolver::readSource(const ModuleName& name) +{ + auto it = source.find(name); + if (it == source.end()) + return std::nullopt; + + SourceCode::Type sourceType = SourceCode::Module; + + auto it2 = sourceTypes.find(name); + if (it2 != sourceTypes.end()) + sourceType = it2->second; + + return SourceCode{it->second, sourceType}; +} + +std::optional TestFileResolver::resolveModule(const ModuleInfo* context, AstExpr* expr) +{ + if (AstExprGlobal* g = expr->as()) { - return ModuleName(view.substr(0, lastSeparatorIndex)); + if (g->name == "game") + return ModuleInfo{"game"}; + if (g->name == "workspace") + return ModuleInfo{"workspace"}; + if (g->name == "script") + return context ? std::optional(*context) : std::nullopt; + } + else if (AstExprIndexName* i = expr->as(); i && context) + { + if (i->index == "Parent") + { + std::string_view view = context->name; + size_t lastSeparatorIndex = view.find_last_of('/'); + + if (lastSeparatorIndex == std::string_view::npos) + return std::nullopt; + + return ModuleInfo{ModuleName(view.substr(0, lastSeparatorIndex)), context->optional}; + } + else + { + return ModuleInfo{context->name + '/' + i->index.value, context->optional}; + } + } + else if (AstExprIndexExpr* i = expr->as(); i && context) + { + if (AstExprConstantString* index = i->index->as()) + { + return ModuleInfo{context->name + '/' + std::string(index->value.data, index->value.size), context->optional}; + } + } + else if (AstExprCall* call = expr->as(); call && call->self && call->args.size >= 1 && context) + { + if (AstExprConstantString* index = call->args.data[0]->as()) + { + AstName func = call->func->as()->index; + + if (func == "GetService" && context->name == "game") + return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)}; + } } return std::nullopt; @@ -64,18 +127,30 @@ std::optional TestFileResolver::getEnvironmentForModule(const Modul return std::nullopt; } -Fixture::Fixture(bool freeze) +const Config& TestConfigResolver::getConfig(const ModuleName& name) const +{ + auto it = configFiles.find(name); + if (it != configFiles.end()) + return it->second; + + return defaultConfig; +} + +Fixture::Fixture(bool freeze, bool prepareAutocomplete) : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) - , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true}) + , frontend(&fileResolver, &configResolver, + {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* randomConstraintResolutionSeed */ randomSeed}) , typeChecker(frontend.typeChecker) + , singletonTypes(frontend.singletonTypes) { configResolver.defaultConfig.mode = Mode::Strict; configResolver.defaultConfig.enabledLint.warningMask = ~0ull; configResolver.defaultConfig.parseOptions.captureComments = true; - registerBuiltinTypes(frontend.typeChecker); - registerTestTypes(); + registerBuiltinTypes(frontend); + Luau::freeze(frontend.typeChecker.globalTypes); + Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); Luau::setPrintLine([](auto s) {}); } @@ -85,11 +160,6 @@ Fixture::~Fixture() Luau::resetPrintLine(); } -UnfrozenFixture::UnfrozenFixture() - : Fixture(false) -{ -} - AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& parseOptions) { sourceModule.reset(new SourceModule); @@ -99,7 +169,7 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars sourceModule->name = fromString(mainModuleName); sourceModule->root = result.root; sourceModule->mode = parseMode(result.hotcomments); - sourceModule->ignoreLints = LintWarning::parseMask(result.hotcomments); + sourceModule->hotcomments = std::move(result.hotcomments); if (!result.errors.empty()) { @@ -118,21 +188,9 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars } CheckResult Fixture::check(Mode mode, std::string source) -{ - configResolver.defaultConfig.mode = mode; - fileResolver.source[mainModuleName] = std::move(source); - - CheckResult result = frontend.check(fromString(mainModuleName)); - - configResolver.defaultConfig.mode = Mode::Strict; - - return result; -} - -CheckResult Fixture::check(const std::string& source) { ModuleName mm = fromString(mainModuleName); - configResolver.defaultConfig.mode = Mode::Strict; + configResolver.defaultConfig.mode = mode; fileResolver.source[mm] = std::move(source); frontend.markDirty(mm); @@ -141,9 +199,15 @@ CheckResult Fixture::check(const std::string& source) return result; } +CheckResult Fixture::check(const std::string& source) +{ + return check(Mode::Strict, source); +} + LintResult Fixture::lint(const std::string& source, const std::optional& lintOptions) { ParseOptions parseOptions; + parseOptions.captureComments = true; configResolver.defaultConfig.mode = Mode::Nonstrict; parse(source, parseOptions); @@ -178,7 +242,7 @@ ParseResult Fixture::tryParse(const std::string& source, const ParseOptions& par return result; } -ParseResult Fixture::matchParseError(const std::string& source, const std::string& message) +ParseResult Fixture::matchParseError(const std::string& source, const std::string& message, std::optional location) { ParseOptions options; options.allowDeclarationSyntax = true; @@ -186,9 +250,15 @@ ParseResult Fixture::matchParseError(const std::string& source, const std::strin sourceModule.reset(new SourceModule); ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options); - REQUIRE_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); + CHECK_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); - CHECK_EQ(result.errors.front().getMessage(), message); + if (!result.errors.empty()) + { + CHECK_EQ(result.errors.front().getMessage(), message); + + if (location) + CHECK_EQ(result.errors.front().getLocation(), *location); + } return result; } @@ -201,11 +271,14 @@ ParseResult Fixture::matchParseErrorPrefix(const std::string& source, const std: sourceModule.reset(new SourceModule); ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options); - REQUIRE_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); + CHECK_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); - const std::string& message = result.errors.front().getMessage(); - CHECK_GE(message.length(), prefix.length()); - CHECK_EQ(prefix, message.substr(0, prefix.size())); + if (!result.errors.empty()) + { + const std::string& message = result.errors.front().getMessage(); + CHECK_GE(message.length(), prefix.length()); + CHECK_EQ(prefix, message.substr(0, prefix.size())); + } return result; } @@ -217,7 +290,7 @@ ModulePtr Fixture::getMainModule() SourceModule* Fixture::getMainSourceModule() { - return frontend.getSourceModule(fromString("MainModule")); + return frontend.getSourceModule(fromString(mainModuleName)); } std::optional Fixture::getPrimitiveType(TypeId ty) @@ -239,13 +312,16 @@ std::optional Fixture::getType(const std::string& name) ModulePtr module = getMainModule(); REQUIRE(module); - return lookupName(module->getModuleScope(), name); + if (FFlag::DebugLuauDeferredConstraintResolution) + return linearSearchForBinding(module->getModuleScope().get(), name.c_str()); + else + return lookupName(module->getModuleScope(), name); } TypeId Fixture::requireType(const std::string& name) { std::optional ty = getType(name); - REQUIRE(bool(ty)); + REQUIRE_MESSAGE(bool(ty), "Unable to requireType \"" << name << "\""); return follow(*ty); } @@ -275,6 +351,13 @@ std::optional Fixture::findTypeAtPosition(Position position) return Luau::findTypeAtPosition(*module, *sourceModule, position); } +std::optional Fixture::findExpectedTypeAtPosition(Position position) +{ + ModulePtr module = getMainModule(); + SourceModule* sourceModule = getMainSourceModule(); + return Luau::findExpectedTypeAtPosition(*module, *sourceModule, position); +} + TypeId Fixture::requireTypeAtPosition(Position position) { auto ty = findTypeAtPosition(position); @@ -323,7 +406,7 @@ void Fixture::dumpErrors(std::ostream& os, const std::vector& errors) if (error.location.begin.line >= lines.size()) { os << "\tSource not available?" << std::endl; - return; + continue; } std::string_view theLine = lines[error.location.begin.line]; @@ -337,24 +420,32 @@ void Fixture::dumpErrors(std::ostream& os, const std::vector& errors) void Fixture::registerTestTypes() { - addGlobalBinding(typeChecker, "game", typeChecker.anyType, "@luau"); - addGlobalBinding(typeChecker, "workspace", typeChecker.anyType, "@luau"); - addGlobalBinding(typeChecker, "script", typeChecker.anyType, "@luau"); + addGlobalBinding(frontend, "game", typeChecker.anyType, "@luau"); + addGlobalBinding(frontend, "workspace", typeChecker.anyType, "@luau"); + addGlobalBinding(frontend, "script", typeChecker.anyType, "@luau"); } void Fixture::dumpErrors(const CheckResult& cr) { - dumpErrors(std::cout, cr.errors); + std::string error = getErrors(cr); + if (!error.empty()) + MESSAGE(error); } void Fixture::dumpErrors(const ModulePtr& module) { - dumpErrors(std::cout, module->errors); + std::stringstream ss; + dumpErrors(ss, module->errors); + if (!ss.str().empty()) + MESSAGE(ss.str()); } void Fixture::dumpErrors(const Module& module) { - dumpErrors(std::cout, module.errors); + std::stringstream ss; + dumpErrors(ss, module.errors); + if (!ss.str().empty()) + MESSAGE(ss.str()); } std::string Fixture::getErrors(const CheckResult& cr) @@ -382,13 +473,30 @@ void Fixture::validateErrors(const std::vector& errors) LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) { unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, source, "@test"); + LoadDefinitionFileResult result = frontend.loadDefinitionFile(source, "@test"); freeze(typeChecker.globalTypes); + if (result.module) + dumpErrors(result.module); REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); return result; } +BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) + : Fixture(freeze, prepareAutocomplete) +{ + Luau::unfreeze(frontend.typeChecker.globalTypes); + Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); + + registerBuiltinGlobals(frontend); + if (prepareAutocomplete) + registerBuiltinGlobals(frontend.typeCheckerForAutocomplete); + registerTestTypes(); + + Luau::freeze(frontend.typeChecker.globalTypes); + Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); +} + ModuleName fromString(std::string_view name) { return ModuleName(name); @@ -428,4 +536,37 @@ std::optional lookupName(ScopePtr scope, const std::string& name) return std::nullopt; } +std::optional linearSearchForBinding(Scope* scope, const char* name) +{ + while (scope) + { + for (const auto& [n, ty] : scope->bindings) + { + if (n.astName() == name) + return ty.typeId; + } + + scope = scope->parent.get(); + } + + return std::nullopt; +} + +void registerHiddenTypes(Fixture& fixture, TypeArena& arena) +{ + TypeId t = arena.addType(GenericTypeVar{"T"}); + GenericTypeDefinition genericT{t}; + + ScopePtr moduleScope = fixture.frontend.getGlobalScope(); + moduleScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, arena.addType(NegationTypeVar{t})}; + moduleScope->exportedTypeBindings["fun"] = TypeFun{{}, fixture.singletonTypes->functionType}; +} + +void dump(const std::vector& constraints) +{ + ToStringOptions opts; + for (const auto& c : constraints) + printf("%s\n", toString(c, opts).c_str()); +} + } // namespace Luau diff --git a/tests/Fixture.h b/tests/Fixture.h index c6294b01..5d838b16 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -8,65 +8,35 @@ #include "Luau/Linter.h" #include "Luau/Location.h" #include "Luau/ModuleResolver.h" -#include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" -#include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "IostreamOptional.h" #include "ScopedFlags.h" -#include #include #include - #include namespace Luau { +struct TypeChecker; + struct TestFileResolver : FileResolver , ModuleResolver { - std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override - { - if (auto name = pathExprToModuleName(currentModuleName, pathExpr)) - return {{*name, false}}; + std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override; - return std::nullopt; - } + const ModulePtr getModule(const ModuleName& moduleName) const override; - const ModulePtr getModule(const ModuleName& moduleName) const override - { - LUAU_ASSERT(false); - return nullptr; - } + bool moduleExists(const ModuleName& moduleName) const override; - bool moduleExists(const ModuleName& moduleName) const override - { - auto it = source.find(moduleName); - return (it != source.end()); - } + std::optional readSource(const ModuleName& name) override; - std::optional readSource(const ModuleName& name) override - { - auto it = source.find(name); - if (it == source.end()) - return std::nullopt; - - SourceCode::Type sourceType = SourceCode::Module; - - auto it2 = sourceTypes.find(name); - if (it2 != sourceTypes.end()) - sourceType = it2->second; - - return SourceCode{it->second, sourceType}; - } - - std::optional fromAstFragment(AstExpr* expr) const override; - ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override; - std::optional getParentModuleName(const ModuleName& name) const override; + std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override; std::string getHumanReadableModuleName(const ModuleName& name) const override; @@ -82,19 +52,12 @@ struct TestConfigResolver : ConfigResolver Config defaultConfig; std::unordered_map configFiles; - const Config& getConfig(const ModuleName& name) const override - { - auto it = configFiles.find(name); - if (it != configFiles.end()) - return it->second; - - return defaultConfig; - } + const Config& getConfig(const ModuleName& name) const override; }; struct Fixture { - explicit Fixture(bool freeze = true); + explicit Fixture(bool freeze = true, bool prepareAutocomplete = false); ~Fixture(); // Throws Luau::ParseErrors if the parse fails. @@ -108,7 +71,7 @@ struct Fixture /// Parse with all language extensions enabled ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult tryParse(const std::string& source, const ParseOptions& parseOptions = {}); - ParseResult matchParseError(const std::string& source, const std::string& message); + ParseResult matchParseError(const std::string& source, const std::string& message, std::optional location = std::nullopt); // Verify a parse error occurs and the parse error message has the specified prefix ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix); @@ -124,17 +87,22 @@ struct Fixture std::optional findTypeAtPosition(Position position); TypeId requireTypeAtPosition(Position position); + std::optional findExpectedTypeAtPosition(Position position); std::optional lookupType(const std::string& name); std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; + ScopedFastFlag sff_UnknownNever{"LuauUnknownAndNeverType", true}; TestFileResolver fileResolver; TestConfigResolver configResolver; + NullModuleResolver moduleResolver; std::unique_ptr sourceModule; Frontend frontend; + InternalErrorReporter ice; TypeChecker& typeChecker; + NotNull singletonTypes; std::string decorateWithTypes(const std::string& code); @@ -153,13 +121,9 @@ struct Fixture LoadDefinitionFileResult loadDefinition(const std::string& source); }; -// Disables arena freezing for a given test case. -// Do not use this in new tests. If you are running into access violations, you -// are violating Luau's memory model - the fix is not to use UnfrozenFixture. -// Related: CLI-45692 -struct UnfrozenFixture : Fixture +struct BuiltinsFixture : Fixture { - UnfrozenFixture(); + BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); }; ModuleName fromString(std::string_view name); @@ -181,9 +145,14 @@ bool isInArena(TypeId t, const TypeArena& arena); void dumpErrors(const ModulePtr& module); void dumpErrors(const Module& module); void dump(const std::string& name, TypeId ty); +void dump(const std::vector& constraints); std::optional lookupName(ScopePtr scope, const std::string& name); // Warning: This function runs in O(n**2) +std::optional linearSearchForBinding(Scope* scope, const char* name); + +void registerHiddenTypes(Fixture& fixture, TypeArena& arena); + } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 3f33a5d1..6f92b655 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -2,7 +2,6 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/Frontend.h" -#include "Luau/Parser.h" #include "Luau/RequireTracer.h" #include "Fixture.h" @@ -46,32 +45,44 @@ NaiveModuleResolver naiveModuleResolver; struct NaiveFileResolver : NullFileResolver { - std::optional fromAstFragment(AstExpr* expr) const override + std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override { - AstExprGlobal* g = expr->as(); - if (g && g->name == "Modules") - return "Modules"; + if (AstExprGlobal* g = expr->as()) + { + if (g->name == "Modules") + return ModuleInfo{"Modules"}; - if (g && g->name == "game") - return "game"; + if (g->name == "game") + return ModuleInfo{"game"}; + } + else if (AstExprIndexName* i = expr->as()) + { + if (context) + return ModuleInfo{context->name + '/' + i->index.value, context->optional}; + } + else if (AstExprCall* call = expr->as(); call && call->self && call->args.size >= 1 && context) + { + if (AstExprConstantString* index = call->args.data[0]->as()) + { + AstName func = call->func->as()->index; + + if (func == "GetService" && context->name == "game") + return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)}; + } + } return std::nullopt; } - - ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override - { - return lhs + "/" + ModuleName(rhs); - } }; } // namespace -struct FrontendFixture : Fixture +struct FrontendFixture : BuiltinsFixture { FrontendFixture() { - addGlobalBinding(typeChecker, "game", frontend.typeChecker.anyType, "@test"); - addGlobalBinding(typeChecker, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend, "game", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); } }; @@ -86,8 +97,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "find_a_require") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(1, res.requires.size()); - CHECK_EQ(res.requires[0].first, "Modules/Foo/Bar"); + CHECK_EQ(1, res.requireList.size()); + CHECK_EQ(res.requireList[0].first, "Modules/Foo/Bar"); } // It could be argued that this should not work. @@ -102,7 +113,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "find_a_require_inside_a_function") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(1, res.requires.size()); + CHECK_EQ(1, res.requireList.size()); } TEST_CASE_FIXTURE(FrontendFixture, "real_source") @@ -127,7 +138,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "real_source") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(8, res.requires.size()); + CHECK_EQ(8, res.requireList.size()); } TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_dependent_scripts") @@ -373,6 +384,66 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_error_paths") CHECK_EQ(ce2->cycle[1], "game/Gui/Modules/A"); } +TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface") +{ + fileResolver.source["game/A"] = R"( + return {hello = 2} + )"; + + CheckResult result = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["game/A"] = R"( + local me = require(game.A) + return {hello = 2} + )"; + frontend.markDirty("game/A"); + + result = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(result); + + auto ty = requireType("game/A", "me"); + CHECK_EQ(toString(ty), "any"); +} + +TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_longer") +{ + fileResolver.source["game/A"] = R"( + return {mod_a = 2} + )"; + + CheckResult result = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["game/B"] = R"( + local me = require(game.A) + return {mod_b = 4} + )"; + + result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["game/A"] = R"( + local me = require(game.B) + return {mod_a_prime = 3} + )"; + + frontend.markDirty("game/A"); + frontend.markDirty("game/B"); + + result = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(result); + + TypeId tyA = requireType("game/A", "me"); + CHECK_EQ(toString(tyA), "any"); + + result = frontend.check("game/B"); + LUAU_REQUIRE_ERRORS(result); + + TypeId tyB = requireType("game/B", "me"); + CHECK_EQ(toString(tyB), "any"); +} + TEST_CASE_FIXTURE(FrontendFixture, "dont_reparse_clean_file_when_linting") { fileResolver.source["Modules/A"] = R"( @@ -446,6 +517,29 @@ TEST_CASE_FIXTURE(FrontendFixture, "recheck_if_dependent_script_is_dirty") CHECK_EQ("{| b_value: string |}", toString(*bExports)); } +TEST_CASE_FIXTURE(FrontendFixture, "mark_non_immediate_reverse_deps_as_dirty") +{ + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = R"( + return require(game:GetService('Gui').Modules.A) + )"; + fileResolver.source["game/Gui/Modules/C"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {c_value = B.hello} + )"; + + frontend.check("game/Gui/Modules/C"); + + std::vector markedDirty; + frontend.markDirty("game/Gui/Modules/A", &markedDirty); + + REQUIRE(markedDirty.size() == 3); + CHECK(std::find(markedDirty.begin(), markedDirty.end(), "game/Gui/Modules/A") != markedDirty.end()); + CHECK(std::find(markedDirty.begin(), markedDirty.end(), "game/Gui/Modules/B") != markedDirty.end()); + CHECK(std::find(markedDirty.begin(), markedDirty.end(), "game/Gui/Modules/C") != markedDirty.end()); +} + #if 0 // Does not work yet. :( TEST_CASE_FIXTURE(FrontendFixture, "recheck_if_dependent_script_has_a_parse_error") @@ -528,7 +622,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "ignore_require_to_nonexistent_file") { fileResolver.source["Modules/A"] = R"( local Modules = script - local B = require(Modules.B :: any) + local B = require(Modules.B) :: any )"; CheckResult result = frontend.check("Modules/A"); @@ -685,26 +779,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "test_lint_uses_correct_config") CHECK_EQ(0, result4.warnings.size()); } -TEST_CASE_FIXTURE(FrontendFixture, "lintFragment") -{ - LintOptions lintOptions; - lintOptions.enableWarning(LintWarning::Code_ForRange); - - auto [_sourceModule, result] = frontend.lintFragment(R"( - local t = {} - - for i=#t,1 do - end - - for i=#t,1,-1 do - end - )", - lintOptions); - - CHECK_EQ(1, result.warnings.size()); - CHECK_EQ(0, result.errors.size()); -} - TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") { Frontend fe{&fileResolver, &configResolver, {false}}; @@ -720,6 +794,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") CHECK_EQ(0, module->internalTypes.typeVars.size()); CHECK_EQ(0, module->internalTypes.typePacks.size()); CHECK_EQ(0, module->astTypes.size()); + CHECK_EQ(0, module->astResolvedTypes.size()); + CHECK_EQ(0, module->astResolvedTypePacks.size()); } TEST_CASE_FIXTURE(FrontendFixture, "it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded") @@ -885,8 +961,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "clearStats") TEST_CASE_FIXTURE(FrontendFixture, "typecheck_twice_for_ast_types") { - ScopedFastFlag sffs("LuauTypeCheckTwice", true); - fileResolver.source["Module/A"] = R"( local a = 1 )"; @@ -915,7 +989,7 @@ return a; --!nonstrict local a = require(script.Parent.A) local b = {} -function a:b() end -- this should error, but doesn't +function a:b() end -- this should error, since A doesn't define a:b() return b )"; @@ -930,8 +1004,7 @@ a:b() -- this should error, since A doesn't define a:b() LUAU_REQUIRE_NO_ERRORS(resultA); CheckResult resultB = frontend.check("Module/B"); - // TODO (CLI-45592): this should error, since we shouldn't be adding properties to objects from other modules - LUAU_REQUIRE_NO_ERRORS(resultB); + LUAU_REQUIRE_ERRORS(resultB); CheckResult resultC = frontend.check("Module/C"); LUAU_REQUIRE_ERRORS(resultC); @@ -943,7 +1016,6 @@ TEST_CASE("no_use_after_free_with_type_fun_instantiation") { // This flag forces this test to crash if there's a UAF in this code. ScopedFastFlag sff_DebugLuauFreezeArena("DebugLuauFreezeArena", true); - ScopedFastFlag sff_LuauCloneCorrectlyBeforeMutatingTableType("LuauCloneCorrectlyBeforeMutatingTableType", true); FrontendFixture fix; @@ -962,4 +1034,75 @@ return false; fix.frontend.check("Module/B"); } +TEST_CASE("check_without_builtin_next") +{ + TestFileResolver fileResolver; + TestConfigResolver configResolver; + Frontend frontend(&fileResolver, &configResolver); + + fileResolver.source["Module/A"] = "for k,v in 2 do end"; + fileResolver.source["Module/B"] = "return next"; + + // We don't care about the result. That we haven't crashed is enough. + frontend.check("Module/A"); + frontend.check("Module/B"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_cyclic_type") +{ + fileResolver.source["Module/A"] = R"( + type F = (set: G) -> () + + export type G = { + forEach: (a: F) -> (), + } + + function X(a: F): () + end + + return X + )"; + + fileResolver.source["Module/B"] = R"( + --!strict + local A = require(script.Parent.A) + + export type G = A.G + + return { + A = A, + } + )"; + + CheckResult result = frontend.check("Module/B"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_type_alias") +{ + fileResolver.source["Module/A"] = R"( + type KeyOfTestEvents = "test-file-start" | "test-file-success" | "test-file-failure" | "test-case-result" + type MyAny = any + + export type TestFileEvent = ( + eventName: T, + args: any --[[ ROBLOX TODO: Unhandled node for type: TSIndexedAccessType ]] --[[ TestEvents[T] ]] + ) -> MyAny + + return {} + )"; + + fileResolver.source["Module/B"] = R"( + --!strict + local A = require(script.Parent.A) + + export type TestFileEvent = A.TestFileEvent + )"; + + CheckResult result = frontend.check("Module/B"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/IostreamOptional.h b/tests/IostreamOptional.h index 9f874899..e0756bad 100644 --- a/tests/IostreamOptional.h +++ b/tests/IostreamOptional.h @@ -2,6 +2,10 @@ #pragma once #include +#include + +namespace std +{ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) { @@ -9,10 +13,12 @@ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) } template -std::ostream& operator<<(std::ostream& lhs, const std::optional& t) +auto operator<<(std::ostream& lhs, const std::optional& t) -> decltype(lhs << *t) // SFINAE to only instantiate << for supported types { if (t) return lhs << *t; else return lhs << "none"; } + +} // namespace std diff --git a/tests/JsonEmitter.test.cpp b/tests/JsonEmitter.test.cpp new file mode 100644 index 00000000..ff9a5955 --- /dev/null +++ b/tests/JsonEmitter.test.cpp @@ -0,0 +1,195 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/JsonEmitter.h" + +#include "doctest.h" + +using namespace Luau::Json; + +TEST_SUITE_BEGIN("JsonEmitter"); + +TEST_CASE("write_array") +{ + JsonEmitter emitter; + ArrayEmitter a = emitter.writeArray(); + a.writeValue(123); + a.writeValue("foo"); + a.finish(); + + std::string result = emitter.str(); + CHECK(result == "[123,\"foo\"]"); +} + +TEST_CASE("write_object") +{ + JsonEmitter emitter; + ObjectEmitter o = emitter.writeObject(); + o.writePair("foo", "bar"); + o.writePair("bar", "baz"); + o.finish(); + + std::string result = emitter.str(); + CHECK(result == "{\"foo\":\"bar\",\"bar\":\"baz\"}"); +} + +TEST_CASE("write_bool") +{ + JsonEmitter emitter; + write(emitter, false); + CHECK(emitter.str() == "false"); + + emitter = JsonEmitter{}; + write(emitter, true); + CHECK(emitter.str() == "true"); +} + +TEST_CASE("write_null") +{ + JsonEmitter emitter; + write(emitter, nullptr); + CHECK(emitter.str() == "null"); +} + +TEST_CASE("write_string") +{ + JsonEmitter emitter; + write(emitter, R"(foo,bar,baz, +"this should be escaped")"); + CHECK(emitter.str() == "\"foo,bar,baz,\\n\\\"this should be escaped\\\"\""); +} + +TEST_CASE("write_comma") +{ + JsonEmitter emitter; + emitter.writeComma(); + write(emitter, true); + emitter.writeComma(); + write(emitter, false); + CHECK(emitter.str() == "true,false"); +} + +TEST_CASE("push_and_pop_comma") +{ + JsonEmitter emitter; + emitter.writeComma(); + write(emitter, true); + emitter.writeComma(); + emitter.writeRaw('['); + bool comma = emitter.pushComma(); + emitter.writeComma(); + write(emitter, true); + emitter.writeComma(); + write(emitter, false); + emitter.writeRaw(']'); + emitter.popComma(comma); + emitter.writeComma(); + write(emitter, false); + + CHECK(emitter.str() == "true,[true,false],false"); +} + +TEST_CASE("write_optional") +{ + JsonEmitter emitter; + emitter.writeComma(); + write(emitter, std::optional{true}); + emitter.writeComma(); + write(emitter, std::nullopt); + + CHECK(emitter.str() == "true,null"); +} + +TEST_CASE("write_vector") +{ + std::vector values{1, 2, 3, 4}; + JsonEmitter emitter; + write(emitter, values); + CHECK(emitter.str() == "[1,2,3,4]"); +} + +TEST_CASE("prevent_multiple_object_finish") +{ + JsonEmitter emitter; + ObjectEmitter o = emitter.writeObject(); + o.writePair("a", "b"); + o.finish(); + o.finish(); + + CHECK(emitter.str() == "{\"a\":\"b\"}"); +} + +TEST_CASE("prevent_multiple_array_finish") +{ + JsonEmitter emitter; + ArrayEmitter a = emitter.writeArray(); + a.writeValue(1); + a.finish(); + a.finish(); + + CHECK(emitter.str() == "[1]"); +} + +TEST_CASE("cannot_write_pair_after_finished") +{ + JsonEmitter emitter; + ObjectEmitter o = emitter.writeObject(); + o.finish(); + o.writePair("a", "b"); + + CHECK(emitter.str() == "{}"); +} + +TEST_CASE("cannot_write_value_after_finished") +{ + JsonEmitter emitter; + ArrayEmitter a = emitter.writeArray(); + a.finish(); + a.writeValue(1); + + CHECK(emitter.str() == "[]"); +} + +TEST_CASE("finish_when_destructing_object") +{ + JsonEmitter emitter; + emitter.writeObject(); + + CHECK(emitter.str() == "{}"); +} + +TEST_CASE("finish_when_destructing_array") +{ + JsonEmitter emitter; + emitter.writeArray(); + + CHECK(emitter.str() == "[]"); +} + +namespace Luau::Json +{ + +struct Special +{ + int foo; + int bar; +}; + +void write(JsonEmitter& emitter, const Special& value) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("foo", value.foo); + o.writePair("bar", value.bar); +} + +} // namespace Luau::Json + +TEST_CASE("afford_extensibility") +{ + std::vector vec{Special{1, 2}, Special{3, 4}}; + JsonEmitter e; + write(e, vec); + + std::string result = e.str(); + CHECK(result == R"([{"foo":1,"bar":2},{"foo":3,"bar":4}])"); +} + +TEST_SUITE_END(); diff --git a/tests/JsonEncoder.test.cpp b/tests/JsonEncoder.test.cpp deleted file mode 100644 index 4a717275..00000000 --- a/tests/JsonEncoder.test.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Ast.h" -#include "Luau/JsonEncoder.h" - -#include "doctest.h" - -#include - -using namespace Luau; - -TEST_SUITE_BEGIN("JsonEncoderTests"); - -TEST_CASE("encode_constants") -{ - AstExprConstantNil nil{Location()}; - AstExprConstantBool b{Location(), true}; - AstExprConstantNumber n{Location(), 8.2}; - - CHECK_EQ(R"({"type":"AstExprConstantNil","location":"0,0 - 0,0"})", toJson(&nil)); - CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b)); - CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":8.2})", toJson(&n)); -} - -TEST_CASE("basic_escaping") -{ - std::string s = "hello \"world\""; - AstArray theString{s.data(), s.size()}; - AstExprConstantString str{Location(), theString}; - - std::string expected = R"({"type":"AstExprConstantString","location":"0,0 - 0,0","value":"hello \"world\""})"; - CHECK_EQ(expected, toJson(&str)); -} - -TEST_CASE("encode_AstStatBlock") -{ - AstLocal astlocal{AstName{"a_local"}, Location(), nullptr, 0, 0, nullptr}; - AstLocal* astlocalarray[] = {&astlocal}; - - AstArray vars{astlocalarray, 1}; - AstArray values{nullptr, 0}; - AstStatLocal local{Location(), vars, values, std::nullopt}; - AstStat* statArray[] = {&local}; - - AstArray bodyArray{statArray, 1}; - - AstStatBlock block{Location(), bodyArray}; - - CHECK_EQ( - (R"({"type":"AstStatBlock","location":"0,0 - 0,0","body":[{"type":"AstStatLocal","location":"0,0 - 0,0","vars":["a_local"],"values":[]}]})"), - toJson(&block)); -} - -TEST_SUITE_END(); diff --git a/tests/LValue.test.cpp b/tests/LValue.test.cpp new file mode 100644 index 00000000..0bb91ece --- /dev/null +++ b/tests/LValue.test.cpp @@ -0,0 +1,197 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeInfer.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r) +{ + Luau::merge(l, r, [&arena](TypeId a, TypeId b) -> TypeId { + // TODO: normalize here also. + std::unordered_set s; + + if (auto utv = get(follow(a))) + s.insert(begin(utv), end(utv)); + else + s.insert(a); + + if (auto utv = get(follow(b))) + s.insert(begin(utv), end(utv)); + else + s.insert(b); + + std::vector options(s.begin(), s.end()); + return options.size() == 1 ? options[0] : arena.addType(UnionTypeVar{std::move(options)}); + }); +} + +static LValue mkSymbol(const std::string& s) +{ + return Symbol{AstName{s.data()}}; +} + +struct LValueFixture +{ + SingletonTypes singletonTypes; +}; + +TEST_SUITE_BEGIN("LValue"); + +TEST_CASE_FIXTURE(LValueFixture, "Luau_merge_hashmap_order") +{ + std::string a = "a"; + std::string b = "b"; + std::string c = "c"; + + RefinementMap m{{ + {mkSymbol(b), singletonTypes.stringType}, + {mkSymbol(c), singletonTypes.numberType}, + }}; + + RefinementMap other{{ + {mkSymbol(a), singletonTypes.stringType}, + {mkSymbol(b), singletonTypes.stringType}, + {mkSymbol(c), singletonTypes.booleanType}, + }}; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(3, m.size()); + REQUIRE(m.count(mkSymbol(a))); + REQUIRE(m.count(mkSymbol(b))); + REQUIRE(m.count(mkSymbol(c))); + + CHECK_EQ("string", toString(m[mkSymbol(a)])); + CHECK_EQ("string", toString(m[mkSymbol(b)])); + CHECK_EQ("boolean | number", toString(m[mkSymbol(c)])); +} + +TEST_CASE_FIXTURE(LValueFixture, "Luau_merge_hashmap_order2") +{ + std::string a = "a"; + std::string b = "b"; + std::string c = "c"; + + RefinementMap m{{ + {mkSymbol(a), singletonTypes.stringType}, + {mkSymbol(b), singletonTypes.stringType}, + {mkSymbol(c), singletonTypes.numberType}, + }}; + + RefinementMap other{{ + {mkSymbol(b), singletonTypes.stringType}, + {mkSymbol(c), singletonTypes.booleanType}, + }}; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(3, m.size()); + REQUIRE(m.count(mkSymbol(a))); + REQUIRE(m.count(mkSymbol(b))); + REQUIRE(m.count(mkSymbol(c))); + + CHECK_EQ("string", toString(m[mkSymbol(a)])); + CHECK_EQ("string", toString(m[mkSymbol(b)])); + CHECK_EQ("boolean | number", toString(m[mkSymbol(c)])); +} + +TEST_CASE_FIXTURE(LValueFixture, "one_map_has_overlap_at_end_whereas_other_has_it_in_start") +{ + std::string a = "a"; + std::string b = "b"; + std::string c = "c"; + std::string d = "d"; + std::string e = "e"; + + RefinementMap m{{ + {mkSymbol(a), singletonTypes.stringType}, + {mkSymbol(b), singletonTypes.numberType}, + {mkSymbol(c), singletonTypes.booleanType}, + }}; + + RefinementMap other{{ + {mkSymbol(c), singletonTypes.stringType}, + {mkSymbol(d), singletonTypes.numberType}, + {mkSymbol(e), singletonTypes.booleanType}, + }}; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(5, m.size()); + REQUIRE(m.count(mkSymbol(a))); + REQUIRE(m.count(mkSymbol(b))); + REQUIRE(m.count(mkSymbol(c))); + REQUIRE(m.count(mkSymbol(d))); + REQUIRE(m.count(mkSymbol(e))); + + CHECK_EQ("string", toString(m[mkSymbol(a)])); + CHECK_EQ("number", toString(m[mkSymbol(b)])); + CHECK_EQ("boolean | string", toString(m[mkSymbol(c)])); + CHECK_EQ("number", toString(m[mkSymbol(d)])); + CHECK_EQ("boolean", toString(m[mkSymbol(e)])); +} + +TEST_CASE_FIXTURE(LValueFixture, "hashing_lvalue_global_prop_access") +{ + std::string t1 = "t"; + std::string x1 = "x"; + + LValue t_x1{Field{std::make_shared(Symbol{AstName{t1.data()}}), x1}}; + + std::string t2 = "t"; + std::string x2 = "x"; + + LValue t_x2{Field{std::make_shared(Symbol{AstName{t2.data()}}), x2}}; + + CHECK_EQ(t_x1, t_x1); + CHECK_EQ(t_x1, t_x2); + CHECK_EQ(t_x2, t_x2); + + CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x1)); + CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x2)); + CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); + + RefinementMap m; + m[t_x1] = singletonTypes.stringType; + m[t_x2] = singletonTypes.numberType; + + CHECK_EQ(1, m.size()); +} + +TEST_CASE_FIXTURE(LValueFixture, "hashing_lvalue_local_prop_access") +{ + std::string t1 = "t"; + std::string x1 = "x"; + + AstLocal localt1{AstName{t1.data()}, Location(), nullptr, 0, 0, nullptr}; + LValue t_x1{Field{std::make_shared(Symbol{&localt1}), x1}}; + + std::string t2 = "t"; + std::string x2 = "x"; + + AstLocal localt2{AstName{t2.data()}, Location(), &localt1, 0, 0, nullptr}; + LValue t_x2{Field{std::make_shared(Symbol{&localt2}), x2}}; + + CHECK_EQ(t_x1, t_x1); + CHECK_NE(t_x1, t_x2); + CHECK_EQ(t_x2, t_x2); + + CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x1)); + CHECK_NE(LValueHasher{}(t_x1), LValueHasher{}(t_x2)); + CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); + + RefinementMap m; + m[t_x1] = singletonTypes.stringType; + m[t_x2] = singletonTypes.numberType; + + CHECK_EQ(2, m.size()); +} + +TEST_SUITE_END(); diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp new file mode 100644 index 00000000..890d100b --- /dev/null +++ b/tests/Lexer.test.cpp @@ -0,0 +1,227 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Lexer.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("LexerTests"); + +TEST_CASE("broken_string_works") +{ + const std::string testInput = "[["; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::Type::BrokenString); + CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, 2))); +} + +TEST_CASE("broken_comment") +{ + const std::string testInput = "--[[ "; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::Type::BrokenComment); + CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, 6))); +} + +TEST_CASE("broken_comment_kept") +{ + const std::string testInput = "--[[ "; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + lexer.setSkipComments(true); + CHECK_EQ(lexer.next().type, Lexeme::Type::BrokenComment); +} + +TEST_CASE("comment_skipped") +{ + const std::string testInput = "-- "; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + lexer.setSkipComments(true); + CHECK_EQ(lexer.next().type, Lexeme::Type::Eof); +} + +TEST_CASE("multilineCommentWithLexemeInAndAfter") +{ + const std::string testInput = "--[[ function \n" + "]] end"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme comment = lexer.next(); + Lexeme end = lexer.next(); + + CHECK_EQ(comment.type, Lexeme::Type::BlockComment); + CHECK_EQ(comment.location, Luau::Location(Luau::Position(0, 0), Luau::Position(1, 2))); + CHECK_EQ(end.type, Lexeme::Type::ReservedEnd); + CHECK_EQ(end.location, Luau::Location(Luau::Position(1, 3), Luau::Position(1, 6))); +} + +TEST_CASE("testBrokenEscapeTolerant") +{ + const std::string testInput = "'\\3729472897292378'"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme item = lexer.next(); + + CHECK_EQ(item.type, Lexeme::QuotedString); + CHECK_EQ(item.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, int(testInput.size())))); +} + +TEST_CASE("testBigDelimiters") +{ + const std::string testInput = "--[===[\n" + "\n" + "\n" + "\n" + "]===]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme item = lexer.next(); + + CHECK_EQ(item.type, Lexeme::Type::BlockComment); + CHECK_EQ(item.location, Luau::Location(Luau::Position(0, 0), Luau::Position(4, 5))); +} + +TEST_CASE("lookahead") +{ + const std::string testInput = "foo --[[ comment ]] bar : nil end"; + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + lexer.setSkipComments(true); + lexer.next(); // must call next() before reading data from lexer at least once + + CHECK_EQ(lexer.current().type, Lexeme::Name); + CHECK_EQ(lexer.current().name, std::string("foo")); + CHECK_EQ(lexer.lookahead().type, Lexeme::Name); + CHECK_EQ(lexer.lookahead().name, std::string("bar")); + + lexer.next(); + + CHECK_EQ(lexer.current().type, Lexeme::Name); + CHECK_EQ(lexer.current().name, std::string("bar")); + CHECK_EQ(lexer.lookahead().type, ':'); + + lexer.next(); + + CHECK_EQ(lexer.current().type, ':'); + CHECK_EQ(lexer.lookahead().type, Lexeme::ReservedNil); + + lexer.next(); + + CHECK_EQ(lexer.current().type, Lexeme::ReservedNil); + CHECK_EQ(lexer.lookahead().type, Lexeme::ReservedEnd); + + lexer.next(); + + CHECK_EQ(lexer.current().type, Lexeme::ReservedEnd); + CHECK_EQ(lexer.lookahead().type, Lexeme::Eof); + + lexer.next(); + + CHECK_EQ(lexer.current().type, Lexeme::Eof); + CHECK_EQ(lexer.lookahead().type, Lexeme::Eof); +} + +TEST_CASE("string_interpolation_basic") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + const std::string testInput = R"(`foo {"bar"}`)"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme interpBegin = lexer.next(); + CHECK_EQ(interpBegin.type, Lexeme::InterpStringBegin); + + Lexeme quote = lexer.next(); + CHECK_EQ(quote.type, Lexeme::QuotedString); + + Lexeme interpEnd = lexer.next(); + CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); +} + +TEST_CASE("string_interpolation_double_brace") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + const std::string testInput = R"(`foo{{bad}}bar`)"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + auto brokenInterpBegin = lexer.next(); + CHECK_EQ(brokenInterpBegin.type, Lexeme::BrokenInterpDoubleBrace); + CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.length), std::string("foo")); + + CHECK_EQ(lexer.next().type, Lexeme::Name); + + auto interpEnd = lexer.next(); + CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); + CHECK_EQ(std::string(interpEnd.data, interpEnd.length), std::string("}bar")); +} + +TEST_CASE("string_interpolation_double_but_unmatched_brace") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + const std::string testInput = R"(`{{oops}`, 1)"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + CHECK_EQ(lexer.next().type, Lexeme::BrokenInterpDoubleBrace); + CHECK_EQ(lexer.next().type, Lexeme::Name); + CHECK_EQ(lexer.next().type, Lexeme::InterpStringEnd); + CHECK_EQ(lexer.next().type, ','); + CHECK_EQ(lexer.next().type, Lexeme::Number); +} + +TEST_CASE("string_interpolation_unmatched_brace") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + const std::string testInput = R"({ + `hello {"world"} + } -- this might be incorrectly parsed as a string)"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + CHECK_EQ(lexer.next().type, '{'); + CHECK_EQ(lexer.next().type, Lexeme::InterpStringBegin); + CHECK_EQ(lexer.next().type, Lexeme::QuotedString); + CHECK_EQ(lexer.next().type, Lexeme::BrokenString); + CHECK_EQ(lexer.next().type, '}'); +} + +TEST_CASE("string_interpolation_with_unicode_escape") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + const std::string testInput = R"(`\u{1F41B}`)"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + CHECK_EQ(lexer.next().type, Lexeme::InterpStringSimple); + CHECK_EQ(lexer.next().type, Lexeme::Eof); +} + +TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index c8eff399..78ff85c3 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -21,28 +21,40 @@ end return math.max(fib(5), 1) )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnknownGlobal") { LintResult result = lint("--!nocheck\nreturn foo"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Unknown global 'foo'"); } TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobal") { // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(typeChecker, "Wait", Binding{typeChecker.anyType, {}, true, "wait", "@test/global/Wait"}); + addGlobalBinding(frontend, "Wait", Binding{typeChecker.anyType, {}, true, "wait", "@test/global/Wait"}); LintResult result = lintTyped("Wait(5)"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'Wait' is deprecated, use 'wait' instead"); } +TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobalNoReplacement") +{ + // Normally this would be defined externally, so hack it in for testing + const char* deprecationReplacementString = ""; + addGlobalBinding(frontend, "Version", Binding{typeChecker.anyType, {}, true, deprecationReplacementString}); + + LintResult result = lintTyped("Version()"); + + REQUIRE(1 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "Global 'Version' is deprecated"); +} + TEST_CASE_FIXTURE(Fixture, "PlaceholderRead") { LintResult result = lint(R"( @@ -50,7 +62,18 @@ local _ = 5 return _ )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "Placeholder value '_' is read here; consider using a named variable"); +} + +TEST_CASE_FIXTURE(Fixture, "PlaceholderReadGlobal") +{ + LintResult result = lint(R"( +_ = 5 +print(_) +)"); + + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Placeholder value '_' is read here; consider using a named variable"); } @@ -61,10 +84,10 @@ local _ = 5 _ = 6 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } -TEST_CASE_FIXTURE(Fixture, "BuiltinGlobalWrite") +TEST_CASE_FIXTURE(BuiltinsFixture, "BuiltinGlobalWrite") { LintResult result = lint(R"( math = {} @@ -75,7 +98,7 @@ end assert(5) )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Built-in global 'math' is overwritten here; consider using a local or changing the name"); CHECK_EQ(result.warnings[1].text, "Built-in global 'assert' is overwritten here; consider using a local or changing the name"); } @@ -86,7 +109,7 @@ TEST_CASE_FIXTURE(Fixture, "MultilineBlock") if true then print(1) print(2) print(3) end )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "A new statement is on the same line; add semi-colon on previous statement to silence"); } @@ -96,7 +119,7 @@ TEST_CASE_FIXTURE(Fixture, "MultilineBlockSemicolonsWhitelisted") print(1); print(2); print(3) )"); - CHECK(result.warnings.empty()); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "MultilineBlockMissedSemicolon") @@ -105,7 +128,7 @@ TEST_CASE_FIXTURE(Fixture, "MultilineBlockMissedSemicolon") print(1); print(2) print(3) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "A new statement is on the same line; add semi-colon on previous statement to silence"); } @@ -117,7 +140,7 @@ local _x do end )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "ConfusingIndentation") @@ -127,7 +150,7 @@ print(math.max(1, 2)) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Statement spans multiple lines; use indentation to silence"); } @@ -142,10 +165,117 @@ end return bar() )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'foo' is only used in the enclosing function 'bar'; consider changing it to local"); } +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMultiFx") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function bar() + foo = 6 + return foo +end + +function baz() + foo = 6 + return foo +end + +return bar() + baz() +)"); + + REQUIRE(1 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "Global 'foo' is never read before being written. Consider changing it to local"); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMultiFxWithRead") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function bar() + foo = 6 + return foo +end + +function baz() + foo = 6 + return foo +end + +function read() + print(foo) +end + +return bar() + baz() + read() +)"); + + REQUIRE(0 == result.warnings.size()); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalWithConditional") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function bar() + if true then foo = 6 end + return foo +end + +function baz() + foo = 6 + return foo +end + +return bar() + baz() +)"); + + REQUIRE(0 == result.warnings.size()); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocal3WithConditionalRead") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function bar() + foo = 6 + return foo +end + +function baz() + foo = 6 + return foo +end + +function read() + if false then print(foo) end +end + +return bar() + baz() + read() +)"); + + REQUIRE(0 == result.warnings.size()); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalInnerRead") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function foo() + local f = function() return bar end + f() + bar = 42 +end + +function baz() bar = 0 end + +return foo() + baz() +)"); + + REQUIRE(0 == result.warnings.size()); +} + TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMulti") { LintResult result = lint(R"( @@ -172,7 +302,7 @@ fnA() -- prints "true", "nil" fnB() -- prints "false", "nil" )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'moreInternalLogic' is only used in the enclosing function defined at line 2; consider changing it to local"); } @@ -187,11 +317,11 @@ local arg = 5 print(arg) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'arg' shadows previous declaration at line 2"); } -TEST_CASE_FIXTURE(Fixture, "LocalShadowGlobal") +TEST_CASE_FIXTURE(BuiltinsFixture, "LocalShadowGlobal") { LintResult result = lint(R"( local math = math @@ -205,7 +335,7 @@ end return bar() )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'global' shadows a global variable used at line 3"); } @@ -220,7 +350,7 @@ end return bar() )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'a' shadows previous declaration at line 2"); } @@ -240,7 +370,7 @@ end return bar() )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'arg' is never used; prefix with '_' to silence"); CHECK_EQ(result.warnings[1].text, "Variable 'blarg' is never used; prefix with '_' to silence"); } @@ -248,14 +378,14 @@ return bar() TEST_CASE_FIXTURE(Fixture, "ImportUnused") { // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(typeChecker, "game", typeChecker.anyType, "@test"); + addGlobalBinding(frontend, "game", typeChecker.anyType, "@test"); LintResult result = lint(R"( local Roact = require(game.Packages.Roact) local _Roact = require(game.Packages.Roact) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Import 'Roact' is never used; prefix with '_' to silence"); } @@ -280,7 +410,7 @@ end return foo() )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Function 'bar' is never used; prefix with '_' to silence"); CHECK_EQ(result.warnings[1].text, "Function 'qux' is never used; prefix with '_' to silence"); } @@ -295,7 +425,7 @@ end print("hi!") )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 5); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always returns)"); } @@ -311,7 +441,7 @@ end print("hi!") )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always breaks)"); } @@ -327,7 +457,7 @@ end print("hi!") )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always continues)"); } @@ -363,7 +493,7 @@ end return { foo1, foo2, foo3 } )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 7); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always returns)"); } @@ -383,7 +513,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnreachableCodeAssertFalseReturnSilent") @@ -400,7 +530,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnreachableCodeErrorReturnNonSilentBranchy") @@ -418,7 +548,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 7); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always errors)"); } @@ -439,7 +569,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 8); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always errors)"); } @@ -457,7 +587,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnreachableCodeLoopRepeat") @@ -473,16 +603,12 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), - 0); // this is technically a bug, since the repeat body always returns; fixing this bug is a bit more involved than I'd like + // this is technically a bug, since the repeat body always returns; fixing this bug is a bit more involved than I'd like + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnknownType") { - ScopedFastFlag sff{"LuauLinterUnknownTypeVectorAware", true}; - - SourceModule sm; - unfreeze(typeChecker.globalTypes); TableTypeVar::Props instanceProps{ {"ClassName", {typeChecker.anyType}}, @@ -492,81 +618,26 @@ TEST_CASE_FIXTURE(Fixture, "UnknownType") TypeId instanceType = typeChecker.globalTypes.addType(instanceTable); TypeFun instanceTypeFun{{}, instanceType}; - ClassTypeVar::Props enumItemProps{ - {"EnumType", {typeChecker.anyType}}, - }; - - ClassTypeVar enumItemClass{"EnumItem", enumItemProps, std::nullopt, std::nullopt, {}, {}}; - TypeId enumItemType = typeChecker.globalTypes.addType(enumItemClass); - TypeFun enumItemTypeFun{{}, enumItemType}; - - ClassTypeVar normalIdClass{"NormalId", {}, enumItemType, std::nullopt, {}, {}}; - TypeId normalIdType = typeChecker.globalTypes.addType(normalIdClass); - TypeFun normalIdTypeFun{{}, normalIdType}; - - // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(typeChecker, "game", typeChecker.anyType, "@test"); - addGlobalBinding(typeChecker, "typeof", typeChecker.anyType, "@test"); typeChecker.globalScope->exportedTypeBindings["Part"] = instanceTypeFun; - typeChecker.globalScope->exportedTypeBindings["Workspace"] = instanceTypeFun; - typeChecker.globalScope->exportedTypeBindings["RunService"] = instanceTypeFun; - typeChecker.globalScope->exportedTypeBindings["Instance"] = instanceTypeFun; - typeChecker.globalScope->exportedTypeBindings["ColorSequence"] = TypeFun{{}, typeChecker.anyType}; - typeChecker.globalScope->exportedTypeBindings["EnumItem"] = enumItemTypeFun; - typeChecker.globalScope->importedTypeBindings["Enum"] = {{"NormalId", normalIdTypeFun}}; - freeze(typeChecker.globalTypes); LintResult result = lint(R"( -local _e01 = game:GetService("Foo") -local _e02 = game:GetService("NormalId") -local _e03 = game:FindService("table") -local _e04 = type(game) == "Part" -local _e05 = type(game) == "NormalId" -local _e06 = typeof(game) == "Bar" -local _e07 = typeof(game) == "Part" -local _e08 = typeof(game) == "vector" -local _e09 = typeof(game) == "NormalId" -local _e10 = game:IsA("ColorSequence") -local _e11 = game:IsA("Enum.NormalId") -local _e12 = game:FindFirstChildWhichIsA("function") +local game = ... +local _e01 = type(game) == "Part" +local _e02 = typeof(game) == "Bar" +local _e03 = typeof(game) == "vector" -local _o01 = game:GetService("Workspace") -local _o02 = game:FindService("RunService") -local _o03 = type(game) == "number" -local _o04 = type(game) == "vector" -local _o05 = typeof(game) == "string" -local _o06 = typeof(game) == "Instance" -local _o07 = typeof(game) == "EnumItem" -local _o08 = game:IsA("Part") -local _o09 = game:IsA("NormalId") -local _o10 = game:FindFirstChildWhichIsA("Part") +local _o01 = type(game) == "number" +local _o02 = type(game) == "vector" +local _o03 = typeof(game) == "Part" )"); - REQUIRE_EQ(result.warnings.size(), 12); - CHECK_EQ(result.warnings[0].location.begin.line, 1); - CHECK_EQ(result.warnings[0].text, "Unknown type 'Foo'"); - CHECK_EQ(result.warnings[1].location.begin.line, 2); - CHECK_EQ(result.warnings[1].text, "Unknown type 'NormalId' (expected class type)"); - CHECK_EQ(result.warnings[2].location.begin.line, 3); - CHECK_EQ(result.warnings[2].text, "Unknown type 'table' (expected class type)"); - CHECK_EQ(result.warnings[3].location.begin.line, 4); - CHECK_EQ(result.warnings[3].text, "Unknown type 'Part' (expected primitive type)"); - CHECK_EQ(result.warnings[4].location.begin.line, 5); - CHECK_EQ(result.warnings[4].text, "Unknown type 'NormalId' (expected primitive type)"); - CHECK_EQ(result.warnings[5].location.begin.line, 6); - CHECK_EQ(result.warnings[5].text, "Unknown type 'Bar'"); - CHECK_EQ(result.warnings[6].location.begin.line, 7); - CHECK_EQ(result.warnings[6].text, "Unknown type 'Part' (expected primitive or userdata type)"); - CHECK_EQ(result.warnings[7].location.begin.line, 8); - CHECK_EQ(result.warnings[7].text, "Unknown type 'vector' (expected primitive or userdata type)"); - CHECK_EQ(result.warnings[8].location.begin.line, 9); - CHECK_EQ(result.warnings[8].text, "Unknown type 'NormalId' (expected primitive or userdata type)"); - CHECK_EQ(result.warnings[9].location.begin.line, 10); - CHECK_EQ(result.warnings[9].text, "Unknown type 'ColorSequence' (expected class or enum type)"); - CHECK_EQ(result.warnings[10].location.begin.line, 11); - CHECK_EQ(result.warnings[10].text, "Unknown type 'Enum.NormalId'"); - CHECK_EQ(result.warnings[11].location.begin.line, 12); - CHECK_EQ(result.warnings[11].text, "Unknown type 'function' (expected class type)"); + REQUIRE(3 == result.warnings.size()); + CHECK_EQ(result.warnings[0].location.begin.line, 2); + CHECK_EQ(result.warnings[0].text, "Unknown type 'Part' (expected primitive type)"); + CHECK_EQ(result.warnings[1].location.begin.line, 3); + CHECK_EQ(result.warnings[1].text, "Unknown type 'Bar'"); + CHECK_EQ(result.warnings[2].location.begin.line, 4); + CHECK_EQ(result.warnings[2].text, "Unknown type 'vector' (expected primitive or userdata type)"); } TEST_CASE_FIXTURE(Fixture, "ForRangeTable") @@ -581,7 +652,7 @@ for i=#t,1,-1 do end )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[0].text, "For loop should iterate backwards; did you forget to specify -1 as step?"); } @@ -596,7 +667,7 @@ for i=8,1,-1 do end )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 1); CHECK_EQ(result.warnings[0].text, "For loop should iterate backwards; did you forget to specify -1 as step?"); } @@ -611,7 +682,7 @@ for i=1.3,7.5,1 do end )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 1); CHECK_EQ(result.warnings[0].text, "For loop ends at 7.3 instead of 7.5; did you forget to specify step?"); } @@ -629,7 +700,7 @@ for i=#t,0 do end )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 1); CHECK_EQ(result.warnings[0].text, "For loop starts at 0, but arrays start at 1"); CHECK_EQ(result.warnings[1].location.begin.line, 7); @@ -657,7 +728,7 @@ local _a,_b,_c = pcall(), nil end )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 5); CHECK_EQ(result.warnings[0].text, "Assigning 2 values to 3 variables initializes extra variables with nil; add 'nil' to value list to silence"); CHECK_EQ(result.warnings[1].location.begin.line, 11); @@ -722,7 +793,7 @@ end return f1,f2,f3,f4,f5,f6,f7 )"); - CHECK_EQ(result.warnings.size(), 3); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 4); CHECK_EQ(result.warnings[0].text, "Function 'f1' can implicitly return no values even though there's an explicit return at line 4; add explicit return to silence"); @@ -778,7 +849,7 @@ end return f1,f2,f3,f4 )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 25); CHECK_EQ(result.warnings[0].text, "Function 'f3' can implicitly return no values even though there's an explicit return at line 21; add explicit return to silence"); @@ -791,17 +862,17 @@ TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings") { LintResult result = lint(R"(--!strict type InputData = { - id: number, - inputType: EnumItem, - inputState: EnumItem, - updated: number, - position: Vector3, - keyCode: EnumItem, - name: string + id: number, + inputType: EnumItem, + inputState: EnumItem, + updated: number, + position: Vector3, + keyCode: EnumItem, + name: string } )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "BreakFromInfiniteLoopMakesStatementReachable") @@ -820,7 +891,7 @@ until true return 1 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "IgnoreLintAll") @@ -830,7 +901,7 @@ TEST_CASE_FIXTURE(Fixture, "IgnoreLintAll") return foo )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "IgnoreLintSpecific") @@ -841,7 +912,7 @@ local x = 1 return foo )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'x' is never used; prefix with '_' to silence"); } @@ -857,10 +928,10 @@ string.format("%Y") local _ = ("%"):format() -- correct format strings, just to uh make sure -string.format("hello %d %f", 4, 5) +string.format("hello %+10d %.02f %%", 4, 5) )"); - CHECK_EQ(result.warnings.size(), 4); + REQUIRE(4 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid format string: unfinished format specifier"); CHECK_EQ(result.warnings[1].text, "Invalid format string: invalid format specifier: must be a string format specifier or %"); CHECK_EQ(result.warnings[2].text, "Invalid format string: invalid format specifier: must be a string format specifier or %"); @@ -900,7 +971,7 @@ string.packsize("c99999999999999999999") string.packsize("=!1bbbI3c42") )"); - CHECK_EQ(result.warnings.size(), 11); + REQUIRE(11 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid pack format: unexpected character; must be a pack specifier or space"); CHECK_EQ(result.warnings[1].text, "Invalid pack format: unexpected character; must be a pack specifier or space"); CHECK_EQ(result.warnings[2].text, "Invalid pack format: unexpected character; must be a pack specifier or space"); @@ -944,7 +1015,7 @@ local _ = s:match("%q") string.match(s, "[A-Z]+(%d)%1") )"); - CHECK_EQ(result.warnings.size(), 14); + REQUIRE(14 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); CHECK_EQ(result.warnings[2].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); @@ -976,7 +1047,7 @@ string.match(s, "((a)%1)") string.match(s, "((a)%3)") )~"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: invalid capture reference, must refer to a closed capture"); CHECK_EQ(result.warnings[0].location.begin.line, 7); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: invalid capture reference, must refer to a valid capture"); @@ -1014,7 +1085,7 @@ string.match(s, "[]|'[]") string.match(s, "[^]|'[]") )~"); - CHECK_EQ(result.warnings.size(), 7); + REQUIRE(7 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: expected ] at the end of the string to close a set"); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: expected ] at the end of the string to close a set"); CHECK_EQ(result.warnings[2].text, "Invalid match pattern: character range can't include character sets"); @@ -1045,7 +1116,7 @@ string.find("foo"); ("foo"):find() )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); CHECK_EQ(result.warnings[0].location.begin.line, 4); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); @@ -1068,7 +1139,7 @@ string.gsub(s, '[A-Z]+(%d)', "%0%1") string.gsub(s, 'foo', "%0") )"); - CHECK_EQ(result.warnings.size(), 4); + REQUIRE(4 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match replacement: unfinished replacement"); CHECK_EQ(result.warnings[1].text, "Invalid match replacement: unexpected replacement character; must be a digit or %"); CHECK_EQ(result.warnings[2].text, "Invalid match replacement: invalid capture index, must refer to pattern capture"); @@ -1082,16 +1153,18 @@ TEST_CASE_FIXTURE(Fixture, "FormatStringDate") os.date("%") os.date("%L") os.date("%?") +os.date("\0") -- correct formats os.date("it's %c now") os.date("!*t") )"); - CHECK_EQ(result.warnings.size(), 3); + REQUIRE(4 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid date format: unfinished replacement"); CHECK_EQ(result.warnings[1].text, "Invalid date format: unexpected replacement character; must be a date format specifier or %"); CHECK_EQ(result.warnings[2].text, "Invalid date format: unexpected replacement character; must be a date format specifier or %"); + CHECK_EQ(result.warnings[3].text, "Invalid date format: date format can not contain null characters"); } TEST_CASE_FIXTURE(Fixture, "FormatStringTyped") @@ -1106,7 +1179,7 @@ s:match("[]") nons:match("[]") )~"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: expected ] at the end of the string to close a set"); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: expected ] at the end of the string to close a set"); @@ -1156,7 +1229,7 @@ _ = { } )"); - CHECK_EQ(result.warnings.size(), 6); + REQUIRE(6 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Table field 'first' is a duplicate; previously defined at line 3"); CHECK_EQ(result.warnings[1].text, "Table field 'first' is a duplicate; previously defined at line 9"); CHECK_EQ(result.warnings[2].text, "Table index 1 is a duplicate; previously defined as a list entry"); @@ -1173,7 +1246,7 @@ TEST_CASE_FIXTURE(Fixture, "ImportOnlyUsedInTypeAnnotation") local x: Foo.Y = 1 )"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'x' is never used; prefix with '_' to silence"); } @@ -1184,7 +1257,7 @@ TEST_CASE_FIXTURE(Fixture, "DisableUnknownGlobalWithTypeChecking") unknownGlobal() )"); - REQUIRE_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "no_spurious_warning_after_a_function_type_alias") @@ -1196,7 +1269,7 @@ TEST_CASE_FIXTURE(Fixture, "no_spurious_warning_after_a_function_type_alias") return exports )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") @@ -1219,7 +1292,7 @@ TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") LintResult result = frontend.lint("A"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "DeadLocalsUsed") @@ -1245,7 +1318,7 @@ do end )"); - CHECK_EQ(result.warnings.size(), 3); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'x' defined at line 4 is never initialized or assigned; initialize with 'nil' to silence"); CHECK_EQ(result.warnings[1].text, "Assigning 2 values to 3 variables initializes extra variables with nil; add 'nil' to value list to silence"); CHECK_EQ(result.warnings[2].text, "Variable 'c' defined at line 12 is never initialized or assigned; initialize with 'nil' to silence"); @@ -1258,7 +1331,7 @@ local foo function foo() end )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "DuplicateGlobalFunction") @@ -1333,7 +1406,7 @@ TEST_CASE_FIXTURE(Fixture, "DontTriggerTheWarningIfTheFunctionsAreInDifferentSco return c )"); - CHECK(result.warnings.empty()); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") @@ -1369,13 +1442,14 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") local h: Hooty.Pt )"); - CHECK_EQ(result.warnings.size(), 12); + REQUIRE(12 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") { unfreeze(typeChecker.globalTypes); - TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}}); + TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); + persist(instanceType); typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; getMutable(instanceType)->props = { @@ -1383,22 +1457,32 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") {"DataCost", {typeChecker.numberType, /* deprecated= */ true}}, {"Wait", {typeChecker.anyType, /* deprecated= */ true}}, }; + + TypeId colorType = typeChecker.globalTypes.addType(TableTypeVar{{}, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}); + + getMutable(colorType)->props = {{"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"}}}; + + addGlobalBinding(frontend, "Color3", Binding{colorType, {}}); + freeze(typeChecker.globalTypes); LintResult result = lintTyped(R"( return function (i: Instance) i:Wait(1.0) print(i.Name) + print(Color3.toHSV()) + print(Color3.doesntexist, i.doesntexist) -- type error, but this verifies we correctly handle non-existent members return i.DataCost end )"); - REQUIRE_EQ(result.warnings.size(), 2); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Member 'Instance.Wait' is deprecated"); - CHECK_EQ(result.warnings[1].text, "Member 'Instance.DataCost' is deprecated"); + CHECK_EQ(result.warnings[1].text, "Member 'toHSV' is deprecated, use 'Color3:ToHSV' instead"); + CHECK_EQ(result.warnings[2].text, "Member 'Instance.DataCost' is deprecated"); } -TEST_CASE_FIXTURE(Fixture, "TableOperations") +TEST_CASE_FIXTURE(BuiltinsFixture, "TableOperations") { LintResult result = lintTyped(R"( local t = {} @@ -1417,9 +1501,15 @@ table.remove(t, 0) table.remove(t, #t-1) table.insert(t, string.find("hello", "h")) + +table.move(t, 0, #t, 1, tt) +table.move(t, 1, #t, 0, tt) + +table.create(42, {}) +table.create(42, {} :: {}) )"); - REQUIRE_EQ(result.warnings.size(), 6); + REQUIRE(10 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the " "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency"); @@ -1429,6 +1519,12 @@ table.insert(t, string.find("hello", "h")) "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[5].text, "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument"); + CHECK_EQ(result.warnings[6].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + CHECK_EQ(result.warnings[7].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + CHECK_EQ( + result.warnings[8].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + CHECK_EQ( + result.warnings[9].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); } TEST_CASE_FIXTURE(Fixture, "DuplicateConditions") @@ -1454,9 +1550,11 @@ _ = (true and true) or true _ = (true and false) and (42 and false) _ = true and true or false -- no warning since this is is a common pattern used as a ternary replacement + +_ = if true then 1 elseif true then 2 else 3 )"); - REQUIRE_EQ(result.warnings.size(), 7); + REQUIRE(8 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); CHECK_EQ(result.warnings[0].location.begin.line + 1, 4); CHECK_EQ(result.warnings[1].text, "Condition has already been checked on column 5"); @@ -1466,6 +1564,23 @@ _ = true and true or false -- no warning since this is is a common pattern used CHECK_EQ(result.warnings[5].text, "Condition has already been checked on column 6"); CHECK_EQ(result.warnings[6].text, "Condition has already been checked on column 15"); CHECK_EQ(result.warnings[6].location.begin.line + 1, 19); + CHECK_EQ(result.warnings[7].text, "Condition has already been checked on column 8"); +} + +TEST_CASE_FIXTURE(Fixture, "DuplicateConditionsExpr") +{ + LintResult result = lint(R"( +local correct, opaque = ... + +if correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls")}) then +elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls")}) then +elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls", false)}) then +end +)"); + + REQUIRE(1 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 4"); + CHECK_EQ(result.warnings[0].location.begin.line + 1, 5); } TEST_CASE_FIXTURE(Fixture, "DuplicateLocal") @@ -1484,11 +1599,167 @@ end return foo, moo, a1, a2 )"); - REQUIRE_EQ(result.warnings.size(), 4); + REQUIRE(4 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Function parameter 'a1' already defined on column 14"); CHECK_EQ(result.warnings[1].text, "Variable 'a1' is never used; prefix with '_' to silence"); CHECK_EQ(result.warnings[2].text, "Variable 'a1' already defined on column 7"); CHECK_EQ(result.warnings[3].text, "Function parameter 'self' already defined implicitly"); } +TEST_CASE_FIXTURE(Fixture, "MisleadingAndOr") +{ + LintResult result = lint(R"( +_ = math.random() < 0.5 and true or 42 +_ = math.random() < 0.5 and false or 42 -- misleading +_ = math.random() < 0.5 and nil or 42 -- misleading +_ = math.random() < 0.5 and 0 or 42 +_ = (math.random() < 0.5 and false) or 42 -- currently ignored +)"); + + REQUIRE(2 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "The and-or expression always evaluates to the second alternative because the first alternative is false; " + "consider using if-then-else expression instead"); + CHECK_EQ(result.warnings[1].text, "The and-or expression always evaluates to the second alternative because the first alternative is nil; " + "consider using if-then-else expression instead"); +} + +TEST_CASE_FIXTURE(Fixture, "WrongComment") +{ + LintResult result = lint(R"( +--!strict +--!struct +--!nolintGlobal +--!nolint Global +--!nolint KnownGlobal +--!nolint UnknownGlobal +--! no more lint +--!strict here +do end +--!nolint +)"); + + REQUIRE(6 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "Unknown comment directive 'struct'; did you mean 'strict'?"); + CHECK_EQ(result.warnings[1].text, "Unknown comment directive 'nolintGlobal'"); + CHECK_EQ(result.warnings[2].text, "nolint directive refers to unknown lint rule 'Global'"); + CHECK_EQ(result.warnings[3].text, "nolint directive refers to unknown lint rule 'KnownGlobal'; did you mean 'UnknownGlobal'?"); + CHECK_EQ(result.warnings[4].text, "Comment directive with the type checking mode has extra symbols at the end of the line"); + CHECK_EQ(result.warnings[5].text, "Comment directive is ignored because it is placed after the first non-comment token"); +} + +TEST_CASE_FIXTURE(Fixture, "WrongCommentMuteSelf") +{ + LintResult result = lint(R"( +--!nolint +--!struct +)"); + + REQUIRE(0 == result.warnings.size()); // --!nolint disables WrongComment lint :) +} + +TEST_CASE_FIXTURE(Fixture, "DuplicateConditionsIfStatAndExpr") +{ + LintResult result = lint(R"( +if if 1 then 2 else 3 then +elseif if 1 then 2 else 3 then +elseif if 0 then 5 else 4 then +end +)"); + + REQUIRE(1 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); +} + +TEST_CASE_FIXTURE(Fixture, "WrongCommentOptimize") +{ + LintResult result = lint(R"( +--!optimize +--!optimize me +--!optimize 100500 +--!optimize 2 +)"); + + REQUIRE(3 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "optimize directive requires an optimization level"); + CHECK_EQ(result.warnings[1].text, "optimize directive uses unknown optimization level 'me', 0..2 expected"); + CHECK_EQ(result.warnings[2].text, "optimize directive uses unknown optimization level '100500', 0..2 expected"); + + result = lint("--!optimize "); + REQUIRE(1 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "optimize directive requires an optimization level"); +} + +TEST_CASE_FIXTURE(Fixture, "TestStringInterpolation") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + LintResult result = lint(R"( + --!nocheck + local _ = `unknown {foo}` + )"); + + REQUIRE(1 == result.warnings.size()); +} + +TEST_CASE_FIXTURE(Fixture, "IntegerParsing") +{ + LintResult result = lint(R"( +local _ = 0b10000000000000000000000000000000000000000000000000000000000000000 +local _ = 0x10000000000000000 +)"); + + REQUIRE(2 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "Binary number literal exceeded available precision and has been truncated to 2^64"); + CHECK_EQ(result.warnings[1].text, "Hexadecimal number literal exceeded available precision and has been truncated to 2^64"); +} + +// TODO: remove with FFlagLuauErrorDoubleHexPrefix +TEST_CASE_FIXTURE(Fixture, "IntegerParsingDoublePrefix") +{ + ScopedFastFlag luauErrorDoubleHexPrefix{"LuauErrorDoubleHexPrefix", false}; // Lint will be available until we start rejecting code + + LintResult result = lint(R"( +local _ = 0x0x123 +local _ = 0x0xffffffffffffffffffffffffffffffffff +)"); + + REQUIRE(2 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, + "Hexadecimal number literal has a double prefix, which will fail to parse in the future; remove the extra 0x to fix"); + CHECK_EQ(result.warnings[1].text, + "Hexadecimal number literal has a double prefix, which will fail to parse in the future; remove the extra 0x to fix"); +} + +TEST_CASE_FIXTURE(Fixture, "ComparisonPrecedence") +{ + LintResult result = lint(R"( +local a, b = ... + +local _ = not a == b +local _ = not a ~= b +local _ = not a <= b +local _ = a <= b == 0 +local _ = a <= b <= 0 + +local _ = not a == not b -- weird but ok + +-- silence tests for all of the above +local _ = not (a == b) +local _ = (not a) == b +local _ = not (a ~= b) +local _ = (not a) ~= b +local _ = not (a <= b) +local _ = (not a) <= b +local _ = (a <= b) == 0 +local _ = a <= (b == 0) +)"); + + REQUIRE(5 == result.warnings.size()); + CHECK_EQ(result.warnings[0].text, "not X == Y is equivalent to (not X) == Y; consider using X ~= Y, or add parentheses to silence"); + CHECK_EQ(result.warnings[1].text, "not X ~= Y is equivalent to (not X) ~= Y; consider using X == Y, or add parentheses to silence"); + CHECK_EQ(result.warnings[2].text, "not X <= Y is equivalent to (not X) <= Y; add parentheses to silence"); + CHECK_EQ(result.warnings[3].text, "X <= Y == Z is equivalent to (X <= Y) == Z; add parentheses to silence"); + CHECK_EQ(result.warnings[4].text, "X <= Y <= Z is equivalent to (X <= Y) <= Z; did you mean X <= Y and Y <= Z?"); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 1b146ed2..33d9c75a 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -1,5 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Clone.h" #include "Luau/Module.h" +#include "Luau/Scope.h" +#include "Luau/RecursionCounter.h" #include "Fixture.h" @@ -7,6 +10,8 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + TEST_SUITE_BEGIN("ModuleTests"); TEST_CASE_FIXTURE(Fixture, "is_within_comment") @@ -40,27 +45,23 @@ TEST_CASE_FIXTURE(Fixture, "is_within_comment") TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; + CloneState cloneState; // numberType is persistent. We leave it as-is. - TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks); + TypeId newNumber = clone(typeChecker.numberType, dest, cloneState); CHECK_EQ(newNumber, typeChecker.numberType); } TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; + CloneState cloneState; // Create a new number type that isn't persistent unfreeze(typeChecker.globalTypes); TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveTypeVar{PrimitiveTypeVar::Number}); freeze(typeChecker.globalTypes); - TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks); + TypeId newNumber = clone(oldNumber, dest, cloneState); CHECK_NE(newNumber, oldNumber); CHECK_EQ(*oldNumber, *newNumber); @@ -86,11 +87,9 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") TypeId counterType = requireType("Cyclic"); - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - TypeArena dest; - TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks); + CloneState cloneState; + TypeId counterCopy = clone(counterType, dest, cloneState); TableTypeVar* ttv = getMutable(counterCopy); REQUIRE(ttv != nullptr); @@ -103,7 +102,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") const FunctionTypeVar* ftv = get(methodType); REQUIRE(ftv != nullptr); - std::optional methodReturnType = first(ftv->retType); + std::optional methodReturnType = first(ftv->retTypes); REQUIRE(methodReturnType); CHECK_EQ(methodReturnType, counterCopy); @@ -111,7 +110,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") CHECK_EQ(2, dest.typeVars.size()); // One table and one function } -TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") { CheckResult result = check(R"( return {sign=math.sign} @@ -132,20 +131,21 @@ TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") REQUIRE(signType != nullptr); CHECK(!isInArena(signType, module->interfaceTypes)); - CHECK(isInArena(signType, typeChecker.globalTypes)); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK(isInArena(signType, frontend.globalTypes)); + else + CHECK(isInArena(signType, typeChecker.globalTypes)); } TEST_CASE_FIXTURE(Fixture, "deepClone_union") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; + CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldUnion = typeChecker.globalTypes.addType(UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks); + TypeId newUnion = clone(oldUnion, dest, cloneState); CHECK_NE(newUnion, oldUnion); CHECK_EQ("number | string", toString(newUnion)); @@ -155,14 +155,12 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; + CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks); + TypeId newIntersection = clone(oldIntersection, dest, cloneState); CHECK_NE(newIntersection, oldIntersection); CHECK_EQ("number & string", toString(newIntersection)); @@ -175,20 +173,18 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") { {"__add", {typeChecker.anyType}}, }, - std::nullopt, std::nullopt, {}, {}}}; + std::nullopt, std::nullopt, {}, {}, "Test"}}; TypeVar exampleClass{ClassTypeVar{"ExampleClass", { {"PropOne", {typeChecker.numberType}}, {"PropTwo", {typeChecker.stringType}}, }, - std::nullopt, &exampleMetaClass, {}, {}}}; + std::nullopt, &exampleMetaClass, {}, {}, "Test"}}; TypeArena dest; + CloneState cloneState; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - - TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks); + TypeId cloned = clone(&exampleClass, dest, cloneState); const ClassTypeVar* ctv = get(cloned); REQUIRE(ctv != nullptr); @@ -200,49 +196,42 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") CHECK_EQ("ExampleClassMeta", metatable->name); } -TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") +TEST_CASE_FIXTURE(Fixture, "clone_free_types") { TypeVar freeTy(FreeTypeVar{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; + CloneState cloneState; - bool encounteredFreeType = false; - TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); - CHECK(Luau::get(clonedTy)); - CHECK(encounteredFreeType); + TypeId clonedTy = clone(&freeTy, dest, cloneState); + CHECK(get(clonedTy)); - encounteredFreeType = false; - TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType); - CHECK(Luau::get(clonedTp)); - CHECK(encounteredFreeType); + cloneState = {}; + TypePackId clonedTp = clone(&freeTp, dest, cloneState); + CHECK(get(clonedTp)); } -TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") +TEST_CASE_FIXTURE(Fixture, "clone_free_tables") { TypeVar tableTy{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tableTy); ttv->state = TableState::Free; TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; + CloneState cloneState; - bool encounteredFreeType = false; - TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); + TypeId cloned = clone(&tableTy, dest, cloneState); const TableTypeVar* clonedTtv = get(cloned); - CHECK_EQ(clonedTtv->state, TableState::Sealed); - CHECK(encounteredFreeType); + CHECK_EQ(clonedTtv->state, TableState::Free); } -TEST_CASE_FIXTURE(Fixture, "clone_self_property") +TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") { fileResolver.source["Module/A"] = R"( --!nonstrict local a = {} - function a:foo(x) + function a:foo(x: number) return -x; end return a; @@ -258,10 +247,130 @@ TEST_CASE_FIXTURE(Fixture, "clone_self_property") )"; result = frontend.check("Module/B"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("This function must be called with self. Did you mean to use a colon instead of a dot?", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") +{ +#if defined(_DEBUG) || defined(_NOOPT) + int limit = 250; +#else + int limit = 400; +#endif + ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; + + TypeArena src; + + TypeId table = src.addType(TableTypeVar{}); + TypeId nested = table; + + for (int i = 0; i < limit + 100; i++) + { + TableTypeVar* ttv = getMutable(nested); + + ttv->props["a"].type = src.addType(TableTypeVar{}); + nested = ttv->props["a"].type; + } + + TypeArena dest; + CloneState cloneState; + + CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); +} + +TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") +{ + fileResolver.source["Module/A"] = R"( +export type A = B +type B = A + )"; + + FrontendOptions opts; + opts.retainFullTypeGraphs = false; + CheckResult result = frontend.check("Module/A", opts); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), "This function was declared to accept self, but you did not pass enough arguments. Use a colon instead of a " - "dot or pass 1 extra nil to suppress this warning"); + auto mod = frontend.moduleResolver.getModule("Module/A"); + auto it = mod->getModuleScope()->exportedTypeBindings.find("A"); + REQUIRE(it != mod->getModuleScope()->exportedTypeBindings.end()); + CHECK(toString(it->second.type) == "any"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_reexports") +{ + ScopedFastFlag flags[] = { + {"LuauClonePublicInterfaceLess", true}, + {"LuauSubstitutionReentrant", true}, + {"LuauClassTypeVarsInSubstitution", true}, + {"LuauSubstitutionFixMissingFields", true}, + }; + + fileResolver.source["Module/A"] = R"( +export type A = {p : number} +return {} + )"; + + fileResolver.source["Module/B"] = R"( +local a = require(script.Parent.A) +export type B = {q : a.A} +return {} + )"; + + CheckResult result = frontend.check("Module/B"); + LUAU_REQUIRE_NO_ERRORS(result); + + ModulePtr modA = frontend.moduleResolver.getModule("Module/A"); + ModulePtr modB = frontend.moduleResolver.getModule("Module/B"); + REQUIRE(modA); + REQUIRE(modB); + auto modAiter = modA->getModuleScope()->exportedTypeBindings.find("A"); + auto modBiter = modB->getModuleScope()->exportedTypeBindings.find("B"); + REQUIRE(modAiter != modA->getModuleScope()->exportedTypeBindings.end()); + REQUIRE(modBiter != modB->getModuleScope()->exportedTypeBindings.end()); + TypeId typeA = modAiter->second.type; + TypeId typeB = modBiter->second.type; + TableTypeVar* tableB = getMutable(typeB); + REQUIRE(tableB); + CHECK(typeA == tableB->props["q"].type); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_types_of_reexported_values") +{ + ScopedFastFlag flags[] = { + {"LuauClonePublicInterfaceLess", true}, + {"LuauSubstitutionReentrant", true}, + {"LuauClassTypeVarsInSubstitution", true}, + {"LuauSubstitutionFixMissingFields", true}, + }; + + fileResolver.source["Module/A"] = R"( +local exports = {a={p=5}} +return exports + )"; + + fileResolver.source["Module/B"] = R"( +local a = require(script.Parent.A) +local exports = {b=a.a} +return exports + )"; + + CheckResult result = frontend.check("Module/B"); + LUAU_REQUIRE_NO_ERRORS(result); + + ModulePtr modA = frontend.moduleResolver.getModule("Module/A"); + ModulePtr modB = frontend.moduleResolver.getModule("Module/B"); + REQUIRE(modA); + REQUIRE(modB); + std::optional typeA = first(modA->getModuleScope()->returnType); + std::optional typeB = first(modB->getModuleScope()->returnType); + REQUIRE(typeA); + REQUIRE(typeB); + TableTypeVar* tableA = getMutable(*typeA); + TableTypeVar* tableB = getMutable(*typeB); + CHECK(tableA->props["a"].type == tableB->props["b"].type); } TEST_SUITE_END(); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index f3c76d55..89aab5ee 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -1,5 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -31,7 +31,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") REQUIRE_EQ("any", toString(args[0])); REQUIRE_EQ("any", toString(args[1])); - auto rets = flatten(ftv->retType).first; + auto rets = flatten(ftv->retTypes).first; REQUIRE_EQ(0, rets.size()); } @@ -136,14 +136,9 @@ TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") T:staticmethod() )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(std::any_of(result.errors.begin(), result.errors.end(), [](const TypeError& e) { - return get(e); - })); - CHECK(std::any_of(result.errors.begin(), result.errors.end(), [](const TypeError& e) { - return get(e); - })); + CHECK_EQ("This function does not take self. Did you mean to use a dot instead of a colon?", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "offer_a_hint_if_you_use_a_dot_instead_of_a_colon") @@ -151,16 +146,13 @@ TEST_CASE_FIXTURE(Fixture, "offer_a_hint_if_you_use_a_dot_instead_of_a_colon") CheckResult result = check(R"( --!nonstrict local T = {} - function T:method() end - T.method() + function T:method(x: number) end + T.method(5) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - auto e = get(result.errors[0]); - REQUIRE(e != nullptr); - - REQUIRE_EQ(1, e->requiredExtraNils); + CHECK_EQ("This function must be called with self. Did you mean to use a colon instead of a dot?", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "table_props_are_any") @@ -177,6 +169,7 @@ TEST_CASE_FIXTURE(Fixture, "table_props_are_any") REQUIRE(ttv != nullptr); + REQUIRE(ttv->props.count("foo")); TypeId fooProp = ttv->props["foo"].type; REQUIRE(fooProp != nullptr); @@ -204,7 +197,7 @@ TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") CHECK_MESSAGE(get(ttv->props["three"].type), "Should be a function: " << *ttv->props["three"].type); } -TEST_CASE_FIXTURE(Fixture, "for_in_iterator_variables_are_any") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_iterator_variables_are_any") { CheckResult result = check(R"( --!nonstrict @@ -223,7 +216,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_iterator_variables_are_any") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_dot_insert_and_recursive_calls") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_insert_and_recursive_calls") { CheckResult result = check(R"( --!nonstrict @@ -279,4 +272,38 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") REQUIRE_EQ("any", toString(getMainModule()->getModuleScope()->returnType)); } +TEST_CASE_FIXTURE(Fixture, "returning_insufficient_return_values") +{ + CheckResult result = check(R"( + --!nonstrict + + function foo(): (boolean, string?) + if true then + return true, "hello" + else + return false + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "returning_too_many_values") +{ + CheckResult result = check(R"( + --!nonstrict + + function foo(): boolean + if true then + return true, "hello" + else + return false + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp new file mode 100644 index 00000000..e6bf00a1 --- /dev/null +++ b/tests/Normalize.test.cpp @@ -0,0 +1,606 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "Luau/Common.h" +#include "Luau/TypeVar.h" +#include "doctest.h" + +#include "Luau/Normalize.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +namespace +{ +struct IsSubtypeFixture : Fixture +{ + bool isSubtype(TypeId a, TypeId b) + { + return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); + } +}; +} // namespace + +void createSomeClasses(Frontend& frontend) +{ + auto& arena = frontend.globalTypes; + + unfreeze(arena); + + TypeId parentType = arena.addType(ClassTypeVar{"Parent", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); + + ClassTypeVar* parentClass = getMutable(parentType); + parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})}; + + parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})}; + + addGlobalBinding(frontend, "Parent", {parentType}); + frontend.getGlobalScope()->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; + + TypeId childType = arena.addType(ClassTypeVar{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); + + ClassTypeVar* childClass = getMutable(childType); + childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})}; + + addGlobalBinding(frontend, "Child", {childType}); + frontend.getGlobalScope()->exportedTypeBindings["Child"] = TypeFun{{}, childType}; + + TypeId unrelatedType = arena.addType(ClassTypeVar{"Unrelated", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); + + addGlobalBinding(frontend, "Unrelated", {unrelatedType}); + frontend.getGlobalScope()->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; + + for (const auto& [name, ty] : frontend.getGlobalScope()->exportedTypeBindings) + persist(ty.type); + + freeze(arena); +} + +TEST_SUITE_BEGIN("isSubtype"); + +TEST_CASE_FIXTURE(IsSubtypeFixture, "primitives") +{ + check(R"( + local a = 41 + local b = 32 + + local c = "hello" + local d = "world" + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(isSubtype(d, c)); + CHECK(!isSubtype(d, a)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions") +{ + check(R"( + function a(x: number): number return x end + function b(x: number): number return x end + + function c(x: number?): number return x end + function d(x: number): number? return x end + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(isSubtype(c, a)); + CHECK(!isSubtype(d, a)); + CHECK(isSubtype(a, d)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_and_any") +{ + check(R"( + function a(n: number) return "string" end + function b(q: any) return 5 :: any end + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + // any makes things work even when it makes no sense. + + CHECK(isSubtype(b, a)); + CHECK(isSubtype(a, b)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_functions_with_no_head") +{ + check(R"( + local a: (...number) -> () + local b: (...number?) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); +} + +#if 0 +TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_function_with_head") +{ + check(R"( + local a: (...number) -> () + local b: (number, number) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); +} +#endif + +TEST_CASE_FIXTURE(IsSubtypeFixture, "union") +{ + check(R"( + local a: number | string + local b: number + local c: string + local d: number? + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(!isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(b, d)); + CHECK(!isSubtype(d, b)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_union_prop") +{ + check(R"( + local a: {x: number} + local b: {x: number?} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop") +{ + check(R"( + local a: {x: number} + local b: {x: any} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(isSubtype(b, a)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + check(R"( + local a: number & string + local b: number + local c: string + local d: number & nil + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); + + CHECK(!isSubtype(c, a)); + CHECK(isSubtype(a, c)); + + // These types are both equivalent to never + CHECK(isSubtype(d, a)); + CHECK(isSubtype(a, d)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "union_and_intersection") +{ + check(R"( + local a: number & string + local b: number | nil + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "tables") +{ + check(R"( + local a: {x: number} + local b: {x: any} + local c: {y: number} + local d: {x: number, y: number} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(a, b)); + CHECK(isSubtype(b, a)); + + CHECK(!isSubtype(c, a)); + CHECK(!isSubtype(a, c)); + + CHECK(isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(d, b)); + CHECK(!isSubtype(b, d)); +} + +#if 0 +TEST_CASE_FIXTURE(IsSubtypeFixture, "table_indexers_are_invariant") +{ + check(R"( + local a: {[string]: number} + local b: {[string]: any} + local c: {[string]: number} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(isSubtype(a, c)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "mismatched_indexers") +{ + check(R"( + local a: {x: number} + local b: {[string]: number} + local c: {} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(!isSubtype(c, b)); + CHECK(isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "cyclic_table") +{ + check(R"( + type A = {method: (A) -> ()} + local a: A + + type B = {method: (any) -> ()} + local b: B + + type C = {method: (C) -> ()} + local c: C + + type D = {method: (D) -> (), another: (D) -> ()} + local d: D + + type E = {method: (A) -> (), another: (E) -> ()} + local e: E + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + TypeId e = requireType("e"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(e, a)); + CHECK(!isSubtype(a, e)); +} +#endif + +TEST_CASE_FIXTURE(IsSubtypeFixture, "classes") +{ + createSomeClasses(frontend); + + check(""); // Ensure that we have a main Module. + + TypeId p = typeChecker.globalScope->lookupType("Parent")->type; + TypeId c = typeChecker.globalScope->lookupType("Child")->type; + TypeId u = typeChecker.globalScope->lookupType("Unrelated")->type; + + CHECK(isSubtype(c, p)); + CHECK(!isSubtype(p, c)); + CHECK(!isSubtype(u, p)); + CHECK(!isSubtype(p, u)); +} + +#if 0 +TEST_CASE_FIXTURE(IsSubtypeFixture, "metatable" * doctest::expected_failures{1}) +{ + check(R"( + local T = {} + T.__index = T + function T.new() + return setmetatable({}, T) + end + + function T:method() end + + local a: typeof(T.new) + local b: {method: (any) -> ()} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); +} +#endif + +TEST_SUITE_END(); + +struct NormalizeFixture : Fixture +{ + ScopedFastFlag sff1{"LuauNegatedFunctionTypes", true}; + + TypeArena arena; + InternalErrorReporter iceHandler; + UnifierSharedState unifierState{&iceHandler}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&unifierState}}; + + NormalizeFixture() + { + registerHiddenTypes(*this, arena); + } + + const NormalizedType* toNormalizedType(const std::string& annotation) + { + CheckResult result = check("type _Res = " + annotation); + LUAU_REQUIRE_NO_ERRORS(result); + std::optional ty = lookupType("_Res"); + REQUIRE(ty); + return normalizer.normalize(*ty); + } + + TypeId normal(const std::string& annotation) + { + const NormalizedType* norm = toNormalizedType(annotation); + REQUIRE(norm); + return normalizer.typeFromNormal(*norm); + } +}; + +TEST_SUITE_BEGIN("Normalize"); + +TEST_CASE_FIXTURE(NormalizeFixture, "negate_string") +{ + CHECK("number" == toString(normal(R"( + (number | string) & Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "negate_string_from_cofinite_string_intersection") +{ + CHECK("number" == toString(normal(R"( + (number | (string & Not<"hello"> & Not<"world">)) & Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "no_op_negation_is_dropped") +{ + CHECK("number" == toString(normal(R"( + number & Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_of_negation") +{ + CHECK("string" == toString(normal(R"( + (string & Not<"hello">) | "hello" + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersect_truthy") +{ + CHECK("number | string | true" == toString(normal(R"( + (string | number | boolean | nil) & Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersect_truthy_expressed_as_intersection") +{ + CHECK("number | string | true" == toString(normal(R"( + (string | number | boolean | nil) & Not & Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_of_union") +{ + CHECK(R"("alpha" | "beta" | "gamma")" == toString(normal(R"( + ("alpha" | "beta") | "gamma" + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_of_negations") +{ + CHECK(R"(string & ~"world")" == toString(normal(R"( + (string & Not<"hello"> & Not<"world">) | (string & Not<"goodbye"> & Not<"world">) + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "disjoint_negations_normalize_to_string") +{ + CHECK(R"(string)" == toString(normal(R"( + (string & Not<"hello"> & Not<"world">) | (string & Not<"goodbye">) + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean") +{ + CHECK("true" == toString(normal(R"( + boolean & Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "negate_boolean_2") +{ + CHECK("never" == toString(normal(R"( + true & Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersect_function_and_top_function") +{ + CHECK("() -> ()" == toString(normal(R"( + fun & (() -> ()) + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersect_function_and_top_function_reverse") +{ + CHECK("() -> ()" == toString(normal(R"( + (() -> ()) & fun + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_function_and_top_function") +{ + CHECK("function" == toString(normal(R"( + fun | (() -> ()) + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "negated_function_is_anything_except_a_function") +{ + CHECK("(boolean | number | string | thread)?" == toString(normal(R"( + Not + )"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "specific_functions_cannot_be_negated") +{ + CHECK(nullptr == toNormalizedType("Not<(boolean) -> boolean>")); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "bare_negated_boolean") +{ + // TODO: We don't yet have a way to say number | string | thread | nil | Class | Table | Function + CHECK("(function | number | string | thread)?" == toString(normal(R"( + Not + )"))); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function") +{ + check(R"( + function apply(f, x) + return f(x) + end + + local a = apply(function(x: number) return x + x end, 5) + )"); + + TypeId aType = requireType("a"); + CHECK_MESSAGE(isNumber(follow(aType)), "Expected a number but got ", toString(aType)); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_with_annotation") +{ + check(R"( + function apply(f: (a) -> b, x) + return f(x) + end + )"); + + CHECK_EQ("((a) -> b, a) -> b", toString(requireType("apply"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_table_normalizes_sensibly") +{ + CheckResult result = check(R"( + local Cyclic = {} + function Cyclic.get() + return Cyclic + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = requireType("Cyclic"); + CHECK_EQ("t1 where t1 = { get: () -> t1 }", toString(ty, {true})); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "skip_force_normal_on_external_types") +{ + createSomeClasses(frontend); + + CheckResult result = check(R"( +export type t0 = { a: Child } +export type t1 = { a: typeof(string.byte) } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_combine_on_bound_self") +{ + CheckResult result = check(R"( +export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,})) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp new file mode 100644 index 00000000..b827b81b --- /dev/null +++ b/tests/NotNull.test.cpp @@ -0,0 +1,160 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/NotNull.h" + +#include "doctest.h" + +#include +#include +#include + +using Luau::NotNull; + +static_assert(!std::is_convertible, bool>::value, "NotNull ought not to be convertible into bool"); + +namespace +{ + +struct Test +{ + int x; + float y; + + static int count; + Test() + { + ++count; + } + + ~Test() + { + --count; + } +}; + +int Test::count = 0; + +} // namespace + +int foo(NotNull p) +{ + return *p; +} + +void bar(int* q) {} + +TEST_SUITE_BEGIN("NotNull"); + +TEST_CASE("basic_stuff") +{ + NotNull a = NotNull{new int(55)}; // Does runtime test + NotNull b{new int(55)}; // As above + // NotNull c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull in the general case is not + // good. + + // a = nullptr; // nope + + NotNull d = a; // No runtime test. a is known not to be null. + + int e = *d; + *d = 1; + CHECK(e == 55); + + const NotNull f = d; + *f = 5; // valid: there is a difference between const NotNull and NotNull + // f = a; // nope + + CHECK_EQ(a, d); + CHECK(a != b); + + NotNull g(a); + CHECK(g == a); + + // *g = 123; // nope + + (void)f; + + NotNull t{new Test}; + t->x = 5; + t->y = 3.14f; + + const NotNull u = t; + u->x = 44; + int v = u->x; + CHECK(v == 44); + + bar(a); + + // a++; // nope + // a[41]; // nope + // a + 41; // nope + // a - 41; // nope + + delete a; + delete b; + delete t; + + CHECK_EQ(0, Test::count); +} + +TEST_CASE("hashable") +{ + std::unordered_map, const char*> map; + int a_ = 8; + int b_ = 10; + + NotNull a{&a_}; + NotNull b{&b_}; + + std::string hello = "hello"; + std::string world = "world"; + + map[a] = hello.c_str(); + map[b] = world.c_str(); + + CHECK_EQ(2, map.size()); + CHECK_EQ(hello.c_str(), map[a]); + CHECK_EQ(world.c_str(), map[b]); +} + +TEST_CASE("const") +{ + int p = 0; + int q = 0; + + NotNull n{&p}; + + *n = 123; + + NotNull m = n; // Conversion from NotNull to NotNull is allowed + + CHECK(123 == *m); // readonly access of m is ok + + // *m = 321; // nope. m points at const data. + + // NotNull o = m; // nope. Conversion from NotNull to NotNull is forbidden + + NotNull n2{&q}; + m = n2; // ok. m points to const data, but is not itself const + + const NotNull m2 = n; + // m2 = n2; // nope. m2 is const. + *m2 = 321; // ok. m2 is const, but points to mutable data + + CHECK(321 == *n); +} + +TEST_CASE("const_compatibility") +{ + int* raw = new int(8); + + NotNull a(raw); + NotNull b(raw); + NotNull c = a; + // NotNull d = c; // nope - no conversion from const to non-const + + CHECK_EQ(*c, 8); + + delete raw; +} + +TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index cb03a7bd..18e91e1b 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1,12 +1,14 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" -#include "Luau/TypeInfer.h" +#include "AstQueryDsl.h" #include "Fixture.h" #include "ScopedFlags.h" #include "doctest.h" +#include + using namespace Luau; namespace @@ -96,138 +98,6 @@ TEST_CASE("initial_double_is_aligned") TEST_SUITE_END(); -TEST_SUITE_BEGIN("LexerTests"); - -TEST_CASE("broken_string_works") -{ - const std::string testInput = "[["; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - Lexeme lexeme = lexer.next(); - CHECK_EQ(lexeme.type, Lexeme::Type::BrokenString); - CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, 2))); -} - -TEST_CASE("broken_comment") -{ - const std::string testInput = "--[[ "; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - Lexeme lexeme = lexer.next(); - CHECK_EQ(lexeme.type, Lexeme::Type::BrokenComment); - CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, 6))); -} - -TEST_CASE("broken_comment_kept") -{ - const std::string testInput = "--[[ "; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - lexer.setSkipComments(true); - CHECK_EQ(lexer.next().type, Lexeme::Type::BrokenComment); -} - -TEST_CASE("comment_skipped") -{ - const std::string testInput = "-- "; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - lexer.setSkipComments(true); - CHECK_EQ(lexer.next().type, Lexeme::Type::Eof); -} - -TEST_CASE("multilineCommentWithLexemeInAndAfter") -{ - const std::string testInput = "--[[ function \n" - "]] end"; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - Lexeme comment = lexer.next(); - Lexeme end = lexer.next(); - - CHECK_EQ(comment.type, Lexeme::Type::BlockComment); - CHECK_EQ(comment.location, Luau::Location(Luau::Position(0, 0), Luau::Position(1, 2))); - CHECK_EQ(end.type, Lexeme::Type::ReservedEnd); - CHECK_EQ(end.location, Luau::Location(Luau::Position(1, 3), Luau::Position(1, 6))); -} - -TEST_CASE("testBrokenEscapeTolerant") -{ - const std::string testInput = "'\\3729472897292378'"; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - Lexeme item = lexer.next(); - - CHECK_EQ(item.type, Lexeme::QuotedString); - CHECK_EQ(item.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, int(testInput.size())))); -} - -TEST_CASE("testBigDelimiters") -{ - const std::string testInput = "--[===[\n" - "\n" - "\n" - "\n" - "]===]"; - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - Lexeme item = lexer.next(); - - CHECK_EQ(item.type, Lexeme::Type::BlockComment); - CHECK_EQ(item.location, Luau::Location(Luau::Position(0, 0), Luau::Position(4, 5))); -} - -TEST_CASE("lookahead") -{ - const std::string testInput = "foo --[[ comment ]] bar : nil end"; - - Luau::Allocator alloc; - AstNameTable table(alloc); - Lexer lexer(testInput.c_str(), testInput.size(), table); - lexer.setSkipComments(true); - lexer.next(); // must call next() before reading data from lexer at least once - - CHECK_EQ(lexer.current().type, Lexeme::Name); - CHECK_EQ(lexer.current().name, std::string("foo")); - CHECK_EQ(lexer.lookahead().type, Lexeme::Name); - CHECK_EQ(lexer.lookahead().name, std::string("bar")); - - lexer.next(); - - CHECK_EQ(lexer.current().type, Lexeme::Name); - CHECK_EQ(lexer.current().name, std::string("bar")); - CHECK_EQ(lexer.lookahead().type, ':'); - - lexer.next(); - - CHECK_EQ(lexer.current().type, ':'); - CHECK_EQ(lexer.lookahead().type, Lexeme::ReservedNil); - - lexer.next(); - - CHECK_EQ(lexer.current().type, Lexeme::ReservedNil); - CHECK_EQ(lexer.lookahead().type, Lexeme::ReservedEnd); - - lexer.next(); - - CHECK_EQ(lexer.current().type, Lexeme::ReservedEnd); - CHECK_EQ(lexer.lookahead().type, Lexeme::Eof); - - lexer.next(); - - CHECK_EQ(lexer.current().type, Lexeme::Eof); - CHECK_EQ(lexer.lookahead().type, Lexeme::Eof); -} - -TEST_SUITE_END(); - TEST_SUITE_BEGIN("ParserTests"); TEST_CASE_FIXTURE(Fixture, "basic_parse") @@ -300,8 +170,9 @@ TEST_CASE_FIXTURE(Fixture, "functions_can_have_return_annotations") AstStatFunction* statFunction = block->body.data[0]->as(); REQUIRE(statFunction != nullptr); - CHECK_EQ(statFunction->func->returnAnnotation.types.size, 1); - CHECK(statFunction->func->returnAnnotation.tailType == nullptr); + REQUIRE(statFunction->func->returnAnnotation.has_value()); + CHECK_EQ(statFunction->func->returnAnnotation->types.size, 1); + CHECK(statFunction->func->returnAnnotation->tailType == nullptr); } TEST_CASE_FIXTURE(Fixture, "functions_can_have_a_function_type_annotation") @@ -316,9 +187,9 @@ TEST_CASE_FIXTURE(Fixture, "functions_can_have_a_function_type_annotation") AstStatFunction* statFunc = block->body.data[0]->as(); REQUIRE(statFunc != nullptr); - AstArray& retTypes = statFunc->func->returnAnnotation.types; - REQUIRE(statFunc->func->hasReturnAnnotation); - CHECK(statFunc->func->returnAnnotation.tailType == nullptr); + REQUIRE(statFunc->func->returnAnnotation.has_value()); + CHECK(statFunc->func->returnAnnotation->tailType == nullptr); + AstArray& retTypes = statFunc->func->returnAnnotation->types; REQUIRE(retTypes.size == 1); AstTypeFunction* funTy = retTypes.data[0]->as(); @@ -337,9 +208,9 @@ TEST_CASE_FIXTURE(Fixture, "function_return_type_should_disambiguate_from_functi AstStatFunction* statFunc = block->body.data[0]->as(); REQUIRE(statFunc != nullptr); - AstArray& retTypes = statFunc->func->returnAnnotation.types; - REQUIRE(statFunc->func->hasReturnAnnotation); - CHECK(statFunc->func->returnAnnotation.tailType == nullptr); + REQUIRE(statFunc->func->returnAnnotation.has_value()); + CHECK(statFunc->func->returnAnnotation->tailType == nullptr); + AstArray& retTypes = statFunc->func->returnAnnotation->types; REQUIRE(retTypes.size == 2); AstTypeReference* ty0 = retTypes.data[0]->as(); @@ -363,9 +234,9 @@ TEST_CASE_FIXTURE(Fixture, "function_return_type_should_parse_as_function_type_a AstStatFunction* statFunc = block->body.data[0]->as(); REQUIRE(statFunc != nullptr); - AstArray& retTypes = statFunc->func->returnAnnotation.types; - REQUIRE(statFunc->func->hasReturnAnnotation); - CHECK(statFunc->func->returnAnnotation.tailType == nullptr); + REQUIRE(statFunc->func->returnAnnotation.has_value()); + CHECK(statFunc->func->returnAnnotation->tailType == nullptr); + AstArray& retTypes = statFunc->func->returnAnnotation->types; REQUIRE(retTypes.size == 1); AstTypeFunction* funTy = retTypes.data[0]->as(); @@ -625,10 +496,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_messages") )"), "Cannot have more than one table indexer"); - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauGenericFunctionsParserFix", true}; - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - CHECK_EQ(getParseError(R"( type T = foo )"), @@ -711,12 +578,25 @@ TEST_CASE_FIXTURE(Fixture, "mode_is_unset_if_no_hot_comment") TEST_CASE_FIXTURE(Fixture, "sense_hot_comment_on_first_line") { - ParseResult result = parseEx(" --!strict "); + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx(" --!strict ", options); std::optional mode = parseMode(result.hotcomments); REQUIRE(bool(mode)); CHECK_EQ(int(*mode), int(Mode::Strict)); } +TEST_CASE_FIXTURE(Fixture, "non_header_hot_comments") +{ + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx("do end --!strict", options); + std::optional mode = parseMode(result.hotcomments); + REQUIRE(!mode); +} + TEST_CASE_FIXTURE(Fixture, "stop_if_line_ends_with_hyphen") { CHECK_THROWS_AS(parse(" -"), std::exception); @@ -724,7 +604,10 @@ TEST_CASE_FIXTURE(Fixture, "stop_if_line_ends_with_hyphen") TEST_CASE_FIXTURE(Fixture, "nonstrict_mode") { - ParseResult result = parseEx("--!nonstrict"); + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx("--!nonstrict", options); CHECK(result.errors.empty()); std::optional mode = parseMode(result.hotcomments); REQUIRE(bool(mode)); @@ -733,7 +616,10 @@ TEST_CASE_FIXTURE(Fixture, "nonstrict_mode") TEST_CASE_FIXTURE(Fixture, "nocheck_mode") { - ParseResult result = parseEx("--!nocheck"); + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx("--!nocheck", options); CHECK(result.errors.empty()); std::optional mode = parseMode(result.hotcomments); REQUIRE(bool(mode)); @@ -771,33 +657,47 @@ TEST_CASE_FIXTURE(Fixture, "parse_numbers_decimal") TEST_CASE_FIXTURE(Fixture, "parse_numbers_hexadecimal") { - AstStat* stat = parse("return 0xab, 0XAB05, 0xff_ff"); + AstStat* stat = parse("return 0xab, 0XAB05, 0xff_ff, 0xffffffffffffffff"); REQUIRE(stat != nullptr); AstStatReturn* str = stat->as()->body.data[0]->as(); - CHECK(str->list.size == 3); + CHECK(str->list.size == 4); CHECK_EQ(str->list.data[0]->as()->value, 0xab); CHECK_EQ(str->list.data[1]->as()->value, 0xAB05); CHECK_EQ(str->list.data[2]->as()->value, 0xFFFF); + CHECK_EQ(str->list.data[3]->as()->value, double(ULLONG_MAX)); } TEST_CASE_FIXTURE(Fixture, "parse_numbers_binary") { - AstStat* stat = parse("return 0b1, 0b0, 0b101010"); + AstStat* stat = parse("return 0b1, 0b0, 0b101010, 0b1111111111111111111111111111111111111111111111111111111111111111"); REQUIRE(stat != nullptr); AstStatReturn* str = stat->as()->body.data[0]->as(); - CHECK(str->list.size == 3); + CHECK(str->list.size == 4); CHECK_EQ(str->list.data[0]->as()->value, 1); CHECK_EQ(str->list.data[1]->as()->value, 0); CHECK_EQ(str->list.data[2]->as()->value, 42); + CHECK_EQ(str->list.data[3]->as()->value, double(ULLONG_MAX)); } TEST_CASE_FIXTURE(Fixture, "parse_numbers_error") { + ScopedFastFlag luauErrorDoubleHexPrefix{"LuauErrorDoubleHexPrefix", true}; + CHECK_EQ(getParseError("return 0b123"), "Malformed number"); CHECK_EQ(getParseError("return 123x"), "Malformed number"); CHECK_EQ(getParseError("return 0xg"), "Malformed number"); + CHECK_EQ(getParseError("return 0x0x123"), "Malformed number"); + CHECK_EQ(getParseError("return 0xffffffffffffffffffffllllllg"), "Malformed number"); + CHECK_EQ(getParseError("return 0x0xffffffffffffffffffffffffffff"), "Malformed number"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_numbers_error_soft") +{ + ScopedFastFlag luauErrorDoubleHexPrefix{"LuauErrorDoubleHexPrefix", false}; + + CHECK_EQ(getParseError("return 0x0x0x0x0x0x0x0"), "Malformed number"); } TEST_CASE_FIXTURE(Fixture, "break_return_not_last_error") @@ -1004,6 +904,145 @@ TEST_CASE_FIXTURE(Fixture, "parse_compound_assignment_error_multiple") } } +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_begin") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + try + { + parse(R"( + _ = `{{oops}}` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Double braces are not permitted within interpolated strings. Did you mean '\\{'?", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_mid") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + try + { + parse(R"( + _ = `{nice} {{oops}}` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Double braces are not permitted within interpolated strings. Did you mean '\\{'?", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + auto columnOfEndBraceError = [this](const char* code) { + try + { + parse(code); + FAIL("Expected ParseErrors to be thrown"); + return UINT_MAX; + } + catch (const ParseErrors& e) + { + CHECK_EQ(e.getErrors().size(), 1); + + auto error = e.getErrors().front(); + CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", error.getMessage()); + return error.getLocation().begin.column; + } + }; + + // This makes sure that the error is coming from the brace itself + CHECK_EQ(columnOfEndBraceError("_ = `{a`"), columnOfEndBraceError("_ = `{abcdefg`")); + CHECK_NE(columnOfEndBraceError("_ = `{a`"), columnOfEndBraceError("_ = `{a`")); +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace_in_table") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + try + { + parse(R"( + _ = { `{a` } + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ(e.getErrors().size(), 2); + + CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", e.getErrors().front().getMessage()); + CHECK_EQ("Expected '}' (to close '{' at line 2), got ", e.getErrors().back().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_mid_without_end_brace_in_table") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + try + { + parse(R"( + _ = { `x {"y"} {z` } + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ(e.getErrors().size(), 2); + + CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", e.getErrors().front().getMessage()); + CHECK_EQ("Expected '}' (to close '{' at line 2), got ", e.getErrors().back().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_as_type_fail") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + try + { + parse(R"( + local a: `what` = `???` + local b: `what {"the"}` = `???` + local c: `what {"the"} heck` = `???` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& parseErrors) + { + CHECK_EQ(parseErrors.getErrors().size(), 3); + + for (ParseError error : parseErrors.getErrors()) + CHECK_EQ(error.getMessage(), "Interpolated string literals cannot be used as types"); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_call_without_parens") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + try + { + parse(R"( + _ = print `{42}` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Expected identifier when parsing expression, got `{", e.getErrors().front().getMessage()); + } +} + TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection") { try @@ -1257,7 +1296,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_type_group") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") { ScopedFastInt sfis{"LuauRecursionLimit", 10}; - ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true}; matchParseErrorPrefix( "function f() if true then if true then if true then if true then if true then if true then if true then if true then if true " @@ -1268,7 +1306,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements") { ScopedFastInt sfis{"LuauRecursionLimit", 10}; - ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true}; matchParseErrorPrefix( "function f() if false then elseif false then elseif false then elseif false then elseif false then elseif false then elseif " @@ -1278,7 +1315,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements" TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; ScopedFastInt sfis{"LuauRecursionLimit", 10}; matchParseError("function f() return if true then 1 elseif true then 2 elseif true then 3 elseif true then 4 elseif true then 5 elseif true then " @@ -1288,7 +1324,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1 TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions2") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; ScopedFastInt sfis{"LuauRecursionLimit", 10}; matchParseError( @@ -1504,6 +1539,15 @@ return CHECK_EQ(std::string(str->value.data, str->value.size), "\n"); } +TEST_CASE_FIXTURE(Fixture, "parse_error_broken_comment") +{ + const char* expected = "Expected identifier when parsing expression, got unfinished comment"; + + matchParseError("--[[unfinished work", expected); + matchParseError("--!strict\n--[[unfinished work", expected); + matchParseError("local x = 1 --[[unfinished work", expected); +} + TEST_CASE_FIXTURE(Fixture, "string_literals_escapes_broken") { const char* expected = "String literal contains malformed escape sequence"; @@ -1584,6 +1628,35 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_of_functions_unions_and_intersections") CHECK_EQ((Position{3, 42}), block->body.data[2]->location.end); } +TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") +{ + AstStatBlock* block = parse(R"( + type F = number + --comment + print('hello') + )"); + + REQUIRE_EQ(2, block->body.size); + CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); +} + +TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments_even_with_capture") +{ + // Same should hold when comments are captured + ParseOptions opts; + opts.captureComments = true; + + AstStatBlock* block = parse(R"( + type F = number + --comment + print('hello') + )", + opts); + + REQUIRE_EQ(2, block->body.size); + CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); +} + TEST_CASE_FIXTURE(Fixture, "parse_error_loop_control") { matchParseError("break", "break statement must be inside a loop"); @@ -1624,6 +1697,17 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_confusing_function_call") "statements"); CHECK(result3.errors.size() == 1); + + auto result4 = matchParseError(R"( + local t = {} + function f() return t end + t.x, (f) + ().y = 5, 6 + )", + "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " + "statements"); + + CHECK(result4.errors.size() == 1); } TEST_CASE_FIXTURE(Fixture, "parse_error_varargs") @@ -1651,6 +1735,46 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_type_annotation") matchParseError("local a : 2 = 2", "Expected type, got '2'"); } +TEST_CASE_FIXTURE(Fixture, "parse_error_missing_type_annotation") +{ + { + ParseResult result = tryParse("local x:"); + CHECK(result.errors.size() == 1); + Position begin = result.errors[0].getLocation().begin; + Position end = result.errors[0].getLocation().end; + CHECK(begin.line == end.line); + int width = end.column - begin.column; + CHECK(width == 0); + CHECK(result.errors[0].getMessage() == "Expected type, got "); + } + + { + ParseResult result = tryParse(R"( +local x:=42 + )"); + CHECK(result.errors.size() == 1); + Position begin = result.errors[0].getLocation().begin; + Position end = result.errors[0].getLocation().end; + CHECK(begin.line == end.line); + int width = end.column - begin.column; + CHECK(width == 1); // Length of `=` + CHECK(result.errors[0].getMessage() == "Expected type, got '='"); + } + + { + ParseResult result = tryParse(R"( +function func():end + )"); + CHECK(result.errors.size() == 1); + Position begin = result.errors[0].getLocation().begin; + Position end = result.errors[0].getLocation().end; + CHECK(begin.line == end.line); + int width = end.column - begin.column; + CHECK(width == 3); // Length of `end` + CHECK(result.errors[0].getMessage() == "Expected type, got 'end'"); + } +} + TEST_CASE_FIXTURE(Fixture, "parse_declarations") { AstStatBlock* stat = parseEx(R"( @@ -1824,9 +1948,6 @@ TEST_CASE_FIXTURE(Fixture, "variadic_definition_parsing") TEST_CASE_FIXTURE(Fixture, "generic_pack_parsing") { - // Doesn't need LuauGenericFunctions - ScopedFastFlag sffs{"LuauParseGenericFunctions", true}; - ParseResult result = parseEx(R"( function f(...: a...) end @@ -1861,9 +1982,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_pack_parsing") TEST_CASE_FIXTURE(Fixture, "generic_function_declaration_parsing") { - // Doesn't need LuauGenericFunctions - ScopedFastFlag sffs{"LuauParseGenericFunctions", true}; - ParseResult result = parseEx(R"( declare function f() )"); @@ -1953,12 +2071,117 @@ TEST_CASE_FIXTURE(Fixture, "function_type_named_arguments") matchParseError("type MyFunc = (a: number, b: string, c: number) -> (d: number, e: string, f: number)", "Expected '->' when parsing function type, got "); - { - ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; - ScopedFastFlag luauGenericFunctionsParserFix{"LuauGenericFunctionsParserFix", true}; + matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); +} - matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); +TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") +{ + matchParseError("local a: (number -> string", "Expected ')' (to close '(' at column 13), got '->'"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type") +{ + AstStat* stat = parse(R"( +type A = {} +type B = {} +type C = {} +type D = {} +type E = {} +type F = (T...) -> U... +type G = (U...) -> T... + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") +{ + matchParseError("type Y = {}", "Expected default type after type name", Location{{0, 20}, {0, 21}}); + matchParseError("type Y = {}", "Expected default type pack after type pack name", Location{{0, 29}, {0, 30}}); + matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_pack_errors") +{ + matchParseError("type Y = {a: T..., b: number}", "Unexpected '...' after type name; type pack is not allowed in this context", + Location{{0, 20}, {0, 23}}); + matchParseError("type Y = {a: (number | string)...", "Unexpected '...' after type annotation", Location{{0, 36}, {0, 39}}); +} + +TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") +{ + { + AstStat* stat = parse("return if true then 1 else 2"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr = str->list.data[0]->as(); + REQUIRE(ifElseExpr != nullptr); } + + { + AstStat* stat = parse("return if true then 1 elseif true then 2 else 3"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr1 = str->list.data[0]->as(); + REQUIRE(ifElseExpr1 != nullptr); + auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); + REQUIRE(ifElseExpr2 != nullptr); + } + + // Use "else if" as opposed to elseif + { + AstStat* stat = parse("return if true then 1 else if true then 2 else 3"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr1 = str->list.data[0]->as(); + REQUIRE(ifElseExpr1 != nullptr); + auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); + REQUIRE(ifElseExpr2 != nullptr); + } + + // Use an if-else expression as the conditional expression of an if-else expression + { + AstStat* stat = parse("return if if true then false else true then 1 else 2"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr = str->list.data[0]->as(); + REQUIRE(ifElseExpr != nullptr); + auto* nestedIfElseExpr = ifElseExpr->condition->as(); + REQUIRE(nestedIfElseExpr != nullptr); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") +{ + AstStat* stat = parse(R"( +type Packed = () -> T... + +type A = Packed +type B = Packed<...number> +type C = Packed<(number, X...)> + )"); + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") +{ + ScopedFastFlag luauFixNamedFunctionParse{"LuauFixNamedFunctionParse", true}; + + matchParseError("type A = (b: number)", "Expected '->' when parsing function type, got "); + matchParseError("type P = () -> T... type B = P<(x: number, y: string)>", "Expected '->' when parsing function type, got '>'"); + matchParseError("type F = (T...) -> ()", "Expected '->' when parsing function type, got '>'"); } TEST_SUITE_END(); @@ -2302,12 +2525,10 @@ TEST_CASE_FIXTURE(Fixture, "capture_comments") TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - ParseOptions options; options.captureComments = true; - ParseResult result = parseEx(R"( + ParseResult result = tryParse(R"( --[[ )", options); @@ -2318,8 +2539,6 @@ TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") TEST_CASE_FIXTURE(Fixture, "capture_broken_comment") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - ParseOptions options; options.captureComments = true; @@ -2362,8 +2581,6 @@ type Fn = ( CHECK_EQ("Expected '->' when parsing function type, got ')'", e.getErrors().front().getMessage()); } - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - try { parse(R"(type Fn = (any, string | number | ()) -> any)"); @@ -2397,8 +2614,6 @@ TEST_CASE_FIXTURE(Fixture, "AstName_comparison") TEST_CASE_FIXTURE(Fixture, "generic_type_list_recovery") { - ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; - try { parse(R"( @@ -2462,61 +2677,159 @@ do end CHECK_EQ(1, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") +TEST_CASE_FIXTURE(Fixture, "recover_expected_type_pack") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; + ParseResult result = tryParse(R"( +type Y = (T...) -> U... + )"); + CHECK_EQ(1, result.errors.size()); +} - { - AstStat* stat = parse("return if true then 1 else 2"); +TEST_CASE_FIXTURE(Fixture, "recover_unexpected_type_pack") +{ + ParseResult result = tryParse(R"( +type X = { a: T..., b: number } +type Y = { a: T..., b: number } +type Z = { a: string | T..., b: number } + )"); + REQUIRE_EQ(3, result.errors.size()); +} - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr = str->list.data[0]->as(); - REQUIRE(ifElseExpr != nullptr); - } +TEST_CASE_FIXTURE(Fixture, "recover_function_return_type_annotations") +{ + ParseResult result = tryParse(R"( +type Custom = { x: A, y: B, z: C } +type Packed = { x: (A...) -> () } +type F = (number): Custom +type G = Packed<(number): (string, number, boolean)> +local function f(x: number) -> Custom +end + )"); + REQUIRE_EQ(3, result.errors.size()); + CHECK_EQ(result.errors[0].getMessage(), "Return types in function type annotations are written after '->' instead of ':'"); + CHECK_EQ(result.errors[1].getMessage(), "Return types in function type annotations are written after '->' instead of ':'"); + CHECK_EQ(result.errors[2].getMessage(), "Function return type annotations are written after ':' instead of '->'"); +} - { - AstStat* stat = parse("return if true then 1 elseif true then 2 else 3"); +TEST_CASE_FIXTURE(Fixture, "error_message_for_using_function_as_type_annotation") +{ + ParseResult result = tryParse(R"( + type Foo = function + )"); + REQUIRE_EQ(1, result.errors.size()); + CHECK_EQ("Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> ...any'", + result.errors[0].getMessage()); +} - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr1 = str->list.data[0]->as(); - REQUIRE(ifElseExpr1 != nullptr); - auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); - REQUIRE(ifElseExpr2 != nullptr); - } +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_function_argument_list") +{ + ParseResult result = tryParse(R"( + foo(a, b, c,) + )"); - // Use "else if" as opposed to elseif - { - AstStat* stat = parse("return if true then 1 else if true then 2 else 3"); + REQUIRE(1 == result.errors.size()); - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr1 = str->list.data[0]->as(); - REQUIRE(ifElseExpr1 != nullptr); - auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); - REQUIRE(ifElseExpr2 != nullptr); - } + CHECK(Location({1, 20}, {1, 21}) == result.errors[0].getLocation()); + CHECK("Expected expression after ',' but got ')' instead" == result.errors[0].getMessage()); +} - // Use an if-else expression as the conditional expression of an if-else expression - { - AstStat* stat = parse("return if if true then false else true then 1 else 2"); +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_function_parameter_list") +{ + ParseResult result = tryParse(R"( + export type VisitFn = ( + any, + Array>, -- extra comma here + ) -> any + )"); - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr = str->list.data[0]->as(); - REQUIRE(ifElseExpr != nullptr); - auto* nestedIfElseExpr = ifElseExpr->condition->as(); - REQUIRE(nestedIfElseExpr != nullptr); - } + REQUIRE(1 == result.errors.size()); + + CHECK(Location({4, 8}, {4, 9}) == result.errors[0].getLocation()); + CHECK("Expected type after ',' but got ')' instead" == result.errors[0].getMessage()); +} + +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_generic_parameter_list") +{ + ParseResult result = tryParse(R"( + export type VisitFn = (a: A, b: B) -> () + )"); + + REQUIRE(1 == result.errors.size()); + + CHECK(Location({1, 36}, {1, 37}) == result.errors[0].getLocation()); + CHECK("Expected type after ',' but got '>' instead" == result.errors[0].getMessage()); + + REQUIRE(1 == result.root->body.size); + + AstStatTypeAlias* t = result.root->body.data[0]->as(); + REQUIRE(t != nullptr); + + AstTypeFunction* f = t->type->as(); + REQUIRE(f != nullptr); + + CHECK(2 == f->generics.size); +} + +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_between_table_members") +{ + ParseResult result = tryParse(R"( + local t = { + first = 1 + second = 2, + third = 3, + fouth = 4, + } + )"); + + REQUIRE(1 == result.errors.size()); + + CHECK(Location({3, 12}, {3, 18}) == result.errors[0].getLocation()); + CHECK("Expected ',' after table constructor element" == result.errors[0].getMessage()); + + REQUIRE(1 == result.root->body.size); + + AstExprTable* table = Luau::query(result.root); + REQUIRE(table); + CHECK(table->items.size == 4); +} + +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_after_last_table_member") +{ + ParseResult result = tryParse(R"( + local t = { + first = 1 + + local ok = true + local good = ok == true + )"); + + REQUIRE(1 == result.errors.size()); + + CHECK(Location({4, 8}, {4, 13}) == result.errors[0].getLocation()); + CHECK("Expected '}' (to close '{' at line 2), got 'local'" == result.errors[0].getMessage()); + + REQUIRE(3 == result.root->body.size); + + AstExprTable* table = Luau::query(result.root); + REQUIRE(table); + CHECK(table->items.size == 1); +} + +TEST_CASE_FIXTURE(Fixture, "missing_default_type_pack_argument_after_variadic_type_parameter") +{ + ScopedFastFlag sff{"LuauParserErrorsOnMissingDefaultTypePackArgument", true}; + + ParseResult result = tryParse(R"( + type Foo = nil + )"); + + REQUIRE_EQ(2, result.errors.size()); + + CHECK_EQ(Location{{1, 23}, {1, 25}}, result.errors[0].getLocation()); + CHECK_EQ("Expected type, got '>'", result.errors[0].getMessage()); + + CHECK_EQ(Location{{1, 23}, {1, 24}}, result.errors[1].getLocation()); + CHECK_EQ("Expected type pack after '=', got type", result.errors[1].getMessage()); } TEST_SUITE_END(); diff --git a/tests/Predicate.test.cpp b/tests/Predicate.test.cpp deleted file mode 100644 index bb5a93c5..00000000 --- a/tests/Predicate.test.cpp +++ /dev/null @@ -1,123 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/TypeInfer.h" - -#include "Fixture.h" -#include "ScopedFlags.h" - -#include "doctest.h" - -using namespace Luau; - -static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r) -{ - Luau::merge(l, r, [&arena](TypeId a, TypeId b) -> TypeId { - // TODO: normalize here also. - std::unordered_set s; - - if (auto utv = get(follow(a))) - s.insert(begin(utv), end(utv)); - else - s.insert(a); - - if (auto utv = get(follow(b))) - s.insert(begin(utv), end(utv)); - else - s.insert(b); - - std::vector options(s.begin(), s.end()); - return options.size() == 1 ? options[0] : arena.addType(UnionTypeVar{std::move(options)}); - }); -} - -TEST_SUITE_BEGIN("Predicate"); - -TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order") -{ - ScopedFastFlag sff{"LuauOrPredicate", true}; - - RefinementMap m{ - {"b", typeChecker.stringType}, - {"c", typeChecker.numberType}, - }; - - RefinementMap other{ - {"a", typeChecker.stringType}, - {"b", typeChecker.stringType}, - {"c", typeChecker.booleanType}, - }; - - TypeArena arena; - merge(arena, m, other); - - REQUIRE_EQ(3, m.size()); - REQUIRE(m.count("a")); - REQUIRE(m.count("b")); - REQUIRE(m.count("c")); - - CHECK_EQ("string", toString(m["a"])); - CHECK_EQ("string", toString(m["b"])); - CHECK_EQ("boolean | number", toString(m["c"])); -} - -TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order2") -{ - ScopedFastFlag sff{"LuauOrPredicate", true}; - - RefinementMap m{ - {"a", typeChecker.stringType}, - {"b", typeChecker.stringType}, - {"c", typeChecker.numberType}, - }; - - RefinementMap other{ - {"b", typeChecker.stringType}, - {"c", typeChecker.booleanType}, - }; - - TypeArena arena; - merge(arena, m, other); - - REQUIRE_EQ(3, m.size()); - REQUIRE(m.count("a")); - REQUIRE(m.count("b")); - REQUIRE(m.count("c")); - - CHECK_EQ("string", toString(m["a"])); - CHECK_EQ("string", toString(m["b"])); - CHECK_EQ("boolean | number", toString(m["c"])); -} - -TEST_CASE_FIXTURE(Fixture, "one_map_has_overlap_at_end_whereas_other_has_it_in_start") -{ - ScopedFastFlag sff{"LuauOrPredicate", true}; - - RefinementMap m{ - {"a", typeChecker.stringType}, - {"b", typeChecker.numberType}, - {"c", typeChecker.booleanType}, - }; - - RefinementMap other{ - {"c", typeChecker.stringType}, - {"d", typeChecker.numberType}, - {"e", typeChecker.booleanType}, - }; - - TypeArena arena; - merge(arena, m, other); - - REQUIRE_EQ(5, m.size()); - REQUIRE(m.count("a")); - REQUIRE(m.count("b")); - REQUIRE(m.count("c")); - REQUIRE(m.count("d")); - REQUIRE(m.count("e")); - - CHECK_EQ("string", toString(m["a"])); - CHECK_EQ("number", toString(m["b"])); - CHECK_EQ("boolean | string", toString(m["c"])); - CHECK_EQ("number", toString(m["d"])); - CHECK_EQ("boolean", toString(m["e"])); -} - -TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp new file mode 100644 index 00000000..c22d464e --- /dev/null +++ b/tests/Repl.test.cpp @@ -0,0 +1,423 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" + +#include "Repl.h" + +#include "doctest.h" + +#include +#include +#include +#include +#include + +struct Completion +{ + std::string completion; + std::string display; + + bool operator<(Completion const& other) const + { + return std::tie(completion, display) < std::tie(other.completion, other.display); + } +}; + +using CompletionSet = std::set; + +class ReplFixture +{ +public: + ReplFixture() + : luaState(luaL_newstate(), lua_close) + { + L = luaState.get(); + setupState(L); + luaL_sandboxthread(L); + + std::string result = runCode(L, prettyPrintSource); + } + + // Returns all of the output captured from the pretty printer + std::string getCapturedOutput() + { + lua_getglobal(L, "capturedoutput"); + const char* str = lua_tolstring(L, -1, nullptr); + std::string result(str); + lua_pop(L, 1); + return result; + } + + CompletionSet getCompletionSet(const char* inputPrefix) + { + CompletionSet result; + int top = lua_gettop(L); + getCompletions(L, inputPrefix, [&result](const std::string& completion, const std::string& display) { + result.insert(Completion{completion, display}); + }); + // Ensure that generating completions doesn't change the position of luau's stack top. + CHECK(top == lua_gettop(L)); + + return result; + } + + bool checkCompletion(const CompletionSet& completions, const std::string& prefix, const std::string& expected) + { + std::string expectedDisplay(expected.substr(0, expected.find_first_of('('))); + Completion expectedCompletion{prefix + expected, expectedDisplay}; + return completions.count(expectedCompletion) == 1; + } + + lua_State* L; + +private: + std::unique_ptr luaState; + + // This is a simplistic and incomplete pretty printer. + // It is included here to test that the pretty printer hook is being called. + // More elaborate tests to ensure correct output can be added if we introduce + // a more feature rich pretty printer. + std::string prettyPrintSource = R"( +-- Accumulate pretty printer output in `capturedoutput` +capturedoutput = "" + +function arraytostring(arr) + local strings = {} + table.foreachi(arr, function(k,v) table.insert(strings, pptostring(v)) end ) + return "{" .. table.concat(strings, ", ") .. "}" +end + +function pptostring(x) + if type(x) == "table" then + -- Just assume array-like tables for now. + return arraytostring(x) + elseif type(x) == "string" then + return '"' .. x .. '"' + else + return tostring(x) + end +end + +-- Note: Instead of calling print, the pretty printer just stores the output +-- in `capturedoutput` so we can check for the correct results. +function _PRETTYPRINT(...) + local args = table.pack(...) + local strings = {} + for i=1, args.n do + local item = args[i] + local str = pptostring(item, customoptions) + if i == 1 then + capturedoutput = capturedoutput .. str + else + capturedoutput = capturedoutput .. "\t" .. str + end + end +end +)"; +}; + +TEST_SUITE_BEGIN("ReplPrettyPrint"); + +TEST_CASE_FIXTURE(ReplFixture, "AdditionStatement") +{ + runCode(L, "return 30 + 12"); + CHECK(getCapturedOutput() == "42"); +} + +TEST_CASE_FIXTURE(ReplFixture, "TableLiteral") +{ + runCode(L, "return {1, 2, 3, 4}"); + CHECK(getCapturedOutput() == "{1, 2, 3, 4}"); +} + +TEST_CASE_FIXTURE(ReplFixture, "StringLiteral") +{ + runCode(L, "return 'str'"); + CHECK(getCapturedOutput() == "\"str\""); +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithStringLiterals") +{ + runCode(L, "return {1, 'two', 3, 'four'}"); + CHECK(getCapturedOutput() == "{1, \"two\", 3, \"four\"}"); +} + +TEST_CASE_FIXTURE(ReplFixture, "MultipleArguments") +{ + runCode(L, "return 3, 'three'"); + CHECK(getCapturedOutput() == "3\t\"three\""); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("ReplCodeCompletion"); + +TEST_CASE_FIXTURE(ReplFixture, "CompleteGlobalVariables") +{ + runCode(L, R"( + myvariable1 = 5 + myvariable2 = 5 +)"); + { + // Try to complete globals that are added by the user's script + CompletionSet completions = getCompletionSet("myvar"); + + std::string prefix = ""; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "myvariable1")); + CHECK(checkCompletion(completions, prefix, "myvariable2")); + } + { + // Try completing some builtin functions + CompletionSet completions = getCompletionSet("math.m"); + + std::string prefix = "math."; + CHECK(completions.size() == 3); + CHECK(checkCompletion(completions, prefix, "max(")); + CHECK(checkCompletion(completions, prefix, "min(")); + CHECK(checkCompletion(completions, prefix, "modf(")); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "CompleteTableKeys") +{ + runCode(L, R"( + t = { color = "red", size = 1, shape = "circle" } +)"); + { + CompletionSet completions = getCompletionSet("t."); + + std::string prefix = "t."; + CHECK(completions.size() == 3); + CHECK(checkCompletion(completions, prefix, "color")); + CHECK(checkCompletion(completions, prefix, "size")); + CHECK(checkCompletion(completions, prefix, "shape")); + } + + { + CompletionSet completions = getCompletionSet("t.s"); + + std::string prefix = "t."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "size")); + CHECK(checkCompletion(completions, prefix, "shape")); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "StringMethods") +{ + runCode(L, R"( + s = "" +)"); + { + CompletionSet completions = getCompletionSet("s:l"); + + std::string prefix = "s:"; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "len(")); + CHECK(checkCompletion(completions, prefix, "lower(")); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithMetatableIndexTable") +{ + runCode(L, R"( + -- Create 't' which is a table with a metatable with an __index table + mt = {} + mt.__index = mt + + t = {} + setmetatable(t, mt) + + mt.mtkey1 = {x="x value", y="y value", 1, 2} + mt.mtkey2 = 2 + + t.tkey1 = {data1 = 2, data2 = "str", 3, 4} + t.tkey2 = 4 +)"); + { + CompletionSet completions = getCompletionSet("t.t"); + + std::string prefix = "t."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "tkey1")); + CHECK(checkCompletion(completions, prefix, "tkey2")); + } + { + CompletionSet completions = getCompletionSet("t.tkey1.data2:re"); + + std::string prefix = "t.tkey1.data2:"; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "rep(")); + CHECK(checkCompletion(completions, prefix, "reverse(")); + } + { + CompletionSet completions = getCompletionSet("t.mtk"); + + std::string prefix = "t."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "mtkey1")); + CHECK(checkCompletion(completions, prefix, "mtkey2")); + } + { + CompletionSet completions = getCompletionSet("t.mtkey1."); + + std::string prefix = "t.mtkey1."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "x")); + CHECK(checkCompletion(completions, prefix, "y")); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithMetatableIndexFunction") +{ + runCode(L, R"( + -- Create 't' which is a table with a metatable with an __index function + mt = {} + mt.__index = function(table, key) + print("mt.__index called") + if key == "foo" then + return "FOO" + elseif key == "bar" then + return "BAR" + else + return nil + end + end + + t = {} + setmetatable(t, mt) + t.tkey = 0 +)"); + { + CompletionSet completions = getCompletionSet("t.t"); + + std::string prefix = "t."; + CHECK(completions.size() == 1); + CHECK(checkCompletion(completions, prefix, "tkey")); + } + { + // t.foo is a valid key, but should not be completed because it requires calling an __index function + CompletionSet completions = getCompletionSet("t.foo"); + + CHECK(completions.size() == 0); + } + { + // t.foo is a valid key, but should not be found because it requires calling an __index function + CompletionSet completions = getCompletionSet("t.foo:"); + + CHECK(completions.size() == 0); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithMultipleMetatableIndexTables") +{ + runCode(L, R"( + -- Create a table with a chain of metatables + mt2 = {} + mt2.__index = mt2 + + mt = {} + mt.__index = mt + setmetatable(mt, mt2) + + t = {} + setmetatable(t, mt) + + mt2.mt2key = {x=1, y=2} + mt.mtkey = 2 + t.tkey = 3 +)"); + { + CompletionSet completions = getCompletionSet("t."); + + std::string prefix = "t."; + CHECK(completions.size() == 4); + CHECK(checkCompletion(completions, prefix, "__index")); + CHECK(checkCompletion(completions, prefix, "tkey")); + CHECK(checkCompletion(completions, prefix, "mtkey")); + CHECK(checkCompletion(completions, prefix, "mt2key")); + } + { + CompletionSet completions = getCompletionSet("t.__index."); + + std::string prefix = "t.__index."; + CHECK(completions.size() == 3); + CHECK(checkCompletion(completions, prefix, "__index")); + CHECK(checkCompletion(completions, prefix, "mtkey")); + CHECK(checkCompletion(completions, prefix, "mt2key")); + } + { + CompletionSet completions = getCompletionSet("t.mt2key."); + + std::string prefix = "t.mt2key."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "x")); + CHECK(checkCompletion(completions, prefix, "y")); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithDeepMetatableIndexTables") +{ + runCode(L, R"( +-- Creates a table with a chain of metatables of length `count` +function makeChainedTable(count) + local result = {} + result.__index = result + result[string.format("entry%d", count)] = { count = count } + if count == 0 then + return result + else + return setmetatable(result, makeChainedTable(count - 1)) + end +end + +t30 = makeChainedTable(30) +t60 = makeChainedTable(60) +)"); + { + // Check if entry0 exists + CompletionSet completions = getCompletionSet("t30.entry0"); + + std::string prefix = "t30."; + CHECK(checkCompletion(completions, prefix, "entry0")); + } + { + // Check if entry0.count exists + CompletionSet completions = getCompletionSet("t30.entry0.co"); + + std::string prefix = "t30.entry0."; + CHECK(checkCompletion(completions, prefix, "count")); + } + { + // Check if entry0 exists. With the max traversal limit of 50 in the repl, this should fail. + CompletionSet completions = getCompletionSet("t60.entry0"); + + CHECK(completions.size() == 0); + } + { + // Check if entry0.count exists. With the max traversal limit of 50 in the repl, this should fail. + CompletionSet completions = getCompletionSet("t60.entry0.co"); + + CHECK(completions.size() == 0); + } +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("RegressionTests"); + +TEST_CASE_FIXTURE(ReplFixture, "InfiniteRecursion") +{ + // If the infinite recrusion is not caught, test will fail + runCode(L, R"( +local NewProxyOne = newproxy(true) +local MetaTableOne = getmetatable(NewProxyOne) +MetaTableOne.__index = function() + return NewProxyOne.Game +end +print(NewProxyOne.HelloICauseACrash) +)"); +} + +TEST_SUITE_END(); diff --git a/tests/RequireTracer.test.cpp b/tests/RequireTracer.test.cpp index cbd4af29..ba03f363 100644 --- a/tests/RequireTracer.test.cpp +++ b/tests/RequireTracer.test.cpp @@ -1,6 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/RequireTracer.h" +#include "Luau/Parser.h" #include "Fixture.h" @@ -57,6 +57,7 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local") { AstStatBlock* block = parse(R"( local m = workspace.Foo.Bar.Baz + require(m) )"); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); @@ -70,22 +71,22 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local") AstExprIndexName* value = loc->values.data[0]->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value]); + CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value].name); value = value->expr->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo/Bar", result.exprs[value]); + CHECK_EQ("workspace/Foo/Bar", result.exprs[value].name); value = value->expr->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo", result.exprs[value]); + CHECK_EQ("workspace/Foo", result.exprs[value].name); AstExprGlobal* workspace = value->expr->as(); REQUIRE(workspace); REQUIRE(result.exprs.contains(workspace)); - CHECK_EQ("workspace", result.exprs[workspace]); + CHECK_EQ("workspace", result.exprs[workspace].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") @@ -93,9 +94,10 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") AstStatBlock* block = parse(R"( local m = workspace.Foo.Bar.Baz local n = m.Quux + require(n) )"); - REQUIRE_EQ(2, block->body.size); + REQUIRE_EQ(3, block->body.size); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); @@ -104,13 +106,13 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") REQUIRE_EQ(1, local->vars.size); REQUIRE(result.exprs.contains(local->values.data[0])); - CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]]); + CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments") { AstStatBlock* block = parse(R"( - local M = require(workspace.Game.Thing, workspace.Something.Else) + local M = require(workspace.Game.Thing) )"); REQUIRE_EQ(1, block->body.size); @@ -124,52 +126,9 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments") AstExprCall* call = local->values.data[0]->as(); REQUIRE(call != nullptr); - REQUIRE_EQ(2, call->args.size); - - CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]]); - CHECK_EQ("workspace/Something/Else", result.exprs[call->args.data[1]]); -} - -TEST_CASE_FIXTURE(RequireTracerFixture, "follow_GetService_calls") -{ - AstStatBlock* block = parse(R"( - local R = game:GetService('ReplicatedStorage').Roact - local Roact = require(R) - )"); - REQUIRE_EQ(2, block->body.size); - - RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); - - AstStatLocal* local = block->body.data[0]->as(); - REQUIRE(local != nullptr); - - CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[local->values.data[0]]); - - AstStatLocal* local2 = block->body.data[1]->as(); - REQUIRE(local2 != nullptr); - REQUIRE_EQ(1, local2->values.size); - - AstExprCall* call = local2->values.data[0]->as(); - REQUIRE(call != nullptr); REQUIRE_EQ(1, call->args.size); - CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[call->args.data[0]]); -} - -TEST_CASE_FIXTURE(RequireTracerFixture, "follow_WaitForChild_calls") -{ - ScopedFastFlag luauTraceRequireLookupChild("LuauTraceRequireLookupChild", true); - - AstStatBlock* block = parse(R"( -local A = require(workspace:WaitForChild('ReplicatedStorage').Content) -local B = require(workspace:FindFirstChild('ReplicatedFirst').Data) - )"); - - RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); - - REQUIRE_EQ(2, result.requires.size()); - CHECK_EQ("workspace/ReplicatedStorage/Content", result.requires[0].first); - CHECK_EQ("workspace/ReplicatedFirst/Data", result.requires[1].first); + CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof") @@ -200,22 +159,23 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof") REQUIRE(call != nullptr); REQUIRE_EQ(1, call->args.size); - CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]]); + CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "follow_string_indexexpr") { AstStatBlock* block = parse(R"( local R = game["Test"] + require(R) )"); - REQUIRE_EQ(1, block->body.size); + REQUIRE_EQ(2, block->body.size); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); AstStatLocal* local = block->body.data[0]->as(); REQUIRE(local != nullptr); - CHECK_EQ("game/Test", result.exprs[local->values.data[0]]); + CHECK_EQ("game/Test", result.exprs[local->values.data[0]].name); } TEST_SUITE_END(); diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp new file mode 100644 index 00000000..7e50d5b6 --- /dev/null +++ b/tests/RuntimeLimits.test.cpp @@ -0,0 +1,271 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +/* Tests in this source file are meant to be a bellwether to verify that the numeric limits we've set are sufficient for + * most real-world scripts. + * + * If a change breaks a test in this source file, please don't adjust the flag values set in the fixture. Instead, + * consider it a latent performance problem by default. + * + * We should periodically revisit this to retest the limits. + */ + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +struct LimitFixture : BuiltinsFixture +{ +#if defined(_NOOPT) || defined(_DEBUG) + ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; +#endif +}; + +template +bool hasError(const CheckResult& result, T* = nullptr) +{ + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](const TypeError& a) { + return nullptr != get(a); + }); + return it != result.errors.end(); +} + +TEST_SUITE_BEGIN("RuntimeLimits"); + +TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") +{ + constexpr const char* src = R"LUA( + --!strict + + -- Big thanks to Dionysusnu by letting us use this code as part of our test suite! + -- https://github.com/Dionysusnu/rbxts-rust-classes + -- Licensed under the MPL 2.0: https://raw.githubusercontent.com/Dionysusnu/rbxts-rust-classes/master/LICENSE + + local TS = _G[script] + local lazyGet = TS.import(script, script.Parent.Parent, "util", "lazyLoad").lazyGet + local unit = TS.import(script, script.Parent.Parent, "util", "Unit").unit + local Iterator + lazyGet("Iterator", function(c) + Iterator = c + end) + local Option + lazyGet("Option", function(c) + Option = c + end) + local Vec + lazyGet("Vec", function(c) + Vec = c + end) + local Result + do + Result = setmetatable({}, { + __tostring = function() + return "Result" + end, + }) + Result.__index = Result + function Result.new(...) + local self = setmetatable({}, Result) + self:constructor(...) + return self + end + function Result:constructor(okValue, errValue) + self.okValue = okValue + self.errValue = errValue + end + function Result:ok(val) + return Result.new(val, nil) + end + function Result:err(val) + return Result.new(nil, val) + end + function Result:fromCallback(c) + local _0 = c + local _1, _2 = pcall(_0) + local result = _1 and { + success = true, + value = _2, + } or { + success = false, + error = _2, + } + return result.success and Result:ok(result.value) or Result:err(Option:wrap(result.error)) + end + function Result:fromVoidCallback(c) + local _0 = c + local _1, _2 = pcall(_0) + local result = _1 and { + success = true, + value = _2, + } or { + success = false, + error = _2, + } + return result.success and Result:ok(unit()) or Result:err(Option:wrap(result.error)) + end + Result.fromPromise = TS.async(function(self, p) + local _0, _1 = TS.try(function() + return TS.TRY_RETURN, { Result:ok(TS.await(p)) } + end, function(e) + return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } + end) + if _0 then + return unpack(_1) + end + end) + Result.fromVoidPromise = TS.async(function(self, p) + local _0, _1 = TS.try(function() + TS.await(p) + return TS.TRY_RETURN, { Result:ok(unit()) } + end, function(e) + return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } + end) + if _0 then + return unpack(_1) + end + end) + function Result:isOk() + return self.okValue ~= nil + end + function Result:isErr() + return self.errValue ~= nil + end + function Result:contains(x) + return self.okValue == x + end + function Result:containsErr(x) + return self.errValue == x + end + function Result:okOption() + return Option:wrap(self.okValue) + end + function Result:errOption() + return Option:wrap(self.errValue) + end + function Result:map(func) + return self:isOk() and Result:ok(func(self.okValue)) or Result:err(self.errValue) + end + function Result:mapOr(def, func) + local _0 + if self:isOk() then + _0 = func(self.okValue) + else + _0 = def + end + return _0 + end + function Result:mapOrElse(def, func) + local _0 + if self:isOk() then + _0 = func(self.okValue) + else + _0 = def(self.errValue) + end + return _0 + end + function Result:mapErr(func) + return self:isErr() and Result:err(func(self.errValue)) or Result:ok(self.okValue) + end + Result["and"] = function(self, other) + return self:isErr() and Result:err(self.errValue) or other + end + function Result:andThen(func) + return self:isErr() and Result:err(self.errValue) or func(self.okValue) + end + Result["or"] = function(self, other) + return self:isOk() and Result:ok(self.okValue) or other + end + function Result:orElse(other) + return self:isOk() and Result:ok(self.okValue) or other(self.errValue) + end + function Result:expect(msg) + if self:isOk() then + return self.okValue + else + error(msg) + end + end + function Result:unwrap() + return self:expect("called `Result.unwrap()` on an `Err` value: " .. tostring(self.errValue)) + end + function Result:unwrapOr(def) + local _0 + if self:isOk() then + _0 = self.okValue + else + _0 = def + end + return _0 + end + function Result:unwrapOrElse(gen) + local _0 + if self:isOk() then + _0 = self.okValue + else + _0 = gen(self.errValue) + end + return _0 + end + function Result:expectErr(msg) + if self:isErr() then + return self.errValue + else + error(msg) + end + end + function Result:unwrapErr() + return self:expectErr("called `Result.unwrapErr()` on an `Ok` value: " .. tostring(self.okValue)) + end + function Result:transpose() + return self:isOk() and self.okValue:map(function(some) + return Result:ok(some) + end) or Option:some(Result:err(self.errValue)) + end + function Result:flatten() + return self:isOk() and Result.new(self.okValue.okValue, self.okValue.errValue) or Result:err(self.errValue) + end + function Result:match(ifOk, ifErr) + local _0 + if self:isOk() then + _0 = ifOk(self.okValue) + else + _0 = ifErr(self.errValue) + end + return _0 + end + function Result:asPtr() + local _0 = (self.okValue) + if _0 == nil then + _0 = (self.errValue) + end + return _0 + end + end + local resultMeta = Result + resultMeta.__eq = function(a, b) + return b:match(function(ok) + return a:contains(ok) + end, function(err) + return a:containsErr(err) + end) + end + resultMeta.__tostring = function(result) + return result:match(function(ok) + return "Result.ok(" .. tostring(ok) .. ")" + end, function(err) + return "Result.err(" .. tostring(err) .. ")" + end) + end + return { + Result = Result, + } + )LUA"; + + CheckResult result = check(src); + CodeTooComplex ctc; + + CHECK(hasError(result, &ctc)); +} + +TEST_SUITE_END(); diff --git a/tests/Symbol.test.cpp b/tests/Symbol.test.cpp index 44fe3a3c..278c6ce2 100644 --- a/tests/Symbol.test.cpp +++ b/tests/Symbol.test.cpp @@ -10,7 +10,7 @@ using namespace Luau; TEST_SUITE_BEGIN("SymbolTests"); -TEST_CASE("hashing") +TEST_CASE("equality_and_hashing_of_globals") { std::string s1 = "name"; std::string s2 = "name"; @@ -31,10 +31,57 @@ TEST_CASE("hashing") CHECK_EQ(std::hash()(two), std::hash()(two)); std::unordered_map theMap; - theMap[AstName{s1.data()}] = 5; - theMap[AstName{s2.data()}] = 1; + theMap[n1] = 5; + theMap[n2] = 1; REQUIRE_EQ(1, theMap.size()); } +TEST_CASE("equality_and_hashing_of_locals") +{ + std::string s1 = "name"; + std::string s2 = "name"; + + // These two names point to distinct memory areas. + AstLocal one{AstName{s1.data()}, Location(), nullptr, 0, 0, nullptr}; + AstLocal two{AstName{s2.data()}, Location(), &one, 0, 0, nullptr}; + + Symbol n1{&one}; + Symbol n2{&two}; + + CHECK(n1 == n1); + CHECK(n1 != n2); + CHECK(n2 == n2); + + CHECK_EQ(std::hash()(&one), std::hash()(&one)); + CHECK_NE(std::hash()(&one), std::hash()(&two)); + CHECK_EQ(std::hash()(&two), std::hash()(&two)); + + std::unordered_map theMap; + theMap[n1] = 5; + theMap[n2] = 1; + + REQUIRE_EQ(2, theMap.size()); +} + +TEST_CASE("equality_of_empty_symbols") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + std::string s1 = "name"; + std::string s2 = "name"; + + AstName one{s1.data()}; + AstLocal two{AstName{s2.data()}, Location(), nullptr, 0, 0, nullptr}; + + Symbol global{one}; + Symbol local{&two}; + Symbol empty1{}; + Symbol empty2{}; + + CHECK(empty1 != global); + CHECK(empty1 != local); + CHECK(empty1 == empty2); +} + TEST_SUITE_END(); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp new file mode 100644 index 00000000..26c9a1ee --- /dev/null +++ b/tests/ToDot.test.cpp @@ -0,0 +1,395 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Scope.h" +#include "Luau/ToDot.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +struct ToDotClassFixture : Fixture +{ + ToDotClassFixture() + { + TypeArena& arena = typeChecker.globalTypes; + + unfreeze(arena); + + TypeId baseClassMetaType = arena.addType(TableTypeVar{}); + + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}, "Test"}); + getMutable(baseClassInstanceType)->props = { + {"BaseField", {typeChecker.numberType}}, + }; + typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}, "Test"}); + getMutable(childClassInstanceType)->props = { + {"ChildField", {typeChecker.stringType}}, + }; + typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + + for (const auto& [name, ty] : typeChecker.globalScope->exportedTypeBindings) + persist(ty.type); + + freeze(arena); + } +}; + +TEST_SUITE_BEGIN("ToDot"); + +TEST_CASE_FIXTURE(Fixture, "primitive") +{ + CheckResult result = check(R"( +local a: nil +local b: number +local c: any +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_NE("nil", toDot(requireType("a"))); + + CHECK_EQ(R"(digraph graphname { +n1 [label="number"]; +})", + toDot(requireType("b"))); + + CHECK_EQ(R"(digraph graphname { +n1 [label="any"]; +})", + toDot(requireType("c"))); + + ToDotOptions opts; + opts.showPointers = false; + opts.duplicatePrimitives = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="PrimitiveTypeVar number"]; +})", + toDot(requireType("b"), opts)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="AnyTypeVar 1"]; +})", + toDot(requireType("c"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "bound") +{ + CheckResult result = check(R"( +function a(): number return 444 end +local b = a() +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = getType("b"); + REQUIRE(bool(ty)); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="BoundTypeVar 1"]; +n1 -> n2; +n2 [label="number"]; +})", + toDot(*ty, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "function") +{ + CheckResult result = check(R"( +local function f(a, ...: string) return a end +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(a, ...string) -> a", toString(requireType("f"))); + + ToDotOptions opts; + opts.showPointers = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="FunctionTypeVar 1"]; +n1 -> n2 [label="arg"]; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="GenericTypeVar 3"]; +n2 -> n4 [label="tail"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n1 -> n6 [label="ret"]; +n6 [label="BoundTypePack 6"]; +n6 -> n7; +n7 [label="TypePack 7"]; +n7 -> n3; +})", + toDot(requireType("f"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "union") +{ + CheckResult result = check(R"( +local a: string | number +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="UnionTypeVar 1"]; +n1 -> n2; +n2 [label="string"]; +n1 -> n3; +n3 [label="number"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "intersection") +{ + CheckResult result = check(R"( +local a: string & number -- uninhabited +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="IntersectionTypeVar 1"]; +n1 -> n2; +n2 [label="string"]; +n1 -> n3; +n3 [label="number"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "table") +{ + CheckResult result = check(R"( +type A = { x: T, y: (U...) -> (), [string]: any } +local a: A +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="TableTypeVar A"]; +n1 -> n2 [label="x"]; +n2 [label="number"]; +n1 -> n3 [label="y"]; +n3 [label="FunctionTypeVar 3"]; +n3 -> n4 [label="arg"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n3 -> n6 [label="ret"]; +n6 [label="TypePack 6"]; +n1 -> n7 [label="[index]"]; +n7 [label="string"]; +n1 -> n8 [label="[value]"]; +n8 [label="any"]; +n1 -> n9 [label="typeParam"]; +n9 [label="number"]; +n1 -> n4 [label="typePackParam"]; +})", + toDot(requireType("a"), opts)); + + // Extra coverage with pointers (unstable values) + (void)toDot(requireType("a")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable") +{ + CheckResult result = check(R"( +local a: typeof(setmetatable({}, {})) +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="MetatableTypeVar 1"]; +n1 -> n2 [label="table"]; +n2 [label="TableTypeVar 2"]; +n1 -> n3 [label="metatable"]; +n3 [label="TableTypeVar 3"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "free") +{ + TypeVar type{TypeVariant{FreeTypeVar{TypeLevel{0, 0}}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FreeTypeVar 1"]; +})", + toDot(&type, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "error") +{ + TypeVar type{TypeVariant{ErrorTypeVar{}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="ErrorTypeVar 1"]; +})", + toDot(&type, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "generic") +{ + TypeVar type{TypeVariant{GenericTypeVar{"T"}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="GenericTypeVar T"]; +})", + toDot(&type, opts)); +} + +TEST_CASE_FIXTURE(ToDotClassFixture, "class") +{ + CheckResult result = check(R"( +local a: ChildClass +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="ClassTypeVar ChildClass"]; +n1 -> n2 [label="ChildField"]; +n2 [label="string"]; +n1 -> n3 [label="[parent]"]; +n3 [label="ClassTypeVar BaseClass"]; +n3 -> n4 [label="BaseField"]; +n4 [label="number"]; +n3 -> n5 [label="[metatable]"]; +n5 [label="TableTypeVar 5"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "free_pack") +{ + TypePackVar pack{TypePackVariant{FreeTypePack{TypeLevel{0, 0}}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FreeTypePack 1"]; +})", + toDot(&pack, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "error_pack") +{ + TypePackVar pack{TypePackVariant{Unifiable::Error{}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="ErrorTypePack 1"]; +})", + toDot(&pack, opts)); + + // Extra coverage with pointers (unstable values) + (void)toDot(&pack); +} + +TEST_CASE_FIXTURE(Fixture, "generic_pack") +{ + TypePackVar pack1{TypePackVariant{GenericTypePack{}}}; + TypePackVar pack2{TypePackVariant{GenericTypePack{"T"}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="GenericTypePack 1"]; +})", + toDot(&pack1, opts)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="GenericTypePack T"]; +})", + toDot(&pack2, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "bound_pack") +{ + TypePackVar pack{TypePackVariant{TypePack{{typeChecker.numberType}, {}}}}; + TypePackVar bound{TypePackVariant{BoundTypePack{&pack}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="BoundTypePack 1"]; +n1 -> n2; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="number"]; +})", + toDot(&bound, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "bound_table") +{ + CheckResult result = check(R"( +local a = {x=2} +local b +b.x = 2 +b = a +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = getType("b"); + REQUIRE(bool(ty)); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="TableTypeVar 1"]; +n1 -> n2 [label="boundTo"]; +n2 [label="TableTypeVar a"]; +n2 -> n3 [label="x"]; +n3 [label="number"]; +})", + toDot(*ty, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "singletontypes") +{ + CheckResult result = check(R"( + local x: "hi" | "\"hello\"" | true | false + )"); + + ToDotOptions opts; + opts.showPointers = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="UnionTypeVar 1"]; +n1 -> n2; +n2 [label="SingletonTypeVar string: hi"]; +n1 -> n3; +)" + "n3 [label=\"SingletonTypeVar string: \\\"hello\\\"\"];" + R"( +n1 -> n4; +n4 [label="SingletonTypeVar boolean: true"]; +n1 -> n5; +n5 [label="SingletonTypeVar boolean: false"]; +})", + toDot(requireType("x"), opts)); +} + +TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index d7d68c46..05f49422 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Fixture.h" @@ -8,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); + TEST_SUITE_BEGIN("ToString"); TEST_CASE_FIXTURE(Fixture, "primitive") @@ -57,7 +60,72 @@ TEST_CASE_FIXTURE(Fixture, "named_table") CHECK_EQ("TheTable", toString(&table)); } -TEST_CASE_FIXTURE(Fixture, "exhaustive_toString_of_cyclic_table") +TEST_CASE_FIXTURE(Fixture, "empty_table") +{ + CheckResult result = check(R"( + local a: {} + )"); + + CHECK_EQ("{| |}", toString(requireType("a"))); + + // Should stay the same with useLineBreaks enabled + ToStringOptions opts; + opts.useLineBreaks = true; + CHECK_EQ("{| |}", toString(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") +{ + CheckResult result = check(R"( + local a: { prop: string, anotherProp: number, thirdProp: boolean } + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + opts.DEPRECATED_indent = true; + + //clang-format off + CHECK_EQ("{|\n" + " anotherProp: number,\n" + " prop: string,\n" + " thirdProp: boolean\n" + "|}", + toString(requireType("a"), opts)); + //clang-format on +} + +TEST_CASE_FIXTURE(Fixture, "metatable") +{ + TypeVar table{TypeVariant(TableTypeVar())}; + TypeVar metatable{TypeVariant(TableTypeVar())}; + TypeVar mtv{TypeVariant(MetatableTypeVar{&table, &metatable})}; + CHECK_EQ("{ @metatable { }, { } }", toString(&mtv)); +} + +TEST_CASE_FIXTURE(Fixture, "named_metatable") +{ + TypeVar table{TypeVariant(TableTypeVar())}; + TypeVar metatable{TypeVariant(TableTypeVar())}; + TypeVar mtv{TypeVariant(MetatableTypeVar{&table, &metatable, "NamedMetatable"})}; + CHECK_EQ("NamedMetatable", toString(&mtv)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "named_metatable_toStringNamedFunction") +{ + CheckResult result = check(R"( + local function createTbl(): NamedMetatable + return setmetatable({}, {}) + end + type NamedMetatable = typeof(createTbl()) + )"); + + TypeId ty = requireType("createTbl"); + const FunctionTypeVar* ftv = get(follow(ty)); + REQUIRE(ftv); + CHECK_EQ("createTbl(): NamedMetatable", toStringNamedFunction("createTbl", *ftv)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") { CheckResult result = check(R"( --!strict @@ -123,6 +191,39 @@ TEST_CASE_FIXTURE(Fixture, "functions_are_always_parenthesized_in_unions_or_inte CHECK_EQ(toString(&itv), "((number, string) -> (string, number)) & ((string, number) -> (number, string))"); } +TEST_CASE_FIXTURE(Fixture, "intersections_respects_use_line_breaks") +{ + CheckResult result = check(R"( + local a: ((string) -> string) & ((number) -> number) + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + + //clang-format off + CHECK_EQ("((number) -> number)\n" + "& ((string) -> string)", + toString(requireType("a"), opts)); + //clang-format on +} + +TEST_CASE_FIXTURE(Fixture, "unions_respects_use_line_breaks") +{ + CheckResult result = check(R"( + local a: string | number | boolean + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + + //clang-format off + CHECK_EQ("boolean\n" + "| number\n" + "| string", + toString(requireType("a"), opts)); + //clang-format on +} + TEST_CASE_FIXTURE(Fixture, "quit_stringifying_table_type_when_length_is_exceeded") { TableTypeVar ttv{}; @@ -163,11 +264,23 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") ToStringOptions o; o.exhaustive = false; - o.maxTypeLength = 40; - CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); - CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); - CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + o.maxTypeLength = 30; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + } + else + { + o.maxTypeLength = 40; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + } } TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") @@ -182,11 +295,22 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") ToStringOptions o; o.exhaustive = true; - o.maxTypeLength = 40; - CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); - CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); - CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + o.maxTypeLength = 30; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + } + else + { + o.maxTypeLength = 40; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + } } TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table_state_braces") @@ -197,7 +321,8 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table TypeVar tv{ttv}; - ToStringOptions o{/* exhaustive= */ false, /* useLineBreaks= */ false, /* functionTypeArguments= */ false, /* hideTableKind= */ false, 40}; + ToStringOptions o; + o.maxTableLength = 40; CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 5 more ... |}"); } @@ -258,9 +383,6 @@ TEST_CASE_FIXTURE(Fixture, "function_type_with_argument_names") TEST_CASE_FIXTURE(Fixture, "function_type_with_argument_names_generic") { - ScopedFastFlag luauGenericFunctions{"LuauGenericFunctions", true}; - ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; - CheckResult result = check("local function f(n: number, ...: a...): (a...) return ... end"); LUAU_REQUIRE_NO_ERRORS(result); @@ -317,30 +439,30 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") LUAU_REQUIRE_NO_ERRORS(result); - TypeId id3Type = requireType("id3"); - ToStringResult nameData = toStringDetailed(id3Type); - - REQUIRE_EQ(3, nameData.nameMap.typeVars.size()); - REQUIRE_EQ("(a, b, c) -> (a, b, c)", nameData.name); - ToStringOptions opts; - opts.nameMap = std::move(nameData.nameMap); + + TypeId id3Type = requireType("id3"); + ToStringResult nameData = toStringDetailed(id3Type, opts); + + REQUIRE(3 == opts.nameMap.typeVars.size()); + + REQUIRE_EQ("(a, b, c) -> (a, b, c)", nameData.name); const FunctionTypeVar* ftv = get(follow(id3Type)); REQUIRE(ftv != nullptr); auto params = flatten(ftv->argTypes).first; - REQUIRE_EQ(3, params.size()); + REQUIRE(3 == params.size()); - REQUIRE_EQ("a", toString(params[0], opts)); - REQUIRE_EQ("b", toString(params[1], opts)); - REQUIRE_EQ("c", toString(params[2], opts)); + CHECK("a" == toString(params[0], opts)); + CHECK("b" == toString(params[1], opts)); + CHECK("c" == toString(params[2], opts)); } -TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") +TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") { ScopedFastFlag sff[] = { - {"LuauGenericFunctions", true}, + {"DebugLuauSharedSelf", true}, }; CheckResult result = check(R"( @@ -356,13 +478,12 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") )"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId tType = requireType("inst"); - ToStringResult r = toStringDetailed(tType); - CHECK_EQ("{ @metatable {| __index: { @metatable {| __index: base |}, child } |}, inst }", r.name); - CHECK_EQ(0, r.nameMap.typeVars.size()); - ToStringOptions opts; - opts.nameMap = r.nameMap; + + TypeId tType = requireType("inst"); + ToStringResult r = toStringDetailed(tType, opts); + CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name); + CHECK(0 == opts.nameMap.typeVars.size()); const MetatableTypeVar* tMeta = get(tType); REQUIRE(tMeta); @@ -380,20 +501,20 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") TableTypeVar* tMeta5 = getMutable(tMeta4->props["__index"].type); REQUIRE(tMeta5); + REQUIRE(tMeta5->props.count("one") > 0); TableTypeVar* tMeta6 = getMutable(tMeta3->table); REQUIRE(tMeta6); + REQUIRE(tMeta6->props.count("two") > 0); ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type, opts); - opts.nameMap = oneResult.nameMap; std::string twoResult = toString(tMeta6->props["two"].type, opts); - REQUIRE_EQ("(a) -> number", oneResult.name); - REQUIRE_EQ("(b) -> number", twoResult); + CHECK_EQ("(a) -> number", oneResult.name); + CHECK_EQ("(b) -> number", twoResult); } - TEST_CASE_FIXTURE(Fixture, "toStringErrorPack") { CheckResult result = check(R"( @@ -401,7 +522,7 @@ local function target(callback: nil) return callback(4, "hello") end )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(requireType("target")), "(nil) -> (*unknown*)"); + CHECK_EQ("(nil) -> (*error-type*)", toString(requireType("target"))); } TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") @@ -416,8 +537,6 @@ function foo(a, b) return a(b) end TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_TypePack") { - ScopedFastFlag sff{"LuauToStringFollowsBoundTo", true}; - TypeVar tv1{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tv1); ttv->state = TableState::Sealed; @@ -438,10 +557,21 @@ TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_T CHECK_EQ("{| hello: number, world: number |}", toString(&tpv2)); } +TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_return_type_if_pack_has_an_empty_head_link") +{ + TypeArena arena; + TypePackId realTail = arena.addTypePack({singletonTypes->stringType}); + TypePackId emptyTail = arena.addTypePack({}, realTail); + + TypePackId argList = arena.addTypePack({singletonTypes->stringType}); + + TypeId functionType = arena.addType(FunctionTypeVar{argList, emptyTail}); + + CHECK("(string) -> string" == toString(functionType)); +} + TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( type F = ((() -> number)?) -> F? local function f(p) return f end @@ -455,8 +585,6 @@ TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union" TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_intersection") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( function f() return f end local a: ((number) -> ()) & typeof(f) @@ -469,8 +597,6 @@ TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_inters TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { - ScopedFastFlag luauInstantiatedTypeParamRecursion{"LuauInstantiatedTypeParamRecursion", true}; - TypeVar tableTy{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tableTy); ttv->name = "Table"; @@ -479,4 +605,205 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") CHECK_EQ(toString(tableTy), "Table"); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") +{ + CheckResult result = check(R"( + local function id(x) return x end + )"); + + TypeId ty = requireType("id"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("id(x: a): a", toStringNamedFunction("id", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") +{ + CheckResult result = check(R"( + local function map(arr, fn) + local t = {} + for i = 0, #arr do + t[i] = fn(arr[i]) + end + return t + end + )"); + + TypeId ty = requireType("map"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") +{ + CheckResult result = check(R"( + local function f(a: number, b: string) end + local function test(...: T...): U... + f(...) + return 1, 2, 3 + end + )"); + + TypeId ty = requireType("test"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("test(...: T...): U...", toStringNamedFunction("test", *ftv)); +} + +TEST_CASE("toStringNamedFunction_unit_f") +{ + TypePackVar empty{TypePack{}}; + FunctionTypeVar ftv{&empty, &empty, {}, false}; + CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") +{ + CheckResult result = check(R"( + local function f(x: a, ...): (a, a, b...) + return x, x, ... + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(x: a, ...: any): (a, a, b...)", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") +{ + CheckResult result = check(R"( + local function f(): ...number + return 1, 2, 3 + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(): ...number", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") +{ + CheckResult result = check(R"( + local function f(): (string, ...number) + return 'a', 1, 2, 3 + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(): (string, ...number)", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames") +{ + CheckResult result = check(R"( + local f: (number, y: number) -> number + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(_: number, y: number): number", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") +{ + CheckResult result = check(R"( + local function f(x: T, g: (T) -> U)): () + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + ToStringOptions opts; + opts.hideNamedFunctionTypeParameters = true; + CHECK_EQ("f(x: T, g: (T) -> U): ()", toStringNamedFunction("f", *ftv, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") +{ + CheckResult result = check(R"( + local function test(a, b : string, ... : number) return a end + )"); + + TypeId ty = requireType("test"); + const FunctionTypeVar* ftv = get(follow(ty)); + + ToStringOptions opts; + opts.namedFunctionOverrideArgNames = {"first", "second", "third"}; + CHECK_EQ("test(first: a, second: string, ...: number): a", toStringNamedFunction("test", *ftv, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_generics") +{ + CheckResult result = check(R"( + function foo(x: a, y) end + )"); + + CHECK("(a, b) -> ()" == toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") +{ + ScopedFastFlag sff[]{ + {"DebugLuauSharedSelf", true}, + }; + + CheckResult result = check(R"( + local foo = {} + function foo:method(arg: string): () + end + )"); + + TypeId parentTy = requireType("foo"); + auto ttv = get(follow(parentTy)); + auto ftv = get(ttv->props.at("method").type); + + CHECK_EQ("foo:method(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") +{ + ScopedFastFlag sff[]{ + {"DebugLuauSharedSelf", true}, + }; + + CheckResult result = check(R"( + local foo = {} + function foo:method(arg: string): () + end + )"); + + TypeId parentTy = requireType("foo"); + auto ttv = get(follow(parentTy)); + auto ftv = get(ttv->props.at("method").type); + + ToStringOptions opts; + opts.hideFunctionSelfArgument = true; + CHECK_EQ("foo:method(arg: string): ()", toStringNamedFunction("foo:method", *ftv, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "tostring_unsee_ttv_if_array") +{ + ScopedFastFlag sff("LuauUnseeArrayTtv", true); + + CheckResult result = check(R"( + local x: {string} + -- This code is constructed very specifically to use the same (by pointer + -- identity) type in the function twice. + local y: (typeof(x), typeof(x)) -> () + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK(toString(requireType("y")) == "({string}, {string}) -> ()"); +} + TEST_SUITE_END(); diff --git a/tests/TopoSort.test.cpp b/tests/TopoSort.test.cpp index 9b990866..1f14ae88 100644 --- a/tests/TopoSort.test.cpp +++ b/tests/TopoSort.test.cpp @@ -340,26 +340,28 @@ TEST_CASE_FIXTURE(Fixture, "nested_type_annotations_depends_on_later_typealiases TEST_CASE_FIXTURE(Fixture, "return_comes_last") { - CheckResult result = check(R"( -export type Module = { bar: (number) -> boolean, foo: () -> string } + AstStatBlock* program = parse(R"( + local module = {} -return function() : Module - local module = {} + local function confuseCompiler() return module.foo() end - local function confuseCompiler() return module.foo() end - - module.foo = function() return "" end + module.foo = function() return "" end - function module.bar(x:number) - confuseCompiler() - return true - end - - return module -end + function module.bar(x:number) + confuseCompiler() + return true + end + + return module )"); - LUAU_REQUIRE_NO_ERRORS(result); + auto sorted = toposort(*program); + + CHECK_EQ(sorted[0], program->body.data[0]); + CHECK_EQ(sorted[2], program->body.data[1]); + CHECK_EQ(sorted[1], program->body.data[2]); + CHECK_EQ(sorted[3], program->body.data[3]); + CHECK_EQ(sorted[4], program->body.data[4]); } TEST_CASE_FIXTURE(Fixture, "break_comes_last") diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index bfff60f9..e79bc9b7 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -6,6 +6,7 @@ #include "Luau/Transpiler.h" #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -21,7 +22,7 @@ local function isPortal(element) return false end - return element.component==Core.Portal + return element.component == Core.Portal end )"; @@ -223,12 +224,24 @@ TEST_CASE("escaped_strings") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("escaped_strings_2") +{ + const std::string code = R"( local s="\a\b\f\n\r\t\v\'\"\\" )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("need_a_space_between_number_literals_and_dots") { const std::string code = R"( return point and math.ceil(point* 100000* 100)/ 100000 .. '%'or '' )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("binary_keywords") +{ + const std::string code = "local c = a0 ._ or b0 ._"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("do_blocks") { const std::string code = R"( @@ -364,10 +377,10 @@ TEST_CASE_FIXTURE(Fixture, "type_lists_should_be_emitted_correctly") )"; std::string expected = R"( - local a:(string,number,...string)->(string,...number)=function(a:string,b:number,...:...string): (string,...number) + local a:(string,number,...string)->(string,...number)=function(a:string,b:number,...:string): (string,...number) end - local b:(...string)->(...number)=function(...:...string): ...number + local b:(...string)->(...number)=function(...:string): ...number end local c:()->()=function(): () @@ -376,7 +389,7 @@ TEST_CASE_FIXTURE(Fixture, "type_lists_should_be_emitted_correctly") std::string actual = decorateWithTypes(code); - CHECK_EQ(expected, decorateWithTypes(code)); + CHECK_EQ(expected, actual); } TEST_CASE_FIXTURE(Fixture, "function_type_location") @@ -400,4 +413,288 @@ TEST_CASE_FIXTURE(Fixture, "function_type_location") CHECK_EQ(expected, actual); } +TEST_CASE_FIXTURE(Fixture, "transpile_type_assertion") +{ + std::string code = "local a = 5 :: number"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") +{ + std::string code = "local a = if 1 then 2 else 3"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_type_reference_import") +{ + fileResolver.source["game/A"] = R"( +export type Type = { a: number } +return {} + )"; + + std::string code = R"( +local Import = require(game.A) +local a: Import.Type + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_type_packs") +{ + std::string code = R"( +type Packed = (T...)->(T...) +local a: Packed<> +local b: Packed<(number, string)> + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested") +{ + std::string code = "local a: ((number)->(string))|((string)->(string))"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested_2") +{ + std::string code = "local a: (number&string)|(string&boolean)"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested_3") +{ + std::string code = "local a: nil | (string & number)"; + + CHECK_EQ("local a: ( string & number)?", transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_intersection_type_nested") +{ + std::string code = "local a: ((number)->(string))&((string)->(string))"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_intersection_type_nested_2") +{ + std::string code = "local a: (number|string)&(string|boolean)"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_varargs") +{ + std::string code = "local function f(...) return ... end"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_index_expr") +{ + std::string code = "local a = {1, 2, 3} local b = a[2]"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_unary") +{ + std::string code = R"( +local a = 1 +local b = -1 +local c = true +local d = not c +local e = 'hello' +local d = #e + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_break_continue") +{ + std::string code = R"( +local a, b, c +repeat + if a then break end + if b then continue end +until c + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_compound_assignmenr") +{ + std::string code = R"( +local a = 1 +a += 2 +a -= 3 +a *= 4 +a /= 5 +a %= 6 +a ^= 7 +a ..= ' - result' + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") +{ + std::string code = "a, b, c = 1, 2, 3"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_generic_function") +{ + std::string code = R"( +local function foo(a: T, ...: S...) return 1 end +local f: (T, S...)->(number) = foo + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_reverse") +{ + std::string code = "local a: nil | number"; + + CHECK_EQ("local a: number?", transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple") +{ + std::string code = "for k,v in next,{}do print(k,v) end"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_error_expr") +{ + std::string code = "local a = f:-"; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + CHECK_EQ("local a = (error-expr: f:%error-id%)-(error-expr)", transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_error_stat") +{ + std::string code = "-"; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + CHECK_EQ("(error-stat: (error-expr))", transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_error_type") +{ + std::string code = "local a: "; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + CHECK_EQ("local a:%error-type%", transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_parse_error") +{ + std::string code = "local a = -"; + + auto result = transpile(code); + CHECK_EQ("", result.code); + CHECK_EQ("Expected identifier when parsing expression, got ", result.parseError); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_to_string") +{ + std::string code = "local a: string = 'hello'"; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + REQUIRE(parseResult.root); + REQUIRE(parseResult.root->body.size == 1); + AstStatLocal* statLocal = parseResult.root->body.data[0]->as(); + REQUIRE(statLocal); + CHECK_EQ("local a: string = 'hello'", toString(statLocal)); + REQUIRE(statLocal->vars.size == 1); + AstLocal* local = statLocal->vars.data[0]; + REQUIRE(local->annotation); + CHECK_EQ("string", toString(local->annotation)); + REQUIRE(statLocal->values.size == 1); + AstExpr* expr = statLocal->values.data[0]; + CHECK_EQ("'hello'", toString(expr)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_type_alias_default_type_parameters") +{ + std::string code = R"( +type Packed = (T, U, V...)->(W...) +local a: Packed + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_singleton_types") +{ + std::string code = R"( +type t1 = 'hello' +type t2 = true +type t3 = '' +type t4 = false + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_array_types") +{ + std::string code = R"( +type t1 = {number} +type t2 = {[string]: number} + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple_types") +{ + std::string code = "for k:string,v:boolean in next,{}do end"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_string_interp") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + std::string code = R"( local _ = `hello {name}` )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + std::string code = R"( local _ = ` bracket = \{, backtick = \` = {'ok'} ` )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp new file mode 100644 index 00000000..38e246c8 --- /dev/null +++ b/tests/TypeInfer.aliases.test.cpp @@ -0,0 +1,946 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) +LUAU_FASTFLAG(LuauNewLibraryTypeNames) + +TEST_SUITE_BEGIN("TypeAliases"); + +TEST_CASE_FIXTURE(Fixture, "basic_alias") +{ + CheckResult result = check(R"( + type T = number + local x: T = 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") +{ + CheckResult result = check(R"( + type F = () -> F? + local function f() + return f + end + + local g: F = f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); +} + +TEST_CASE_FIXTURE(Fixture, "names_are_ascribed") +{ + CheckResult result = check(R"( + type T = { x: number } + local x: T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("T", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") +{ + // This is a tricky case. In order to support recursive type aliases, + // we first walk the block and generate free types as placeholders. + // We then walk the AST as normal. If we declare a type alias as below, + // we generate a free type. We then begin our normal walk, examining + // local x: T = "foo", which establishes two constraints: + // a <: b + // string <: a + // We then visit the type alias, and establish that + // b <: number + // Then, when solving these constraints, we dispatch them in the order + // they appear above. This means that a ~ b, and a ~ string, thus + // b ~ string. This means the b <: number constraint has no effect. + // Essentially we've "stolen" the alias's type out from under it. + // This test ensures that we don't actually do this. + CheckResult result = check(R"( + local x: T = "foo" + type T = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK(result.errors[0] == TypeError{ + Location{{1, 21}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + singletonTypes->numberType, + singletonTypes->stringType, + }, + }); + } + else + { + CHECK(result.errors[0] == TypeError{ + Location{{1, 8}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + singletonTypes->numberType, + singletonTypes->stringType, + }, + }); + } +} + +TEST_CASE_FIXTURE(Fixture, "mismatched_generic_type_param") +{ + CheckResult result = check(R"( + type T = (A...) -> () + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == + "Generic type 'A' is used as a variadic type parameter; consider changing 'A' to 'A...' in the generic argument list"); + CHECK(result.errors[0].location == Location{{1, 21}, {1, 25}}); +} + +TEST_CASE_FIXTURE(Fixture, "mismatched_generic_pack_type_param") +{ + CheckResult result = check(R"( + type T = (A) -> () + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == + "Variadic type parameter 'A...' is used as a regular generic type; consider changing 'A...' to 'A' in the generic argument list"); + CHECK(result.errors[0].location == Location{{1, 24}, {1, 25}}); +} + +TEST_CASE_FIXTURE(Fixture, "default_type_parameter") +{ + CheckResult result = check(R"( + type T = { a: A, b: B } + local x: T = { a = "foo", b = "bar" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("x")) == "T"); +} + +TEST_CASE_FIXTURE(Fixture, "default_pack_parameter") +{ + CheckResult result = check(R"( + type T = { fn: (A...) -> () } + local x: T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("x")) == "T"); +} + +TEST_CASE_FIXTURE(Fixture, "saturate_to_first_type_pack") +{ + CheckResult result = check(R"( + type T = { fn: (A, B) -> C... } + local x: T + local f = x.fn + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("x")) == "T"); + CHECK(toString(requireType("f")) == "(string, number) -> (string, boolean)"); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") +{ + CheckResult result = check(R"( + --!strict + type Node = { Parent: Node?; } + local node: Node; + node.Parent = 1 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("Node?", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_aliases") +{ + CheckResult result = check(R"( + --!strict + type T = { f: number, g: U } + type U = { h: number, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = 3, g = { h = 5, i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_aliases") +{ + ScopedFastFlag sff_DebugLuauDeferredConstraintResolution{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + type T = { v: a } + local x: T = { v = 123 } + local y: T = { v = "foo" } + local bad: T = { v = "foo" } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + const char* expectedError; + if (FFlag::LuauTypeMismatchInvarianceInError) + expectedError = "Type '{ v: string }' could not be converted into 'T'\n" + "caused by:\n" + " Property 'v' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"; + else + expectedError = "Type '{ v: string }' could not be converted into 'T'\n" + "caused by:\n" + " Property 'v' is not compatible. Type 'string' could not be converted into 'number'"; + + CHECK(result.errors[0].location == Location{{4, 31}, {4, 44}}); + CHECK(toString(result.errors[0]) == expectedError); +} + +TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") +{ + ScopedFastFlag sff_DebugLuauDeferredConstraintResolution{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + type T = { v: a } + type U = { t: T } + local x: U = { t = { v = 123 } } + local bad: U = { t = { v = "foo" } } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + const char* expectedError; + if (FFlag::LuauTypeMismatchInvarianceInError) + expectedError = "Type '{ t: { v: string } }' could not be converted into 'U'\n" + "caused by:\n" + " Property 't' is not compatible. Type '{ v: string }' could not be converted into 'T'\n" + "caused by:\n" + " Property 'v' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"; + else + expectedError = "Type '{ t: { v: string } }' could not be converted into 'U'\n" + "caused by:\n" + " Property 't' is not compatible. Type '{ v: string }' could not be converted into 'T'\n" + "caused by:\n" + " Property 'v' is not compatible. Type 'string' could not be converted into 'number'"; + + CHECK(result.errors[0].location == Location{{4, 31}, {4, 52}}); + CHECK(toString(result.errors[0]) == expectedError); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_generic_aliases") +{ + CheckResult result = check(R"( + --!strict + type T = { f: a, g: U } + type U = { h: a, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = "lo", i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") +{ + CheckResult result = check(R"( + --!strict + type T = { f: a, g: U } + type U = { h: b, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = 5, i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_ERRORS(result); + + // We had a UAF in this example caused by not cloning type function arguments + ModulePtr module = frontend.moduleResolver.getModule("MainModule"); + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes); + freeze(module->interfaceTypes); + module->internalTypes.clear(); + module->astTypes.clear(); + + // Make sure the error strings don't include "VALUELESS" + for (auto error : module->errors) + CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error)); +} + +TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors") +{ + CheckResult result = check(R"( + type Pair = {first: T, second: U} + local a: Pair + local b: Pair + + a = b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK_EQ("Pair", toString(tm->wantedType)); + CHECK_EQ("Pair", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition") +{ + CheckResult result = check(R"( + type A = number + type A = string -- Redefinition of type 'A', previously defined at line 1 + local foo: string = 1 -- "Type 'number' could not be converted into 'string'" + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type") +{ + CheckResult result = check(R"( + type Table = { a: T } + type Wrapped = Table + local l: Wrapped = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("Wrapped", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") +{ + CheckResult result = check(R"( + type Table = { a: T } + type Wrapped = (Table) -> string + local l: Wrapped = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +// Check that recursive intersection type doesn't generate an OOM +TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") +{ + CheckResult result = check(R"( + function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any + end + type t0 = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_)) + _(_) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise") +{ + CheckResult result = check(R"( + local foo: Id = 1 + type Id = T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") +{ + const std::string code = R"( + type A = {v:T, b:B} + type B = {v:T, a:A} + local aa:A + local bb = aa + )"; + + const std::string expected = R"( + type A = {v:T, b:B} + type B = {v:T, a:A} + local aa:A + local bb:A=aa + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") +{ + CheckResult result = check(R"( + type A = () -> (number, B) + type B = () -> (string, A) + local a: A + local b: B + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a"))); + CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "generic_param_remap") +{ + const std::string code = R"( + -- An example of a forwarded use of a type that has different type arguments than parameters + type A = {t:T, u:U, next:A?} + local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } + local bb = aa + )"; + + const std::string expected = R"( + + type A = {t:T, u:U, next:A?} + local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } + local bb:A=aa + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") +{ + CheckResult result = check(R"( + export type Foo = number + type Foo = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto dtd = get(result.errors[0]); + REQUIRE(dtd); + CHECK_EQ(dtd->name, "Foo"); +} + +TEST_CASE_FIXTURE(Fixture, "reported_location_is_correct_when_type_alias_are_duplicates") +{ + CheckResult result = check(R"( + type A = string + type B = number + type C = string + type B = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto dtd = get(result.errors[0]); + REQUIRE(dtd); + CHECK_EQ(dtd->name, "B"); + REQUIRE(dtd->previousLocation); + CHECK_EQ(dtd->previousLocation->begin.line + 1, 3); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") +{ + CheckResult result = check(R"( + type Node = { value: T, child: Node? } + + local function visitor(node: Node?) + local a: Node + + if node then + a = node.child -- Observe the output of the error message. + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto e = get(result.errors[0]); + REQUIRE(e != nullptr); + CHECK_EQ("Node?", toString(e->givenType)); + CHECK_EQ("Node", toString(e->wantedType)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_multi_assign") +{ + fileResolver.source["workspace/A"] = R"( + export type myvec2 = {x: number, y: number} + return {} + )"; + + fileResolver.source["workspace/B"] = R"( + export type myvec3 = {x: number, y: number, z: number} + return {} + )"; + + fileResolver.source["workspace/C"] = R"( + local Foo, Bar = require(workspace.A), require(workspace.B) + + local a: Foo.myvec2 + local b: Bar.myvec3 + )"; + + CheckResult result = frontend.check("workspace/C"); + LUAU_REQUIRE_NO_ERRORS(result); + ModulePtr m = frontend.moduleResolver.modules["workspace/C"]; + + REQUIRE(m != nullptr); + + std::optional aTypeId = lookupName(m->getModuleScope(), "a"); + REQUIRE(aTypeId); + const Luau::TableTypeVar* aType = get(follow(*aTypeId)); + REQUIRE(aType); + REQUIRE(aType->props.size() == 2); + + std::optional bTypeId = lookupName(m->getModuleScope(), "b"); + REQUIRE(bTypeId); + const Luau::TableTypeVar* bType = get(follow(*bTypeId)); + REQUIRE(bType); + REQUIRE(bType->props.size() == 3); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") +{ + CheckResult result = check("type t10 = typeof(table)"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = getGlobalBinding(frontend, "table"); + + if (FFlag::LuauNewLibraryTypeNames) + CHECK(toString(ty) == "typeof(table)"); + else + CHECK(toString(ty) == "table"); + + const TableTypeVar* ttv = get(ty); + REQUIRE(ttv); + + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") +{ + CheckResult result = check(R"( +type Cool = { a: number, b: string } +local c: Cool = { a = 1, b = "s" } +type NotCool = Cool +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "Cool"); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") +{ + CheckResult result = check(R"( +type Cool = { a: number, b: string } +type NotCool = Cool +local c: Cool = { a = 1, b = "s" } +local d: NotCool = { a = 1, b = "s" } +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "Cool"); + + ty = requireType("d"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "NotCool"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation") +{ + CheckResult result = check(R"( +local c = { a = 1, b = "s" } +type Cool = typeof(c) +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK_EQ(ttv->name, "Cool"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_of_an_imported_recursive_type") +{ + fileResolver.source["game/A"] = R"( +export type X = { a: number, b: X? } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + std::optional ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + std::optional ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(follow(*ty1), follow(*ty2)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_of_an_imported_recursive_generic_type") +{ + fileResolver.source["game/A"] = R"( +export type X = { a: T, b: U, C: X? } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + std::optional ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + std::optional ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true})); + + bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); + CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); +} + +TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak") +{ + CheckResult result = check(R"( +function get() + return function(obj) return true end +end + +export type f = typeof(get()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak") +{ + CheckResult result = check(R"( +function get() + return {a = 1, b = function(obj) return true end} +end + +export type f = typeof(get()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") +{ + CheckResult result = check(R"( + type Tree = { data: T, children: Forest } + type Forest = {Tree} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") +{ + CheckResult result = check(R"( + -- OK because forwarded types are used with their parameters. + type Tree = { data: T, children: Forest } + type Forest = {Tree<{T}>} + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") +{ + CheckResult result = check(R"( + -- Not OK because forwarded types are used with different types than their parameters. + type Forest = {Tree<{T}>} + type Tree = { data: T, children: Forest } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") +{ + CheckResult result = check(R"( + type Tree1 = { data: T, children: {Tree2} } + type Tree2 = { data: U, children: {Tree1} } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") +{ + CheckResult result = check(R"( + type Tree1 = { data: T, children: {Tree2} } + type Tree2 = { data: U, children: {Tree1} } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") +{ + CheckResult result = check(R"( + function f(x) return x[1] end + -- x has type X? for a free type variable X + local x = f ({}) + type ContainsFree = { this: a, that: typeof(x) } + type ContainsContainsFree = { that: ContainsFree } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") +{ + CheckResult result = check(R"( + type Array = { [number]: T } + type Tuple = Array + + local p: Tuple + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{number | string}", toString(requireType("p"), {true})); +} + +/* + * We had a problem where all type aliases would be prototyped into a child scope that happened + * to have the same level. This caused a problem where, if a sibling function referred to that + * type alias in its type signature, it would erroneously be quantified away, even though it doesn't + * actually belong to the function. + * + * We solved this by ascribing a unique subLevel to each prototyped alias. + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_quantify_unresolved_aliases") +{ + CheckResult result = check(R"( + --!strict + + local KeyPool = {} + + local function newkey(pool: KeyPool, index) + return {} + end + + function newKeyPool() + local pool = { + available = {} :: {Key}, + } + + return setmetatable(pool, KeyPool) + end + + export type KeyPool = typeof(newKeyPool()) + export type Key = typeof(newkey(newKeyPool(), 1)) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * We keep a cache of type alias onto TypeVar to prevent infinite types from + * being constructed via recursive or corecursive aliases. We have to adjust + * the TypeLevels of those generic TypeVars so that the unifier doesn't think + * they have improperly leaked out of their scope. + */ +TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_scope_if_they_are_reused_in_multiple_aliases") +{ + CheckResult result = check(R"( + type Array = {T} + type Exclude = T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * The two-pass alias definition system starts by ascribing a free TypeVar to each alias. It then + * circles back to fill in the actual type later on. + * + * If this free type is unified with something degenerate like `any`, we need to take extra care + * to ensure that the alias actually binds to the type that the user expected. + */ +TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any") +{ + CheckResult result = check(R"( + local function x() + local y: FutureType = {}::any + return 1 + end + type FutureType = { foo: typeof(x()) } + local d: FutureType = { smth = true } -- missing error, 'd' is resolved to 'any' + )"); + + CHECK_EQ("{| foo: number |}", toString(requireType("d"), {true})); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2") +{ + ScopedFastFlag sff[] = { + {"DebugLuauSharedSelf", true}, + }; + + CheckResult result = check(R"( + local B = {} + B.bar = 4 + + function B:smth1() + local self: FutureIntersection = self + self.foo = 4 + return 4 + end + + function B:smth2() + local self: FutureIntersection = self + self.bar = 5 -- error, even though we should have B part with bar + end + + type A = { foo: typeof(B.smth1({foo=3})) } -- trick toposort into sorting functions before types + type B = typeof(B) + + type FutureIntersection = A & B + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // To be quite honest, I don't know exactly why DCR fixes this. + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + // TODO: shared self causes this test to break in bizarre ways. + LUAU_REQUIRE_ERRORS(result); + } +} + +TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") +{ + CheckResult result = check(R"( + type Tree = { data: T, children: {Tree} } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") +{ + CheckResult result = check(R"( + -- this would be an infinite type if we allowed it + type Tree = { data: T, children: {Tree<{T}>} } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "report_shadowed_aliases") +{ + ScopedFastFlag sff{"LuauReportShadowedTypeAlias", true}; + + // We allow a previous type alias to depend on a future type alias. That exact feature enables a confusing example, like the following snippet, + // which has the type alias FakeString point to the type alias `string` that which points to `number`. + CheckResult result = check(R"( + type MyString = string + type string = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Redefinition of type 'string'"); + + std::optional t1 = lookupType("MyString"); + REQUIRE(t1); + CHECK(isPrim(*t1, PrimitiveTypeVar::String)); + + std::optional t2 = lookupType("string"); + REQUIRE(t2); + CHECK(isPrim(*t2, PrimitiveTypeVar::String)); +} + +TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_shadow_user_defined_alias") +{ + CheckResult result = check(R"( + type T = number + + do + type T = string + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_create_cyclic_type_with_unknown_module") +{ + CheckResult result = check(R"( + type AAA = B.AAA + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Unknown type 'B.AAA'"); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 2e400164..b94e1df0 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -31,11 +30,21 @@ TEST_CASE_FIXTURE(Fixture, "successful_check") dumpErrors(result); } +TEST_CASE_FIXTURE(Fixture, "variable_type_is_supertype") +{ + CheckResult result = check(R"( + local x: number = 1 + local y: number? = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "function_parameters_can_have_annotations") { CheckResult result = check(R"( function double(x: number) - return x * 2 + return 2 end local four = double(2) @@ -48,7 +57,7 @@ TEST_CASE_FIXTURE(Fixture, "function_parameter_annotations_are_checked") { CheckResult result = check(R"( function double(x: number) - return x * 2 + return 2 end local four = double("two") @@ -71,13 +80,13 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotations_are_checked") const FunctionTypeVar* ftv = get(fiftyType); REQUIRE(ftv != nullptr); - TypePackId retPack = ftv->retType; + TypePackId retPack = follow(ftv->retTypes); const TypePack* tp = get(retPack); REQUIRE(tp != nullptr); REQUIRE_EQ(1, tp->head.size()); - REQUIRE_EQ(typeChecker.anyType, tp->head[0]); + REQUIRE_EQ(typeChecker.anyType, follow(tp->head[0])); } TEST_CASE_FIXTURE(Fixture, "function_return_multret_annotations_are_checked") @@ -117,6 +126,23 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotation_should_continuously_parse LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "unknown_type_reference_generates_error") +{ + CheckResult result = check(R"( + local x: IDoNotExist + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{ + Location{{1, 17}, {1, 28}}, + getMainSourceModule()->name, + UnknownSymbol{ + "IDoNotExist", + UnknownSymbol::Context::Type, + }, + }); +} + TEST_CASE_FIXTURE(Fixture, "typeof_variable_type_annotation_should_return_its_type") { CheckResult result = check(R"( @@ -207,6 +233,31 @@ TEST_CASE_FIXTURE(Fixture, "as_expr_does_not_propagate_type_info") CHECK_EQ("number", toString(requireType("b"))); } +TEST_CASE_FIXTURE(Fixture, "as_expr_is_bidirectional") +{ + CheckResult result = check(R"( + local a = 55 :: number? + local b = a :: number + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number?", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "as_expr_warns_on_unrelated_cast") +{ + CheckResult result = check(R"( + local a = 55 :: string + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Cannot cast 'number' into 'string' because the types are unrelated", toString(result.errors[0])); + CHECK_EQ("string", toString(requireType("a"))); +} + TEST_CASE_FIXTURE(Fixture, "type_annotations_inside_function_bodies") { CheckResult result = check(R"( @@ -390,7 +441,7 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") )"); TypeId fType = requireType("aa"); - const ErrorTypeVar* ftv = get(follow(fType)); + const AnyTypeVar* ftv = get(follow(fType)); REQUIRE(ftv != nullptr); REQUIRE(!result.errors.empty()); } @@ -465,7 +516,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") CHECK(arrayTable->indexer); CHECK(isInArena(array.type, mod.interfaceTypes)); - CHECK_EQ(array.typeParams[0], arrayTable->indexer->indexResultType); + CHECK_EQ(array.typeParams[0].ty, arrayTable->indexer->indexResultType); } TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definitions") @@ -504,9 +555,9 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti CHECK_EQ(recordType, bType); } -TEST_CASE_FIXTURE(Fixture, "use_type_required_from_another_file") +TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") { - addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict @@ -530,9 +581,9 @@ TEST_CASE_FIXTURE(Fixture, "use_type_required_from_another_file") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "cannot_use_nonexported_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_use_nonexported_type") { - addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict @@ -556,9 +607,9 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_nonexported_type") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "builtin_types_are_not_exported") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_are_not_exported") { - addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict @@ -588,7 +639,7 @@ struct AssertionCatcher { tripped = 0; oldhook = Luau::assertHandler(); - Luau::assertHandler() = [](const char* expr, const char* file, int line) -> int { + Luau::assertHandler() = [](const char* expr, const char* file, int line, const char* function) -> int { ++tripped; return 0; }; @@ -606,26 +657,24 @@ struct AssertionCatcher int AssertionCatcher::tripped; } // namespace -TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice") +TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag") { ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; AssertionCatcher ac; CHECK_THROWS_AS(check(R"( - local a: _luau_ice = 55 - )"), - std::runtime_error); + local a: _luau_ice = 55 + )"), + InternalCompilerError); LUAU_ASSERT(1 == AssertionCatcher::tripped); } -TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_handler") +TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag_handler") { ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; - AssertionCatcher ac; - bool caught = false; frontend.iceHandler.onInternalError = [&](const char*) { @@ -633,13 +682,11 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_handler") }; CHECK_THROWS_AS(check(R"( - local a: _luau_ice = 55 - )"), - std::runtime_error); + local a: _luau_ice = 55 + )"), + InternalCompilerError); CHECK_EQ(true, caught); - - frontend.iceHandler.onInternalError = {}; } TEST_CASE_FIXTURE(Fixture, "luau_ice_is_not_special_without_the_flag") @@ -652,9 +699,14 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_is_not_special_without_the_flag") )"); } -TEST_CASE_FIXTURE(Fixture, "luau_print_is_magic_if_the_flag_is_set") +TEST_CASE_FIXTURE(BuiltinsFixture, "luau_print_is_magic_if_the_flag_is_set") { - // Luau::resetPrintLine(); + static std::vector output; + output.clear(); + Luau::setPrintLine([](const std::string& s) { + output.push_back(s); + }); + ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; CheckResult result = check(R"( @@ -662,6 +714,8 @@ TEST_CASE_FIXTURE(Fixture, "luau_print_is_magic_if_the_flag_is_set") )"); LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE(1 == output.size()); } TEST_CASE_FIXTURE(Fixture, "luau_print_is_not_special_without_the_flag") @@ -725,4 +779,14 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_intersection_typevar") REQUIRE(ocf); } +TEST_CASE_FIXTURE(Fixture, "instantiation_clone_has_to_follow") +{ + CheckResult result = check(R"( + export type t8 = (t0)&(((true)|(any))->"") + export type t0 = ({})&({_:{[any]:number},}) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp new file mode 100644 index 00000000..91201812 --- /dev/null +++ b/tests/TypeInfer.anyerror.test.cpp @@ -0,0 +1,346 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferAnyError"); + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") +{ + CheckResult result = check(R"( + function bar(): any + return true + end + + local a + for b in bar do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(typeChecker.anyType, requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") +{ + CheckResult result = check(R"( + function bar(): any + return true + end + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") +{ + CheckResult result = check(R"( + local bar: any + + local a + for b in bar do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") +{ + CheckResult result = check(R"( + local bar: any + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") +{ + CheckResult result = check(R"( + local a + for b in bar do + a = b + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("*error-type*", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") +{ + CheckResult result = check(R"( + function bar(c) return c end + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("*error-type*", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") +{ + CheckResult result = check(R"( + local l = #this_is_not_defined + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_error_type_does_not_produce_an_error") +{ + CheckResult result = check(R"( + local originalReward = unknown.Parent.Reward:GetChildren()[1] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "dot_on_error_type_does_not_produce_an_error") +{ + CheckResult result = check(R"( + local foo = (true).x + foo.x = foo.y + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "any_type_propagates") +{ + CheckResult result = check(R"( + local foo: any + local bar = foo:method("argument") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "can_subscript_any") +{ + CheckResult result = check(R"( + local foo: any + local bar = foo[5] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("bar"))); +} + +// Not strictly correct: metatables permit overriding this +TEST_CASE_FIXTURE(Fixture, "can_get_length_of_any") +{ + CheckResult result = check(R"( + local foo: any = {} + local bar = #foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "assign_prop_to_table_by_calling_any_yields_any") +{ + CheckResult result = check(R"( + local f: any + local T = {} + + T.prop = f() + + return T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* ttv = getMutable(requireType("T")); + REQUIRE(ttv); + REQUIRE(ttv->props.count("prop")); + + REQUIRE_EQ("any", toString(ttv->props["prop"].type)); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") +{ + CheckResult result = check(R"( + local A : any + function A.B() end + A:C() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId aType = requireType("A"); + CHECK_EQ(aType, typeChecker.anyType); +} + +TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") +{ + CheckResult result = check(R"( + local a = unknown.Parent.Reward.GetChildren() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownSymbol* err = get(result.errors[0]); + REQUIRE(err != nullptr); + + CHECK_EQ("unknown", err->name); + + CHECK_EQ("*error-type*", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") +{ + CheckResult result = check(R"( + local a = Utility.Create "Foo" {} + )"); + + CHECK_EQ("*error-type*", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") +{ + CheckResult result = check(R"( + local a: any + local b + for _, i in pairs(a) do + b = i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "call_to_any_yields_any") +{ + CheckResult result = check(R"( + local a: any + local b = a() + )"); + + REQUIRE_EQ("any", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfAny") +{ + CheckResult result = check(R"( +local x: any = {} +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") +{ + CheckResult result = check(R"( +local x = (true).foo +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_of_any_can_be_a_table") +{ + CheckResult result = check(R"( +--!strict +local T: any +T = {} +T.__index = T +function T.new(...) + local self = {} + setmetatable(self, T) + self:construct(...) + return self +end +function T:construct(index) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_error_addition") +{ + CheckResult result = check(R"( +--!strict +local foo = makesandwich() +local bar = foo.nutrition + 100 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + // We should definitely get this error + CHECK_EQ("Unknown global 'makesandwich'", toString(result.errors[0])); + // We get this error if makesandwich() returns a free type + // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") +{ + CheckResult result = check(R"( + local function f(thing: any | string) + local foo = thing.SomeRandomKey + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "union_of_types_regression_test") +{ + CheckResult result = check(R"( +--!strict +local stat +stat = stat and tonumber(stat) or stat + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index e8974776..32e31e16 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/BuiltinDefinitions.h" @@ -9,9 +8,11 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + TEST_SUITE_BEGIN("BuiltinTests"); -TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") +TEST_CASE_FIXTURE(BuiltinsFixture, "math_things_are_defined") { CheckResult result = check(R"( local a00 = math.frexp @@ -49,44 +50,44 @@ TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "next_iterator_should_infer_types_and_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "next_iterator_should_infer_types_and_type_check") { CheckResult result = check(R"( local a: string, b: number = next({ 1 }) local s = "foo" local t = { [s] = 1 } - local c: string, d: number = next(t) + local c: string?, d: number = next(t) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "pairs_iterator_should_infer_types_and_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_iterator_should_infer_types_and_type_check") { CheckResult result = check(R"( type Map = { [K]: V } local map: Map = { ["foo"] = 1, ["bar"] = 2, ["baz"] = 3 } - local it: (Map, string | nil) -> (string, number), t: Map, i: nil = pairs(map) + local it: (Map, string | nil) -> (string?, number), t: Map, i: nil = pairs(map) )"); LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "ipairs_iterator_should_infer_types_and_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_iterator_should_infer_types_and_type_check") { CheckResult result = check(R"( type Map = { [K]: V } local array: Map = { "foo", "bar", "baz" } - local it: (Map, number) -> (number, string), t: Map, i: number = ipairs(array) + local it: (Map, number) -> (number?, string), t: Map, i: number = ipairs(array) )"); LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_dot_remove_optionally_returns_generic") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_remove_optionally_returns_generic") { CheckResult result = check(R"( local t = { 1 } @@ -97,7 +98,7 @@ TEST_CASE_FIXTURE(Fixture, "table_dot_remove_optionally_returns_generic") CHECK_EQ(toString(requireType("n")), "number?"); } -TEST_CASE_FIXTURE(Fixture, "table_concat_returns_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_concat_returns_string") { CheckResult result = check(R"( local r = table.concat({1,2,3,4}, ",", 2); @@ -107,7 +108,7 @@ TEST_CASE_FIXTURE(Fixture, "table_concat_returns_string") CHECK_EQ(*typeChecker.stringType, *requireType("r")); } -TEST_CASE_FIXTURE(Fixture, "sort") +TEST_CASE_FIXTURE(BuiltinsFixture, "sort") { CheckResult result = check(R"( local t = {1, 2, 3}; @@ -117,7 +118,7 @@ TEST_CASE_FIXTURE(Fixture, "sort") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sort_with_predicate") +TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_predicate") { CheckResult result = check(R"( --!strict @@ -129,7 +130,7 @@ TEST_CASE_FIXTURE(Fixture, "sort_with_predicate") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sort_with_bad_predicate") +TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_bad_predicate") { CheckResult result = check(R"( --!strict @@ -139,6 +140,12 @@ TEST_CASE_FIXTURE(Fixture, "sort_with_bad_predicate") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '(number, number) -> boolean' could not be converted into '((a, a) -> boolean)?' +caused by: + None of the union options are compatible. For example: Type '(number, number) -> boolean' could not be converted into '(a, a) -> boolean' +caused by: + Argument #1 type is not compatible. Type 'string' could not be converted into 'number')", + toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "strings_have_methods") @@ -151,7 +158,7 @@ TEST_CASE_FIXTURE(Fixture, "strings_have_methods") CHECK_EQ(*typeChecker.stringType, *requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "math_max_variatic") +TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_variatic") { CheckResult result = check(R"( local n = math.max(1,2,3,4,5,6,7,8,9,0) @@ -161,16 +168,17 @@ TEST_CASE_FIXTURE(Fixture, "math_max_variatic") CHECK_EQ(*typeChecker.numberType, *requireType("n")); } -TEST_CASE_FIXTURE(Fixture, "math_max_checks_for_numbers") +TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_checks_for_numbers") { CheckResult result = check(R"( local n = math.max(1,2,"3") )"); CHECK(!result.errors.empty()); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "builtin_tables_sealed") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_tables_sealed") { CheckResult result = check(R"LUA( local b = bit32 @@ -182,7 +190,7 @@ TEST_CASE_FIXTURE(Fixture, "builtin_tables_sealed") CHECK_EQ(bit32t->state, TableState::Sealed); } -TEST_CASE_FIXTURE(Fixture, "lua_51_exported_globals_all_exist") +TEST_CASE_FIXTURE(BuiltinsFixture, "lua_51_exported_globals_all_exist") { // Extracted from lua5.1 CheckResult result = check(R"( @@ -339,7 +347,7 @@ TEST_CASE_FIXTURE(Fixture, "lua_51_exported_globals_all_exist") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "setmetatable_unpacks_arg_types_correctly") +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_unpacks_arg_types_correctly") { CheckResult result = check(R"( setmetatable({}, setmetatable({}, {})) @@ -347,7 +355,7 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_unpacks_arg_types_correctly") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_2_args_overload") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_2_args_overload") { CheckResult result = check(R"( local t = {} @@ -359,7 +367,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_2_args_o CHECK_EQ(typeChecker.stringType, requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "table_insert_corrrectly_infers_type_of_array_3_args_overload") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_3_args_overload") { CheckResult result = check(R"( local t = {} @@ -371,7 +379,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_corrrectly_infers_type_of_array_3_args_ CHECK_EQ("string", toString(requireType("s"))); } -TEST_CASE_FIXTURE(Fixture, "table_pack") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack") { CheckResult result = check(R"( local t = table.pack(1, "foo", true) @@ -381,7 +389,7 @@ TEST_CASE_FIXTURE(Fixture, "table_pack") CHECK_EQ("{| [number]: boolean | number | string, n: number |}", toString(requireType("t"))); } -TEST_CASE_FIXTURE(Fixture, "table_pack_variadic") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack_variadic") { CheckResult result = check(R"( --!strict @@ -396,7 +404,7 @@ local t = table.pack(f()) CHECK_EQ("{| [number]: number | string, n: number |}", toString(requireType("t"))); } -TEST_CASE_FIXTURE(Fixture, "table_pack_reduce") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack_reduce") { CheckResult result = check(R"( local t = table.pack(1, 2, true) @@ -413,7 +421,7 @@ TEST_CASE_FIXTURE(Fixture, "table_pack_reduce") CHECK_EQ("{| [number]: string, n: number |}", toString(requireType("t"))); } -TEST_CASE_FIXTURE(Fixture, "gcinfo") +TEST_CASE_FIXTURE(BuiltinsFixture, "gcinfo") { CheckResult result = check(R"( local n = gcinfo() @@ -423,12 +431,12 @@ TEST_CASE_FIXTURE(Fixture, "gcinfo") CHECK_EQ(*typeChecker.numberType, *requireType("n")); } -TEST_CASE_FIXTURE(Fixture, "getfenv") +TEST_CASE_FIXTURE(BuiltinsFixture, "getfenv") { LUAU_REQUIRE_NO_ERRORS(check("getfenv(1)")); } -TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "os_time_takes_optional_date_table") { CheckResult result = check(R"( local n1 = os.time() @@ -442,22 +450,18 @@ TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") CHECK_EQ(*typeChecker.numberType, *requireType("n3")); } -TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "thread_is_a_type") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( local co = coroutine.create(function() end) )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.threadType, *requireType("co")); + CHECK("thread" == toString(requireType("co"))); } -TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") +TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_resume_anything_goes") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( local function nifty(x, y) print(x, y) @@ -474,10 +478,8 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") +TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_wrap_anything_goes") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( --!nonstrict local function nifty(x, y) @@ -495,7 +497,7 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_should_not_mutate_persisted_types") { CheckResult result = check(R"( local string = string @@ -510,7 +512,7 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") REQUIRE(ttv); } -TEST_CASE_FIXTURE(Fixture, "string_format_arg_types_inference") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_arg_types_inference") { CheckResult result = check(R"( --!strict @@ -523,7 +525,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_arg_types_inference") CHECK_EQ("(number, number, string) -> string", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "string_format_arg_count_mismatch") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_arg_count_mismatch") { CheckResult result = check(R"( --!strict @@ -539,7 +541,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_arg_count_mismatch") CHECK_EQ(result.errors[2].location.begin.line, 4); } -TEST_CASE_FIXTURE(Fixture, "string_format_correctly_ordered_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_correctly_ordered_types") { CheckResult result = check(R"( --!strict @@ -553,7 +555,30 @@ TEST_CASE_FIXTURE(Fixture, "string_format_correctly_ordered_types") CHECK_EQ(tm->givenType, typeChecker.numberType); } -TEST_CASE_FIXTURE(Fixture, "xpcall") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_tostring_specifier") +{ + CheckResult result = check(R"( + --!strict + string.format("%* %* %* %*", "string", 1, true, function() end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_tostring_specifier_type_constraint") +{ + CheckResult result = check(R"( + local function f(x): string + local _ = string.format("%*", x) + return x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(string) -> string", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall") { CheckResult result = check(R"( --!strict @@ -564,12 +589,21 @@ TEST_CASE_FIXTURE(Fixture, "xpcall") )"); LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("boolean", toString(requireType("a"))); - REQUIRE_EQ("number", toString(requireType("b"))); - REQUIRE_EQ("boolean", toString(requireType("c"))); + CHECK_EQ("boolean", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + CHECK_EQ("boolean", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "see_thru_select") +TEST_CASE_FIXTURE(BuiltinsFixture, "trivial_select") +{ + CheckResult result = check(R"( + local a:number = select(1, 42) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "see_thru_select") { CheckResult result = check(R"( local a:number, b:boolean = select(2,"hi", 10, true) @@ -578,7 +612,7 @@ TEST_CASE_FIXTURE(Fixture, "see_thru_select") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "see_thru_select_count") +TEST_CASE_FIXTURE(BuiltinsFixture, "see_thru_select_count") { CheckResult result = check(R"( local a = select("#","hi", 10, true) @@ -588,7 +622,7 @@ TEST_CASE_FIXTURE(Fixture, "see_thru_select_count") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "select_with_decimal_argument_is_rounded_down") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_decimal_argument_is_rounded_down") { CheckResult result = check(R"( local a: number, b: boolean = select(2.9, "foo", 1, true) @@ -598,7 +632,7 @@ TEST_CASE_FIXTURE(Fixture, "select_with_decimal_argument_is_rounded_down") } // Could be flaky if the fix has regressed. -TEST_CASE_FIXTURE(Fixture, "bad_select_should_not_crash") +TEST_CASE_FIXTURE(BuiltinsFixture, "bad_select_should_not_crash") { CheckResult result = check(R"( do end @@ -610,10 +644,12 @@ TEST_CASE_FIXTURE(Fixture, "bad_select_should_not_crash") end )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Argument count mismatch. Function '_' expects at least 1 argument, but none are specified", toString(result.errors[0])); + CHECK_EQ("Argument count mismatch. Function 'select' expects 1 argument, but none are specified", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "select_way_out_of_range") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_way_out_of_range") { CheckResult result = check(R"( select(5432598430953240958) @@ -624,7 +660,7 @@ TEST_CASE_FIXTURE(Fixture, "select_way_out_of_range") CHECK_EQ("bad argument #1 to select (index out of range)", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "select_slightly_out_of_range") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_slightly_out_of_range") { CheckResult result = check(R"( select(3, "a", 1) @@ -635,7 +671,7 @@ TEST_CASE_FIXTURE(Fixture, "select_slightly_out_of_range") CHECK_EQ("bad argument #1 to select (index out of range)", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail") { CheckResult result = check(R"( --!nonstrict @@ -654,7 +690,7 @@ TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail") CHECK_EQ("any", toString(requireType("quux"))); } -TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail_and_string_head") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_string_head") { CheckResult result = check(R"( --!nonstrict @@ -667,7 +703,11 @@ TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail_and_string_head") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("foo"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("string", toString(requireType("foo"))); + else + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); CHECK_EQ("any", toString(requireType("baz"))); CHECK_EQ("any", toString(requireType("quux"))); @@ -693,7 +733,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); + CHECK_EQ("Argument count mismatch. Function expects 2 arguments, but 3 are specified", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument2") @@ -708,7 +748,21 @@ TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument2") CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "debug_traceback_is_crazy") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_use_correct_argument3") +{ + CheckResult result = check(R"( + local s1 = string.format("%d") + local s2 = string.format("%d", 1) + local s3 = string.format("%d", 1, 2) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ("Argument count mismatch. Function expects 2 arguments, but only 1 is specified", toString(result.errors[0])); + CHECK_EQ("Argument count mismatch. Function expects 2 arguments, but 3 are specified", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "debug_traceback_is_crazy") { CheckResult result = check(R"( local co: thread = ... @@ -725,7 +779,7 @@ debug.traceback(co, "msg", 1) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "debug_info_is_crazy") +TEST_CASE_FIXTURE(BuiltinsFixture, "debug_info_is_crazy") { CheckResult result = check(R"( local co: thread, f: ()->() = ... @@ -739,7 +793,7 @@ debug.info(f, "n") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "aliased_string_format") +TEST_CASE_FIXTURE(BuiltinsFixture, "aliased_string_format") { CheckResult result = check(R"( local fmt = string.format @@ -750,7 +804,7 @@ TEST_CASE_FIXTURE(Fixture, "aliased_string_format") CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "string_lib_self_noself") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_lib_self_noself") { CheckResult result = check(R"( --!nonstrict @@ -769,20 +823,20 @@ TEST_CASE_FIXTURE(Fixture, "string_lib_self_noself") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "gmatch_definition") +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_definition") { CheckResult result = check(R"_( -local a, b, c = ("hey"):gmatch("(.)(.)(.)")() + local a, b, c = ("hey"):gmatch("(.)(.)(.)")() -for c in ("hey"):gmatch("(.)") do - print(c:upper()) -end -)_"); + for c in ("hey"):gmatch("(.)") do + print(c:upper()) + end + )_"); LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "select_on_variadic") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_on_variadic") { CheckResult result = check(R"( local function f(): (number, ...(boolean | number)) @@ -798,17 +852,18 @@ TEST_CASE_FIXTURE(Fixture, "select_on_variadic") CHECK_EQ("any", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_positions") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_report_all_type_errors_at_correct_positions") { CheckResult result = check(R"( ("%s%d%s"):format(1, "hello", true) + string.format("%s%d%s", 1, "hello", true) )"); TypeId stringType = typeChecker.stringType; TypeId numberType = typeChecker.numberType; TypeId booleanType = typeChecker.booleanType; - LUAU_REQUIRE_ERROR_COUNT(3, result); + LUAU_REQUIRE_ERROR_COUNT(6, result); CHECK_EQ(Location(Position{1, 26}, Position{1, 27}), result.errors[0].location); CHECK_EQ(TypeErrorData(TypeMismatch{stringType, numberType}), result.errors[0].data); @@ -818,12 +873,40 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi CHECK_EQ(Location(Position{1, 38}, Position{1, 42}), result.errors[2].location); CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[2].data); + + CHECK_EQ(Location(Position{2, 32}, Position{2, 33}), result.errors[3].location); + CHECK_EQ(TypeErrorData(TypeMismatch{stringType, numberType}), result.errors[3].data); + + CHECK_EQ(Location(Position{2, 35}, Position{2, 42}), result.errors[4].location); + CHECK_EQ(TypeErrorData(TypeMismatch{numberType, stringType}), result.errors[4].data); + + CHECK_EQ(Location(Position{2, 44}, Position{2, 48}), result.errors[5].location); + CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[5].data); } -TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "tonumber_returns_optional_number_type") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; + CheckResult result = check(R"( + --!strict + local b: number = tonumber('asdf') + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tonumber_returns_optional_number_type2") +{ + CheckResult result = check(R"( + --!strict + local b: number = tonumber('asdf') or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_add_definitions_to_persistent_types") +{ CheckResult result = check(R"( local f = math.sin local function g(x) return math.sin(x) end @@ -844,4 +927,365 @@ TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") REQUIRE(gtv->definition); } +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types") +{ + CheckResult result = check(R"( + local function f(x: (number | boolean)?) + return assert(x) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types2") +{ + CheckResult result = check(R"( + local function f(x: (number | boolean)?): number | true + return assert(x) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type") +{ + CheckResult result = check(R"( + local function f(...: number?) + return assert(...) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(...number?) -> (number, ...number?)", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy") +{ + CheckResult result = check(R"( + local function f(x: nil) + return assert(x, "hmm") + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(nil) -> (never, ...never)", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") +{ + CheckResult result = check(R"( + local t1: {a: number} = {a = 42} + local t2: {b: string} = {b = "hello"} + local t3: {boolean} = {false, true} + + local tf1 = table.freeze(t1) + local tf2 = table.freeze(t2) + local tf3 = table.freeze(t3) + + local a = tf1.a + local b = tf2.b + local c = tf3[2] + + local d = tf1.b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0])); + + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("string", toString(requireType("b"))); + CHECK_EQ("boolean", toString(requireType("c"))); + CHECK_EQ("*error-type*", toString(requireType("d"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") +{ + ScopedFastFlag sff{"LuauSetMetaTableArgsCheck", true}; + CheckResult result = check(R"( +local a = {b=setmetatable} +a.b() +a:b() +a:b({}) + )"); + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'a.b' expects 2 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function 'a.b' expects 2 arguments, but only 1 is specified"); +} + +TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") +{ + CheckResult result = check(R"( +local function f(a: typeof(f)) end +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Unknown global 'f'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "no_persistent_typelevel_change") +{ + TypeId mathTy = requireType(typeChecker.globalScope, "math"); + REQUIRE(mathTy); + TableTypeVar* ttv = getMutable(mathTy); + REQUIRE(ttv); + const FunctionTypeVar* ftv = get(ttv->props["frexp"].type); + REQUIRE(ftv); + auto original = ftv->level; + + CheckResult result = check("local a = math.frexp"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(ftv->level.level == original.level); + CHECK(ftv->level.subLevel == original.subLevel); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "global_singleton_types_are_sealed") +{ + CheckResult result = check(R"( +local function f(x: string) + local p = x:split('a') + p = table.pack(table.unpack(p, 1, #p - 1)) + return p +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") +{ + CheckResult result = check(R"END( + local a, b, c = string.gmatch("This is a string", "(.()(%a+))")() + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types2") +{ + CheckResult result = check(R"END( + local a, b, c = ("This is a string"):gmatch("(.()(%a+))")() + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") +{ + CheckResult result = check(R"END( + local a, b, c, d = string.gmatch("T(his)() is a string", ".")() + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::FunctionResult); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 4); + + CHECK_EQ(toString(requireType("a")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens") +{ + CheckResult result = check(R"END( + local a, b, c, d = string.gmatch("T(his) is a string", "((.)%b()())")() + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::FunctionResult); + CHECK_EQ(acm->expected, 3); + CHECK_EQ(acm->actual, 4); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "string"); + CHECK_EQ(toString(requireType("c")), "number"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_ignored") +{ + CheckResult result = check(R"END( + local a, b, c = string.gmatch("T(his)() is a string", "(T[()])()")() + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::FunctionResult); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 3); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket") +{ + CheckResult result = check(R"END( + local a, b = string.gmatch("[[[", "()([[])")() + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "number"); + CHECK_EQ(toString(requireType("b")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_leading_end_bracket_is_part_of_set") +{ + CheckResult result = check(R"END( + -- An immediate right-bracket following a left-bracket is included within the set; + -- thus, '[]]'' is the set containing ']', and '[]' is an invalid set missing an enclosing + -- right-bracket. We detect an invalid set in this case and fall back to to default gmatch + -- typing. + local foo = string.gmatch("T[hi%]s]]]() is a string", "([]s)") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_invalid_pattern_fallback_to_builtin") +{ + CheckResult result = check(R"END( + local foo = string.gmatch("T(his)() is a string", ")") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_invalid_pattern_fallback_to_builtin2") +{ + CheckResult result = check(R"END( + local foo = string.gmatch("T(his)() is a string", "[") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types") +{ + CheckResult result = check(R"END( + local a, b, c = string.match("This is a string", "(.()(%a+))") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2") +{ + CheckResult result = check(R"END( + local a, b, c = string.match("This is a string", "(.()(%a+))", "this should be a number") + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(toString(tm->wantedType), "number?"); + CHECK_EQ(toString(tm->givenType), "string"); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types") +{ + CheckResult result = check(R"END( + local d, e, a, b, c = string.find("This is a string", "(.()(%a+))") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + CHECK_EQ(toString(requireType("d")), "number?"); + CHECK_EQ(toString(requireType("e")), "number?"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types2") +{ + CheckResult result = check(R"END( + local d, e, a, b, c = string.find("This is a string", "(.()(%a+))", "this should be a number") + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(toString(tm->wantedType), "number?"); + CHECK_EQ(toString(tm->givenType), "string"); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + CHECK_EQ(toString(requireType("d")), "number?"); + CHECK_EQ(toString(requireType("e")), "number?"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") +{ + CheckResult result = check(R"END( + local d, e, a, b, c = string.find("This is a string", "(.()(%a+))", 1, "this should be a bool") + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(toString(tm->wantedType), "boolean?"); + CHECK_EQ(toString(tm->givenType), "string"); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); + CHECK_EQ(toString(requireType("d")), "number?"); + CHECK_EQ(toString(requireType("e")), "number?"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") +{ + CheckResult result = check(R"END( + local d, e, a, b = string.find("This is a string", "(.()(%a+))", 1, true) + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::FunctionResult); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 4); + + CHECK_EQ(toString(requireType("d")), "number?"); + CHECK_EQ(toString(requireType("e")), "number?"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 6da33a08..07dfc33f 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -1,100 +1,18 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" +#include "Luau/Common.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Fixture.h" +#include "ClassFixture.h" #include "doctest.h" using namespace Luau; using std::nullopt; -struct ClassFixture : Fixture -{ - ClassFixture() - { - TypeArena& arena = typeChecker.globalTypes; - TypeId numberType = typeChecker.numberType; - - unfreeze(arena); - - TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}}); - getMutable(baseClassInstanceType)->props = { - {"BaseMethod", {makeFunction(arena, baseClassInstanceType, {numberType}, {})}}, - {"BaseField", {numberType}}, - }; - - TypeId baseClassType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}}); - getMutable(baseClassType)->props = { - {"StaticMethod", {makeFunction(arena, nullopt, {}, {numberType})}}, - {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, - {"New", {makeFunction(arena, nullopt, {}, {baseClassInstanceType})}}, - }; - typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; - addGlobalBinding(typeChecker, "BaseClass", baseClassType, "@test"); - - TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}}); - - getMutable(childClassInstanceType)->props = { - {"Method", {makeFunction(arena, childClassInstanceType, {}, {typeChecker.stringType})}}, - }; - - TypeId childClassType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassType, nullopt, {}, {}}); - getMutable(childClassType)->props = { - {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, - }; - typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; - addGlobalBinding(typeChecker, "ChildClass", childClassType, "@test"); - - TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}}); - - getMutable(grandChildInstanceType)->props = { - {"Method", {makeFunction(arena, grandChildInstanceType, {}, {typeChecker.stringType})}}, - }; - - TypeId grandChildType = arena.addType(ClassTypeVar{"GrandChild", {}, baseClassType, nullopt, {}, {}}); - getMutable(grandChildType)->props = { - {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, - }; - typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; - addGlobalBinding(typeChecker, "GrandChild", childClassType, "@test"); - - TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}}); - - getMutable(anotherChildInstanceType)->props = { - {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {typeChecker.stringType})}}, - }; - - TypeId anotherChildType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassType, nullopt, {}, {}}); - getMutable(anotherChildType)->props = { - {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, - }; - typeChecker.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; - addGlobalBinding(typeChecker, "AnotherChild", childClassType, "@test"); - - TypeId vector2MetaType = arena.addType(TableTypeVar{}); - - TypeId vector2InstanceType = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, vector2MetaType, {}, {}}); - getMutable(vector2InstanceType)->props = { - {"X", {numberType}}, - {"Y", {numberType}}, - }; - - TypeId vector2Type = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, nullopt, {}, {}}); - getMutable(vector2Type)->props = { - {"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}}, - }; - getMutable(vector2MetaType)->props = { - {"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}}, - }; - typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; - addGlobalBinding(typeChecker, "Vector2", vector2Type, "@test"); - - freeze(arena); - } -}; +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError); TEST_SUITE_BEGIN("TypeInferClasses"); @@ -232,8 +150,6 @@ TEST_CASE_FIXTURE(ClassFixture, "can_assign_to_prop_of_base_class") TEST_CASE_FIXTURE(ClassFixture, "can_read_prop_of_base_class_using_string") { - ScopedFastFlag luauClassPropertyAccessAsString("LuauClassPropertyAccessAsString", true); - CheckResult result = check(R"( local c = ChildClass.New() local x = 1 + c["BaseField"] @@ -244,8 +160,6 @@ TEST_CASE_FIXTURE(ClassFixture, "can_read_prop_of_base_class_using_string") TEST_CASE_FIXTURE(ClassFixture, "can_assign_to_prop_of_base_class_using_string") { - ScopedFastFlag luauClassPropertyAccessAsString("LuauClassPropertyAccessAsString", true); - CheckResult result = check(R"( local c = ChildClass.New() c["BaseField"] = 444 @@ -437,8 +351,6 @@ TEST_CASE_FIXTURE(ClassFixture, "class_unification_type_mismatch_is_correct_orde TEST_CASE_FIXTURE(ClassFixture, "optional_class_field_access_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local b: Vector2? = nil local a = b.X + b.Z @@ -453,4 +365,132 @@ b.X = 2 -- real Vector2.X is also read-only CHECK_EQ("Value of type 'Vector2?' could be nil", toString(result.errors[3])); } +TEST_CASE_FIXTURE(ClassFixture, "detailed_class_unification_error") +{ + CheckResult result = check(R"( +local function foo(v) + return v.X :: number + string.len(v.Y) +end + +local a: Vector2 +local b = foo +b(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'Vector2' could not be converted into '{- X: a, Y: string -}' +caused by: + Property 'Y' is not compatible. Type 'number' could not be converted into 'string')", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(ClassFixture, "class_type_mismatch_with_name_conflict") +{ + CheckResult result = check(R"( +local i = ChildClass.New() +type ChildClass = { x: number } +local a: ChildClass = i + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'ChildClass' from 'Test' could not be converted into 'ChildClass' from 'MainModule'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(ClassFixture, "intersections_of_unions_of_classes") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (BaseClass | Vector2) & (ChildClass | AnotherChild) + local y : (ChildClass | AnotherChild) + x = y + y = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "unions_of_intersections_of_classes") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (BaseClass & ChildClass) | (BaseClass & AnotherChild) | (BaseClass & Vector2) + local y : (ChildClass | AnotherChild) + x = y + y = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "index_instance_property") +{ + ScopedFastFlag luauAllowIndexClassParameters{"LuauAllowIndexClassParameters", true}; + + CheckResult result = check(R"( + local function execute(object: BaseClass, name: string) + print(object[name]) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Attempting a dynamic property access on type 'BaseClass' is unsafe and may cause exceptions at runtime", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(ClassFixture, "index_instance_property_nonstrict") +{ + ScopedFastFlag luauAllowIndexClassParameters{"LuauAllowIndexClassParameters", true}; + + CheckResult result = check(R"( + --!nonstrict + + local function execute(object: BaseClass, name: string) + print(object[name]) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "type_mismatch_invariance_required_for_error") +{ + CheckResult result = check(R"( +type A = { x: ChildClass } +type B = { x: BaseClass } + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'x' is not compatible. Type 'ChildClass' could not be converted into 'BaseClass' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'x' is not compatible. Type 'ChildClass' could not be converted into 'BaseClass')"); +} + +TEST_CASE_FIXTURE(ClassFixture, "callable_classes") +{ + ScopedFastFlag luauCallableClasses{"LuauCallableClasses", true}; + + CheckResult result = check(R"( + local x : CallableClass + local y = x("testing") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("y"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 41e3e45a..26115046 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -12,6 +11,33 @@ using namespace Luau; TEST_SUITE_BEGIN("DefinitionTests"); +TEST_CASE_FIXTURE(Fixture, "definition_file_simple") +{ + loadDefinition(R"( + declare foo: number + declare function bar(x: number): string + declare foo2: typeof(foo) + )"); + + TypeId globalFooTy = getGlobalBinding(frontend, "foo"); + CHECK_EQ(toString(globalFooTy), "number"); + + TypeId globalBarTy = getGlobalBinding(frontend, "bar"); + CHECK_EQ(toString(globalBarTy), "(number) -> string"); + + TypeId globalFoo2Ty = getGlobalBinding(frontend, "foo2"); + CHECK_EQ(toString(globalFoo2Ty), "number"); + + CheckResult result = check(R"( + local x: number = foo - 1 + local y: string = bar(x) + local z: number | string = x + z = y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "definition_file_loading") { loadDefinition(R"( @@ -22,20 +48,20 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_loading") declare function var(...: any): string )"); - TypeId globalFooTy = getGlobalBinding(frontend.typeChecker, "foo"); + TypeId globalFooTy = getGlobalBinding(frontend, "foo"); CHECK_EQ(toString(globalFooTy), "number"); - std::optional globalAsdfTy = frontend.typeChecker.globalScope->lookupType("Asdf"); + std::optional globalAsdfTy = frontend.getGlobalScope()->lookupType("Asdf"); REQUIRE(bool(globalAsdfTy)); CHECK_EQ(toString(globalAsdfTy->type), "number | string"); - TypeId globalBarTy = getGlobalBinding(frontend.typeChecker, "bar"); + TypeId globalBarTy = getGlobalBinding(frontend, "bar"); CHECK_EQ(toString(globalBarTy), "(number) -> string"); - TypeId globalFoo2Ty = getGlobalBinding(frontend.typeChecker, "foo2"); + TypeId globalFoo2Ty = getGlobalBinding(frontend, "foo2"); CHECK_EQ(toString(globalFoo2Ty), "number"); - TypeId globalVarTy = getGlobalBinding(frontend.typeChecker, "var"); + TypeId globalVarTy = getGlobalBinding(frontend, "var"); CHECK_EQ(toString(globalVarTy), "(...any) -> string"); @@ -59,7 +85,7 @@ TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_sc freeze(typeChecker.globalTypes); REQUIRE(!parseFailResult.success); - std::optional fooTy = tryGetGlobalBinding(typeChecker, "foo"); + std::optional fooTy = tryGetGlobalBinding(frontend, "foo"); CHECK(!fooTy.has_value()); LoadDefinitionFileResult checkFailResult = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( @@ -69,7 +95,7 @@ TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_sc "@test"); REQUIRE(!checkFailResult.success); - std::optional barTy = tryGetGlobalBinding(typeChecker, "bar"); + std::optional barTy = tryGetGlobalBinding(frontend, "bar"); CHECK(!barTy.has_value()); } @@ -171,9 +197,6 @@ TEST_CASE_FIXTURE(Fixture, "no_cyclic_defined_classes") TEST_CASE_FIXTURE(Fixture, "declaring_generic_functions") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - loadDefinition(R"( declare function f(a: a, b: b): string declare function g(...: a...): b... @@ -286,6 +309,21 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") CHECK_EQ(yTtv->props["x"].documentationSymbol, "@test/global/y.x"); } +TEST_CASE_FIXTURE(Fixture, "definitions_symbols_are_generated_for_recursively_referenced_types") +{ + loadDefinition(R"( + declare class MyClass + function myMethod(self) + end + + declare function myFunc(): MyClass + )"); + + std::optional myClassTy = typeChecker.globalScope->lookupType("MyClass"); + REQUIRE(bool(myClassTy)); + CHECK_EQ(myClassTy->type->documentationSymbol, "@test/globaltype/MyClass"); +} + TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_types") { loadDefinition(R"( @@ -297,4 +335,85 @@ TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_type CHECK_EQ(ty->type->documentationSymbol, std::nullopt); } +TEST_CASE_FIXTURE(Fixture, "single_class_type_identity_in_global_types") +{ + loadDefinition(R"( +declare class Cls +end + +declare GetCls: () -> (Cls) + )"); + + CheckResult result = check(R"( +local s : Cls = GetCls() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "class_definition_overload_metamethods") +{ + loadDefinition(R"( + declare class Vector3 + end + + declare class CFrame + function __mul(self, other: CFrame): CFrame + function __mul(self, other: Vector3): Vector3 + end + + declare function newVector3(): Vector3 + declare function newCFrame(): CFrame + )"); + + CheckResult result = check(R"( + local base = newCFrame() + local shouldBeCFrame = base * newCFrame() + local shouldBeVector = base * newVector3() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("shouldBeCFrame")), "CFrame"); + CHECK_EQ(toString(requireType("shouldBeVector")), "Vector3"); +} + +TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") +{ + loadDefinition(R"( + declare class Foo + ["a property"]: string + end + )"); + + CheckResult result = check(R"( + local x: Foo + local y = x["a property"] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("y")), "string"); +} + +TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") +{ + ScopedFastFlag LuauDeclareClassPrototype("LuauDeclareClassPrototype", true); + + unfreeze(typeChecker.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + declare class Channel + Messages: { Message } + OnMessage: (message: Message) -> () + end + + declare class Message + Text: string + Channel: Channel + end + )", + "@test"); + freeze(typeChecker.globalTypes); + + REQUIRE(result.success); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp new file mode 100644 index 00000000..55218040 --- /dev/null +++ b/tests/TypeInfer.functions.test.cpp @@ -0,0 +1,1847 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Error.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauInstantiateInSubtyping); + +TEST_SUITE_BEGIN("TypeInferFunctions"); + +TEST_CASE_FIXTURE(Fixture, "tc_function") +{ + CheckResult result = check("function five() return 5 end"); + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* fiveType = get(requireType("five")); + REQUIRE(fiveType != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "check_function_bodies") +{ + CheckResult result = check("function myFunction() local a = 0 a = true end"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 44}, Position{0, 48}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.booleanType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_hoist_interior_defns_into_signature") +{ + // This test verifies that the signature does not have access to types + // declared within the body. Under DCR, if the function's inner scope + // encompasses the entire function expression, it would be possible for this + // to type check (but the solver output is somewhat undefined). This test + // ensures that this isn't the case. + CheckResult result = check(R"( + local function f(x: T) + type T = number + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{Location{{1, 28}, {1, 29}}, getMainSourceModule()->name, + UnknownSymbol{ + "T", + UnknownSymbol::Context::Type, + }}); +} + +TEST_CASE_FIXTURE(Fixture, "infer_return_type") +{ + CheckResult result = check("function take_five() return 5 end"); + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* takeFiveType = get(requireType("take_five")); + REQUIRE(takeFiveType != nullptr); + + std::vector retVec = flatten(takeFiveType->retTypes).first; + REQUIRE(!retVec.empty()); + + REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") +{ + CheckResult result = check("function take_five() return 5 end local five = take_five()"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *follow(requireType("five"))); +} + +TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") +{ + CheckResult result = check(R"( + function take_five() + return 5 + end + + take_five().prop = 888 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{typeChecker.numberType}})); +} + +TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and_size") +{ + CheckResult result = check(R"( + function f(...) end + + f(1) + f("foo", 2) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") +{ + CheckResult result = check(R"( + local T = {} + function T.f(...) + local result = {} + + for i = 1, select("#", ...) do + local dictionary = select(i, ...) + for key, value in pairs(dictionary) do + result[key] = value + end + end + + return result + end + + return T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto r = first(getMainModule()->getModuleScope()->returnType); + REQUIRE(r); + + TableTypeVar* ttv = getMutable(*r); + REQUIRE(ttv); + + REQUIRE(ttv->props.count("f")); + TypeId k = ttv->props["f"].type; + REQUIRE(k); +} + +TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_count") +{ + CheckResult result = check(R"( + local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) + multiply("") + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); + + ExtraInformation* ei = get(result.errors[1]); + REQUIRE(ei); + CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); +} + +TEST_CASE_FIXTURE(Fixture, "list_all_overloads_if_no_overload_takes_given_argument_count") +{ + CheckResult result = check(R"( + local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) + multiply() + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("No overload for function accepts 0 arguments.", ge->message); + + ExtraInformation* ei = get(result.errors[1]); + REQUIRE(ei); + CHECK_EQ("Available overloads: (number) -> number; (number) -> string; and (number, number) -> number", ei->message); +} + +TEST_CASE_FIXTURE(Fixture, "dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists") +{ + CheckResult result = check(R"( + local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) + multiply(1, "") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "infer_return_type_from_selected_overload") +{ + CheckResult result = check(R"( + type T = {method: ((T, number) -> number) & ((number) -> string)} + local T: T + + local a = T.method(T, 4) + local b = T.method(5) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("string", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_arguments") +{ + CheckResult result = check(R"( + --!nonstrict + + function g(a: number) end + + g() + + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(1, acm->expected); + CHECK_EQ(0, acm->actual); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_arguments_error_location") +{ + CheckResult result = check(R"( + --!strict + + function myfunction(a: number, b:number) end + myfunction(1) + + function getmyfunction() + return myfunction + end + getmyfunction()() + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + { + TypeError err = result.errors[0]; + + // Ensure the location matches the location of the function identifier + CHECK_EQ(err.location, Location(Position(4, 8), Position(4, 18))); + + auto acm = get(err); + REQUIRE(acm); + CHECK_EQ(2, acm->expected); + CHECK_EQ(1, acm->actual); + } + { + TypeError err = result.errors[1]; + + // Ensure the location matches the location of the expression returning the function + CHECK_EQ(err.location, Location(Position(9, 8), Position(9, 23))); + + auto acm = get(err); + REQUIRE(acm); + CHECK_EQ(2, acm->expected); + CHECK_EQ(0, acm->actual); + } +} + +TEST_CASE_FIXTURE(Fixture, "recursive_function") +{ + CheckResult result = check(R"( + function count(n: number) + if n == 0 then + return 0 + else + return count(n - 1) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "lambda_form_of_local_function_cannot_be_recursive") +{ + CheckResult result = check(R"( + local f = function() return f() end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_local_function") +{ + CheckResult result = check(R"( + local function count(n: number) + if n == 0 then + return 0 + else + return count(n - 1) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +// FIXME: This and the above case get handled very differently. It's pretty dumb. +// We really should unify the two code paths, probably by deleting AstStatFunction. +TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") +{ + CheckResult result = check(R"( + local count + function count(n: number) + if n == 0 then + return 0 + else + return count(n - 1) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") +{ + CheckResult result = check(R"( + function f() + return f + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = () -> t1", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") +{ + CheckResult result = check(R"( + local Get_des + function Get_des(func) + Get_des(func) + end + + local function f(d) + d:IsA("BasePart") + d.Parent:FindFirstChild("Humanoid") + d:IsA("Decal") + end + Get_des(f) + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "another_other_higher_order_function") +{ + CheckResult result = check(R"( + local d + d:foo() + d:foo() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "local_function") +{ + CheckResult result = check(R"( + function f() + return 8 + end + + function g() + local function f() + return 'hello' + end + return f + end + + local h = g() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId h = follow(requireType("h")); + + const FunctionTypeVar* ftv = get(h); + REQUIRE(ftv != nullptr); + + std::optional rt = first(ftv->retTypes); + REQUIRE(bool(rt)); + + TypeId retType = follow(*rt); + CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(retType)); +} + +TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") +{ + CheckResult result = check(R"( + local p = function(x) return x end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + const Luau::FunctionTypeVar* fn = get(requireType("p")); + REQUIRE(fn); + auto ret = first(fn->retTypes); + REQUIRE(ret); + REQUIRE(get(follow(*ret))); +} + +TEST_CASE_FIXTURE(Fixture, "first_argument_can_be_optional") +{ + CheckResult result = check(R"( + local T = {} + function T.new(a: number?, b: number?, c: number?) return 5 end + local m = T.new() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "it_is_ok_not_to_supply_enough_retvals") +{ + CheckResult result = check(R"( + function get_two() return 5, 6 end + + local a = get_two() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_functions2") +{ + CheckResult result = check(R"( + function foo() end + + function bar() + local function foo() end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_functions_allowed_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + function foo() end + + function foo() end + + function bar() + local function foo() end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_functions_with_different_signatures_not_allowed_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + function foo(): number + return 1 + end + foo() + + function foo(n: number): number + return 2 + end + foo() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("() -> number", toString(tm->wantedType)); + CHECK_EQ("(number) -> number", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotation") +{ + CheckResult result = check(R"( + local i = 0 + function most_of_the_natural_numbers(): number? + if i < 10 then + i = i + 1 + return i + else + return nil + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = requireType("most_of_the_natural_numbers"); + const FunctionTypeVar* functionType = get(ty); + REQUIRE_MESSAGE(functionType, "Expected function but got " << toString(ty)); + + std::optional retType = first(functionType->retTypes); + REQUIRE(retType); + CHECK(get(*retType)); +} + +TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") +{ + CheckResult result = check(R"( + function apply(f, x) + return f(x) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("apply")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(2, argVec.size()); + + const FunctionTypeVar* fType = get(follow(argVec[0])); + REQUIRE_MESSAGE(fType != nullptr, "Expected a function but got " << toString(argVec[0])); + + std::vector fArgs = flatten(fType->argTypes).first; + + TypeId xType = follow(argVec[1]); + + CHECK_EQ(1, fArgs.size()); + CHECK_EQ(xType, follow(fArgs[0])); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") +{ + CheckResult result = check(R"( + function bottomupmerge(comp, a, b, left, mid, right) + local i, j = left, mid + for k = left, right do + if i < mid and (j > right or not comp(a[j], a[i])) then + b[k] = a[i] + i = i + 1 + else + b[k] = a[j] + j = j + 1 + end + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("bottomupmerge")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(6, argVec.size()); + + const FunctionTypeVar* fType = get(follow(argVec[0])); + REQUIRE(fType != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") +{ + CheckResult result = check(R"( + function swap(p) + local t = p[0] + p[0] = p[1] + p[1] = t + return nil + end + + function swapTwice(p) + swap(p) + swap(p) + return p + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("swapTwice")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(1, argVec.size()); + + const TableTypeVar* argType = get(follow(argVec[0])); + REQUIRE(argType != nullptr); + + CHECK(bool(argType->indexer)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4") +{ + CheckResult result = check(R"( + function bottomupmerge(comp, a, b, left, mid, right) + local i, j = left, mid + for k = left, right do + if i < mid and (j > right or not comp(a[j], a[i])) then + b[k] = a[i] + i = i + 1 + else + b[k] = a[j] + j = j + 1 + end + end + end + + function mergesort(arr, comp) + local work = {} + for i = 1, #arr do + work[i] = arr[i] + end + local width = 1 + while width < #arr do + for i = 1, #arr, 2*width do + bottomupmerge(comp, arr, work, i, math.min(i+width, #arr), math.min(i+2*width-1, #arr)) + end + local temp = work + work = arr + arr = temp + width = width * 2 + end + return arr + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + /* + * mergesort takes two arguments: an array of some type T and a function that takes two Ts. + * We must assert that these two types are in fact the same type. + * In other words, comp(arr[x], arr[y]) is well-typed. + */ + + const FunctionTypeVar* ftv = get(requireType("mergesort")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(2, argVec.size()); + + const TableTypeVar* arg0 = get(follow(argVec[0])); + REQUIRE(arg0 != nullptr); + REQUIRE(bool(arg0->indexer)); + + const FunctionTypeVar* arg1 = get(follow(argVec[1])); + REQUIRE(arg1 != nullptr); + REQUIRE_EQ(2, size(arg1->argTypes)); + + std::vector arg1Args = flatten(arg1->argTypes).first; + + CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[0]); + CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "mutual_recursion") +{ + CheckResult result = check(R"( + --!strict + + function newPlayerCharacter() + startGui() -- Unknown symbol 'startGui' + end + + local characterAddedConnection: any + function startGui() + characterAddedConnection = game:GetService("Players").LocalPlayer.CharacterAdded:connect(newPlayerCharacter) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "toposort_doesnt_break_mutual_recursion") +{ + CheckResult result = check(R"( + --!strict + local x = nil + function f() g() end + -- make sure print(x) doesn't get toposorted here, breaking the mutual block + function g() x = f end + print(x) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") +{ + CheckResult result = check(R"( + --!nonstrict + + function f() + return 114 + end + + return function() + return f():andThen() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") +{ + CheckResult result = check(R"( + function onerror() end + function foo() end + xpcall(foo, onerror) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "another_indirect_function_case_where_it_is_ok_to_provide_too_many_arguments") +{ + CheckResult result = check(R"( + local mycb: (number, number) -> () + + function f() end + + mycb = f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + + local function f1(v): number? + if v then + return 1 + end + end + + local function f2(v) + if v then + return 1 + end + end + + local function f3(v): () + if v then + return + end + end + + local function f4(v) + if v then + return + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + FunctionExitsWithoutReturning* err = get(result.errors[0]); + CHECK(err); +} + +TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_strict") +{ + CheckResult result = check(R"( + --!strict + + local function f1(v): number? + if v then + return 1 + end + end + + local function f2(v) + if v then + return 1 + end + end + + local function f3(v): () + if v then + return + end + end + + local function f4(v) + if v then + return + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + FunctionExitsWithoutReturning* annotatedErr = get(result.errors[0]); + CHECK(annotatedErr); + + FunctionExitsWithoutReturning* inferredErr = get(result.errors[1]); + CHECK(inferredErr); +} + +TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields_errors_spanning_argument") +{ + CheckResult result = check(R"( + function foo(a: number, b: string) end + + foo("Test", 123) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ(result.errors[0], (TypeError{Location{Position{3, 12}, Position{3, 18}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.stringType, + }})); + + CHECK_EQ(result.errors[1], (TypeError{Location{Position{3, 20}, Position{3, 23}}, TypeMismatch{ + typeChecker.stringType, + typeChecker.numberType, + }})); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "calling_function_with_anytypepack_doesnt_leak_free_types") +{ + CheckResult result = check(R"( + --!nonstrict + + function Test(a) + return 1, "" + end + + + local tab = {} + table.insert(tab, Test(1)); + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + opts.maxTableLength = 0; + + CHECK_EQ("{any}", toString(requireType("tab"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_return_values") +{ + ScopedFastFlag sff{"LuauBetterMessagingOnCountMismatch", true}; + + CheckResult result = check(R"( + --!strict + + function f() + return 55 + end + + local a, b = f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::FunctionResult); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 2); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_return_values_in_parentheses") +{ + ScopedFastFlag sff{"LuauBetterMessagingOnCountMismatch", true}; + + CheckResult result = check(R"( + --!strict + + function f() + return 55 + end + + local a, b = (f()) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::FunctionResult); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 2); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_return_values_no_function") +{ + ScopedFastFlag sff{"LuauBetterMessagingOnCountMismatch", true}; + + CheckResult result = check(R"( + --!strict + + local a, b = 55 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::ExprListResult); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 2); +} + +TEST_CASE_FIXTURE(Fixture, "ignored_return_values") +{ + CheckResult result = check(R"( + --!strict + + function f() + return 55, "" + end + + local a = f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") +{ + CheckResult result = check(R"( + --!strict + + function f(): (number, string) + return 55 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Return); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 1); +} + +TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") +{ + CheckResult result = check(R"( + function foo(a, b): number + return 0 + end + + local a: (string)->number = foo + local b: (number, number)->(number, number) = foo + + local c: (string, number)->number = foo -- no error + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + auto tm1 = get(result.errors[0]); + REQUIRE(tm1); + + CHECK_EQ("(string) -> number", toString(tm1->wantedType)); + CHECK_EQ("(string, *error-type*) -> number", toString(tm1->givenType)); + + auto tm2 = get(result.errors[1]); + REQUIRE(tm2); + + CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); + CHECK_EQ("(string, *error-type*) -> number", toString(tm2->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") +{ + CheckResult result = check(R"( + --!strict + local tbl = {} + function tbl:abc(a: number, b: number) + return a + end + tbl:abc(1, 2) -- Line 6 + -- | Column 14 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + TypeId type = requireTypeAtPosition(Position(6, 14)); + CHECK_EQ("(tbl, number, number) -> number", toString(type)); + auto ftv = get(type); + REQUIRE(ftv); + CHECK(ftv->hasSelf); +} + +TEST_CASE_FIXTURE(Fixture, "record_matching_overload") +{ + CheckResult result = check(R"( + type Overload = ((string) -> string) & ((number) -> number) + local abc: Overload + abc(1) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // AstExprCall is the node that has the overload stored on it. + // findTypeAtPosition will look at the AstExprLocal, but this is not what + // we want to look at. + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), Position(3, 10)); + REQUIRE_GE(ancestry.size(), 2); + AstExpr* parentExpr = ancestry[ancestry.size() - 2]->asExpr(); + REQUIRE(bool(parentExpr)); + REQUIRE(parentExpr->is()); + + ModulePtr module = getMainModule(); + auto it = module->astOverloadResolvedTypes.find(parentExpr); + REQUIRE(it); + CHECK_EQ(toString(*it), "(number) -> number"); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") +{ + CheckResult result = check(R"( + type Overload = ((string) -> string) & ((number, number) -> number) + local abc: Overload + local x = abc(true) + local y = abc(true,true) + local z = abc(true,true,true) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("string", toString(requireType("x"))); + CHECK_EQ("number", toString(requireType("y"))); + // Should this be string|number? + CHECK_EQ("string", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") +{ + // Simple direct arg to arg propagation + CheckResult result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // An optional function is accepted, but since we already provide a function, nil can be ignored + result = check(R"( +type Table = { x: number, y: number } +local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Make sure self calls match correct index + result = check(R"( +type Table = { x: number, y: number } +local x = {} +x.b = {x = 1, y = 2} +function x:f(a: (Table) -> number) return a(self.b) end +x:f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Mix inferred and explicit argument types + result = check(R"( +function f(a: (a: number, b: number, c: boolean) -> number) return a(1, 2, true) end +f(function(a: number, b, c) return c and a + b or b - a end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Anonymous function has a variadic pack + result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(...) return select(1, ...).z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Can't accept more arguments than provided + result = check(R"( +function f(a: (a: number, b: number) -> number) return a(1, 2) end +f(function(a, b, c, ...) return a + b end) + )"); + + LUAU_REQUIRE_ERRORS(result); + + if (FFlag::LuauInstantiateInSubtyping) + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + } + + // Infer from variadic packs into elements + result = check(R"( +function f(a: (...number) -> number) return a(1, 2) end +f(function(a, b) return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Infer from variadic packs into variadic packs + result = check(R"( +type Table = { x: number, y: number } +function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end +f(function(a, ...) local b = ... return b.z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Return type inference + result = check(R"( +type Table = { x: number, y: number } +function f(a: (number) -> Table) return a(4) end +f(function(x) return x * 2 end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); + + // Return type doesn't inference 'nil' + result = check(R"( + function f(a: (number) -> nil) return a(4) end + f(function(x) print(x) end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") +{ + // Simple direct arg to arg propagation + CheckResult result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // An optional function is accepted, but since we already provide a function, nil can be ignored + result = check(R"( +type Table = { x: number, y: number } +local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Make sure self calls match correct index + result = check(R"( +type Table = { x: number, y: number } +local x = {} +x.b = {x = 1, y = 2} +function x:f(a: (Table) -> number) return a(self.b) end +x:f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Mix inferred and explicit argument types + result = check(R"( +function f(a: (a: number, b: number, c: boolean) -> number) return a(1, 2, true) end +f(function(a: number, b, c) return c and a + b or b - a end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Anonymous function has a variadic pack + result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(...) return select(1, ...).z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Can't accept more arguments than provided + result = check(R"( +function f(a: (a: number, b: number) -> number) return a(1, 2) end +f(function(a, b, c, ...) return a + b end) + )"); + + LUAU_REQUIRE_ERRORS(result); + + if (FFlag::LuauInstantiateInSubtyping) + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + } + + // Infer from variadic packs into elements + result = check(R"( +function f(a: (...number) -> number) return a(1, 2) end +f(function(a, b) return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Infer from variadic packs into variadic packs + result = check(R"( +type Table = { x: number, y: number } +function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end +f(function(a, ...) local b = ... return b.z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Return type inference + result = check(R"( +type Table = { x: number, y: number } +function f(a: (number) -> Table) return a(4) end +f(function(x) return x * 2 end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); + + // Return type doesn't inference 'nil' + result = check(R"( + function f(a: (number) -> nil) return a(4) end + f(function(x) print(x) end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") +{ + CheckResult result = check(R"( +type Table = { x: number, y: number } +local f: (Table) -> number = function(t) return t.x + t.y end + +type TableWithFunc = { x: number, y: number, f: (number, number) -> number } +local a: TableWithFunc = { x = 3, y = 4, f = function(a, b) return a + b end } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") +{ + CheckResult result = check(R"( +local function f(): {string|number} + return {1, "b", 3} +end + +local function g(): (number, {string|number}) + return 4, {1, "b", 3} +end + +local function h(): ...{string|number} + return {4}, {1, "b", 3}, {"s"} +end + +local function i(): ...{string|number} + return {1, "b", 3}, h() +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") +{ + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") +{ + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, string) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, string) -> string' +caused by: + Argument #2 type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") +{ + CheckResult result = check(R"( +type A = (number, number) -> (number) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> number' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Function only returns 1 value, but 2 are required here)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") +{ + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, number) -> number + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, number) -> number' +caused by: + Return type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") +{ + CheckResult result = check(R"( +type A = (number, number) -> (number, string) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_quantify_right_type") +{ + fileResolver.source["game/isAMagicMock"] = R"( +--!nonstrict +return function(value) + return false +end + )"; + + CheckResult result = check(R"( +--!nonstrict +local MagicMock = {} +MagicMock.is = require(game.isAMagicMock) + +function MagicMock.is(value) + return false +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite") +{ + CheckResult result = check(R"( +function string.len(): number + return 1 +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // if 'string' library property was replaced with an internal module type, it will be freed and the next check will crash + frontend.clear(); + + result = check(R"( +print(string.len('hello')) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite_2") +{ + CheckResult result = check(R"( +local t: { f: ((x: number) -> number)? } = {} + +function t.f(x) + print(x + 5) + return x .. "asd" -- 1st error: we know that return type is a number, not a string +end + +t.f = function(x) + print(x + 5) + return x .. "asd" -- 2nd error: we know that return type is a number, not a string +end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'string' could not be converted into 'number')"); + CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time2") +{ + CheckResult result = check(R"( + --!strict + + local function resolveDispatcher() + return (nil :: any) :: {useContext: (number?) -> any} + end + + local useContext + useContext = function(unstable_observedBits: number?) + resolveDispatcher().useContext(unstable_observedBits) + end + )"); + + // LUAU_REQUIRE_NO_ERRORS is particularly unhelpful when this test is broken. + // You get a TypeMismatch error where both types stringify the same. + + CHECK(result.errors.empty()); + if (!result.errors.empty()) + { + for (const auto& e : result.errors) + printf("%s %s: %s\n", e.moduleName.c_str(), toString(e.location).c_str(), toString(e).c_str()); + } +} + +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time3") +{ + CheckResult result = check(R"( + local foo + + foo():bar(function() + return foo() + end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_unsealed_overwrite") +{ + CheckResult result = check(R"( +local t = { f = nil :: ((x: number) -> number)? } + +function t.f(x: string): string -- 1st error: new function value type is incompatible + return x .. "asd" +end + +t.f = function(x) + print(x + 5) + return x .. "asd" -- 2nd error: we know that return type is a number, not a string +end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(string) -> string' could not be converted into '((number) -> number)?' +caused by: + None of the union options are compatible. For example: Type '(string) -> string' could not be converted into '(number) -> number' +caused by: + Argument #1 type is not compatible. Type 'number' could not be converted into 'string')"); + CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") +{ + CheckResult result = check(R"( + local function f(x: any) end + f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_statement_sealed_table_assignment_through_indexer") +{ + CheckResult result = check(R"( +local t: {[string]: () -> number} = {} + +function t.a() return 1 end -- OK +function t:b() return 2 end -- not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '(*error-type*) -> number' could not be converted into '() -> number' +caused by: + Argument count mismatch. Function expects 1 argument, but none are specified)", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") +{ + CheckResult result = check(R"( + function test(a: number, b: string, ...) + end + + test(1) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(2, acm->expected); + CHECK_EQ(1, acm->actual); + CHECK_EQ(CountMismatch::Context::Arg, acm->context); + CHECK(acm->isVariadic); +} + +TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic") +{ + CheckResult result = check(R"( +function test(a: number, b: string, ...) + return 1 +end + +function wrapper(f: (A...) -> number, ...: A...) +end + +wrapper(test) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(3, acm->expected); + CHECK_EQ(1, acm->actual); + CHECK_EQ(CountMismatch::Context::Arg, acm->context); + CHECK(acm->isVariadic); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "too_few_arguments_variadic_generic2") +{ + CheckResult result = check(R"( +function test(a: number, b: string, ...) + return 1 +end + +function wrapper(f: (A...) -> number, ...: A...) +end + +pcall(wrapper, test) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(4, acm->expected); + CHECK_EQ(2, acm->actual); + CHECK_EQ(CountMismatch::Context::Arg, acm->context); + CHECK(acm->isVariadic); +} + +TEST_CASE_FIXTURE(Fixture, "occurs_check_failure_in_function_return_type") +{ + CheckResult result = check(R"( + function f() + return 5, f() + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(nullptr != get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "free_is_not_bound_to_unknown") +{ + CheckResult result = check(R"( + local function foo(f: (unknown) -> (), x) + f(x) + end + )"); + + CHECK_EQ("((unknown) -> (), a) -> ()", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_infer_parameter_types_for_functions_from_their_call_site") +{ + CheckResult result = check(R"( + local t = {} + + function t.f(x) + return x + end + + t.__index = t + + function g(s) + local q = s.p and s.p.q or nil + return q and t.f(q) or nil + end + + local f = t.f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(a) -> a", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_mutate_the_underlying_head_of_typepack_when_calling_with_self") +{ + CheckResult result = check(R"( + local t = {} + function t:m(x) end + function f(): never return 5 :: never end + t:m(f()) + t:m(f()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "improved_function_arg_mismatch_errors") +{ + CheckResult result = check(R"( +local function foo1(a: number) end +foo1() + +local function foo2(a: number, b: string?) end +foo2() + +local function foo3(a: number, b: string?, c: any) end -- any is optional +foo3() + +string.find() + +local t = {} +function t.foo(x: number, y: string?, ...: any) return 1 end +function t:bar(x: number, y: string?) end +t.foo() + +t:bar() + +local u = { a = t, b = function() return t end } +u.a.foo() +local x = (u.a).foo() + +u.b().foo() + )"); + + LUAU_REQUIRE_ERROR_COUNT(9, result); + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo1' expects 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function 'foo2' expects 1 to 2 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[2]), "Argument count mismatch. Function 'foo3' expects 1 to 3 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[3]), "Argument count mismatch. Function 'string.find' expects 2 to 4 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[4]), "Argument count mismatch. Function 't.foo' expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[5]), "Argument count mismatch. Function 't.bar' expects 2 to 3 arguments, but only 1 is specified"); + CHECK_EQ(toString(result.errors[6]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[7]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[8]), "Argument count mismatch. Function expects at least 1 argument, but none are specified"); +} + +// This might be surprising, but since 'any' became optional, unannotated functions in non-strict 'expect' 0 arguments +TEST_CASE_FIXTURE(BuiltinsFixture, "improved_function_arg_mismatch_error_nonstrict") +{ + CheckResult result = check(R"( +--!nonstrict +local function foo(a, b) end +foo(string.find("hello", "e")) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo' expects 0 to 2 arguments, but 3 are specified"); +} + +TEST_CASE_FIXTURE(Fixture, "luau_subtyping_is_np_hard") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + {"LuauOverloadedFunctionSubtypingPerf", true}, + }; + + CheckResult result = check(R"( +--!strict + +-- An example of coding up graph coloring in the Luau type system. +-- This codes a three-node, two color problem. +-- A three-node triangle is uncolorable, +-- but a three-node line is colorable. + +type Red = "red" +type Blue = "blue" +type Color = Red | Blue +type Coloring = (Color) -> (Color) -> (Color) -> boolean +type Uncolorable = (Color) -> (Color) -> (Color) -> false + +type Line = Coloring + & ((Red) -> (Red) -> (Color) -> false) + & ((Blue) -> (Blue) -> (Color) -> false) + & ((Color) -> (Red) -> (Red) -> false) + & ((Color) -> (Blue) -> (Blue) -> false) + +type Triangle = Line + & ((Red) -> (Color) -> (Red) -> false) + & ((Blue) -> (Color) -> (Blue) -> false) + +local x : Triangle +local y : Line +local z : Uncolorable +z = x -- OK, so the triangle is uncolorable +z = y -- Not OK, so the line is colorable + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + "Type '((\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> boolean) & ((\"blue\" | \"red\") -> (\"blue\") -> (\"blue\") " + "-> false) & ((\"blue\" | \"red\") -> (\"red\") -> (\"red\") -> false) & ((\"blue\") -> (\"blue\") -> (\"blue\" | \"red\") -> false) & " + "((\"red\") -> (\"red\") -> (\"blue\" | \"red\") -> false)' could not be converted into '(\"blue\" | \"red\") -> (\"blue\" | \"red\") -> " + "(\"blue\" | \"red\") -> false'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions") +{ + ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; + registerHiddenTypes(*this, frontend.globalTypes); + + CheckResult result = check(R"( + function foo(f: fun) end + + function a() end + function id(x) return x end + + foo(a) + foo(id) + foo(foo) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "concrete_functions_are_not_supertypes_of_function") +{ + ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; + registerHiddenTypes(*this, frontend.globalTypes); + + CheckResult result = check(R"( + local a: fun = function() end + + function one(arg: () -> ()) end + function two(arg: (T) -> T) end + + one(a) + two(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(6 == result.errors[0].location.begin.line); + CHECK(7 == result.errors[1].location.begin.line); +} + +TEST_CASE_FIXTURE(Fixture, "other_things_are_not_related_to_function") +{ + ScopedFastFlag sff{"LuauNegatedFunctionTypes", true}; + registerHiddenTypes(*this, frontend.globalTypes); + + CheckResult result = check(R"( + local a: fun = function() end + local b: {} = a + local c: boolean = a + local d: fun = true + local e: fun = {} + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + + CHECK(2 == result.errors[0].location.begin.line); + CHECK(3 == result.errors[1].location.begin.line); + CHECK(4 == result.errors[2].location.begin.line); + CHECK(5 == result.errors[3].location.begin.line); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_must_follow_in_overload_resolution") +{ + ScopedFastFlag luauTypeInferMissingFollows{"LuauTypeInferMissingFollows", true}; + + CheckResult result = check(R"( +for _ in function():(t0)&((()->())&(()->())) +end do +_(_(_,_,_),_) +end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 3a04a18f..c25f8e5f 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -1,20 +1,23 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" +#include "Luau/Scope.h" + +#include #include "Fixture.h" #include "doctest.h" +LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) + using namespace Luau; TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( function id(x:a): a return x @@ -27,8 +30,6 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_function") TEST_CASE_FIXTURE(Fixture, "check_generic_local_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function id(x:a): a return x @@ -41,10 +42,6 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_local_function") TEST_CASE_FIXTURE(Fixture, "check_generic_typepack_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauGenericVariadicsUnification", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function id(...: a...): (a...) return ... end local x: string, y: boolean = id("hi", true) @@ -56,8 +53,6 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_typepack_function") TEST_CASE_FIXTURE(Fixture, "types_before_typepacks") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( function f() end )"); @@ -66,8 +61,6 @@ TEST_CASE_FIXTURE(Fixture, "types_before_typepacks") TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function id(x:a):a return x end local f: (a)->a = id @@ -77,9 +70,8 @@ TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_polytypes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") +TEST_CASE_FIXTURE(BuiltinsFixture, "inferred_local_vars_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function id(x) return x end print("This is bogus") -- TODO: CLI-39916 @@ -90,9 +82,8 @@ TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_instantiated_polytypes") +TEST_CASE_FIXTURE(BuiltinsFixture, "local_vars_can_be_instantiated_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function id(x) return x end print("This is bogus") -- TODO: CLI-39916 @@ -104,8 +95,6 @@ TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_instantiated_polytypes") TEST_CASE_FIXTURE(Fixture, "properties_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local t = {} t.m = function(x: a):a return x end @@ -117,8 +106,6 @@ TEST_CASE_FIXTURE(Fixture, "properties_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "properties_can_be_instantiated_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local t: { m: (number)->number } = { m = function(x:number) return x+1 end } local function id(x:a):a return x end @@ -129,8 +116,6 @@ TEST_CASE_FIXTURE(Fixture, "properties_can_be_instantiated_polytypes") TEST_CASE_FIXTURE(Fixture, "check_nested_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function f() local function id(x:a): a @@ -145,8 +130,6 @@ TEST_CASE_FIXTURE(Fixture, "check_nested_generic_function") TEST_CASE_FIXTURE(Fixture, "check_recursive_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function id(x:a):a local y: string = id("hi") @@ -159,8 +142,6 @@ TEST_CASE_FIXTURE(Fixture, "check_recursive_generic_function") TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local id2 local function id1(x:a):a @@ -179,8 +160,6 @@ TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions") TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( type T = { id: (a) -> a } local x: T = { id = function(x:a):a return x end } @@ -192,8 +171,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types") TEST_CASE_FIXTURE(Fixture, "generic_factories") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( type T = { id: (a) -> a } type Factory = { build: () -> T } @@ -215,10 +192,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_factories") TEST_CASE_FIXTURE(Fixture, "factories_of_generics") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; - CheckResult result = check(R"( type T = { id: (a) -> a } type Factory = { build: () -> T } @@ -241,7 +214,6 @@ TEST_CASE_FIXTURE(Fixture, "factories_of_generics") TEST_CASE_FIXTURE(Fixture, "infer_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( function id(x) return x @@ -255,17 +227,16 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function") const FunctionTypeVar* idFun = get(idType); REQUIRE(idFun); auto [args, varargs] = flatten(idFun->argTypes); - auto [rets, varrets] = flatten(idFun->retType); + auto [rets, varrets] = flatten(idFun->retTypes); CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); - CHECK_EQ(args[0], idFun->generics[0]); - CHECK_EQ(rets[0], idFun->generics[0]); + CHECK_EQ(follow(args[0]), follow(idFun->generics[0])); + CHECK_EQ(follow(rets[0]), follow(idFun->generics[0])); } TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function id(x) return x @@ -279,17 +250,16 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") const FunctionTypeVar* idFun = get(idType); REQUIRE(idFun); auto [args, varargs] = flatten(idFun->argTypes); - auto [rets, varrets] = flatten(idFun->retType); + auto [rets, varrets] = flatten(idFun->retTypes); CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); - CHECK_EQ(args[0], idFun->generics[0]); - CHECK_EQ(rets[0], idFun->generics[0]); + CHECK_EQ(follow(args[0]), follow(idFun->generics[0])); + CHECK_EQ(follow(rets[0]), follow(idFun->generics[0])); } TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function f() local function id(x) @@ -304,19 +274,20 @@ TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + CheckResult result = check(R"( local x = {} function x:id(x) return x end function x:f(): string return self:id("hello") end function x:g(): number return self:id(37) end )"); - LUAU_REQUIRE_NO_ERRORS(result); + // TODO: Quantification should be doing the conversion, not normalization. + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local x = {} function x:id(x) return x end @@ -326,13 +297,11 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") end )"); // TODO: Should typecheck but currently errors CLI-39916 - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_generic_property") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauRankNTypes", true}; CheckResult result = check(R"( local t = {} t.m = function(x) return x end @@ -344,9 +313,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_property") TEST_CASE_FIXTURE(Fixture, "function_arguments_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; CheckResult result = check(R"( local function f(g: (a)->a) local x: number = g(37) @@ -358,9 +324,6 @@ TEST_CASE_FIXTURE(Fixture, "function_arguments_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "function_results_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; CheckResult result = check(R"( local function f() : (a)->a local function id(x:a):a return x end @@ -372,9 +335,6 @@ TEST_CASE_FIXTURE(Fixture, "function_results_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "type_parameters_can_be_polytypes") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; CheckResult result = check(R"( local function id(x:a):a return x end local f: (a)->a = id(id) @@ -384,7 +344,6 @@ TEST_CASE_FIXTURE(Fixture, "type_parameters_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function f(y) -- this will only typecheck if we infer z: any @@ -401,12 +360,11 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") -- so this assignment should fail local b: boolean = f(true) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function f(y) local z = y @@ -418,17 +376,11 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") local y: number = id(37) end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") { - ScopedFastFlag sffs[] = { - {"LuauGenericFunctions", true}, - {"LuauParseGenericFunctions", true}, - {"LuauRankNTypes", true}, - }; - CheckResult result = check(R"( type T = { m: (a) -> T } function f(t : T) @@ -440,10 +392,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") TEST_CASE_FIXTURE(Fixture, "dont_unify_bound_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; - CheckResult result = check(R"( type F = () -> (a, b) -> a type G = (b, b) -> b @@ -470,7 +418,6 @@ TEST_CASE_FIXTURE(Fixture, "mutable_state_polymorphism") // Replaying the classic problem with polymorphism and mutable state in Luau // See, e.g. Tofte (1990) // https://www.sciencedirect.com/science/article/pii/089054019090018D. - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( --!strict -- Our old friend the polymorphic identity function @@ -508,7 +455,6 @@ TEST_CASE_FIXTURE(Fixture, "mutable_state_polymorphism") TEST_CASE_FIXTURE(Fixture, "rank_N_types_via_typeof") { - ScopedFastFlag sffs{"LuauGenericFunctions", false}; CheckResult result = check(R"( --!strict local function id(x) return x end @@ -531,8 +477,6 @@ TEST_CASE_FIXTURE(Fixture, "rank_N_types_via_typeof") TEST_CASE_FIXTURE(Fixture, "duplicate_generic_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( function f(x:a):a return x end )"); @@ -541,7 +485,6 @@ TEST_CASE_FIXTURE(Fixture, "duplicate_generic_types") TEST_CASE_FIXTURE(Fixture, "duplicate_generic_type_packs") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( function f() end )"); @@ -550,7 +493,6 @@ TEST_CASE_FIXTURE(Fixture, "duplicate_generic_type_packs") TEST_CASE_FIXTURE(Fixture, "typepacks_before_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( function f() end )"); @@ -559,9 +501,6 @@ TEST_CASE_FIXTURE(Fixture, "typepacks_before_types") TEST_CASE_FIXTURE(Fixture, "variadic_generics") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: a) end @@ -573,9 +512,6 @@ TEST_CASE_FIXTURE(Fixture, "variadic_generics") TEST_CASE_FIXTURE(Fixture, "generic_type_pack_syntax") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: a...): (a...) return ... end )"); @@ -586,10 +522,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_type_pack_syntax") TEST_CASE_FIXTURE(Fixture, "generic_type_pack_parentheses") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauGenericVariadicsUnification", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: a...): any return (...) end )"); @@ -599,9 +531,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_type_pack_parentheses") TEST_CASE_FIXTURE(Fixture, "better_mismatch_error_messages") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: T...) return ... @@ -626,9 +555,6 @@ TEST_CASE_FIXTURE(Fixture, "better_mismatch_error_messages") TEST_CASE_FIXTURE(Fixture, "reject_clashing_generic_and_pack_names") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f() end )"); @@ -641,8 +567,6 @@ TEST_CASE_FIXTURE(Fixture, "reject_clashing_generic_and_pack_names") TEST_CASE_FIXTURE(Fixture, "instantiation_sharing_types") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - CheckResult result = check(R"( function f(z) local o = {} @@ -665,8 +589,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiation_sharing_types") TEST_CASE_FIXTURE(Fixture, "quantification_sharing_types") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - CheckResult result = check(R"( function f(x) return {5} end function g(x, y) return f(x) end @@ -680,8 +602,6 @@ TEST_CASE_FIXTURE(Fixture, "quantification_sharing_types") TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - CheckResult result = check(R"( type T = { x: {a}, y: {number} } local o1: T = { x = {true}, y = {5} } @@ -695,4 +615,668 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") CHECK(requireType("y1") == requireType("y2")); } +TEST_CASE_FIXTURE(BuiltinsFixture, "bound_tables_do_not_clone_original_fields") +{ + CheckResult result = check(R"( +local exports = {} +local nested = {} + +nested.name = function(t, k) + local a = t.x.y + return rawget(t, k) +end + +exports.nested = nested +return exports + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "instantiated_function_argument_names") +{ + CheckResult result = check(R"( +local function f(a: T, ...: U...) end + +f(1, 2, 3) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto ty = findTypeAtPosition(Position(3, 0)); + REQUIRE(ty); + ToStringOptions opts; + opts.functionTypeArguments = true; + CHECK_EQ(toString(*ty, opts), "(a: number, number, number) -> ()"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_types") +{ + CheckResult result = check(R"( +type C = () -> () +type D = () -> () + +local c: C +local d: D = c + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type parameters)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_pack") +{ + CheckResult result = check(R"( +type C = () -> () +type D = () -> () + +local c: C +local d: D = c + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(toString(result.errors[0]), + R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_functions_dont_cache_type_parameters") +{ + CheckResult result = check(R"( +-- See https://github.com/Roblox/luau/issues/332 +-- This function has a type parameter with the same name as clones, +-- so if we cache type parameter names for functions these get confused. +-- function id(x : Z) : Z +function id(x : X) : X + return x +end + +function clone(dict: {[X]:Y}): {[X]:Y} + local copy = {} + for k, v in pairs(dict) do + copy[k] = v + end + return copy +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe") +{ + CheckResult result = check(R"( +--!strict +-- At one point this produced a UAF +type T = { a: U, b: a } +type U = { c: T?, d : a } +local x: T = { a = { c = nil, d = 5 }, b = 37 } +x.a.c = x +local y: T = { a = { c = nil, d = 5 }, b = 37 } +y.a.c = y + )"); + + LUAU_REQUIRE_ERRORS(result); + if (FFlag::LuauTypeMismatchInvarianceInError) + { + CHECK_EQ(toString(result.errors[0]), + R"(Type 'y' could not be converted into 'T' +caused by: + Property 'a' is not compatible. Type '{ c: T?, d: number }' could not be converted into 'U' +caused by: + Property 'd' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + } + else + { + CHECK_EQ(toString(result.errors[0]), + R"(Type 'y' could not be converted into 'T' +caused by: + Property 'a' is not compatible. Type '{ c: T?, d: number }' could not be converted into 'U' +caused by: + Property 'd' is not compatible. Type 'number' could not be converted into 'string')"); + } +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification1") +{ + CheckResult result = check(R"( +--!strict +type Dispatcher = { + useMemo: (create: () -> T...) -> T... +} + +local TheDispatcher: Dispatcher = { + useMemo = function(create: () -> U...): U... + return create() + end +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification2") +{ + CheckResult result = check(R"( +--!strict +type Dispatcher = { + useMemo: (create: () -> T...) -> T... +} + +local TheDispatcher: Dispatcher = { + useMemo = function(create) + return create() + end +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification3") +{ + CheckResult result = check(R"( +--!strict +type Dispatcher = { + useMemo: (arg: S, create: (S) -> T...) -> T... +} + +local TheDispatcher: Dispatcher = { + useMemo = function(arg: T, create: (T) -> U...): U... + return create(arg) + end +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_few") +{ + CheckResult result = check(R"( +function test(a: number) + return 1 +end + +function wrapper(f: (A...) -> number, ...: A...) +end + +wrapper(test) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function 'wrapper' expects 2 arguments, but only 1 is specified)"); +} + +TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_many") +{ + CheckResult result = check(R"( +function test2(a: number, b: string) + return 1 +end + +function wrapper(f: (A...) -> number, ...: A...) +end + +wrapper(test2, 1, "", 3) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function 'wrapper' expects 3 arguments, but 4 are specified)"); +} + +TEST_CASE_FIXTURE(Fixture, "generic_function") +{ + CheckResult result = check(R"( + function id(x) return x end + local a = id(55) + local b = id(nil) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(a) -> a", toString(requireType("id"))); + CHECK_EQ(*typeChecker.numberType, *requireType("a")); + CHECK_EQ(*typeChecker.nilType, *requireType("b")); +} + +TEST_CASE_FIXTURE(Fixture, "generic_table_method") +{ + CheckResult result = check(R"( + local T = {} + + function T:bar(i) + return i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId tType = requireType("T"); + TableTypeVar* tTable = getMutable(tType); + REQUIRE(tTable != nullptr); + + REQUIRE(tTable->props.count("bar")); + TypeId barType = tTable->props["bar"].type; + REQUIRE(barType != nullptr); + + const FunctionTypeVar* ftv = get(follow(barType)); + REQUIRE_MESSAGE(ftv != nullptr, "Should be a function: " << *barType); + + std::vector args = flatten(ftv->argTypes).first; + TypeId argType = args.at(1); + + CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); +} + +TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") +{ + CheckResult result = check(R"( + local T = {} + + function T:foo() + return T:bar(5) + end + + function T:bar(i) + return i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + const TableTypeVar* t = get(requireType("T")); + REQUIRE(t != nullptr); + + std::optional fooProp = get(t->props, "foo"); + REQUIRE(bool(fooProp)); + + const FunctionTypeVar* foo = get(follow(fooProp->type)); + REQUIRE(bool(foo)); + + std::optional ret_ = first(foo->retTypes); + REQUIRE(bool(ret_)); + TypeId ret = follow(*ret_); + + REQUIRE_EQ(getPrimitiveType(ret), PrimitiveTypeVar::Number); +} + +/* + * We had a bug in instantiation where the argument types of 'f' and 'g' would be inferred as + * f {+ method: function(): (t2, T3...) +} + * g {+ method: function({+ method: function(): (t2, T3...) +}): (t5, T6...) +} + * + * The type of 'g' is totally wrong as t2 and t5 should be unified, as should T3 with T6. + * + * The correct unification of the argument to 'g' is + * + * {+ method: function(): (t5, T6...) +} + */ +TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") +{ + auto result = check(R"( + function f(o) + o:method() + end + + function g(o) + f(o) + end + )"); + + TypeId g = requireType("g"); + const FunctionTypeVar* gFun = get(g); + REQUIRE(gFun != nullptr); + + auto optionArg = first(gFun->argTypes); + REQUIRE(bool(optionArg)); + + TypeId arg = follow(*optionArg); + const TableTypeVar* argTable = get(arg); + REQUIRE(argTable != nullptr); + + std::optional methodProp = get(argTable->props, "method"); + REQUIRE(bool(methodProp)); + + const FunctionTypeVar* methodFunction = get(methodProp->type); + REQUIRE(methodFunction != nullptr); + + std::optional methodArg = first(methodFunction->argTypes); + REQUIRE(bool(methodArg)); + + REQUIRE_EQ(follow(*methodArg), follow(arg)); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments") +{ + CheckResult result = check(R"( + function foo(a, b) + return a(b) + end + + function bar() + local c: ((number)->number, number)->number = foo -- no error + c = foo -- no error + local d: ((number)->number, string)->number = foo -- error from arg 2 (string) not being convertable to number from the call a(b) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("((number) -> number, string) -> number", toString(tm->wantedType)); + if (FFlag::LuauInstantiateInSubtyping) + CHECK_EQ("((a) -> (b...), a) -> (b...)", toString(tm->givenType)); + else + CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") +{ + CheckResult result = check(R"( + function foo(a, b) + return a(b) + end + + function bar() + local _: (string, string)->number = foo -- string cannot be converted to (string)->number + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("(string, string) -> number", toString(tm->wantedType)); + if (FFlag::LuauInstantiateInSubtyping) + CHECK_EQ("((a) -> (b...), a) -> (b...)", toString(tm->givenType)); + else + CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") +{ + // Mutability in type function application right now can create strange recursive types + CheckResult result = check(R"( +type Table = { a: number } +type Self = T +local a: Self
+ )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a")), "Table"); +} + +TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") +{ + CheckResult result = check(R"( + function _(l0:t0): (any, ()->()) + end + + type t0 = t0 | {} + )"); + + LUAU_REQUIRE_ERRORS(result); + + std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); + REQUIRE(t0); + CHECK_EQ("*error-type*", toString(t0->type)); + + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { + return get(err); + }); + CHECK(it != result.errors.end()); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument") +{ + CheckResult result = check(R"( + local function sum(x: a, y: a, f: (a, a) -> a) + return f(x, y) + end + return sum(2, 3, function(a, b) return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = check(R"( + local function map(arr: {a}, f: (a) -> b) + local r = {} + for i,v in ipairs(arr) do + table.insert(r, f(v)) + end + return r + end + local a = {1, 2, 3} + local r = map(a, function(a) return a + a > 100 end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("{boolean}", toString(requireType("r"))); + + check(R"( + local function foldl(arr: {a}, init: b, f: (b, a) -> b) + local r = init + for i,v in ipairs(arr) do + r = f(r, v) + end + return r + end + local a = {1, 2, 3} + local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") +{ + CheckResult result = check(R"( + local g12: ((T, (T) -> T) -> T) & ((T, T, (T, T) -> T) -> T) + + g12(1, function(x) return x + x end) + g12(1, 2, function(x, y) return x + y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = check(R"( + local g12: ((T, (T) -> T) -> T) & ((T, T, (T, T) -> T) -> T) + + g12({x=1}, function(x) return {x=-x.x} end) + g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_lib_function_function_argument") +{ + CheckResult result = check(R"( +local a = {{x=4}, {x=7}, {x=1}} +table.sort(a, function(x, y) return x.x < y.x end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_infer_generic_functions") +{ + CheckResult result = check(R"( +local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end + +local function sumrec(f: typeof(sum)) + return sum(2, 3, function(a, b) return a + b end) +end + +local b = sumrec(sum) -- ok +local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") +{ + CheckResult result = check(R"( + type A = { x: number } + local a: A = { x = 1 } + local b = a + type B = typeof(b) + type X = T + local c: X + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics1") +{ + // https://github.com/Roblox/luau/issues/484 + CheckResult result = check(R"( +--!strict +type MyObject = { + getReturnValue: (cb: () -> V) -> V +} +local object: MyObject = { + getReturnValue = function(cb: () -> U): U + return cb() + end, +} + +type ComplexObject = { + id: T, + nested: MyObject +} + +local complex: ComplexObject = { + id = "Foo", + nested = object, +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics2") +{ + // https://github.com/Roblox/luau/issues/484 + CheckResult result = check(R"( +--!strict +type MyObject = { + getReturnValue: (cb: () -> V) -> V +} +type ComplexObject = { + id: T, + nested: MyObject +} + +local complex2: ComplexObject = nil + +local x = complex2.nested.getReturnValue(function(): string + return "" +end) + +local y = complex2.nested.getReturnValue(function() + return 3 +end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_functions_even_if_they_have_an_explicit_generic") +{ + CheckResult result = check(R"( + function foo(f, x: X) + return f(x) + end + )"); + + CHECK("((X) -> (a...), X) -> (a...)" == toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_always_instantiate_generic_intersection_types") +{ + ScopedFastFlag sff[] = { + {"LuauMaybeGenericIntersectionTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type Array = { [number]: T } + + type Array_Statics = { + new: () -> Array, + } + + local _Arr : Array & Array_Statics = {} :: Array_Statics + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "hof_subtype_instantiation_regression") +{ + CheckResult result = check(R"( +--!strict + +local function defaultSort(a: T, b: T) + return true +end +type A = any +return function(array: {T}): {T} + table.sort(array, defaultSort) + return array +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "higher_rank_polymorphism_should_not_accept_instantiated_arguments") +{ + ScopedFastFlag sffs[] = { + {"LuauInstantiateInSubtyping", true}, + }; + + CheckResult result = check(R"( +--!strict + +local function instantiate(f: (a) -> a): (number) -> number + return f +end + +instantiate(function(x: string) return "foo" end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto tm1 = get(result.errors[0]); + REQUIRE(tm1); + + CHECK_EQ("(a) -> a", toString(tm1->wantedType)); + CHECK_EQ("(string) -> string", toString(tm1->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_and_generalization_play_nice") +{ + CheckResult result = check(R"( + local foo = function(a) + return a() + end + + local a = foo(function() return 1 end) + local b = foo(function() return "bar" end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("number" == toString(requireType("a"))); + CHECK("string" == toString(requireType("b"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 9685f4f3..188be63c 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -9,6 +8,7 @@ using namespace Luau; + TEST_SUITE_BEGIN("IntersectionTypes"); TEST_CASE_FIXTURE(Fixture, "select_correct_union_fn") @@ -185,7 +185,15 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_dep )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("r")); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("string", toString(requireType("r"))); + } + else + { + CHECK_EQ("string & string", toString(requireType("r"))); + } } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_mixed_types") @@ -199,7 +207,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_mixed_types") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number & string", toString(requireType("r"))); // TODO(amccord): This should be an error. + CHECK_EQ("number & string", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_part_missing_the_property") @@ -219,7 +227,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_part_missing_ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_property_of_type_any") { CheckResult result = check(R"( - type A = {x: number} + type A = {y: number} type B = {x: any} local t: A & B @@ -311,10 +319,34 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") { CheckResult result = check(R"( - type X = { x: (number) -> number } - type Y = { y: (string) -> string } + type X = { x: (number) -> number } + type Y = { y: (string) -> string } - type XY = X & Y + type XY = X & Y + + local xy : XY = { + x = function(a: number) return -a end, + y = function(a: string) return a .. "b" end + } + function xy.z(a:number) return a * 10 end + function xy:y(a:number) return a * 10 end + function xy:w(a:number) return a * 10 end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); + CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); +} + +TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") +{ + // After normalization, previous 'table_intersection_write_sealed_indirect' is identical to this one + CheckResult result = check(R"( + type XY = { x: (number) -> number, y: (string) -> string } local xy : XY = { x = function(a: number) return -a end, @@ -325,13 +357,16 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") function xy:w(a:number) return a * 10 end )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); - CHECK_EQ(toString(result.errors[1]), "Cannot add property 'y' to table 'X & Y'"); - CHECK_EQ(toString(result.errors[2]), "Cannot add property 'w' to table 'X & Y'"); + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'XY'"); + CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'XY'"); } -TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_intersection_setmetatable") { CheckResult result = check(R"( local t: {} & {} @@ -341,4 +376,569 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part") +{ + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } +type XYZ = X & Y & Z +local a: XYZ = 3 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into 'X & Y & Z' +caused by: + Not all intersection parts are compatible. Type 'number' could not be converted into 'X')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") +{ + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } +type XYZ = X & Y & Z +local a: XYZ +local b: number = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); +} + +TEST_CASE_FIXTURE(Fixture, "overload_is_not_a_function") +{ + check(R"( +--!nonstrict +function _(...):((typeof(not _))&(typeof(not _)))&((typeof(not _))&(typeof(not _))) +_(...)(setfenv,_,not _,"")[_] = nil +end +do end +_(...)(...,setfenv,_):_G() +)"); +} + +TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") +{ + CheckResult result = check(R"( + local l0,l0 + repeat + type t0 = ((any)|((any)&((any)|((any)&((any)|(any))))))&(t0) + function _(l0):(t0)&(t0) + while nil do + end + end + until _(_)(_)._ + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "intersect_bool_and_false") +{ + CheckResult result = check(R"( + local x : (boolean & false) + local y : false = x -- OK + local z : true = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type 'boolean & false' could not be converted into 'true'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersect_false_and_bool_and_false") +{ + CheckResult result = check(R"( + local x : false & (boolean & false) + local y : false = x -- OK + local z : true = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + // TODO: odd stringification of `false & (boolean & false)`.) + CHECK_EQ(toString(result.errors[0]), + "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersect_saturate_overloaded_functions") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : ((number?) -> number?) & ((string?) -> string?) + local y : (nil) -> nil = x -- OK + local z : (number) -> number = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> number?) & ((string?) -> string?)' could not be converted into '(number) -> number'; " + "none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_saturate_overloaded_functions") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : ((number) -> number) & ((string) -> string) + local y : ((number | string) -> (number | string)) = x -- OK + local z : ((number | boolean) -> (number | boolean)) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number) & ((string) -> string)' could not be converted into '(boolean | number) -> " + "boolean | number'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : { p : number?, q : string? } & { p : number?, q : number?, r : number? } + local y : { p : number?, q : nil, r : number? } = x -- OK + local z : { p : nil } = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into " + "'{| p: nil |}'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") +{ + CheckResult result = check(R"( + local x : { p : number?, q : any } & { p : unknown, q : string? } + local y : { p : number?, q : string? } = x -- OK + local z : { p : string?, q : number? } = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into '{| p: string?, " + "q: number? |}'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + {"LuauUninhabitedSubAnything2", true}, + }; + + CheckResult result = check(R"( + local x : { p : number?, q : never } & { p : never, q : string? } -- OK + local y : { p : never, q : never } = x -- OK + local z : never = x -- OK + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : ((number?) -> ({ p : number } & { q : number })) & ((string?) -> ({ p : number } & { r : number })) + local y : (nil) -> { p : number, q : number, r : number} = x -- OK + local z : (number?) -> { p : number, q : number, r : number} = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into " + "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number?) -> (a | number)) & ((string?) -> (a | string)) + local y : (nil) -> a = x -- OK + local z : (number?) -> a = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> a | number) & ((string?) -> a | string)' could not be converted into '(number?) -> a'; " + "none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generics") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((a?) -> (a | b)) & ((c?) -> (b | c)) + local y : (nil) -> ((a & c) | b) = x -- OK + local z : (a?) -> ((a & c) | b) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + "Type '((a?) -> a | b) & ((c?) -> b | c)' could not be converted into '(a?) -> (a & c) | b'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic_packs") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...)) + local y : ((nil, a...) -> (nil, b...)) = x -- OK + local z : ((nil, b...) -> (nil, a...)) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))' could not be converted " + "into '(nil, b...) -> (nil, a...)'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_result") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number) -> number) & ((nil) -> unknown) + local y : (number?) -> unknown = x -- OK + local z : (number?) -> number? = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> unknown) & ((number) -> number)' could not be converted into '(number?) -> number?'; none " + "of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_arguments") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number) -> number?) & ((unknown) -> string?) + local y : (number) -> nil = x -- OK + local z : (number?) -> nil = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number?) & ((unknown) -> string?)' could not be converted into '(number?) -> nil'; none " + "of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_result") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number) -> number) & ((nil) -> never) + local y : (number?) -> number = x -- OK + local z : (number?) -> never = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> never) & ((number) -> number)' could not be converted into '(number?) -> never'; none of " + "the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_arguments") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number) -> number?) & ((never) -> string?) + local y : (never) -> nil = x -- OK + local z : (number?) -> nil = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((never) -> string?) & ((number) -> number?)' could not be converted into '(number?) -> nil'; none " + "of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_overlapping_results_and_variadics") +{ + CheckResult result = check(R"( + local x : ((string?) -> (string | number)) & ((number?) -> ...number) + local y : ((nil) -> (number, number?)) = x -- OK + local z : ((string | number) -> (number, number?)) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> (...number)) & ((string?) -> number | string)' could not be converted into '(number | " + "string) -> (number, number?)'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_1") +{ + CheckResult result = check(R"( + function f() + local x : (() -> a...) & (() -> b...) + local y : (() -> b...) & (() -> a...) = x -- OK + local z : () -> () = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + "Type '(() -> (a...)) & (() -> (b...))' could not be converted into '() -> ()'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_2") +{ + CheckResult result = check(R"( + function f() + local x : ((a...) -> ()) & ((b...) -> ()) + local y : ((b...) -> ()) & ((a...) -> ()) = x -- OK + local z : () -> () = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + "Type '((a...) -> ()) & ((b...) -> ())' could not be converted into '() -> ()'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_3") +{ + CheckResult result = check(R"( + function f() + local x : (() -> a...) & (() -> (number?,a...)) + local y : (() -> (number?,a...)) & (() -> a...) = x -- OK + local z : () -> (number) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + "Type '(() -> (a...)) & (() -> (number?, a...))' could not be converted into '() -> number'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_4") +{ + CheckResult result = check(R"( + function f() + local x : ((a...) -> ()) & ((number,a...) -> number) + local y : ((number,a...) -> number) & ((a...) -> ()) = x -- OK + local z : (number?) -> () = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((a...) -> ()) & ((number, a...) -> number)' could not be converted into '(number?) -> ()'; none of " + "the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local a : string? = nil + local b : number? = nil + + local x = setmetatable({}, { p = 5, q = a }); + local y = setmetatable({}, { q = b, r = "hi" }); + local z = setmetatable({}, { p = 5, q = nil, r = "hi" }); + + type X = typeof(x) + type Y = typeof(y) + type Z = typeof(z) + + local xy : X&Y = z; + local yx : Y&X = z; + z = xy; + z = yx; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_subtypes") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x = setmetatable({ a = 5 }, { p = 5 }); + local y = setmetatable({ b = "hi" }, { p = 5, q = "hi" }); + local z = setmetatable({ a = 5, b = "hi" }, { p = 5, q = "hi" }); + + type X = typeof(x) + type Y = typeof(y) + type Z = typeof(z) + + local xy : X&Y = z; + local yx : Y&X = z; + z = xy; + z = yx; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables_with_properties") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x = setmetatable({ a = 5 }, { p = 5 }); + local y = setmetatable({ b = "hi" }, { q = "hi" }); + local z = setmetatable({ a = 5, b = "hi" }, { p = 5, q = "hi" }); + + type X = typeof(x) + type Y = typeof(y) + type Z = typeof(z) + + local xy : X&Y = z; + z = xy; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_with_table") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x = setmetatable({ a = 5 }, { p = 5 }); + local z = setmetatable({ a = 5, b = "hi" }, { p = 5 }); + + type X = typeof(x) + type Y = { b : string } + type Z = typeof(z) + + -- TODO: once we have shape types, we should be able to initialize these with z + local xy : X&Y; + local yx : Y&X; + z = xy; + z = yx; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "CLI-44817") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + type X = {x: number} + type Y = {y: number} + type Z = {z: number} + + type XY = {x: number, y: number} + type XYZ = {x:number, y: number, z: number} + + local xy: XY = {x = 0, y = 0} + local xyz: XYZ = {x = 0, y = 0, z = 0} + + local xNy: X&Y = xy + local xNyNz: X&Y&Z = xyz + + local t1: XY = xNy -- Type 'X & Y' could not be converted into 'XY' + local t2: XY = xNyNz -- Type 'X & Y & Z' could not be converted into 'XY' + local t3: XYZ = xNyNz -- Type 'X & Y & Z' could not be converted into 'XYZ' + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function f(t): { x: number } & { x: string } + local x = t.x + return t + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({| x: number |} & {| x: string |}) -> {| x: number |} & {| x: string |}", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types_2") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function f(t: { x: number } & { x: string }) + return t.x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({| x: number |} & {| x: string |}) -> number & string", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp new file mode 100644 index 00000000..40912a95 --- /dev/null +++ b/tests/TypeInfer.loops.test.cpp @@ -0,0 +1,664 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + +TEST_SUITE_BEGIN("TypeInferLoops"); + +TEST_CASE_FIXTURE(Fixture, "for_loop") +{ + CheckResult result = check(R"( + local q + for i=0, 50, 2 do + q = i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("q")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") +{ + CheckResult result = check(R"( + local n + local s + for i, v in pairs({ "foo" }) do + n = i + s = v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*typeChecker.stringType, *requireType("s")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_next") +{ + CheckResult result = check(R"( + local n + local s + for i, v in next, { "foo", "bar" } do + n = i + s = v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*typeChecker.stringType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") +{ + CheckResult result = check(R"( + local it: any + local a, b + for i, v in it do + a, b = i, v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") +{ + CheckResult result = check(R"( + local foo = "bar" + for i, v in foo do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Cannot call non-function string", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_just_one_iterator_is_ok") +{ + CheckResult result = check(R"( + local function keys(dictionary) + local new = {} + local index = 1 + + for key in pairs(dictionary) do + new[index] = key + index = index + 1 + end + + return new + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators_dcr") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + function no_iter() end + for key in no_iter() do end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_a_custom_iterator_should_type_check") +{ + CheckResult result = check(R"( + local function range(l, h): () -> number + return function() + return l + end + end + + for n: string in range(1, 10) do + print(n) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") +{ + CheckResult result = check(R"( + function f(x) + gobble.prop = x.otherprop + end + + local p + for _, part in i_am_not_defined do + p = part + f(part) + part.thirdprop = false + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + TypeId p = requireType("p"); + CHECK_EQ("*error-type*", toString(p)); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") +{ + CheckResult result = check(R"( + local bad_iter = 5 + + for a in bad_iter() do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") +{ + CheckResult result = check(R"( + local function hasDivisors(value: number, table) + return false + end + + function prime_iter(state, index) + while hasDivisors(index, state) do + index += 1 + end + + state[index] = true + return index + end + + function primes1() + return prime_iter, {} + end + + function primes2() + return prime_iter, {}, "" + end + + function primes3() + return prime_iter, {}, 2 + end + + for p in primes1() do print(p) end -- mismatch in argument count + + for p in primes2() do print(p) end -- mismatch in argument types, prime_iter takes {}, number, we are given {}, string + + for p in primes3() do print(p) end -- no error + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Arg); + CHECK_EQ(2, acm->expected); + CHECK_EQ(1, acm->actual); + + TypeMismatch* tm = get(result.errors[1]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") +{ + CheckResult result = check(R"( + function prime_iter(state, index) + return 1 + end + + for p in prime_iter do print(p) end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Arg); + CHECK_EQ(2, acm->expected); + CHECK_EQ(0, acm->actual); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_incompatible_args_to_iterator") +{ + CheckResult result = check(R"( + function my_iter(state: string, index: number) + return state, index + end + + local my_state = {} + local first_index = "first" + + -- Type errors here. my_state and first_index cannot be passed to my_iter + for a, b in my_iter, my_state, first_index do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(get(result.errors[1])); + CHECK(Location{{9, 29}, {9, 37}} == result.errors[0].location); + + CHECK(get(result.errors[1])); + CHECK(Location{{9, 39}, {9, 50}} == result.errors[1].location); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") +{ + CheckResult result = check(R"( + function primes() + return function (state: number) end, 2 + end + + for p, q in primes do + q = "" + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "while_loop") +{ + CheckResult result = check(R"( + local i + while true do + i = 8 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("i")); +} + +TEST_CASE_FIXTURE(Fixture, "repeat_loop") +{ + CheckResult result = check(R"( + local i + repeat + i = 'hi' + until true + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.stringType, *requireType("i")); +} + +TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") +{ + CheckResult result = check(R"( + repeat + local x = true + until x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") +{ + CheckResult result = check(R"( + repeat + local x = true + until x + + print(x) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "varlist_declared_by_for_in_loop_should_be_free") +{ + CheckResult result = check(R"( + local T = {} + + function T.f(p) + for i, v in pairs(p) do + T.f(v) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "properly_infer_iteratee_is_a_free_table") +{ + // In this case, we cannot know the element type of the table {}. It could be anything. + // We therefore must initially ascribe a free typevar to iter. + CheckResult result = check(R"( + for iter in pairs({}) do + iter:g().p = true + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_while") +{ + CheckResult result = check(R"( + while true do + local a = 1 + end + + print(a) -- oops! + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownSymbol* us = get(result.errors[0]); + REQUIRE(us); + CHECK_EQ(us->name, "a"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_produces_integral_indices") +{ + CheckResult result = check(R"( + local key + for i, e in ipairs({}) do key = i end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("number", toString(requireType("key"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") +{ + // This code doesn't pass typechecking. We just care that it doesn't crash. + (void)check(R"( + --!nonstrict + function _:_(...) + end + + repeat + if _ then + else + _ = ... + end + until _ + + for _ in _() do + end + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "unreachable_code_after_infinite_loop") +{ + { + CheckResult result = check(R"( + function unreachablecodepath(a): number + while true do + if a then return 10 end + end + -- unreachable + end + unreachablecodepath(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + } + + { + CheckResult result = check(R"( + function reachablecodepath(a): number + while true do + if a then break end + return 10 + end + + print("x") -- correct error + end + reachablecodepath(4) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK(get(result.errors[0])); + } + + { + CheckResult result = check(R"( + function unreachablecodepath(a): number + repeat + if a then return 10 end + until false + + -- unreachable + end + unreachablecodepath(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + } + + { + CheckResult result = check(R"( + function reachablecodepath(a, b): number + repeat + if a then break end + + if b then return 10 end + until false + + print("x") -- correct error + end + reachablecodepath(4) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK(get(result.errors[0])); + } + + { + CheckResult result = check(R"( + function unreachablecodepath(a: number?): number + repeat + return 10 + until a ~= nil + + -- unreachable + end + unreachablecodepath(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional") +{ + CheckResult result = check(R"( + local t = {} + for _ in t do + for _ in assert(missing()) do + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") +{ + // Just check that this doesn't assert + check(R"( + --!nonstrict + function _(l0:number) + return _ + end + for _ in _(8) do + end + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_generic_next") +{ + CheckResult result = check(R"( + for k: number, v: number in next, {1, 2, 3} do + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") +{ + CheckResult result = check(R"( + local t: {string} = {} + local key + for k: number in t do + end + for k: number, v: string in t do + end + for k, v in t do + key = k + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.numberType, *requireType("key")); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") +{ + CheckResult result = check(R"( + local t: {string} = {} + local extra + for k, v, e in t do + extra = e + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + CHECK_EQ(*typeChecker.nilType, *requireType("extra")); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_strict") +{ + CheckResult result = check(R"( + local t = {} + for k, v in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("Cannot iterate over a table without indexer", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_nonstrict") +{ + CheckResult result = check(Mode::Nonstrict, R"( + local t = {} + for k, v in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_nil") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local t = setmetatable({}, { __iter = function(o) return next, nil end, }) + for k: number, v: string in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Type 'nil' could not be converted into '{- [a]: b -}'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_not_enough_returns") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local t = setmetatable({}, { __iter = function(o) end }) + for k: number, v: string in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{ + Location{{2, 36}, {2, 37}}, + GenericError{"__iter must return at least one value"}, + }); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_ok") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local t = setmetatable({ + children = {"foo"} + }, { __iter = function(o) return next, o.children end }) + for k: number, v: string in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_metamethod_ok_with_inference") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local t = setmetatable({ + children = {"foo"} + }, { __iter = function(o) return next, o.children end }) + + local a, b + for k, v in t do + a = k + b = v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("a")) == "number"); + CHECK(toString(requireType("b")) == "string"); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp new file mode 100644 index 00000000..b06c80e9 --- /dev/null +++ b/tests/TypeInfer.modules.test.cpp @@ -0,0 +1,487 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferModules"); + +TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_require_basic") +{ + fileResolver.source["game/A"] = R"( + --!strict + return { + a = 1, + } + )"; + + fileResolver.source["game/B"] = R"( + --!strict + local A = require(game.A) + + local b = A.a + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ModulePtr b = frontend.moduleResolver.modules["game/B"]; + REQUIRE(b != nullptr); + std::optional bType = requireType(b, "b"); + REQUIRE(bType); + CHECK(toString(*bType) == "number"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "require") +{ + fileResolver.source["game/A"] = R"( + local function hooty(x: number): string + return "Hi there!" + end + + return {hooty=hooty} + )"; + + fileResolver.source["game/B"] = R"( + local Hooty = require(game.A) + + local h -- free! + local i = Hooty.hooty(h) + )"; + + CheckResult aResult = frontend.check("game/A"); + dumpErrors(aResult); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = frontend.check("game/B"); + dumpErrors(bResult); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ModulePtr b = frontend.moduleResolver.modules["game/B"]; + + REQUIRE(b != nullptr); + + dumpErrors(bResult); + + std::optional iType = requireType(b, "i"); + REQUIRE_EQ("string", toString(*iType)); + + std::optional hType = requireType(b, "h"); + REQUIRE_EQ("number", toString(*hType)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "require_types") +{ + fileResolver.source["workspace/A"] = R"( + export type Point = {x: number, y: number} + + return {} + )"; + + fileResolver.source["workspace/B"] = R"( + local Hooty = require(workspace.A) + + local h: Hooty.Point + )"; + + CheckResult bResult = frontend.check("workspace/B"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ModulePtr b = frontend.moduleResolver.modules["workspace/B"]; + REQUIRE(b != nullptr); + + TypeId hType = requireType(b, "h"); + REQUIRE_MESSAGE(bool(get(hType)), "Expected table but got " << toString(hType)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "require_a_variadic_function") +{ + fileResolver.source["game/A"] = R"( + local T = {} + function T.f(...) end + return T + )"; + + fileResolver.source["game/B"] = R"( + local A = require(game.A) + local f = A.f + )"; + + CheckResult result = frontend.check("game/B"); + + ModulePtr bModule = frontend.moduleResolver.getModule("game/B"); + REQUIRE(bModule != nullptr); + + TypeId f = follow(requireType(bModule, "f")); + + const FunctionTypeVar* ftv = get(f); + REQUIRE(ftv); + + auto iter = begin(ftv->argTypes); + auto endIter = end(ftv->argTypes); + + REQUIRE(iter == endIter); + REQUIRE(iter.tail()); + + CHECK(get(*iter.tail())); +} + +TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") +{ + CheckResult result = check(R"( + local p: SomeModule.DoesNotExist + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE_EQ(result.errors[0], (TypeError{Location{{1, 17}, {1, 40}}, UnknownSymbol{"SomeModule.DoesNotExist"}})); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "require_module_that_does_not_export") +{ + const std::string sourceA = R"( + )"; + + const std::string sourceB = R"( + local Hooty = require(script.Parent.A) + )"; + + fileResolver.source["game/Workspace/A"] = sourceA; + fileResolver.source["game/Workspace/B"] = sourceB; + + frontend.check("game/Workspace/A"); + frontend.check("game/Workspace/B"); + + ModulePtr aModule = frontend.moduleResolver.modules["game/Workspace/A"]; + ModulePtr bModule = frontend.moduleResolver.modules["game/Workspace/B"]; + + CHECK(aModule->errors.empty()); + REQUIRE_EQ(1, bModule->errors.size()); + CHECK_MESSAGE(get(bModule->errors[0]), "Should be IllegalRequire: " << toString(bModule->errors[0])); + + auto hootyType = requireType(bModule, "Hooty"); + + CHECK_EQ("*error-type*", toString(hootyType)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "warn_if_you_try_to_require_a_non_modulescript") +{ + fileResolver.source["Modules/A"] = ""; + fileResolver.sourceTypes["Modules/A"] = SourceCode::Local; + + fileResolver.source["Modules/B"] = R"( + local M = require(script.Parent.A) + )"; + + CheckResult result = frontend.check("Modules/B"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_call_expression") +{ + fileResolver.source["game/A"] = R"( +--!strict +return { def = 4 } + )"; + + fileResolver.source["game/B"] = R"( +--!strict +local tbl = { abc = require(game.A) } +local a : string = "" +a = tbl.abc.def + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_type_mismatch") +{ + fileResolver.source["game/A"] = R"( +return { def = 4 } + )"; + + fileResolver.source["game/B"] = R"( +local tbl: string = require(game.A) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "bound_free_table_export_is_ok") +{ + CheckResult result = check(R"( +local n = {} +function n:Clone() end + +local m = {} + +function m.a(x) + x:Clone() +end + +function m.b() + m.a(n) +end + +return m +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "custom_require_global") +{ + CheckResult result = check(R"( +--!nonstrict +require = function(a) end + +local crash = require(game.A) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "require_failed_module") +{ + fileResolver.source["game/A"] = R"( +return unfortunately() + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(aResult); + + CheckResult result = check(R"( +local ModuleA = require(game.A) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional oty = requireType("ModuleA"); + + CHECK_EQ("*error-type*", toString(*oty)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types") +{ + fileResolver.source["game/A"] = R"( +export type Type = { unrelated: boolean } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = {} +function x:Destroy(): () end + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_2") +{ + fileResolver.source["game/A"] = R"( +export type Type = { x: { a: number } } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = { x = { a = 2 } } +type Rename = typeof(x.x) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_3") +{ + fileResolver.source["game/A"] = R"( +local y = setmetatable({}, {}) +export type Type = { x: typeof(y) } +return { x = y } + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = types +type Rename = typeof(x.x) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_4") +{ + fileResolver.source["game/A"] = R"( +export type Array = {T} +local arrayops = {} +function arrayops.foo(x: Array) end +return arrayops + )"; + + CheckResult result = check(R"( +local arrayops = require(game.A) + +local tbl = {} +tbl.a = 2 +function tbl:foo(b: number, c: number) + -- introduce BoundTypeVar to imported type + arrayops.foo(self._regions) +end +-- this alias decreases function type level and causes a demotion of its type +type Table = typeof(tbl) +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_5") +{ + fileResolver.source["game/A"] = R"( +export type Type = {x: number, y: number} +local arrayops = {} +function arrayops.foo(x: Type) end +return arrayops + )"; + + CheckResult result = check(R"( +local arrayops = require(game.A) + +local tbl = {} +tbl.a = 2 +function tbl:foo(b: number, c: number) + -- introduce boundTo TableTypeVar to imported type + self.x.a = 2 + arrayops.foo(self.x) +end +-- this alias decreases function type level and causes a demotion of its type +type Table = typeof(tbl) +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict") +{ + fileResolver.source["game/A"] = R"( +export type T = { x: number } +return {} + )"; + + fileResolver.source["game/B"] = R"( +export type T = { x: string } +return {} + )"; + + fileResolver.source["game/C"] = R"( +local A = require(game.A) +local B = require(game.B) +local a: A.T = { x = 2 } +local b: B.T = a + )"; + + CheckResult result = frontend.check("game/C"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' +caused by: + Property 'x' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' +caused by: + Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict_instantiated") +{ + fileResolver.source["game/A"] = R"( +export type Wrap = { x: T } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local A = require(game.A) +export type T = A.Wrap +return {} + )"; + + fileResolver.source["game/C"] = R"( +local A = require(game.A) +export type T = A.Wrap +return {} + )"; + + fileResolver.source["game/D"] = R"( +local A = require(game.B) +local B = require(game.C) +local a: A.T = { x = 2 } +local b: B.T = a + )"; + + CheckResult result = frontend.check("game/D"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' +caused by: + Property 'x' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' +caused by: + Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "constrained_anyification_clone_immutable_types") +{ + fileResolver.source["game/A"] = R"( +return function(...) end + )"; + + fileResolver.source["game/B"] = R"( +local l0 = require(game.A) +return l0 + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_anyify_variadic_return_must_follow") +{ + ScopedFastFlag luauTypeInferMissingFollows{"LuauTypeInferMissingFollows", true}; + + CheckResult result = check(R"( +return unpack(l0[_]) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.negations.test.cpp b/tests/TypeInfer.negations.test.cpp new file mode 100644 index 00000000..0e7fb03d --- /dev/null +++ b/tests/TypeInfer.negations.test.cpp @@ -0,0 +1,51 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/Common.h" +#include "ScopedFlags.h" + +using namespace Luau; + +namespace +{ +struct NegationFixture : Fixture +{ + TypeArena arena; + ScopedFastFlag sff[1]{ + {"LuauSubtypeNormalizer", true}, + }; + + NegationFixture() + { + registerHiddenTypes(*this, arena); + } +}; +} // namespace + +TEST_SUITE_BEGIN("Negations"); + +TEST_CASE_FIXTURE(NegationFixture, "negated_string_is_a_subtype_of_string") +{ + CheckResult result = check(R"( + function foo(arg: string) end + local a: string & Not<"Hello"> + foo(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(NegationFixture, "string_is_not_a_subtype_of_negated_string") +{ + CheckResult result = check(R"( + function foo(arg: string & Not<"hello">) end + local a: string + foo(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp new file mode 100644 index 00000000..41690704 --- /dev/null +++ b/tests/TypeInfer.oop.test.cpp @@ -0,0 +1,275 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferOOP"); + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon") +{ + CheckResult result = check(R"( + local someTable = {} + + someTable.Function1 = function(Arg1) + end + + someTable.Function1() -- Argument count mismatch + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") +{ + CheckResult result = check(R"( + local someTable = {} + + someTable.Function2 = function(Arg1, Arg2) + end + + someTable.Function2() -- Argument count mismatch + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_another_overload_works") +{ + CheckResult result = check(R"( + type T = {method: ((T, number) -> number) & ((number) -> number)} + local T: T + + T.method(4) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "method_depends_on_table") +{ + CheckResult result = check(R"( + -- This catches a bug where x:m didn't count as a use of x + -- so toposort would happily reorder a definition of + -- function x:m before the definition of x. + function g() f() end + local x = {} + function x:m() end + function f() x:m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "methods_are_topologically_sorted") +{ + CheckResult result = check(R"( + local T = {} + + function T:foo() + return T:bar(999), T:bar("hi") + end + + function T:bar(i) + return i + end + + local a, b = T:foo() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("a"))); + CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_methods_defined_using_dot_syntax_and_explicit_self_parameter") +{ + check(R"( + local T = {} + + function T.method(self) + self:method() + end + + function T.method2(self) + self:method() + end + + T:method2() + )"); +} + +TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocate_memory") +{ + CheckResult result = check(R"( + ("foo") + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + )"); + + ModulePtr module = getMainModule(); + CHECK_GE(50, module->internalTypes.typeVars.size()); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "object_constructor_can_refer_to_method_of_self") +{ + // CLI-30902 + CheckResult result = check(R"( + --!strict + + type Foo = { + fooConn: () -> () | nil + } + + local Foo = {} + Foo.__index = Foo + + function Foo.new() + local self: Foo = { + fooConn = nil, + } + setmetatable(self, Foo) + + self.fooConn = function() + self:method() -- Key 'method' not found in table self + end + + return self + end + + function Foo:method() + print("foo") + end + + local foo = Foo.new() + + -- TODO This is the best our current refinement support can offer :( + local bar = foo.fooConn + if bar then bar() end + + -- foo.fooConn() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfSealed") +{ + CheckResult result = check(R"( +local x: {prop: number} = {prop=9999} +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") +{ + CheckResult result = check(R"( + --!nonstrict + local f = {} + function f:foo(a: number, b: number) end + + function bar(...) + f.foo(f, 1, ...) + end + + bar(2) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table") +{ + check(R"( + function Base64FileReader(data) + local reader = {} + local index: number + + function reader:PeekByte() + return data:byte(index) + end + + function reader:Byte() + return data:byte(index - 1) + end + + return reader + end + + Base64FileReader() + + function ReadMidiEvents(data) + + local reader = Base64FileReader(data) + + while reader:HasMore() do + (reader:Byte() % 128) + end + end + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_oop") +{ + CheckResult result = check(R"( + --!strict +local Class = {} +Class.__index = Class + +type Class = typeof(setmetatable({} :: { x: number }, Class)) + +function Class.new(x: number): Class + return setmetatable({x = x}, Class) +end + +function Class.getx(self: Class) + return self.x +end + +function test() + local c = Class.new(42) + local n = c:getx() + local nn = c.x + + print(string.format("%d %d", n, nn)) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp new file mode 100644 index 00000000..93d7361b --- /dev/null +++ b/tests/TypeInfer.operators.test.cpp @@ -0,0 +1,1136 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" +#include "ClassFixture.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + +TEST_SUITE_BEGIN("TypeInferOperators"); + +TEST_CASE_FIXTURE(Fixture, "or_joins_types") +{ + CheckResult result = check(R"( + local s = "a" or 10 + local x:string|number = s + )"); + LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(*requireType("s")), "(string & ~(false?)) | number"); + CHECK_EQ(toString(*requireType("x")), "number | string"); + } + else + { + CHECK_EQ(toString(*requireType("s")), "number | string"); + CHECK_EQ(toString(*requireType("x")), "number | string"); + } +} + +TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") +{ + CheckResult result = check(R"( + local s = "a" or 10 + local x:number|string = s + local y = x or "s" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(*requireType("s")), "(string & ~(false?)) | number"); + CHECK_EQ(toString(*requireType("y")), "((number | string) & ~(false?)) | string"); + } + else + { + CHECK_EQ(toString(*requireType("s")), "number | string"); + CHECK_EQ(toString(*requireType("y")), "number | string"); + } +} + +TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") +{ + CheckResult result = check(R"( + local s = "a" or "b" + local x:string = s + )"); + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*requireType("s"), *typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "and_does_not_always_add_boolean") +{ + ScopedFastFlag sff[]{ + {"LuauTryhardAnd", true}, + }; + + CheckResult result = check(R"( + local s = "a" and 10 + local x:boolean|number = s + )"); + LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(*requireType("s")), "((false?) & string) | number"); + } + else + { + CHECK_EQ(toString(*requireType("s")), "number"); + } +} + +TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") +{ + CheckResult result = check(R"( + local s = "a" and true + local x:boolean = s + )"); + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*requireType("x"), *typeChecker.booleanType); +} + +TEST_CASE_FIXTURE(Fixture, "and_or_ternary") +{ + CheckResult result = check(R"( + local s = (1/2) > 0.5 and "a" or 10 + )"); + LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(toString(*requireType("s")), "((((false?) & boolean) | string) & ~(false?)) | number"); + } + else + { + CHECK_EQ(toString(*requireType("s")), "number | string"); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") +{ + CheckResult result = check(R"( + function add(a: number, b: string) + return a + (tonumber(b) :: number), tostring(a) .. b + end + local n, s = add(2,"3") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* functionType = get(requireType("add")); + + std::optional retType = first(functionType->retTypes); + REQUIRE(retType.has_value()); + CHECK_EQ(typeChecker.numberType, follow(*retType)); + CHECK_EQ(requireType("n"), typeChecker.numberType); + CHECK_EQ(requireType("s"), typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") +{ + CheckResult result = check(R"( + local PI=3.1415926535897931 + local SOLAR_MASS=4*PI * PI + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(requireType("SOLAR_MASS"), typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "primitive_arith_possible_metatable") +{ + CheckResult result = check(R"( + function add(a: number, b: any) + return a + b + end + local t = add(1,2) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") +{ + CheckResult result = check(R"( + local a = 4 + 8 + local b = a + 9 + local s = 'hotdogs' + local t = s .. s + local c = b - a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + CHECK_EQ("string", toString(requireType("s"))); + CHECK_EQ("string", toString(requireType("t"))); + CHECK_EQ("number", toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_intersection") +{ + CheckResult result = check(R"( + --!strict + local Vec3 = {} + Vec3.__index = Vec3 + function Vec3.new() + return setmetatable({x=0, y=0, z=0}, Vec3) + end + + export type Vec3 = typeof(Vec3.new()) + + local thefun: any = function(self, o) return self end + + local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun + + Vec3.__mul = multiply + + local a = Vec3.new() + local b = Vec3.new() + local c = a * b + local d = a * 2 + local e = a * 'cabbage' + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Vec3", toString(requireType("a"))); + CHECK_EQ("Vec3", toString(requireType("b"))); + CHECK_EQ("Vec3", toString(requireType("c"))); + CHECK_EQ("Vec3", toString(requireType("d"))); + CHECK_EQ("Vec3", toString(requireType("e"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") +{ + CheckResult result = check(R"( + --!strict + local Vec3 = {} + Vec3.__index = Vec3 + function Vec3.new() + return setmetatable({x=0, y=0, z=0}, Vec3) + end + + export type Vec3 = typeof(Vec3.new()) + + local thefun: any = function(self, o) return self end + + local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun + + Vec3.__mul = multiply + + local a = Vec3.new() + local b = Vec3.new() + local c = b * a + local d = 2 * a + local e = 'cabbage' * a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Vec3", toString(requireType("a"))); + CHECK_EQ("Vec3", toString(requireType("b"))); + CHECK_EQ("Vec3", toString(requireType("c"))); + CHECK_EQ("Vec3", toString(requireType("d"))); + CHECK_EQ("Vec3", toString(requireType("e"))); +} + +TEST_CASE_FIXTURE(Fixture, "compare_numbers") +{ + CheckResult result = check(R"( + local a = 441 + local b = 0 + local c = a < b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "compare_strings") +{ + CheckResult result = check(R"( + local a = '441' + local b = '0' + local c = a < b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_metatable") +{ + CheckResult result = check(R"( + local a = {} + local b = {} + local c = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* gen = get(result.errors[0]); + + REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") +{ + CheckResult result = check(R"( + local M = {} + function M.new() + return setmetatable({}, M) + end + type M = typeof(M.new()) + + local a = M.new() + local b = M.new() + local c = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* gen = get(result.errors[0]); + REQUIRE(gen != nullptr); + REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") +{ + CheckResult result = check(R"( + --!strict + local M = {} + function M.new() + return setmetatable({}, M) + end + function M.__lt(left, right) return true end + + local a = M.new() + local b = {} + local c = a < b -- line 10 + local d = b < a -- line 11 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + REQUIRE_EQ((Location{{10, 18}, {10, 23}}), result.errors[0].location); + + REQUIRE_EQ((Location{{11, 18}, {11, 23}}), result.errors[1].location); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") +{ + CheckResult result = check(R"( + --!strict + local M = {} + function M.new() + return setmetatable({}, M) + end + function M.__lt(left, right) return true end + type M = typeof(M.new()) + + local a = M.new() + local b = {} + local c = a < b -- line 10 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = get(result.errors[0]); + REQUIRE(err != nullptr); + + // Frail. :| + REQUIRE_EQ("Types M and b cannot be compared with < because they do not have the same metatable", err->message); +} + +TEST_CASE_FIXTURE(Fixture, "in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators") +{ + CheckResult result = check(R"( + --!nonstrict + + function maybe_a_number(): number? + return 50 + end + + local a = maybe_a_number() < maybe_a_number() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_basic") +{ + CheckResult result = check(R"( + local s = 10 + s += 20 + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "number"); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_op") +{ + CheckResult result = check(R"( + local s = 10 + s += true + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.booleanType}})); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") +{ + CheckResult result = check(R"( + local s = 'hello' + s += 10 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") +{ + CheckResult result = check(R"( + --!strict + type V2B = { x: number, y: number } + local v2b: V2B = { x = 0, y = 0 } + local VMT = {} + type V2 = typeof(setmetatable(v2b, VMT)) + + function VMT.__add(a: V2, b: V2): V2 + return setmetatable({ x = a.x + b.x, y = a.y + b.y }, VMT) + end + + local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) + local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) + v1 += v2 + )"); + CHECK_EQ(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_mismatch_metatable") +{ + CheckResult result = check(R"( + --!strict + type V2B = { x: number, y: number } + local v2b: V2B = { x = 0, y = 0 } + local VMT = {} + type V2 = typeof(setmetatable(v2b, VMT)) + + function VMT.__mod(a: V2, b: V2): number + return a.x * b.x + a.y * b.y + end + + local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) + local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) + v1 %= v2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK("Type 'number' could not be converted into 'V2'" == toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") +{ + CheckResult result = check(R"( +function f() return 1; end +function g() return 2; end +(f or g)() +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "CallAndOrOfFunctions") +{ + CheckResult result = check(R"( +function f() return 1; end +function g() return 2; end +local x = false +(x and f or g)() +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") +{ + CheckResult result = check(R"( + --!strict + local foo + local mt = {} + + mt.__unm = function(val: typeof(foo)): string + return tostring(val.value) .. "test" + end + + foo = setmetatable({ + value = 10 + }, mt) + + local a = -foo + + local b = 1+-1 + + local bar = { + value = 10 + } + local c = -bar -- disallowed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("string", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK(toString(result.errors[0]) == "Type '{ value: number }' could not be converted into 'number'"); + } + else + { + GenericError* gen = get(result.errors[0]); + REQUIRE(gen); + REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") +{ + CheckResult result = check(R"( + --!strict + local mt = {} + + mt.__unm = function(val: boolean): string + return "test" + end + + local foo = setmetatable({ + value = 10 + }, mt) + + local a = -foo + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("string", toString(requireType("a"))); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_EQ(*tm->wantedType, *typeChecker.booleanType); + // given type is the typeof(foo) which is complex to compare against +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") +{ + CheckResult result = check(R"( + --!strict + local mt = {} + + mt.__len = function(val): string + return "test" + end + + local foo = setmetatable({ + value = 10, + }, mt) + + local a = #foo + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("number", toString(requireType("a"))); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); + REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "unary_not_is_boolean") +{ + CheckResult result = check(R"( + local b = not "string" + local c = not (math.random() > 0.5 and "string" or 7) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("boolean", toString(requireType("b"))); + REQUIRE_EQ("boolean", toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") +{ + CheckResult result = check(R"( + --!strict + local a = "1.24" + 123 -- not allowed + + local foo = { + value = 10 + } + + local b = foo + 1 -- not allowed + + local bar = { + value = 1 + } + + local mt = {} + + setmetatable(bar, mt) + + mt.__add = function(a: typeof(bar), b: number): number + return a.value + b + end + + local c = bar + 1 -- allowed + + local d = bar + foo -- not allowed + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(*tm->wantedType, *typeChecker.numberType); + CHECK_EQ(*tm->givenType, *typeChecker.stringType); + + GenericError* gen1 = get(result.errors[1]); + REQUIRE(gen1); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(gen1->message, "Operator + is not applicable for '{ value: number }' and 'number' because neither type has a metatable"); + else + CHECK_EQ(gen1->message, "Binary operator '+' not supported by types 'foo' and 'number'"); + + TypeMismatch* tm2 = get(result.errors[2]); + REQUIRE(tm2); + CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); + CHECK_EQ(*tm2->givenType, *requireType("foo")); +} + +// CLI-29033 +TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") +{ + CheckResult result = check(R"( + function merge(lower, greater) + if lower.y == greater.y then + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "concat_op_on_free_lhs_and_string_rhs") +{ + CheckResult result = check(R"( + local function f(x) + return x .. "y" + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") +{ + CheckResult result = check(R"( + local function f(x) + return "foo" .. x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(string) -> string", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") +{ + std::vector ops = {"+", "-", "*", "/", "%", "^", ".."}; + + std::string src = "function foo(a, b)\n"; + + for (const auto& op : ops) + src += "local _ = a " + op + " b\n"; + + src += "end"; + + CheckResult result = check(src); + LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); + + CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") +{ + CheckResult result = check(R"( + --!strict + local t = {} + while true and t[1] do + print(t[1].test) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") +{ + CheckResult result = check(R"( + local a: boolean = true + local b: boolean = false + local foo = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Types 'boolean' and 'boolean' cannot be compared with relational operator <", ge->message); + else + CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2") +{ + CheckResult result = check(R"( + local a: number | string = "" + local b: number | string = 1 + local foo = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Types 'number | string' and 'number | string' cannot be compared with relational operator <", ge->message); + else + CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") +{ + CheckResult result = check(R"( + --!strict + local _ + _ += _ and _ or _ and _ or _ and _ + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type contains a self-recursive construct that cannot be resolved", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "UnknownGlobalCompoundAssign") +{ + // In non-strict mode, global definition is still allowed + { + CheckResult result = check(R"( + --!nonstrict + a = a + 1 + print(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } + + // In strict mode we no longer generate two errors from lhs + { + CheckResult result = check(R"( + --!strict + a += 1 + print(a) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } + + // In non-strict mode, compound assignment is not a definition, it's a modification + { + CheckResult result = check(R"( + --!nonstrict + a += 1 + print(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator") +{ + CheckResult result = check(R"( +--!strict +local a: number? = nil +local b: number = a or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator2") +{ + CheckResult result = check(R"( +--!nonstrict +local a: number? = nil +local b: number = a or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_strip_nil_from_rhs_or_operator") +{ + CheckResult result = check(R"( +--!strict +local a: number? = nil +local b: number = 1 or a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("((number & ~(false?)) | number)?", toString(tm->givenType)); + } + else + { + CHECK_EQ("number?", toString(tm->givenType)); + } +} + +TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") +{ + CheckResult result = check(R"( + type Array = { [number]: T } + type Fiber = { id: number } + type null = {} + + local fiberStack: Array = {} + local index = 0 + + local function f(fiber: Fiber) + local a = fiber ~= fiberStack[index] + local b = fiberStack[index] ~= fiber + end + + return f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap") +{ + CheckResult result = check(R"( + local function f(a: string | number, b: boolean | number) + return a == b + end + )"); + + // This doesn't produce any errors but for the wrong reasons. + // This unit test serves as a reminder to not try and unify the operands on `==`/`~=`. + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") +{ + ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; + + CheckResult result = check(R"( + local a: string | number = "hi" + local b: {x: string}? = {x = "bye"} + + local r1 = a == b + local r2 = b == a + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "refine_and_or") +{ + CheckResult result = check(R"( + local t: {x: number?}? = {x = nil} + local u = t and t.x or 5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("((((false?) & ({| x: number? |}?)) | a) & ~(false?)) | number", toString(requireType("u"))); + } + else + { + CHECK_EQ("number", toString(requireType("u"))); + } +} + +TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") +{ + CheckResult result = check(Mode::Strict, R"( + local function f(x, y) + return x + y + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in + operation; consider adding a type annotation to 'x'"); + + result = check(Mode::Nonstrict, R"( + local function f(x, y) + return x + y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // When type inference is unified, we could add an assertion that + // the strict and nonstrict types are equivalent. This isn't actually + // the case right now, though. +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "equality_operations_succeed_if_any_union_branch_succeeds") +{ + CheckResult result = check(R"( + local mm = {} + type Foo = typeof(setmetatable({}, mm)) + local x: Foo + local y: Foo? + + local v1 = x == y + local v2 = y == x + local v3 = x ~= y + local v4 = y ~= x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CheckResult result2 = check(R"( + local mm1 = { + x = "foo", + } + + local mm2 = { + y = "bar", + } + + type Foo = typeof(setmetatable({}, mm1)) + type Bar = typeof(setmetatable({}, mm2)) + + local x1: Foo + local x2: Foo? + local y1: Bar + local y2: Bar? + + local v1 = x1 == y1 + local v2 = x2 == y2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result2); + CHECK(toString(result2.errors[0]) == "Types Foo and Bar cannot be compared with == because they do not have the same metatable"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_and") +{ + ScopedFastFlag sff{"LuauBinaryNeedsExpectedTypesToo", true}; + + CheckResult result = check(R"( + local x: "a" | "b" | boolean = math.random() > 0.5 and "a" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_or") +{ + ScopedFastFlag sff{"LuauBinaryNeedsExpectedTypesToo", true}; + + CheckResult result = check(R"( + local x: "a" | "b" | boolean = math.random() > 0.5 or "b" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "unrelated_classes_cannot_be_compared") +{ + ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; + + CheckResult result = check(R"( + local a = BaseClass.New() + local b = UnrelatedClass.New() + + local c = a == b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "unrelated_primitives_cannot_be_compared") +{ + ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; + + CheckResult result = check(R"( + local c = 5 == true + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "mm_ops_must_return_a_value") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local mm = { + __add = function(self, other) + return + end, + } + + local x = setmetatable({}, mm) + local y = x + 123 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(requireType("y") == singletonTypes->errorRecoveryType()); + + const GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK(ge->message == "Metamethod '__add' must return a value"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local mm1 = { + __lt = function(self, other) + return 123 + end, + } + + local mm2 = { + __lt = function(self, other) + return + end, + } + + local o1 = setmetatable({}, mm1) + local v1 = o1 < o1 + + local o2 = setmetatable({}, mm2) + local v2 = o2 < o2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(requireType("v1") == singletonTypes->booleanType); + CHECK(requireType("v2") == singletonTypes->booleanType); + + CHECK(toString(result.errors[0]) == "Metamethod '__lt' must return type 'boolean'"); + CHECK(toString(result.errors[1]) == "Metamethod '__lt' must return type 'boolean'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "reworked_and") +{ + ScopedFastFlag sff[]{ + {"LuauTryhardAnd", true}, + }; + + CheckResult result = check(R"( +local a: number? = 5 +local b: boolean = (a or 1) > 10 +local c -- free + +local x = a and 1 +local y = 'a' and 1 +local z = b and 1 +local w = c and 1 + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK("((false?) & (number?)) | number" == toString(requireType("x"))); + CHECK("((false?) & string) | number" == toString(requireType("y"))); + CHECK("((false?) & boolean) | number" == toString(requireType("z"))); + CHECK("((false?) & a) | number" == toString(requireType("w"))); + } + else + { + CHECK("number?" == toString(requireType("x"))); + CHECK("number" == toString(requireType("y"))); + CHECK("boolean | number" == toString(requireType("z"))); // 'false' widened to boolean + CHECK("(boolean | number)?" == toString(requireType("w"))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "reworked_or") +{ + ScopedFastFlag sff[]{ + {"LuauTryhardAnd", true}, + }; + + CheckResult result = check(R"( +local a: number | false = 5 +local b: number? = 6 +local c: boolean = true +local d: true = true +local e: false = false +local f: nil = false + +local a1 = a or 'a' +local b1 = b or 4 +local c1 = c or 'c' +local d1 = d or 'd' +local e1 = e or 'e' +local f1 = f or 'f' + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK("((false | number) & ~(false?)) | string" == toString(requireType("a1"))); + CHECK("((number?) & ~(false?)) | number" == toString(requireType("b1"))); + CHECK("(boolean & ~(false?)) | string" == toString(requireType("c1"))); + CHECK("(true & ~(false?)) | string" == toString(requireType("d1"))); + CHECK("(false & ~(false?)) | string" == toString(requireType("e1"))); + CHECK("(nil & ~(false?)) | string" == toString(requireType("f1"))); + } + else + { + CHECK("number | string" == toString(requireType("a1"))); + CHECK("number" == toString(requireType("b1"))); + CHECK("boolean | string" == toString(requireType("c1"))); // 'true' widened to boolean + CHECK("boolean | string" == toString(requireType("d1"))); // 'true' widened to boolean + CHECK("string" == toString(requireType("e1"))); + CHECK("string" == toString(requireType("f1"))); + } +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp new file mode 100644 index 00000000..3c2c8781 --- /dev/null +++ b/tests/TypeInfer.primitives.test.cpp @@ -0,0 +1,114 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferPrimitives"); + +TEST_CASE_FIXTURE(Fixture, "cannot_call_primitives") +{ + CheckResult result = check("local foo = 5 foo()"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0]) != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "string_length") +{ + CheckResult result = check(R"( + local s = "Hello, World!" + local t = #s + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "string_index") +{ + CheckResult result = check(R"( + local s = "Hello, World!" + local t = s[4] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + NotATable* nat = get(result.errors[0]); + REQUIRE(nat); + CHECK_EQ("string", toString(nat->ty)); + + CHECK_EQ("*error-type*", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "string_method") +{ + CheckResult result = check(R"( + local p = ("tacos"):len() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*requireType("p"), *typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "string_function_indirect") +{ + CheckResult result = check(R"( + local s:string + local l = s.lower + local p = l(s) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*requireType("p"), *typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "string_function_other") +{ + CheckResult result = check(R"( + local s:string + local p = s:match("foo") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("p")), "string"); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") +{ + CheckResult result = check(R"( +local x: number = 9999 +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Cannot add method to non-table type 'number'"); + CHECK_EQ(toString(result.errors[1]), "Type 'number' could not be converted into 'string'"); +} + +TEST_CASE("singleton_types") +{ + BuiltinsFixture a; + + { + BuiltinsFixture b; + } + + // Check that Frontend 'a' environment wasn't modified by 'b' + CheckResult result = a.check("local s: string = 'hello' local t = s:lower()"); + + CHECK(result.errors.empty()); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 46496fdb..b7408f87 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Fixture.h" @@ -8,13 +7,13 @@ #include -LUAU_FASTFLAG(LuauEqConstraint) - using namespace Luau; +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) + TEST_SUITE_BEGIN("ProvisionalTests"); -// These tests check for behavior that differes from the final behavior we'd +// These tests check for behavior that differs from the final behavior we'd // like to have. They serve to document the current state of the typechecker. // When making future improvements, its very likely these tests will break and // will need to be replaced. @@ -41,18 +40,19 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") )"; const std::string expected = R"( - function f(a:{fn:()->(free)}): () + function f(a:{fn:()->(a,b...)}): () if type(a) == 'boolean'then local a1:boolean=a elseif a.fn()then - local a2:{fn:()->(free)}=a + local a2:{fn:()->(a,b...)}=a end end )"; + CHECK_EQ(expected, decorateWithTypes(code)); } -TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns") +TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns") { const std::string code = R"( local a, b, c = xpcall(function() return 1, "foo" end, function() return "foo", 1 end) @@ -62,7 +62,15 @@ TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns") local a:boolean,b:number,c:string=xpcall(function(): (number,string)return 1,'foo'end,function(): (string,number)return'foo',1 end) )"; - CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + CHECK("boolean" == toString(requireType("a"))); + CHECK("number" == toString(requireType("b"))); + CHECK("string" == toString(requireType("c"))); + + CHECK(expected == decorateWithTypes(code)); + + LUAU_REQUIRE_NO_ERRORS(result); } // We had a bug where if you have two type packs that looks like: @@ -104,7 +112,7 @@ TEST_CASE_FIXTURE(Fixture, "it_should_be_agnostic_of_actual_size") // Ideally setmetatable's second argument would be an optional free table. // For now, infer it as just a free table. -TEST_CASE_FIXTURE(Fixture, "setmetatable_constrains_free_type_into_free_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_constrains_free_type_into_free_table") { CheckResult result = check(R"( local a = {} @@ -121,65 +129,9 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_constrains_free_type_into_free_table") CHECK_EQ("number", toString(tm->givenType)); } -TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table") -{ - CheckResult result = check(R"( - local a: {x: number, y: number, [any]: any} | {y: number} - - function f(t) - t.y = 1 - return t - end - - local b = f(a) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // :( - // Should be the same as the type of a - REQUIRE_EQ("{| y: number |}", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") -{ - CheckResult result = check(R"( - local a: {y: number} | {x: number, y: number, [any]: any} - - function f(t) - t.y = 1 - return t - end - - local b = f(a) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // :( - // Should be the same as the type of a - REQUIRE_EQ("{| [any]: any, x: number, y: number |}", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "normal_conditional_expression_has_refinements") -{ - CheckResult result = check(R"( - local foo: {x: number}? = nil - local bar = foo and foo.x -- TODO: Geez. We are inferring the wrong types here. Should be 'number?'. - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Binary and/or return types are straight up wrong. JIRA: CLI-40300 - CHECK_EQ("boolean | number", toString(requireType("bar"))); -} - // Luau currently doesn't yet know how to allow assignments when the binding was refined. TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") { - ScopedFastFlag sffs2{"LuauGenericFunctions", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( type Node = { value: T, child: Node? } @@ -201,7 +153,7 @@ TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") // Originally from TypeInfer.test.cpp. // I dont think type checking the metamethod at every site of == is the correct thing to do. // We should be type checking the metamethod at the call site of setmetatable. -TEST_CASE_FIXTURE(Fixture, "error_on_eq_metamethod_returning_a_type_other_than_boolean") +TEST_CASE_FIXTURE(BuiltinsFixture, "error_on_eq_metamethod_returning_a_type_other_than_boolean") { CheckResult result = check(R"( local tab = {a = 1} @@ -220,34 +172,11 @@ TEST_CASE_FIXTURE(Fixture, "error_on_eq_metamethod_returning_a_type_other_than_b CHECK_EQ("Metamethod '__eq' must return type 'boolean'", ge->message); } -// Requires success typing to confidently determine that this expression has no overlap. -TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") -{ - CheckResult result = check(R"( - local a: string | number = "hi" - local b: {x: string}? = {x = "bye"} - - local r1 = a == b - local r2 = b == a - )"); - - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| x: string |}?' could not be converted into 'number | string'"); - CHECK_EQ(toString(result.errors[1]), "Type 'number | string' could not be converted into '{| x: string |}?'"); - } -} - // Belongs in TypeInfer.refinements.test.cpp. -// We'll need to not only report an error on `a == b`, but also to refine both operands as `never` in the `==` branch. +// We need refine both operands as `never` in the `==` branch. TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; + ScopedFastFlag sff{"LuauIntersectionTestForEquality", true}; CheckResult result = check(R"( local function f(a: string, b: boolean?) @@ -259,7 +188,7 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") end )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "string"); // a == b CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b @@ -268,10 +197,34 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b } -TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) +// Also belongs in TypeInfer.refinements.test.cpp. +// Just needs to fully support equality refinement. Which is annoying without type states. +TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") { - ScopedFastInt sffi{"LuauTarjanChildLimit", 50}; - ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 50}; + CheckResult result = check(R"( + type T = {x: string, y: number} | {x: nil, y: nil} + + local function f(t: T) + if t.x ~= nil then + local foo = t + else + local bar = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{| x: string, y: number |}", toString(requireTypeAtPosition({5, 28}))); + + // Should be {| x: nil, y: nil |} + CHECK_EQ("{| x: nil, y: nil |} | {| x: string, y: number |}", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) +{ + ScopedFastInt sffi{"LuauTarjanChildLimit", 1}; + ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1}; CheckResult result = check(R"LUA( local Result @@ -306,283 +259,551 @@ TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doct } } -TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doctest::timeout(1.0)) +// FIXME: Move this test to another source file when removing FFlag::LuauLowerBoundsCalculation +TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") { - ScopedFastInt sffi{"LuauTarjanChildLimit", 400}; + ScopedFastFlag sff[]{ + {"LuauReturnAnyInsteadOfICE", true}, + }; - CheckResult result = check(R"LUA( + // In-place quantification causes these types to have the wrong types but only because of nasty interaction with prototyping. + // The type of f is initially () -> free1... + // Then the prototype iterator advances, and checks the function expression assigned to g, which has the type () -> free2... + // In the body it calls f and returns what f() returns. This binds free2... with free1..., causing f and g to have same types. + // We then quantify g, leaving it with the final type () -> a... + // Because free1... and free2... were bound, in combination with in-place quantification, f's return type was also turned into a... + // Then the check iterator catches up, and checks the body of f, and attempts to quantify it too. + // Alas, one of the requirements for quantification is that a type must contain free types. () -> a... has no free types. + // Thus the quantification for f was no-op, which explains why f does not have any type parameters. + // Calling f() will attempt to instantiate the function type, which turns generics in type binders into to free types. + // However, instantiations only converts generics contained within the type binders of a function, so instantiation was also no-op. + // Which means that calling f() simply returned a... rather than an instantiation of it. And since the call site was not in tail position, + // picking first element in a... triggers an ICE because calls returning generic packs are unexpected. + CheckResult result = check(R"( + local function f() end + + local g = function() return f() end + + local x = (f()) -- should error: no return values to assign from the call to f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // f and g should have the type () -> () + CHECK_EQ("() -> (a...)", toString(requireType("f"))); + CHECK_EQ("() -> (a...)", toString(requireType("g"))); + CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now +} + +TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") +{ + CheckResult result = check(R"( + local function id(x) return x end + local n2n: (number) -> number = id + local s2s: (string) -> string = id + )"); + + LUAU_REQUIRE_ERRORS(result); // Should not have any errors. +} + +TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") +{ + ScopedFastFlag sff[] = { + // I'm not sure why this is broken without DCR, but it seems to be fixed + // when DCR is enabled. + {"DebugLuauDeferredConstraintResolution", false}, + }; + + CheckResult result = check(R"( + local function f() return end + local g = function() return f() end + )"); + + LUAU_REQUIRE_ERRORS(result); // Should not have any errors. +} + +TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") +{ + ScopedFastFlag sff[] = { + // I'm not sure why this is broken without DCR, but it seems to be fixed + // when DCR is enabled. + {"DebugLuauDeferredConstraintResolution", false}, + }; + + CheckResult result = check(R"( --!strict - local TS = _G[script] - local lazyGet = TS.import(script, script.Parent.Parent, "util", "lazyLoad").lazyGet - local unit = TS.import(script, script.Parent.Parent, "util", "Unit").unit - local Iterator - lazyGet("Iterator", function(c) - Iterator = c - end) - local Option - lazyGet("Option", function(c) - Option = c - end) - local Vec - lazyGet("Vec", function(c) - Vec = c - end) - local Result - do - Result = setmetatable({}, { - __tostring = function() - return "Result" - end, - }) - Result.__index = Result - function Result.new(...) - local self = setmetatable({}, Result) - self:constructor(...) - return self - end - function Result:constructor(okValue, errValue) - self.okValue = okValue - self.errValue = errValue - end - function Result:ok(val) - return Result.new(val, nil) - end - function Result:err(val) - return Result.new(nil, val) - end - function Result:fromCallback(c) - local _0 = c - local _1, _2 = pcall(_0) - local result = _1 and { - success = true, - value = _2, - } or { - success = false, - error = _2, - } - return result.success and Result:ok(result.value) or Result:err(Option:wrap(result.error)) - end - function Result:fromVoidCallback(c) - local _0 = c - local _1, _2 = pcall(_0) - local result = _1 and { - success = true, - value = _2, - } or { - success = false, - error = _2, - } - return result.success and Result:ok(unit()) or Result:err(Option:wrap(result.error)) - end - Result.fromPromise = TS.async(function(self, p) - local _0, _1 = TS.try(function() - return TS.TRY_RETURN, { Result:ok(TS.await(p)) } - end, function(e) - return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } - end) - if _0 then - return unpack(_1) - end - end) - Result.fromVoidPromise = TS.async(function(self, p) - local _0, _1 = TS.try(function() - TS.await(p) - return TS.TRY_RETURN, { Result:ok(unit()) } - end, function(e) - return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } - end) - if _0 then - return unpack(_1) - end - end) - function Result:isOk() - return self.okValue ~= nil - end - function Result:isErr() - return self.errValue ~= nil - end - function Result:contains(x) - return self.okValue == x - end - function Result:containsErr(x) - return self.errValue == x - end - function Result:okOption() - return Option:wrap(self.okValue) - end - function Result:errOption() - return Option:wrap(self.errValue) - end - function Result:map(func) - return self:isOk() and Result:ok(func(self.okValue)) or Result:err(self.errValue) - end - function Result:mapOr(def, func) - local _0 - if self:isOk() then - _0 = func(self.okValue) - else - _0 = def - end - return _0 - end - function Result:mapOrElse(def, func) - local _0 - if self:isOk() then - _0 = func(self.okValue) - else - _0 = def(self.errValue) - end - return _0 - end - function Result:mapErr(func) - return self:isErr() and Result:err(func(self.errValue)) or Result:ok(self.okValue) - end - Result["and"] = function(self, other) - return self:isErr() and Result:err(self.errValue) or other - end - function Result:andThen(func) - return self:isErr() and Result:err(self.errValue) or func(self.okValue) - end - Result["or"] = function(self, other) - return self:isOk() and Result:ok(self.okValue) or other - end - function Result:orElse(other) - return self:isOk() and Result:ok(self.okValue) or other(self.errValue) - end - function Result:expect(msg) - if self:isOk() then - return self.okValue - else - error(msg) - end - end - function Result:unwrap() - return self:expect("called `Result.unwrap()` on an `Err` value: " .. tostring(self.errValue)) - end - function Result:unwrapOr(def) - local _0 - if self:isOk() then - _0 = self.okValue - else - _0 = def - end - return _0 - end - function Result:unwrapOrElse(gen) - local _0 - if self:isOk() then - _0 = self.okValue - else - _0 = gen(self.errValue) - end - return _0 - end - function Result:expectErr(msg) - if self:isErr() then - return self.errValue - else - error(msg) - end - end - function Result:unwrapErr() - return self:expectErr("called `Result.unwrapErr()` on an `Ok` value: " .. tostring(self.okValue)) - end - function Result:transpose() - return self:isOk() and self.okValue:map(function(some) - return Result:ok(some) - end) or Option:some(Result:err(self.errValue)) - end - function Result:flatten() - return self:isOk() and Result.new(self.okValue.okValue, self.okValue.errValue) or Result:err(self.errValue) - end - function Result:match(ifOk, ifErr) - local _0 - if self:isOk() then - _0 = ifOk(self.okValue) - else - _0 = ifErr(self.errValue) - end - return _0 - end - function Result:asPtr() - local _0 = (self.okValue) - if _0 == nil then - _0 = (self.errValue) - end - return _0 - end - end - local resultMeta = Result - resultMeta.__eq = function(a, b) - return b:match(function(ok) - return a:contains(ok) - end, function(err) - return a:containsErr(err) - end) - end - resultMeta.__tostring = function(result) - return result:match(function(ok) - return "Result.ok(" .. tostring(ok) .. ")" - end, function(err) - return "Result.err(" .. tostring(err) .. ")" - end) - end - return { - Result = Result, - } - )LUA"); + local function f(...) return ... end + local g = function(...) return f(...) end + )"); - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& a) { - return nullptr != get(a); - }); - if (it == result.errors.end()) + LUAU_REQUIRE_ERRORS(result); // Should not have any errors. +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(BuiltinsFixture, "pcall_returns_at_least_two_value_but_function_returns_nothing") +{ + ScopedFastFlag sff{"LuauBetterMessagingOnCountMismatch", true}; + + CheckResult result = check(R"( + local function f(): () end + local ok, res = pcall(f) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Function only returns 1 value, but 2 are required here", toString(result.errors[0])); + // LUAU_REQUIRE_NO_ERRORS(result); + // CHECK_EQ("boolean", toString(requireType("ok"))); + // CHECK_EQ("any", toString(requireType("res"))); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(BuiltinsFixture, "choose_the_right_overload_for_pcall") +{ + CheckResult result = check(R"( + local function f(): number + if math.random() > 0.5 then + return 5 + else + error("something") + end + end + + local ok, res = pcall(f) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("boolean", toString(requireType("ok"))); + CHECK_EQ("number", toString(requireType("res"))); + // CHECK_EQ("any", toString(requireType("res"))); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_many_things_but_first_of_it_is_forgotten") +{ + CheckResult result = check(R"( + local function f(): (number, string, boolean) + if math.random() > 0.5 then + return 5, "hello", true + else + error("something") + end + end + + local ok, res, s, b = pcall(f) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("boolean", toString(requireType("ok"))); + CHECK_EQ("number", toString(requireType("res"))); + // CHECK_EQ("any", toString(requireType("res"))); + CHECK_EQ("string", toString(requireType("s"))); + CHECK_EQ("boolean", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "free_is_not_bound_to_any") +{ + CheckResult result = check(R"( + local function foo(f: (any) -> (), x) + f(x) + end + )"); + + CHECK_EQ("((any) -> (), any) -> ()", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_function_with_no_returns") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local T = {} + T.__index = T + + function T.new() + local self = setmetatable({}, T) + return self:ctor() or self + end + + function T:ctor() + -- oops, no return! + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Not all codepaths in this function return 'self, a...'.", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") +{ + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + }; + + CheckResult result = check(R"( + local function hasDivisors(value: number) + end + + function prime_iter(state, index) + hasDivisors(index) + index += 1 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Solving this requires recognizing that we can't dispatch a constraint + // like this without doing further work: + // + // (*blocked*) -> () <: (number) -> (b...) + // + // We solve this by searching both types for BlockedTypeVars and block the + // constraint on any we find. It also gets the job done, but I'm worried + // about the efficiency of doing so many deep type traversals and it may + // make us more prone to getting stuck on constraint cycles. + // + // If this doesn't pan out, a possible solution is to go further down the + // path of supporting partial constraint dispatch. The way it would work is + // that we'd dispatch the above constraint by binding b... to (), but we + // would append a new constraint number <: *blocked* to the constraint set + // to be solved later. This should be faster and theoretically less prone + // to cyclic constraint dependencies. + CHECK("(a, number) -> ()" == toString(requireType("prime_iter"))); +} + +TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") +{ + TypeArena arena; + TypeId nilType = singletonTypes->nilType; + + std::unique_ptr scope = std::make_unique(singletonTypes->anyTypePack); + + TypeId free1 = arena.addType(FreeTypePack{scope.get()}); + TypeId option1 = arena.addType(UnionTypeVar{{nilType, free1}}); + + TypeId free2 = arena.addType(FreeTypePack{scope.get()}); + TypeId option2 = arena.addType(UnionTypeVar{{nilType, free2}}); + + InternalErrorReporter iceHandler; + UnifierSharedState sharedState{&iceHandler}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant}; + + u.tryUnify(option1, option2); + + CHECK(u.errors.empty()); + + u.log.commit(); + + ToStringOptions opts; + CHECK("a?" == toString(option1, opts)); + + // CHECK("a?" == toString(option2, opts)); // This should hold, but does not. + CHECK("b?" == toString(option2, opts)); // This should not hold. +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", false}; + + CheckResult result = check(R"( + function no_iter() end + for key in no_iter() do end -- This should not be ok + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +// Ideally, we would not try to export a function type with generic types from incorrect scope +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface") +{ + fileResolver.source["game/A"] = R"( +local wrapStrictTable + +local metatable = { + __index = function(self, key) + local value = self.__tbl[key] + if type(value) == "table" then + -- unification of the free 'wrapStrictTable' with this function type causes generics of this function to leak out of scope + return wrapStrictTable(value, self.__name .. "." .. key) + end + return value + end, +} + +return wrapStrictTable + )"; + + frontend.check("game/A"); + + fileResolver.source["game/B"] = R"( +local wrapStrictTable = require(game.A) + +local Constants = {} + +return wrapStrictTable(Constants, "Constants") + )"; + + frontend.check("game/B"); + + ModulePtr m = frontend.moduleResolver.modules["game/B"]; + REQUIRE(m); + + std::optional result = first(m->getModuleScope()->returnType); + REQUIRE(result); + CHECK(get(*result)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface_variadic") +{ + fileResolver.source["game/A"] = R"( +local wrapStrictTable + +local metatable = { + __index = function(self, key, ...: T) + local value = self.__tbl[key] + if type(value) == "table" then + -- unification of the free 'wrapStrictTable' with this function type causes generics of this function to leak out of scope + return wrapStrictTable(value, self.__name .. "." .. key) + end + return ... + end, +} + +return wrapStrictTable + )"; + + frontend.check("game/A"); + + fileResolver.source["game/B"] = R"( +local wrapStrictTable = require(game.A) + +local Constants = {} + +return wrapStrictTable(Constants, "Constants") + )"; + + frontend.check("game/B"); + + ModulePtr m = frontend.moduleResolver.modules["game/B"]; + REQUIRE(m); + + std::optional result = first(m->getModuleScope()->returnType); + REQUIRE(result); + CHECK(get(*result)); +} + +// We need a simplification step to make this do the right thing. ("normalization-lite") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") +{ + CheckResult result = check(R"( + local function foo(t, x) + if x == "hi" or x == "bye" then + table.insert(t, x) + end + + return t + end + + local t = foo({}, "hi") + table.insert(t, "totally_unrelated_type" :: "totally_unrelated_type") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // We'd really like for this to be {string} + CHECK_EQ("{string | string}", toString(requireType("t"))); +} + +namespace +{ +struct IsSubtypeFixture : Fixture +{ + bool isSubtype(TypeId a, TypeId b) { - dumpErrors(result); - FAIL("Expected a UnificationTooComplex error"); + return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, singletonTypes, ice); + } +}; +} // namespace + +TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection_of_functions_of_different_arities") +{ + check(R"( + type A = (any) -> () + type B = (any, any) -> () + type T = A & B + + local a: A + local b: B + local t: T + )"); + + [[maybe_unused]] TypeId a = requireType("a"); + [[maybe_unused]] TypeId b = requireType("b"); + + // CHECK(!isSubtype(a, b)); // !! + // CHECK(!isSubtype(b, a)); + + CHECK("((any) -> ()) & ((any, any) -> ())" == toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity") +{ + check(R"( + local a: (number) -> () + local b: () -> () + + local c: () -> number + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + // CHECK(!isSubtype(b, a)); + // CHECK(!isSubtype(c, a)); + + CHECK(!isSubtype(a, b)); + // CHECK(!isSubtype(c, b)); + + CHECK(!isSubtype(a, c)); + CHECK(!isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_with_mismatching_arity_but_optional_parameters") +{ + /* + * (T0..TN) <: (T0..TN, A?) + * (T0..TN) <: (T0..TN, any) + * (T0..TN, A?) R <: U -> S if U <: T and R <: S + * A | B <: T if A <: T and B <: T + * T <: A | B if T <: A or T <: B + */ + check(R"( + local a: (number?) -> () + local b: (number) -> () + local c: (number, number?) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + /* + * (number) -> () () + * because number? () + * because number? () <: (number) -> () + * because number <: number? (because number <: number) + */ + CHECK(isSubtype(a, b)); + + /* + * (number, number?) -> () <: (number) -> (number) + * The packs have inequal lengths, but (number) <: (number, number?) + * and number <: number + */ + // CHECK(!isSubtype(c, b)); + + /* + * (number?) -> () () + * because (number, number?) () () + * because (number, number?) () + local b: (number) -> () + local c: (number, any) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + /* + * (number) -> () () + * because number? () + * because number? () <: (number) -> () + * because number <: number? (because number <: number) + */ + CHECK(isSubtype(a, b)); + + /* + * (number, any) -> () (number) + * The packs have inequal lengths + */ + // CHECK(!isSubtype(c, b)); + + /* + * (number?) -> () () + * The packs have inequal lengths + */ + // CHECK(!isSubtype(a, c)); + + /* + * (number) -> () () + * The packs have inequal lengths + */ + // CHECK(!isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") +{ + CheckResult result = check(R"( + local t: {x: number?} = {x = nil} + + if t.x then + local u: {x: number} = t + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauTypeMismatchInvarianceInError) + { + CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' +caused by: + Property 'x' is not compatible. Type 'number?' could not be converted into 'number' in an invariant context)", + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' +caused by: + Property 'x' is not compatible. Type 'number?' could not be converted into 'number')", + toString(result.errors[0])); } } -// Should be in TypeInfer.tables.test.cpp -// It's unsound to instantiate tables containing generic methods, -// since mutating properties means table properties should be invariant. -// We currently allow this but we shouldn't! -TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_tables_in_call_is_unsound") -{ - CheckResult result = check(R"( - --!strict - local t = {} - function t.m(x) return x end - local a : string = t.m("hi") - local b : number = t.m(5) - function f(x : { m : (number)->number }) - x.m = function(x) return 1+x end - end - f(t) -- This shouldn't typecheck - local c : string = t.m("hi") - )"); - - // TODO: this should error! - // This should be fixed by replacing generic tables by generics with type bounds. - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") -{ - ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; - ScopedFastFlag luauFollowInTypeFunApply{"LuauFollowInTypeFunApply", true}; - ScopedFastFlag luauInstantiatedTypeParamRecursion{"LuauInstantiatedTypeParamRecursion", true}; - - // Mutability in type function application right now can create strange recursive types - // TODO: instantiation right now is problematic, it this example should either leave the Table type alone - // or it should rename the type to 'Self' so that the result will be 'Self
' - CheckResult result = check(R"( -type Table = { a: number } -type Self = T -local a: Self
- )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a")), "Table
"); -} - TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index f2ba0ddc..5a7c8432 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1,16 +1,108 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Normalize.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Fixture.h" #include "doctest.h" -LUAU_FASTFLAG(LuauWeakEqConstraint) -LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2) -LUAU_FASTFLAG(LuauOrPredicate) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) using namespace Luau; +namespace +{ +std::optional> magicFunctionInstanceIsA( + TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + if (expr.args.size != 1) + return std::nullopt; + + auto index = expr.func->as(); + auto str = expr.args.data[0]->as(); + if (!index || !str) + return std::nullopt; + + std::optional lvalue = tryGetLValue(*index->expr); + std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); + if (!lvalue || !tfun) + return std::nullopt; + + unfreeze(typeChecker.globalTypes); + TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); + freeze(typeChecker.globalTypes); + return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; +} + +std::vector dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx) +{ + if (ctx.callSite->args.size != 1) + return {}; + + auto index = ctx.callSite->func->as(); + auto str = ctx.callSite->args.data[0]->as(); + if (!index || !str) + return {}; + + std::optional def = ctx.dfg->getDef(index->expr); + if (!def) + return {}; + + std::optional tfun = ctx.scope->lookupType(std::string(str->value.data, str->value.size)); + if (!tfun) + return {}; + + return {ctx.connectiveArena->proposition(*def, tfun->type)}; +} + +struct RefinementClassFixture : BuiltinsFixture +{ + RefinementClassFixture() + { + TypeArena& arena = typeChecker.globalTypes; + NotNull scope{typeChecker.globalScope.get()}; + + unfreeze(arena); + TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); + getMutable(vec3)->props = { + {"X", Property{typeChecker.numberType}}, + {"Y", Property{typeChecker.numberType}}, + {"Z", Property{typeChecker.numberType}}, + }; + + TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); + + TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); + TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); + TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); + getMutable(isA)->magicFunction = magicFunctionInstanceIsA; + getMutable(isA)->dcrMagicRefinement = dcrMagicRefinementInstanceIsA; + + getMutable(inst)->props = { + {"Name", Property{typeChecker.stringType}}, + {"IsA", Property{isA}}, + }; + + TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); + TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); + getMutable(part)->props = { + {"Position", Property{vec3}}, + }; + + typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; + typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; + typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; + typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; + + for (const auto& [name, ty] : typeChecker.globalScope->exportedTypeBindings) + persist(ty.type); + + freeze(typeChecker.globalTypes); + } +}; +} // namespace + TEST_SUITE_BEGIN("RefinementTest"); TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint") @@ -27,8 +119,16 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({5, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({5, 26}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({5, 26}))); + } } TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint") @@ -45,8 +145,16 @@ TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({5, 26}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); + } } TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through") @@ -63,8 +171,16 @@ TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({5, 26}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); + } } TEST_CASE_FIXTURE(Fixture, "and_constraint") @@ -83,8 +199,16 @@ TEST_CASE_FIXTURE(Fixture, "and_constraint") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("number", toString(requireTypeAtPosition({4, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("(number?) & ~(false?)", toString(requireTypeAtPosition({4, 26}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("number", toString(requireTypeAtPosition({4, 26}))); + } CHECK_EQ("string?", toString(requireTypeAtPosition({6, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({7, 26}))); @@ -109,8 +233,16 @@ TEST_CASE_FIXTURE(Fixture, "not_and_constraint") CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); - CHECK_EQ("string", toString(requireTypeAtPosition({6, 26}))); - CHECK_EQ("number", toString(requireTypeAtPosition({7, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("(number?) & ~(false?)", toString(requireTypeAtPosition({7, 26}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("number", toString(requireTypeAtPosition({7, 26}))); + } } TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates") @@ -132,13 +264,58 @@ TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates") CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); - if (FFlag::LuauOrPredicate) + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({7, 26}))); + } + else { CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); } } +TEST_CASE_FIXTURE(Fixture, "a_and_b_or_a_and_c") +{ + CheckResult result = check(R"( + function f(a: string?, b: number?, c: boolean) + if (a and b) or (a and c) then + local foo = a + local bar = b + local baz = c + else + local foo = a + local bar = b + local baz = c + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & (~(false?) | ~(false?))", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("boolean", toString(requireTypeAtPosition({5, 28}))); + + CHECK_EQ("string?", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({8, 28}))); + CHECK_EQ("boolean", toString(requireTypeAtPosition({9, 28}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("true", toString(requireTypeAtPosition({5, 28}))); // oh no! :( + + CHECK_EQ("string?", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({8, 28}))); + CHECK_EQ("boolean", toString(requireTypeAtPosition({9, 28}))); + } +} + TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") { CheckResult result = check(R"( @@ -152,11 +329,20 @@ TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); - CHECK_EQ("string", toString(requireTypeAtPosition({4, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("number?", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({4, 26}))); + } + else + { + // We're going to drop support for type refinements through type assertions. + CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({4, 26}))); + } } -TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_if_condition_position") { CheckResult result = check(R"( function f(s: any) @@ -168,10 +354,17 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("any & number", toString(requireTypeAtPosition({3, 26}))); + } + else + { + CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); + } } -TEST_CASE_FIXTURE(Fixture, "typeguard_in_assert_position") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") { CheckResult result = check(R"( local a @@ -180,38 +373,11 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_in_assert_position") )"); LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("number", toString(requireType("b"))); } -TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") -{ - CheckResult result = check(R"( - type ActuallyString = string - - do -- Necessary. Otherwise toposort has ActuallyString come after string type alias. - type string = number - local foo: string = 1 - - if type(foo) == "string" then - local bar: ActuallyString = foo - local baz: boolean = foo - end - end - )"); - - if (FFlag::LuauImprovedTypeGuardPredicate2) - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'string' could not be converted into 'boolean'", toString(result.errors[0])); - } -} - -TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") +TEST_CASE_FIXTURE(BuiltinsFixture, "call_an_incompatible_function_after_using_typeguard") { CheckResult result = check(R"( local function f(x: number) @@ -226,10 +392,11 @@ TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "impossible_type_narrow_is_not_an_error") +TEST_CASE_FIXTURE(BuiltinsFixture, "impossible_type_narrow_is_not_an_error") { // This unit test serves as a reminder to not implement this warning until Luau is intelligent enough. // For instance, getting a value out of the indexer and checking whether the value exists is not an error. @@ -248,24 +415,30 @@ TEST_CASE_FIXTURE(Fixture, "impossible_type_narrow_is_not_an_error") TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") { - CheckResult result = check(R"( local t: {x: number?} = {x = 1} if t.x then - local foo: number = t.x + local t2 = t + local foo = t.x end local bar = t.x )"); LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK("{| x: number? |} & {| x: ~(false?) |}" == toString(requireTypeAtPosition({4, 23}))); + CHECK("(number?) & ~(false?)" == toString(requireTypeAtPosition({5, 26}))); + } + CHECK_EQ("number?", toString(requireType("bar"))); } -TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property") +TEST_CASE_FIXTURE(BuiltinsFixture, "index_on_a_refined_property") { - CheckResult result = check(R"( local t: {x: {y: string}?} = {x = {y = "hello!"}} @@ -277,7 +450,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_constraints") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_non_binary_expressions_actually_resolve_constraints") { CheckResult result = check(R"( local foo: string? = "hello" @@ -288,25 +461,8 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") -{ - - CheckResult result = check(R"( - local t: {x: number?} = {x = nil} - - if t.x then - local u: {x: number} = t - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type '{| x: number? |}' could not be converted into '{| x: number |}'", toString(result.errors[0])); -} - TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?, b: boolean?) if a == b then @@ -319,18 +475,18 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) + if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "((number | string)?) & unknown"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "(boolean?) & unknown"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "((number | string)?) & unknown"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "(boolean?) & unknown"); // a ~= b } else { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "nil"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "nil"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b @@ -339,8 +495,6 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?) if a == 1 then @@ -353,22 +507,20 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) + if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1 - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & unknown"); // a == 1 + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & unknown"); // a ~= 1 } else { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number"); // a == 1 + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1; CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 } } TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?) if "hello" == a then @@ -381,22 +533,20 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) + if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == "hello" - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello" & ((number | string)?))"); // a == "hello" + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), R"(((number | string)?) & ~"hello")"); // a ~= "hello" } else { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "string"); // a == "hello" - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello")"); // a == "hello" + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), R"((number | string)?)"); // a ~= "hello" } } TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?) if a ~= nil then @@ -409,22 +559,20 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) + if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "((number | string)?) & ~nil"); // a ~= nil + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "((number | string)?) & unknown"); // a == nil } else { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "nil"); // a == nil + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil } } TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a, b: string?) if a == b then @@ -435,22 +583,21 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) + if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b + ToStringOptions opts; + CHECK_EQ(toString(requireTypeAtPosition({3, 33}), opts), "a & unknown"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36}), opts), "(string?) & unknown"); // a == b } else { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "string?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b } } TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: any, b: {x: number}?) if a ~= b then @@ -461,22 +608,20 @@ TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_e LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) + if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any & unknown"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "({| x: number |}?) & unknown"); // a ~= b } else { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b } } TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local t: {string} = {"hello"} @@ -491,20 +636,21 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b - - if (FFlag::LuauWeakEqConstraint) + if (FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string & unknown"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "(string?) & unknown"); // a ~= b + + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string & unknown"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "(string?) & unknown"); // a == b } else { - // This is technically not wrong, but it's also wrong at the same time. - // The refinement code is none the wiser about the fact we pulled a string out of an array, so it has no choice but to narrow as just string. - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b + + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b } } @@ -524,10 +670,8 @@ TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_narrow_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x) if type(x) == "vector" then @@ -536,16 +680,13 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") end )"); - // This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type. - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0])); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("*error-type*", toString(requireTypeAtPosition({3, 28}))); } -TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") +TEST_CASE_FIXTURE(BuiltinsFixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local t = {"hello"} local v = t[2] @@ -571,10 +712,8 @@ TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true" CHECK_EQ("string", toString(requireTypeAtPosition({12, 24}))); // equivalent to type(v) ~= "nil" } -TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_not_to_be_string") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | number | boolean) if type(x) ~= "string" then @@ -587,14 +726,20 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("boolean | number", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(boolean | number | string) & ~string", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" + CHECK_EQ("(boolean | number | string) & string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" + } + else + { + CHECK_EQ("boolean | number", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" + } } -TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_narrows_for_table") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | {x: number} | {y: boolean}) if type(x) == "table" then @@ -611,10 +756,8 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "table" } -TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_narrows_for_functions") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function weird(x: string | ((number) -> string)) if type(x) == "function" then @@ -627,79 +770,602 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(number) -> string", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" -} - -namespace -{ -std::optional> magicFunctionInstanceIsA( - TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) -{ - if (expr.args.size != 1) - return std::nullopt; - - auto index = expr.func->as(); - auto str = expr.args.data[0]->as(); - if (!index || !str) - return std::nullopt; - - std::optional lvalue = tryGetLValue(*index->expr); - std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); - if (!lvalue || !tfun) - return std::nullopt; - - unfreeze(typeChecker.globalTypes); - TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); - freeze(typeChecker.globalTypes); - return ExprResult{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; -} - -struct RefinementClassFixture : Fixture -{ - RefinementClassFixture() + if (FFlag::DebugLuauDeferredConstraintResolution) { - TypeArena& arena = typeChecker.globalTypes; - - unfreeze(arena); - TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr}); - getMutable(vec3)->props = { - {"X", Property{typeChecker.numberType}}, - {"Y", Property{typeChecker.numberType}}, - {"Z", Property{typeChecker.numberType}}, - }; - - TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); - - TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); - TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); - TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); - getMutable(isA)->magicFunction = magicFunctionInstanceIsA; - - getMutable(inst)->props = { - {"Name", Property{typeChecker.stringType}}, - {"IsA", Property{isA}}, - }; - - TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); - TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); - getMutable(part)->props = { - {"Position", Property{vec3}}, - }; - - typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; - typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; - typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; - typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; - freeze(typeChecker.globalTypes); + CHECK_EQ("(((number) -> string) | string) & function", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" + CHECK_EQ("(((number) -> string) | string) & ~function", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" } -}; -} // namespace + else + { + CHECK_EQ("(number) -> string", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_can_filter_for_intersection_of_tables") +{ + CheckResult result = check(R"( + type XYCoord = {x: number} & {y: number} + local function f(t: XYCoord?) + if type(t) == "table" then + local foo = t + else + local foo = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_can_filter_for_overloaded_function") +{ + CheckResult result = check(R"( + type SomeOverloadedFunction = ((number) -> string) & ((string) -> number) + local function f(g: SomeOverloadedFunction?) + if type(g) == "function" then + local foo = g + else + local foo = g + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("((((number) -> string) & ((string) -> number))?) & function", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("((((number) -> string) & ((string) -> number))?) & ~function", toString(requireTypeAtPosition({6, 28}))); + } + else + { + CHECK_EQ("((number) -> string) & ((string) -> number)", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_narrowed_into_nothingness") +{ + CheckResult result = check(R"( + local function f(t: {x: number}) + if type(t) ~= "table" then + local foo = t + error(("Expected a table, got %s"):format(type(t))) + end + + return t.x + 1 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") +{ + CheckResult result = check(R"( + local function f(a: number?, b: number?) + if (not a) or (not b) then + local foo = a + local bar = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number?", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b2") +{ + CheckResult result = check(R"( + local function f(a: number?, b: number?) + if not (a and b) then + local foo = a + local bar = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number?", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b") +{ + CheckResult result = check(R"( + local function f(a: number?, b: number?) + if (not a) and (not b) then + local foo = a + local bar = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({4, 28}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") +{ + CheckResult result = check(R"( + local function f(a: number?, b: number?) + if not (a or b) then + local foo = a + local bar = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(number?) & ~~(false?)", toString(requireTypeAtPosition({4, 28}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "either_number_or_string") +{ + CheckResult result = check(R"( + local function f(x: any) + if type(x) == "number" or type(x) == "string" then + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(number | string) & any", toString(requireTypeAtPosition({3, 28}))); + } + else + { + CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") +{ + CheckResult result = check(R"( + local function f(t: {x: boolean}?) + if not t or t.x then + local foo = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{| x: boolean |}?", toString(requireTypeAtPosition({3, 28}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") +{ + CheckResult result = check(R"( + local a: (number | string)? + assert(a) + local b = a + assert(type(a) == "number") + local c = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("((number | string)?) & ~(false?)", toString(requireTypeAtPosition({3, 18}))); + CHECK_EQ("((number | string)?) & ~(false?) & number", toString(requireTypeAtPosition({5, 18}))); + } + else + { + CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 18}))); + CHECK_EQ("number", toString(requireTypeAtPosition({5, 18}))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") +{ + // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. + CheckResult result = check(R"( + local function f(b: string | { x: string }, a) + assert(type(a) == "string") + assert(type(b) == "string" or type(b) == "table") + + if type(b) == "string" then + local foo = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(never | string) & (string | {| x: string |}) & string", toString(requireTypeAtPosition({6, 28}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({6, 28}))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") +{ + CheckResult result = check(R"( + local function f(a: string | number | boolean) + if type(a) ~= "number" and type(a) ~= "string" then + local foo = a + else + local foo = a + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(boolean | number | string) & ~number & ~string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(boolean | number | string) & (number | string)", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "is_truthy_constraint_ifelse_expression") +{ + CheckResult result = check(R"( + function f(v:string?) + return if v then v else tostring(v) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({2, 29}))); + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({2, 45}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({2, 29}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({2, 45}))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_ifelse_expression") +{ + CheckResult result = check(R"( + function f(v:string?) + return if not v then tostring(v) else v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string?) & ~~(false?)", toString(requireTypeAtPosition({2, 42}))); + CHECK_EQ("(string?) & ~(false?)", toString(requireTypeAtPosition({2, 50}))); + } + else + { + CHECK_EQ("nil", toString(requireTypeAtPosition({2, 42}))); + CHECK_EQ("string", toString(requireTypeAtPosition({2, 50}))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_comparison_ifelse_expression") +{ + CheckResult result = check(R"( + function returnOne(x) + return 1 + end + + function f(v:any) + return if typeof(v) == "number" then v else returnOne(v) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("any & number", toString(requireTypeAtPosition({6, 49}))); + CHECK_EQ("any & ~number", toString(requireTypeAtPosition({6, 66}))); + } + else + { + CHECK_EQ("number", toString(requireTypeAtPosition({6, 49}))); + CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") +{ + CheckResult result = check(R"( + local foo: string? = "hi" + assert(foo) + local foo: number = 5 + print(foo:sub(1, 1)) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'number' does not have key 'sub'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_property_whose_base_was_previously_refined") +{ + CheckResult result = check(R"( + type T = {x: string | number} + local t: T? = {x = "hi"} + if t then + if type(t.x) == "string" then + local foo = t.x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({5, 30}))); +} + +TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined2") +{ + CheckResult result = check(R"( + type T = { x: { y: number }? } + + local function f(t: T?) + if t and t.x then + local foo = t.x.y + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireTypeAtPosition({5, 32}))); +} + +TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") +{ + CheckResult result = check(R"( + type T = { [string]: { prop: number }? } + local t: T = {} + + if t["hello"] then + local foo = t["hello"].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") +{ + CheckResult result = check(R"( + type T = {tag: "missing", x: nil} | {tag: "exists", x: string} + + local function f(t: T) + if t.x then + local foo = t + else + local bar = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(R"(({| tag: "exists", x: string |} | {| tag: "missing", x: nil |}) & {| x: ~(false?) |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ( + R"(({| tag: "exists", x: string |} | {| tag: "missing", x: nil |}) & {| x: ~~(false?) |})", toString(requireTypeAtPosition({7, 28}))); + } + else + { + CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"({| tag: "exists", x: string |} | {| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "discriminate_tag") +{ + CheckResult result = check(R"( + type Cat = {tag: "Cat", name: string, catfood: string} + type Dog = {tag: "Dog", name: string, dogfood: string} + type Animal = Cat | Dog + + local function f(animal: Animal) + if animal.tag == "Cat" then + local cat = animal + elseif animal.tag == "Dog" then + local dog = animal + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(R"((Cat | Dog) & {| tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ(R"((Cat | Dog) & {| tag: ~"Cat" |} & {| tag: "Dog" |})", toString(requireTypeAtPosition({9, 33}))); + } + else + { + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "discriminate_tag_with_implicit_else") +{ + ScopedFastFlag sff{"LuauImplicitElseRefinement", true}; + + CheckResult result = check(R"( + type Cat = {tag: "Cat", name: string, catfood: string} + type Dog = {tag: "Dog", name: string, dogfood: string} + type Animal = Cat | Dog + + local function f(animal: Animal) + if animal.tag == "Cat" then + local cat = animal + else + local dog = animal + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(R"((Cat | Dog) & {| tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ(R"((Cat | Dog) & {| tag: ~"Cat" |})", toString(requireTypeAtPosition({9, 33}))); + } + else + { + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") +{ + CheckResult result = check(R"( + local function len(a: {any}) + return a and #a or nil + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") +{ + CheckResult result = check(R"( + local function f(x: boolean) + if x then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("boolean & ~(false?)", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("boolean & ~~(false?)", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("true", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("false", toString(requireTypeAtPosition({5, 28}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false") +{ + CheckResult result = check(R"( + type Ok = { ok: true, value: T } + type Err = { ok: false, error: E } + type Result = Ok | Err + + local function apply(t: Result, f: (T) -> (), g: (E) -> ()) + if t.ok then + f(t.value) + else + g(t.error) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "refine_a_property_not_to_be_nil_through_an_intersection_table") +{ + CheckResult result = check(R"( + type T = {} & {f: ((string) -> string)?} + local function f(t: T, x) + if t.f then + t.f(x) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") +{ + CheckResult result = check(R"( + type T = {tag: "Part", x: Part} | {tag: "Folder", x: Folder} + + local function f(t: T) + if t.x:IsA("Part") then + local foo = t + else + local bar = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ(R"(({| tag: "Folder", x: Folder |} | {| tag: "Part", x: Part |}) & {| x: Part |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"(({| tag: "Folder", x: Folder |} | {| tag: "Part", x: Part |}) & {| x: ~Part |})", toString(requireTypeAtPosition({7, 28}))); + } + else + { + CHECK_EQ(R"({| tag: "Part", x: Part |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"({| tag: "Folder", x: Folder |})", toString(requireTypeAtPosition({7, 28}))); + } +} TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(vec) local X, Y, Z = vec.X, vec.Y, vec.Z @@ -714,20 +1380,17 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" - CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" + CHECK_EQ("never", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" } TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: Instance | Vector3) if typeof(x) == "Vector3" then @@ -740,14 +1403,20 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(Instance | Vector3) & Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(Instance | Vector3) & ~Vector3", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | number | Instance | Vector3) if type(x) == "userdata" then @@ -766,11 +1435,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") { - ScopedFastFlag sffs[] = { - {"LuauImprovedTypeGuardPredicate2", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; - CheckResult result = check(R"( local function f(x: Part | Folder | string) if typeof(x) == "Instance" then @@ -783,19 +1447,22 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(Folder | Part | string) & Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(Folder | Part | string) & ~Instance", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); + } } -TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") +TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_from_subclasses_of_instance_or_string_or_vector3") { - ScopedFastFlag sffs[] = { - {"LuauImprovedTypeGuardPredicate2", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; - CheckResult result = check(R"( - local function f(x: Part | Folder | Instance | string | Vector3 | any) + local function f(x: Part | Folder | string | Vector3) if typeof(x) == "Instance" then local foo = x else @@ -806,17 +1473,20 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("Folder | Instance | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Vector3 | any | string", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(Folder | Part | Vector3 | string) & Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(Folder | Part | Vector3 | string) & ~Instance", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Vector3 | string", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; - CheckResult result = check(R"( --!nonstrict @@ -831,18 +1501,100 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ("any", toString(requireTypeAtPosition({7, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("Folder & Instance & {- -}", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("(~Folder | ~Instance) & {- -} & never", toString(requireTypeAtPosition({7, 28}))); + } + else + { + CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("any", toString(requireTypeAtPosition({7, 28}))); + } +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "refine_param_of_type_instance_without_using_typeof") +{ + CheckResult result = check(R"( + local function f(x: Instance) + if x:IsA("Folder") then + local foo = x + elseif typeof(x) == "table" then + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("Folder & Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance & ~Folder & never", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Folder", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("never", toString(requireTypeAtPosition({5, 28}))); + } +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "refine_param_of_type_folder_or_part_without_using_typeof") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder) + if x:IsA("Folder") then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(Folder | Part) & Folder", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(Folder | Part) & ~Folder", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Folder", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); + } +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "isa_type_refinement_must_be_known_ahead_of_time") +{ + CheckResult result = check(R"( + local function f(x): Instance + if x:IsA("Folder") then + local foo = x + else + local foo = x + end + + return x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; - CheckResult result = check(R"( local function f(x: Part | Folder | string) if typeof(x) ~= "Instance" or not x:IsA("Part") then @@ -859,154 +1611,73 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); } -TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_doesnt_leak_to_elseif") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( - type XYCoord = {x: number} & {y: number} - local function f(t: XYCoord?) - if type(t) == "table" then - local foo = t + function f(a) + if type(a) == "boolean" then + local a1 = a + elseif a.fn() then + local a2 = a else - local foo = t + local a3 = a end end )"); LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } -TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( - type SomeOverloadedFunction = ((number) -> string) & ((string) -> number) - local function f(g: SomeOverloadedFunction?) - if type(g) == "function" then - local foo = g + local function f(x: unknown) + if type(x) == "string" then + local foo = x else - local foo = g + local bar = x end end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("((number) -> string) & ((string) -> number)", toString(requireTypeAtPosition({4, 28}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("unknown & string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("unknown & ~string", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("unknown", toString(requireTypeAtPosition({5, 28}))); + } } -TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") +TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( - local function f(t: {x: number}) - if type(t) ~= "table" then - local foo = t - error(("Expected a table, got %s"):format(type(t))) - end - - return t.x + 1 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); -} - -TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") -{ - ScopedFastFlag sff{"LuauOrPredicate", true}; - - CheckResult result = check(R"( - local function f(a: number?, b: number?) - if (not a) or (not b) then - local foo = a - local bar = b + local function f(t: {number}) + local x = t[1] + if not x then + local foo = x + else + local bar = x end end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number?", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); -} - -TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b2") -{ - ScopedFastFlag sff{"LuauOrPredicate", true}; - - CheckResult result = check(R"( - local function f(a: number?, b: number?) - if not (a and b) then - local foo = a - local bar = b - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("number?", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); -} - -TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b") -{ - ScopedFastFlag sff{"LuauOrPredicate", true}; - - CheckResult result = check(R"( - local function f(a: number?, b: number?) - if (not a) and (not b) then - local foo = a - local bar = b - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("number", toString(requireTypeAtPosition({6, 28}))); } -TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") +TEST_CASE_FIXTURE(BuiltinsFixture, "what_nonsensical_condition") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( - local function f(a: number?, b: number?) - if not (a or b) then - local foo = a - local bar = b - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); -} - -TEST_CASE_FIXTURE(Fixture, "either_number_or_string") -{ - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; - - CheckResult result = check(R"( - local function f(x: any) - if type(x) == "number" or type(x) == "string" then + local function f(x) + if type(x) == "string" and type(x) == "number" then local foo = x end end @@ -1014,147 +1685,49 @@ TEST_CASE_FIXTURE(Fixture, "either_number_or_string") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("a & number & string", toString(requireTypeAtPosition({3, 28}))); + } + else + { + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); + } } -TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") +TEST_CASE_FIXTURE(Fixture, "else_with_no_explicit_expression_should_also_refine_the_tagged_union") { - ScopedFastFlag sff{"LuauOrPredicate", true}; + ScopedFastFlag sff{"LuauImplicitElseRefinement", true}; CheckResult result = check(R"( - local function f(t: {x: boolean}?) - if not t or t.x then - local foo = t - end - end - )"); + type Ok = { tag: "ok", value: T } + type Err = { tag: "err", err: E } + type Result = Ok | Err - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("{| x: boolean |}?", toString(requireTypeAtPosition({3, 28}))); -} - -TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") -{ - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; - - CheckResult result = check(R"( - local a: (number | string)? - assert(a) - local b = a - assert(type(a) == "number") - local c = a - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 18}))); - CHECK_EQ("number", toString(requireTypeAtPosition({5, 18}))); -} - -TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") -{ - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; - - // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. - CheckResult result = check(R"( - local function f(b: string | { x: string }, a) - assert(type(a) == "string") - assert(type(b) == "string" or type(b) == "table") - - if type(b) == "string" then - local foo = b - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("string", toString(requireTypeAtPosition({6, 28}))); -} - -TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") -{ - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; - - CheckResult result = check(R"( - local function f(a: string | number | boolean) - if type(a) ~= "number" and type(a) ~= "string" then - local foo = a + function and_then(r: Result, f: (T) -> U): Result + if r.tag == "ok" then + return { tag = "ok", value = f(r.value) } else - local foo = a + return r end end )"); LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); } -TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") +TEST_CASE_FIXTURE(Fixture, "fuzz_filtered_refined_types_are_followed") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + ScopedFastFlag luauTypeInferMissingFollows{"LuauTypeInferMissingFollows", true}; CheckResult result = check(R"( - function f(v:string?) - return if v then v else tostring(v) - end +local _ +do +local _ = _ ~= _ or _ or _ +end )"); LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("string", toString(requireTypeAtPosition({2, 29}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({2, 45}))); -} - -TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") -{ - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - CheckResult result = check(R"( - function f(v:string?) - return if not v then tostring(v) else v - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("nil", toString(requireTypeAtPosition({2, 42}))); - CHECK_EQ("string", toString(requireTypeAtPosition({2, 50}))); -} - -TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") -{ - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - CheckResult result = check(R"( - function returnOne(x) - return 1 - end - - function f(v:any) - return if typeof(v) == "number" then v else returnOne(v) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("number", toString(requireTypeAtPosition({6, 49}))); - CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp new file mode 100644 index 00000000..e37880d0 --- /dev/null +++ b/tests/TypeInfer.singletons.test.cpp @@ -0,0 +1,518 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeSingletons"); + +TEST_CASE_FIXTURE(Fixture, "bool_singletons") +{ + CheckResult result = check(R"( + local a: true = true + local b: false = false + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons") +{ + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: "bar" = "bar" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") +{ + CheckResult result = check(R"( + local a: true = false + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'false' could not be converted into 'true'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons_mismatch") +{ + CheckResult result = check(R"( + local a: "foo" = "bar" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons_escape_chars") +{ + CheckResult result = check(R"( + local a: "\n" = "\000\r" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '"\000\r"' could not be converted into '"\n"')", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "bool_singleton_subtype") +{ + CheckResult result = check(R"( + local a: true = true + local b: boolean = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype") +{ + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: string = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype_multi_assignment") +{ + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: string, c: number = a, 10 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons") +{ + CheckResult result = check(R"( + function f(a: true, b: "foo") end + f(true, "foo") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons_mismatch") +{ + CheckResult result = check(R"( + function f(a: true, b: "foo") end + f(true, "bar") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons") +{ + CheckResult result = check(R"( + function f(a, b) end + local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) + g(true, "foo") + g(false, 37) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") +{ + CheckResult result = check(R"( + function f(a, b) end + local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) + g(true, 37) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") +{ + CheckResult result = check(R"( + type MyEnum = "foo" | "bar" | "baz" + local a : MyEnum = "foo" + local b : MyEnum = "bar" + local c : MyEnum = "baz" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") +{ + CheckResult result = check(R"( + type MyEnum = "foo" | "bar" | "baz" + local a : MyEnum = "bang" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bang\"' could not be converted into '\"bar\" | \"baz\" | \"foo\"'; none of the union options are compatible", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") +{ + CheckResult result = check(R"( + type MyEnum1 = "foo" | "bar" + type MyEnum2 = MyEnum1 | "baz" + local a : MyEnum1 = "foo" + local b : MyEnum2 = a + local c : string = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") +{ + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Dog = { tag = "Dog", howls = true } + local b : Animal = { tag = "Cat", meows = true } + local c : Animal = a + c = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons_mismatch") +{ + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Animal = { tag = "Cat", howls = true } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") +{ + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Animal = { tag = "Cat", meows = true } + a.tag = "Dog" + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_has_a_boolean") +{ + CheckResult result = check(R"( + local t={a=1,b=false} + )"); + + CHECK("{ a: number, b: boolean }" == toString(requireType("t"), {true})); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings") +{ + CheckResult result = check(R"( + --!strict + type T = { + ["foo"] : number, + ["$$bar"] : string, + baz : boolean + } + local t: T = { + ["foo"] = 37, + ["$$bar"] = "hi", + baz = true + } + local a: number = t.foo + local b: string = t["$$bar"] + local c: boolean = t.baz + t.foo = 5 + t["$$bar"] = "lo" + t.baz = false + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} +TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings_mismatch") +{ + CheckResult result = check(R"( + --!strict + type T = { + ["$$bar"] : string, + } + local t: T = { + ["$$bar"] = "hi", + } + t["$$bar"] = 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") +{ + CheckResult result = check(R"( + --!strict + type S = "bar" + type T = { + [("foo")] : number, + [S] : string, + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") +{ + CheckResult result = check(R"( + --!strict + local x: { ["<>"] : number } + x = { ["\n"] = 5 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Table type '{ ["\n"]: number }' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") +{ + CheckResult result = check(R"( +type Cat = { tag: 'cat', catfood: string } +type Dog = { tag: 'dog', dogfood: string } +type Animal = Cat | Dog + +local a: Animal = { tag = 'cat', cafood = 'something' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'a' could not be converted into 'Cat | Dog' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'Cat' because the former is missing field 'catfood')", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") +{ + CheckResult result = check(R"( +type Good = { success: true, result: string } +type Bad = { success: false, error: string } +type Result = Good | Bad + +local a: Result = { success = false, result = 'something' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'a' could not be converted into 'Bad | Good' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'Bad' because the former is missing field 'error')", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "parametric_tagged_union_alias") +{ + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + }; + + CheckResult result = check(R"( + type Ok = {success: true, result: T} + type Err = {success: false, error: T} + type Result = Ok | Err + + local a : Result = {success = false, result = "hotdogs"} + local b : Result = {success = true, result = "hotdogs"} + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + const std::string expectedError = "Type '{ result: string, success: false }' could not be converted into 'Err | Ok'\n" + "caused by:\n" + " None of the union options are compatible. For example: Table type '{ result: string, success: false }'" + " not compatible with type 'Err' because the former is missing field 'error'"; + + CHECK(toString(result.errors[0]) == expectedError); +} + +TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") +{ + CheckResult result = check(R"( +type Cat = { tag: 'cat', catfood: string } +type Dog = { tag: 'dog', dogfood: string } +type Animal = Cat | Dog + +local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag = 'dog', dogfood = 'other' } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_singleton") +{ + CheckResult result = check(R"( + local function foo(f, x) + if x == "hi" then + f(x) + f("foo") + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 18}))); + // should be ((string) -> a..., string) -> () but needs lower bounds calculation + CHECK_EQ("((string) -> (b...), a) -> ()", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") +{ + CheckResult result = check(R"( + local function foo(f, x): "hello"? -- anyone there? + return if x == "hi" + then f(x) + else nil + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); + CHECK_EQ(R"(((string) -> (a, c...), b) -> "hello"?)", toString(requireType("foo"))); + // CHECK_EQ(R"(((string) -> ("hello"?, b...), a) -> "hello"?)", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") +{ + CheckResult result = check(R"( + local foo: "foo" = "foo" + local copy = foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireType("copy"))); +} + +TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables") +{ + CheckResult result = check(R"( + type Cat = {tag: "Cat", meows: boolean} + type Dog = {tag: "Dog", barks: boolean} + type Animal = Cat | Dog + + local function f(tag: "Cat" | "Dog"): Animal? + if tag == "Cat" then + local result = {tag = tag, meows = true} + return result + elseif tag == "Dog" then + local result = {tag = tag, barks = true} + return result + else + return nil + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "functions_are_not_to_be_widened") +{ + CheckResult result = check(R"( + local function foo(my_enum: "A" | "B") end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"(("A" | "B") -> ())", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") +{ + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" then + local x = a:byte() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 22}))); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") +{ + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" or a == "bye" then + local x = a:byte() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 22}))); +} + +TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") +{ + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" then + local x = #a + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); +} + +TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") +{ + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" or a == "bye" then + local x = #a + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 23}))); +} + +TEST_CASE_FIXTURE(Fixture, "no_widening_from_callsites") +{ + ScopedFastFlag sff{"LuauReturnsFromCallsitesAreNotWidened", true}; + + CheckResult result = check(R"( + type Direction = "North" | "East" | "West" | "South" + + local function direction(): Direction + return "North" + end + + local d: Direction = direction() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 1d1b2fae..c379559d 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" +#include "Luau/Common.h" +#include "Luau/Frontend.h" +#include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -12,6 +14,12 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) +LUAU_FASTFLAG(LuauNewLibraryTypeNames) + TEST_SUITE_BEGIN("TableTests"); TEST_CASE_FIXTURE(Fixture, "basic") @@ -43,12 +51,27 @@ TEST_CASE_FIXTURE(Fixture, "augment_table") const TableTypeVar* tType = get(requireType("t")); REQUIRE(tType != nullptr); - CHECK(tType->props.find("foo") != tType->props.end()); + CHECK("{ foo: string }" == toString(requireType("t"), {true})); +} + +TEST_CASE_FIXTURE(Fixture, "augment_nested_table") +{ + CheckResult result = check("local t = { p = {} } t.p.foo = 'bar'"); + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* tType = getMutable(requireType("t")); + REQUIRE(tType != nullptr); + + REQUIRE(tType->props.find("p") != tType->props.end()); + const TableTypeVar* pType = get(tType->props["p"].type); + REQUIRE(pType != nullptr); + + CHECK("{ p: { foo: string } }" == toString(requireType("t"), {true})); } TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") { - CheckResult result = check("local t = {prop=999} t.foo = 'bar'"); + CheckResult result = check("function mkt() return {prop=999} end local t = mkt() t.foo = 'bar'"); LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; @@ -60,7 +83,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") CHECK_EQ(s, "{| prop: number |}"); CHECK_EQ(error->prop, "foo"); CHECK_EQ(error->context, CannotExtendTable::Property); - CHECK_EQ(err.location, (Location{Position{0, 24}, Position{0, 29}})); + CHECK_EQ(err.location, (Location{Position{0, 59}, Position{0, 64}})); } TEST_CASE_FIXTURE(Fixture, "dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table") @@ -85,7 +108,11 @@ TEST_CASE_FIXTURE(Fixture, "updating_sealed_table_prop_is_ok") TEST_CASE_FIXTURE(Fixture, "cannot_change_type_of_unsealed_table_prop") { - CheckResult result = check("local t = {} t.prop = 999 t.prop = 'hello'"); + CheckResult result = check(R"( + local t = {} + t.prop = 999 + t.prop = 'hello' + )"); LUAU_REQUIRE_ERROR_COUNT(1, result); } @@ -185,7 +212,7 @@ TEST_CASE_FIXTURE(Fixture, "used_dot_instead_of_colon") REQUIRE(it != result.errors.end()); } -TEST_CASE_FIXTURE(Fixture, "used_colon_correctly") +TEST_CASE_FIXTURE(BuiltinsFixture, "used_colon_correctly") { CheckResult result = check(R"( --!nonstrict @@ -272,10 +299,11 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; - UnknownProperty* error = get(err); + MissingProperties* error = get(err); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); - CHECK_EQ(error->key, "y"); + CHECK_EQ("y", error->properties[0]); // TODO(rblanckaert): Revist when we can bind self at function creation time // CHECK_EQ(err.location, Location(Position{5, 19}, Position{5, 25})); @@ -340,10 +368,11 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") LUAU_REQUIRE_ERROR_COUNT(1, result); - UnknownProperty* error = get(result.errors[0]); + MissingProperties* error = get(result.errors[0]); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); - CHECK_EQ("baz", error->key); + CHECK_EQ("baz", error->properties[0]); } TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") @@ -359,8 +388,11 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; - UnknownProperty* error = get(err); + MissingProperties* error = get(err); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); + + CHECK_EQ("baz", error->properties[0]); // TODO(rblanckaert): Revist when we can bind self at function creation time /* @@ -448,6 +480,69 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") dumpErrors(result); } +TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") +{ + CheckResult result = check(R"( + --!strict + local t = { u = {} } + t = { u = { p = 37 } } + t = { u = { q = "hi" } } + local x = t.u.p + local y = t.u.q + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number?", toString(requireType("x"))); + CHECK_EQ("string?", toString(requireType("y"))); +} + +TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_call") +{ + CheckResult result = check(R"( + --!strict + function get(x) return x.opts["MYOPT"] end + function set(x,y) x.opts["MYOPT"] = y end + local t = { opts = {} } + set(t,37) + local x = get(t) + )"); + + // Currently this errors but it shouldn't, since set only needs write access + // TODO: file a JIRA for this + LUAU_REQUIRE_ERRORS(result); + // CHECK_EQ("number?", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "width_subtyping") +{ + CheckResult result = check(R"( + --!strict + function f(x : { q : number }) + x.q = 8 + end + local t : { q : number, r : string } = { q = 8, r = "hi" } + f(t) + local x : string = t.r + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "width_subtyping_needs_covariance") +{ + CheckResult result = check(R"( + --!strict + function f(x : { p : { q : number }}) + x.p = { q = 8, r = 5 } + end + local t : { p : { q : number, r : string } } = { p = { q = 8, r = "hi" } } + f(t) -- Shouldn't typecheck + local x : string = t.p.r -- x is 5 + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "infer_array") { CheckResult result = check(R"( @@ -524,7 +619,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too") REQUIRE_EQ(indexer.indexType, typeChecker.numberType); - REQUIRE(nullptr != get(indexer.indexResultType)); + REQUIRE(nullptr != get(follow(indexer.indexResultType))); } TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") @@ -548,7 +643,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") const TableTypeVar* argType = get(follow(argVec[0])); REQUIRE(argType != nullptr); - std::vector retVec = flatten(ftv->retType).first; + std::vector retVec = flatten(ftv->retTypes).first; const TableTypeVar* retType = get(follow(retVec[0])); REQUIRE(retType != nullptr); @@ -597,7 +692,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_value_property_in_literal") const FunctionTypeVar* fType = get(requireType("f")); REQUIRE(fType != nullptr); - auto retType_ = first(fType->retType); + auto retType_ = first(fType->retTypes); REQUIRE(bool(retType_)); auto retType = get(follow(*retType_)); @@ -676,16 +771,23 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sealed_table_value_must_not_infer_an_indexer") +TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") { CheckResult result = check(R"( local t: { a: string, [number]: string } = { a = "foo" } )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_NO_ERRORS(result); +} - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm != nullptr); +TEST_CASE_FIXTURE(Fixture, "array_factory_function") +{ + CheckResult result = check(R"( + function empty() return {} end + local array: {string} = empty() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "sealed_table_indexers_must_unify") @@ -756,37 +858,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_ CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); } -TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_string") -{ - ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true); - - CheckResult result = check(R"( - local t: { a: string } - function f(x: string) return t[x] end - local a = f("a") - local b = f("b") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.anyType, *requireType("a")); - CHECK_EQ(*typeChecker.anyType, *requireType("b")); -} - -TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_number") -{ - ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true); - - CheckResult result = check(R"( - local t = { a = true } - function f(x: number) return t[x] end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); -} - TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer") { CheckResult result = check(R"( @@ -798,18 +869,19 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("a")); + CHECK("string" == toString(*typeChecker.stringType)); TableTypeVar* tableType = getMutable(requireType("t")); REQUIRE(tableType != nullptr); REQUIRE(tableType->indexer == std::nullopt); + REQUIRE(0 != tableType->props.count("a")); TypeId propertyA = tableType->props["a"].type; REQUIRE(propertyA != nullptr); CHECK_EQ(*typeChecker.stringType, *propertyA); } -TEST_CASE_FIXTURE(Fixture, "oop_indexer_works") +TEST_CASE_FIXTURE(BuiltinsFixture, "oop_indexer_works") { CheckResult result = check(R"( local clazz = {} @@ -832,7 +904,7 @@ TEST_CASE_FIXTURE(Fixture, "oop_indexer_works") CHECK_EQ(*typeChecker.stringType, *requireType("words")); } -TEST_CASE_FIXTURE(Fixture, "indexer_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_table") { CheckResult result = check(R"( local clazz = {a="hello"} @@ -845,7 +917,7 @@ TEST_CASE_FIXTURE(Fixture, "indexer_table") CHECK_EQ(*typeChecker.stringType, *requireType("b")); } -TEST_CASE_FIXTURE(Fixture, "indexer_fn") +TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_fn") { CheckResult result = check(R"( local instanace = setmetatable({}, {__index=function() return 10 end}) @@ -856,7 +928,7 @@ TEST_CASE_FIXTURE(Fixture, "indexer_fn") CHECK_EQ(*typeChecker.numberType, *requireType("b")); } -TEST_CASE_FIXTURE(Fixture, "meta_add") +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add") { // Note: meta_add_inferred and this unit test are currently the same exact thing. // We'll want to change this one in particular when we add real syntax for metatables. @@ -873,7 +945,7 @@ TEST_CASE_FIXTURE(Fixture, "meta_add") CHECK_EQ(follow(requireType("a")), follow(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "meta_add_inferred") +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_inferred") { CheckResult result = check(R"( local a = {} @@ -886,7 +958,7 @@ TEST_CASE_FIXTURE(Fixture, "meta_add_inferred") CHECK_EQ(*requireType("a"), *requireType("c")); } -TEST_CASE_FIXTURE(Fixture, "meta_add_both_ways") +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_both_ways") { CheckResult result = check(R"( type VectorMt = { __add: (Vector, number) -> Vector } @@ -906,7 +978,7 @@ TEST_CASE_FIXTURE(Fixture, "meta_add_both_ways") // This test exposed a bug where we let go of the "seen" stack while unifying table types // As a result, type inference crashed with a stack overflow. -TEST_CASE_FIXTURE(Fixture, "unification_of_unions_in_a_self_referential_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "unification_of_unions_in_a_self_referential_type") { CheckResult result = check(R"( type A = {} @@ -935,7 +1007,7 @@ TEST_CASE_FIXTURE(Fixture, "unification_of_unions_in_a_self_referential_type") CHECK_EQ(bmtv->metatable, requireType("bmt")); } -TEST_CASE_FIXTURE(Fixture, "oop_polymorphic") +TEST_CASE_FIXTURE(BuiltinsFixture, "oop_polymorphic") { CheckResult result = check(R"( local animal = {} @@ -986,7 +1058,7 @@ TEST_CASE_FIXTURE(Fixture, "user_defined_table_types_are_named") CHECK_EQ("Vector3", toString(requireType("v"))); } -TEST_CASE_FIXTURE(Fixture, "result_is_always_any_if_lhs_is_any") +TEST_CASE_FIXTURE(BuiltinsFixture, "result_is_always_any_if_lhs_is_any") { CheckResult result = check(R"( type Vector3MT = { @@ -1059,7 +1131,7 @@ TEST_CASE_FIXTURE(Fixture, "nice_error_when_trying_to_fetch_property_of_boolean" CHECK_EQ("Type 'boolean' does not have key 'some_prop'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_builtin_sealed_table_must_fail") +TEST_CASE_FIXTURE(BuiltinsFixture, "defining_a_method_for_a_builtin_sealed_table_must_fail") { CheckResult result = check(R"( function string.m() end @@ -1068,7 +1140,7 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_builtin_sealed_table_must_fa LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_builtin_sealed_table_must_fail") +TEST_CASE_FIXTURE(BuiltinsFixture, "defining_a_self_method_for_a_builtin_sealed_table_must_fail") { CheckResult result = check(R"( function string:m() end @@ -1080,7 +1152,8 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_builtin_sealed_table_mu TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail") { CheckResult result = check(R"( - local t = {x = 1} + function mkt() return {x = 1} end + local t = mkt() function t.m() end )"); @@ -1090,13 +1163,34 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_sealed_table_must_fail") { CheckResult result = check(R"( - local t = {x = 1} + function mkt() return {x = 1} end + local t = mkt() function t:m() end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_unsealed_table_is_ok") +{ + CheckResult result = check(R"( + local t = {x = 1} + function t.m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_unsealed_table_is_ok") +{ + CheckResult result = check(R"( + local t = {x = 1} + function t:m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + // This unit test could be flaky if the fix has regressed. TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_crashing") { @@ -1121,13 +1215,12 @@ TEST_CASE_FIXTURE(Fixture, "passing_compatible_unions_to_a_generic_table_without { CheckResult result = check(R"( type A = {x: number, y: number, [any]: any} | {y: number} - local a: A function f(t) t.y = 1 end - f(a) + f({y = 5} :: A) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1159,7 +1252,7 @@ TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_function_call") CHECK_EQ(toString(te), "Key 'fOo' not found in table 't'. Did you mean 'Foo'?"); } -TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_property_access") +TEST_CASE_FIXTURE(BuiltinsFixture, "found_like_key_in_table_property_access") { CheckResult result = check(R"( local t = {X = 1} @@ -1184,7 +1277,7 @@ TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_property_access") CHECK_EQ(toString(te), "Key 'x' not found in table 't'. Did you mean 'X'?"); } -TEST_CASE_FIXTURE(Fixture, "found_multiple_like_keys") +TEST_CASE_FIXTURE(BuiltinsFixture, "found_multiple_like_keys") { CheckResult result = check(R"( local t = {Foo = 1, foO = 2} @@ -1210,7 +1303,7 @@ TEST_CASE_FIXTURE(Fixture, "found_multiple_like_keys") CHECK_EQ(toString(te), "Key 'foo' not found in table 't'. Did you mean one of 'Foo', 'foO'?"); } -TEST_CASE_FIXTURE(Fixture, "dont_suggest_exact_match_keys") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_suggest_exact_match_keys") { CheckResult result = check(R"( local t = {} @@ -1237,7 +1330,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_exact_match_keys") CHECK_EQ(toString(te), "Key 'Foo' not found in table 't'. Did you mean 'foO'?"); } -TEST_CASE_FIXTURE(Fixture, "getmetatable_returns_pointer_to_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_pointer_to_metatable") { CheckResult result = check(R"( local t = {x = 1} @@ -1250,7 +1343,7 @@ TEST_CASE_FIXTURE(Fixture, "getmetatable_returns_pointer_to_metatable") CHECK_EQ(*requireType("mt"), *requireType("returnedMT")); } -TEST_CASE_FIXTURE(Fixture, "metatable_mismatch_should_fail") +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_mismatch_should_fail") { CheckResult result = check(R"( local t1 = {x = 1} @@ -1272,7 +1365,7 @@ TEST_CASE_FIXTURE(Fixture, "metatable_mismatch_should_fail") CHECK_EQ(*tm->givenType, *requireType("t2")); } -TEST_CASE_FIXTURE(Fixture, "property_lookup_through_tabletypevar_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "property_lookup_through_tabletypevar_metatable") { CheckResult result = check(R"( local t = {x = 1} @@ -1291,7 +1384,7 @@ TEST_CASE_FIXTURE(Fixture, "property_lookup_through_tabletypevar_metatable") CHECK_EQ(up->key, "z"); } -TEST_CASE_FIXTURE(Fixture, "missing_metatable_for_sealed_tables_do_not_get_inferred") +TEST_CASE_FIXTURE(BuiltinsFixture, "missing_metatable_for_sealed_tables_do_not_get_inferred") { CheckResult result = check(R"( local t = {x = 1} @@ -1364,7 +1457,7 @@ TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2") CHECK_EQ("{| |}", toString(mp->subType)); } -TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer") +TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_indexer") { CheckResult result = check(R"( type StringToStringMap = { [string]: string } @@ -1373,6 +1466,25 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer") LUAU_REQUIRE_ERROR_COUNT(1, result); + ToStringOptions o{/* exhaustive= */ true}; + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o)); + // Should t now have an indexer? + // It would if the assignment to rt was correctly typed. + CHECK_EQ("{ [string]: string, foo: number }", toString(tm->givenType, o)); +} + +TEST_CASE_FIXTURE(Fixture, "casting_sealed_tables_with_props_into_table_with_indexer") +{ + CheckResult result = check(R"( + type StringToStringMap = { [string]: string } + function mkrt() return { ["foo"] = 1 } end + local rt: StringToStringMap = mkrt() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); @@ -1402,8 +1514,21 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ("string", toString(tm->wantedType, o)); - CHECK_EQ("number", toString(tm->givenType, o)); + CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); + CHECK_EQ("{ a: number }", toString(tm->givenType, o)); +} + +TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") +{ + CheckResult result = check(R"( + local function foo(a: {[string]: number, a: string}, i: string) + return a[i] + end + local hi: number = foo({ a = "hi" }, "a") -- shouldn't typecheck since at runtime hi is "hi" + )"); + + // This typechecks but shouldn't + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors") @@ -1446,22 +1571,33 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multi TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors") { CheckResult result = check(R"( - local vec3 = {x = 1, y = 2, z = 3} - local vec1 = {x = 1} + function mkvec3() return {x = 1, y = 2, z = 3} end + function mkvec1() return {x = 1} end + + local vec3 = {mkvec3()} + local vec1 = {mkvec1()} vec1 = vec3 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - MissingProperties* mp = get(result.errors[0]); - REQUIRE(mp); - CHECK_EQ(mp->context, MissingProperties::Extra); - REQUIRE_EQ(2, mp->properties.size()); - CHECK_EQ(mp->properties[0], "y"); - CHECK_EQ(mp->properties[1], "z"); - CHECK_EQ("vec1", toString(mp->superType)); - CHECK_EQ("vec3", toString(mp->subType)); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("vec1", toString(tm->wantedType)); + CHECK_EQ("vec3", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") +{ + CheckResult result = check(R"( + local vec3 = {x = 1, y = 2, z = 3} + local vec1 = {x = 1} + + vec1 = vec3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "type_mismatch_on_massive_table_is_cut_short") @@ -1520,7 +1656,8 @@ TEST_CASE_FIXTURE(Fixture, "reasonable_error_when_adding_a_nonexistent_property_ { CheckResult result = check(R"( --!strict - local A = {"value"} + function mkA() return {"value"} end + local A = mkA() A.B = "Hello" )"); @@ -1568,7 +1705,8 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") --!strict local function f() - local t = { x = 1 } + local function mkt() return { x = 1 } end + local t = mkt() function t.a() end function t.b() end @@ -1583,7 +1721,7 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") CHECK_EQ("Cannot add property 'b' to table '{| x: number |}'", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "builtin_table_names") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") { CheckResult result = check(R"( os.h = 2 @@ -1592,11 +1730,19 @@ TEST_CASE_FIXTURE(Fixture, "builtin_table_names") LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("Cannot add property 'h' to table 'os'", toString(result.errors[0])); - CHECK_EQ("Cannot add property 'k' to table 'string'", toString(result.errors[1])); + if (FFlag::LuauNewLibraryTypeNames) + { + CHECK_EQ("Cannot add property 'h' to table 'typeof(os)'", toString(result.errors[0])); + CHECK_EQ("Cannot add property 'k' to table 'typeof(string)'", toString(result.errors[1])); + } + else + { + CHECK_EQ("Cannot add property 'h' to table 'os'", toString(result.errors[0])); + CHECK_EQ("Cannot add property 'k' to table 'string'", toString(result.errors[1])); + } } -TEST_CASE_FIXTURE(Fixture, "persistent_sealed_table_is_immutable") +TEST_CASE_FIXTURE(BuiltinsFixture, "persistent_sealed_table_is_immutable") { CheckResult result = check(R"( --!nonstrict @@ -1604,7 +1750,10 @@ TEST_CASE_FIXTURE(Fixture, "persistent_sealed_table_is_immutable") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Cannot add property 'bad' to table 'os'", toString(result.errors[0])); + if (FFlag::LuauNewLibraryTypeNames) + CHECK_EQ("Cannot add property 'bad' to table 'typeof(os)'", toString(result.errors[0])); + else + CHECK_EQ("Cannot add property 'bad' to table 'os'", toString(result.errors[0])); const TableTypeVar* osType = get(requireType("os")); REQUIRE(osType != nullptr); @@ -1699,7 +1848,7 @@ local foos: {Foo} = { LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "quantifying_a_bound_var_works") +TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") { CheckResult result = check(R"( local clazz = {} @@ -1718,11 +1867,12 @@ TEST_CASE_FIXTURE(Fixture, "quantifying_a_bound_var_works") TypeId ty = requireType("clazz"); TableTypeVar* ttv = getMutable(ty); REQUIRE(ttv); + REQUIRE(ttv->props.count("new")); Property& prop = ttv->props["new"]; REQUIRE(prop.type); const FunctionTypeVar* ftv = get(follow(prop.type)); REQUIRE(ftv); - const TypePack* res = get(follow(ftv->retType)); + const TypePack* res = get(follow(ftv->retTypes)); REQUIRE(res); REQUIRE(res->head.size() == 1); const MetatableTypeVar* mtv = get(follow(res->head[0])); @@ -1732,8 +1882,10 @@ TEST_CASE_FIXTURE(Fixture, "quantifying_a_bound_var_works") REQUIRE_EQ(ttv->state, TableState::Sealed); } -TEST_CASE_FIXTURE(Fixture, "less_exponential_blowup_please") +TEST_CASE_FIXTURE(BuiltinsFixture, "less_exponential_blowup_please") { + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + CheckResult result = check(R"( --!strict @@ -1761,7 +1913,7 @@ TEST_CASE_FIXTURE(Fixture, "less_exponential_blowup_please") newData:First() )"); - LUAU_REQUIRE_ERRORS(result); + LUAU_REQUIRE_ERROR_COUNT(2, result); } TEST_CASE_FIXTURE(Fixture, "common_table_element_union_in_call") @@ -1821,6 +1973,1475 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table local c : string = t.m("hi") )"); + // TODO: test behavior is wrong with LuauInstantiateInSubtyping until we can re-enable the covariant requirement for instantiation in subtyping + if (FFlag::LuauInstantiateInSubtyping) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + local buttons = {} + table.insert(buttons, { a = 1 }) + table.insert(buttons, { a = 2, b = true }) + table.insert(buttons, { a = 3 }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_strict") +{ + CheckResult result = check(R"( + --!strict + local buttons = {} + table.insert(buttons, { a = 1 }) + table.insert(buttons, { a = 2, b = true }) + table.insert(buttons, { a = 3 }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") +{ + CheckResult result = check(R"( +type A = { x: number, y: number } +type B = { x: number, y: string } + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'y' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") +{ + CheckResult result = check(R"( +type AS = { x: number, y: number } +type BS = { x: number, y: string } + +type A = { a: boolean, b: AS } +type B = { a: boolean, b: BS } + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'b' is not compatible. Type 'AS' could not be converted into 'BS' +caused by: + Property 'y' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'b' is not compatible. Type 'AS' could not be converted into 'BS' +caused by: + Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "error_detailed_metatable_prop") +{ + CheckResult result = check(R"( +local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); +local b1 = setmetatable({ x = 2, y = "hello" }, { __call = function(s) end }); +local c1: typeof(a1) = b1 + +local a2 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); +local b2 = setmetatable({ x = 2, y = 4 }, { __call = function(s, t) end }); +local c2: typeof(a2) = b2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' +caused by: + Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' +caused by: + Property 'y' is not compatible. Type 'string' could not be converted into 'number' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' +caused by: + Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' +caused by: + Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); + + if (FFlag::LuauInstantiateInSubtyping) + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' +caused by: + Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' +caused by: + Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + } + else + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' +caused by: + Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' +caused by: + Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + } +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") +{ + CheckResult result = check(R"( + type A = { [number]: string } + type B = { [string]: string } + + local a: A = { 'a', 'b' } + local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property '[indexer key]' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property '[indexer key]' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") +{ + CheckResult result = check(R"( + type A = { [number]: number } + type B = { [number]: string } + + local a: A = { 1, 2, 3 } + local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + if (FFlag::LuauTypeMismatchInvarianceInError) + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") +{ + CheckResult result = check(R"( +--!strict +type Super = { x : number } +type Sub = { x : number, y: number } +type HasSuper = { p : Super } +type HasSub = { p : Sub } +local a: HasSuper = { p = { x = 5, y = 7 }} +a.p = { x = 9 } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") +{ + CheckResult result = check(R"( +--!strict +type Super = { x : number } +type Sub = { x : number, y: number } +type HasSuper = { p : Super } +type HasSub = { p : Sub } +local tmp = { p = { x = 5, y = 7 }} +local a: HasSuper = tmp +a.p = { x = 9 } +-- needs to be an error because +local y: number = tmp.p.y + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'tmp' could not be converted into 'HasSuper' +caused by: + Property 'p' is not compatible. Table type '{ x: number, y: number }' not compatible with type 'Super' because the former has extra field 'y')"); +} + +TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") +{ + CheckResult result = check(R"( +--!strict +type Super = { x : number } +type Sub = { x : number, y: number } +type HasSuper = { [string] : Super } +type HasSub = { [string] : Sub } +local a: HasSuper = { p = { x = 5, y = 7 }} +a.p = { x = 9 } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_type_call") +{ + CheckResult result = check(R"( +local b +b = setmetatable({}, {__call = b}) +b() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable { __call: t1 }, { } })"); +} + +TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") +{ + CheckResult result = check(R"( + --!strict + local function setNumber(t: { p: number? }, x:number) t.p = x end + local function getString(t: { p: string? }):string return t.p or "" end + -- This shouldn't type-check! + local function oh(x:number): string + local t: {} = {} + setNumber(t, x) + return getString(t) + end + local s: string = oh(37) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "top_table_type") +{ + CheckResult result = check(R"( + --!strict + type Table = { [any] : any } + type HasTable = { p: Table? } + type HasHasTable = { p: HasTable? } + local t : Table = { p = 5 } + local u : HasTable = { p = { p = 5 } } + local v : HasHasTable = { p = { p = { p = 5 } } } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "length_operator_union") +{ + CheckResult result = check(R"( +local x: {number} | {string} +local y = #x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "length_operator_intersection") +{ + CheckResult result = check(R"( +local x: {number} & {z:string} -- mixed tables are evil +local y = #x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "length_operator_non_table_union") +{ + CheckResult result = check(R"( +local x: {number} | any | string +local y = #x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "length_operator_union_errors") +{ + CheckResult result = check(R"( +local x: {number} | number | string +local y = #x + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") +{ + // t :: t1 where t1 = {metatable {__index: t1, __tostring: (t1) -> string}} + CheckResult result = check(R"( + local mt = {} + local t = setmetatable({}, mt) + mt.__index = t + + function mt:__tostring() + return t.p + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 't' does not have key 'p'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "give_up_after_one_metatable_index_look_up") +{ + CheckResult result = check(R"( + local data = { x = 5 } + local t1 = setmetatable({}, { __index = data }) + local t2 = setmetatable({}, t1) -- note: must be t1, not a new table + + local x1 = t1.x -- ok + local x2 = t2.x -- nope + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 't2' does not have key 'x'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "confusing_indexing") +{ + CheckResult result = check(R"( + type T = {} & {p: number | string} + local function f(t: T) + return t.p + end + + local foo = f({p = "string"}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number | string", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table") +{ + CheckResult result = check(R"( + local a: {x: number, y: number, [any]: any} | {y: number} + + function f(t) + t.y = 1 + return t + end + + local b = f(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") +{ + CheckResult result = check(R"( + local a: {y: number} | {x: number, y: number, [any]: any} + + function f(t) + t.y = 1 + return t + end + + local b = f(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf1") +{ + CheckResult result = check(R"( +-- This example produced a UAF at one point, caused by pointers to table types becoming +-- invalidated by child unifiers. (Calling log.concat can cause pointers to become invalid.) +type _Entry = { + a: number, + + middle: (self: _Entry) -> (), + + z: number +} + +export type AnyEntry = _Entry + +local Entry = {} +Entry.__index = Entry + +function Entry:dispose() + self:middle() + forgetChildren(self) -- unify free with sealed AnyEntry +end + +function forgetChildren(parent: AnyEntry) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf2") +{ + CheckResult result = check(R"( +-- Another example that UAFd, this time found by fuzzing. +local _ +do +_._ *= (_[{n0=_[{[{[_]=_,}]=_,}],}])[_] +_ = (_.n0) +end +_._ *= (_[false])[_] +_ = (_.cos) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_call_tables") +{ + CheckResult result = check("local foo = {} foo()"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(get(result.errors[0]) != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "table_length") +{ + CheckResult result = check(R"( + local t = {} + local s = #t + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK(nullptr != get(requireType("t"))); + CHECK_EQ(*typeChecker.numberType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer") +{ + CheckResult result = check("local a = {} a[0] = 7 a[0] = nil"); + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") +{ + CheckResult result = check("local a = {} a[0] = 7 a[0] = 't'"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.stringType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") +{ + CheckResult result = check(R"( + local a = {a=1, b=2} + a['a'] = nil + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{2, 17}, Position{2, 20}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.nilType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "free_rhs_table_can_also_be_bound") +{ + check(R"( + local o + local v = o:i() + + function g(u) + v = u + end + + o:f(g) + o:h() + o:h() + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_unifies_into_map") +{ + CheckResult result = check(R"( + local Instance: any + local UDim2: any + + function Create(instanceType) + return function(data) + local obj = Instance.new(instanceType) + for k, v in pairs(data) do + if type(k) == 'number' then + --v.Parent = obj + else + obj[k] = v + end + end + return obj + end + end + + local topbarShadow = Create'ImageLabel'{ + Name = "TopBarShadow"; + Size = UDim2.new(1, 0, 0, 3); + Position = UDim2.new(0, 0, 1, 0); + Image = "rbxasset://textures/ui/TopBar/dropshadow.png"; + BackgroundTransparency = 1; + Active = false; + Visible = false; + }; + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tables_get_names_from_their_locals") +{ + CheckResult result = check(R"( + local T = {} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("T", toString(requireType("T"))); +} + +TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") +{ + CheckResult result = check(R"( + function foo(arr) + local work = {} + for i = 1, #arr do + work[i] = arr[i] + end + + return arr + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + const FunctionTypeVar* fooType = get(requireType("foo")); + REQUIRE(fooType); + + std::optional fooArg1 = first(fooType->argTypes); + REQUIRE(fooArg1); + + const TableTypeVar* fooArg1Table = get(*fooArg1); + REQUIRE(fooArg1Table); + + CHECK_EQ(fooArg1Table->state, TableState::Generic); +} + +/* + * This test case exposed an oversight in the treatment of free tables. + * Free tables, like free TypeVars, need to record the scope depth where they were created so that + * we do not erroneously let-generalize them when they are used in a nested lambda. + * + * For more information about let-generalization, see + * + * The important idea here is that the return type of Counter.new is a table with some metatable. + * That metatable *must* be the same TypeVar as the type of Counter. If it is a copy (produced by + * the generalization process), then it loses the knowledge that its metatable will have an :incr() + * method. + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_scope") +{ + CheckResult result = check(R"( + local Counter = {} + Counter.__index = Counter + + function Counter.new() + local self = setmetatable({count=0}, Counter) + return self + end + + function Counter:incr() + self.count = 1 + return self.count + end + + local self = Counter.new() + print(self:incr()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* counterType = getMutable(requireType("Counter")); + REQUIRE(counterType); + + REQUIRE(counterType->props.count("new")); + const FunctionTypeVar* newType = get(follow(counterType->props["new"].type)); + REQUIRE(newType); + + std::optional newRetType = *first(newType->retTypes); + REQUIRE(newRetType); + + const MetatableTypeVar* newRet = get(follow(*newRetType)); + REQUIRE(newRet); + + const TableTypeVar* newRetMeta = get(newRet->metatable); + REQUIRE(newRetMeta); + + CHECK(newRetMeta->props.count("incr")); + CHECK_EQ(follow(newRet->metatable), follow(requireType("Counter"))); +} + +// TODO: CLI-39624 +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_tables_at_scope_level") +{ + CheckResult result = check(R"( + --!strict + local Option = {} + Option.__index = Option + function Option.Is(obj) + return (type(obj) == "table" and getmetatable(obj) == Option) + end + return Option + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") +{ + CheckResult result = check(R"( + --!strict + function f(U) + U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() + end + )"); + + ModulePtr module = getMainModule(); + CHECK_GE(100, module->internalTypes.typeVars.size()); +} + +TEST_CASE_FIXTURE(Fixture, "MixedPropertiesAndIndexers") +{ + CheckResult result = check(R"( +local x = {} +x.a = "a" +x[0] = true +x.b = 37 +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "setmetatable_cant_be_used_to_mutate_global_types") +{ + { + Fixture fix; + + // inherit env from parent fixture checker + fix.typeChecker.globalScope = typeChecker.globalScope; + + fix.check(R"( +--!nonstrict +type MT = typeof(setmetatable) +function wtf(arg: {MT}): typeof(table) + arg = wtf(arg) +end +)"); + } + + // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down + // note: it's important for typeck to be destroyed at this point! + { + for (auto& p : typeChecker.globalScope->bindings) + { + toString(p.second.typeId); // toString walks the entire type, making sure ASAN catches access to destroyed type arenas + } + } +} + +TEST_CASE_FIXTURE(Fixture, "evil_table_unification") +{ + // this code re-infers the type of _ while processing fields of _, which can cause use-after-free + check(R"( +--!nonstrict +_ = ... +_:table(_,string)[_:gsub(_,...,n0)],_,_:gsub(_,string)[""],_:split(_,...,table)._,n0 = nil +do end +)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") +{ + CheckResult result = check("local x = setmetatable({})"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Argument count mismatch. Function 'setmetatable' expects 2 arguments, but only 1 is specified", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning") +{ + CheckResult result = check(R"( +--!nonstrict +local l0:any,l61:t0 = _,math +while _ do +_() +end +function _():t0 +end +type t0 = any +)"); + + std::optional ty = requireType("math"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning_2") +{ + CheckResult result = check(R"( +type X = T +type K = X +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("math"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_3") +{ + CheckResult result = check(R"( +type X = T +local a = {} +a.x = 4 +local b: X +a.y = 5 +local c: X +c = b +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("a"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "table_indexing_error_location") +{ + CheckResult result = check(R"( +local foo = {42} +local bar: number? +local baz = foo[bar] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_basic") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local a = setmetatable({ + a = 1, + }, { + __call = function(self, b: number) + return self.a * b + end, + }) + + local foo = a(12) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(requireType("foo") == singletonTypes->numberType); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_must_be_callable") +{ + CheckResult result = check(R"( + local a = setmetatable({}, { + __call = 123, + }) + + local foo = a() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{ + Location{{5, 20}, {5, 21}}, + CannotCallNonFunction{singletonTypes->numberType}, + }); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_generic") +{ + CheckResult result = check(R"( + local a = setmetatable({}, { + __call = function(self, b: T) + return b + end, + }) + + local foo = a(12) + local bar = a("bar") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(requireType("foo") == singletonTypes->numberType); + CHECK(requireType("bar") == singletonTypes->stringType); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") +{ + CheckResult result = check(R"( +local a = setmetatable({ x = 2 }, { + __call = function(self) + return (self.x :: number) * 2 -- should work without annotation in the future + end +}) +local b = a() +local c = a(2) -- too many arguments + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Argument count mismatch. Function 'a' expects 1 argument, but 2 are specified", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "access_index_metamethod_that_returns_variadic") +{ + CheckResult result = check(R"( + type Foo = {x: string} + local t = {} + setmetatable(t, { + __index = function(x: string): ...Foo + return {x = x} + end + }) + + local foo = t.bar + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = true; + CHECK_EQ("{| x: string |}", toString(requireType("foo"), o)); +} + +TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back") +{ + fileResolver.source["Module/Backend/Types"] = R"( + export type Fiber = { + return_: Fiber? + } + return {} + )"; + + fileResolver.source["Module/Backend"] = R"( + local Types = require(script.Types) + type Fiber = Types.Fiber + type ReactRenderer = { findFiberByHostInstance: () -> Fiber? } + + local function attach(renderer): () + local function getPrimaryFiber(fiber) + local alternate = fiber.alternate + return fiber + end + + local function getFiberIDForNative() + local fiber = renderer.findFiberByHostInstance() + fiber = fiber.return_ + return getPrimaryFiber(fiber) + end + end + + function culprit(renderer: ReactRenderer): () + attach(renderer) + end + + return culprit + )"; + + CheckResult result = frontend.check("Module/Backend"); +} + +TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") +{ + CheckResult result = check(R"( + local t: {x: number?}? = {x = nil} + local u = t.x and t or 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); + CHECK_EQ("number | {| x: number? |}", toString(requireType("u"))); +} + +TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") +{ + CheckResult result = check(R"( + local t: {x: number?}? = {x = nil} + local u = t and t.x == 5 or t.x == 31337 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); + CHECK_EQ("boolean", toString(requireType("u"))); +} + +/* + * We had an issue where part of the type of pairs() was an unsealed table. + * This test depends on FFlagDebugLuauFreezeArena to trigger it. + */ +TEST_CASE_FIXTURE(Fixture, "pairs_parameters_are_not_unsealed_tables") +{ + check(R"( + function _(l0:{n0:any}) + _ = pairs + end + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_function_check_use_after_free") +{ + CheckResult result = check(R"( +local t = {} + +function t.x(value) + for k,v in pairs(t) do end +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * When we add new properties to an unsealed table, we should do a level check and promote the property type to be at + * the level of the table. + */ +TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the_same_TypeLevel_of_that_table") +{ + CheckResult result = check(R"( + --!strict + local T = {} + + local function f(prop) + T[1] = { + prop = prop, + } + end + + local function g() + local l = T[1].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +// The real bug here was that we weren't always uncondionally typechecking a trailing return statement last. +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_leak_free_table_props") +{ + CheckResult result = check(R"( + local function a(state) + print(state.blah) + end + + local function b(state) -- The bug was that we inferred state: {blah: any, gwar: any} + print(state.gwar) + end + + return function() + return function(state) + a(state) + b(state) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({+ blah: a +}) -> ()", toString(requireType("a"))); + CHECK_EQ("({+ gwar: a +}) -> ()", toString(requireType("b"))); + CHECK_EQ("() -> ({+ blah: a, gwar: b +}) -> ()", toString(getMainModule()->getModuleScope()->returnType)); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") +{ + ScopedFastFlag sff[] = { + // {"LuauLowerBoundsCalculation", true}, + {"DebugLuauSharedSelf", true}, + }; + + check(R"( + function Base64FileReader(data) + local reader = {} + local index: number + + function reader:PeekByte() + return data:byte(index) + end + + function reader:Byte() + return data:byte(index - 1) + end + + return reader + end + )"); + + CHECK_EQ("(t1) -> {| Byte: (a) -> (b...), PeekByte: (a) -> (b...) |} where t1 = {+ byte: (t1, number) -> (b...) +}", + toString(requireType("Base64FileReader"))); +} + +TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") +{ + CheckResult result = check(R"( + local t: { [string]: number } = { 5, 6, 7 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[1])); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[2])); +} + +TEST_CASE_FIXTURE(Fixture, "shared_selfs") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local t = {} + t.x = 5 + + function t:m1() return self.x end + function t:m2() return self.y end + + return t + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ("{| m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b, x: number |}", toString(requireType("t"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "shared_selfs_from_free_param") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local function f(t) + function t:m1() return self.x end + function t:m2() return self.y end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({+ m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b +}) -> ()", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "shared_selfs_through_metatables") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local t = {} + t.__index = t + setmetatable({}, t) + + function t:m1() return self.x end + function t:m2() return self.y end + + return t + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ( + toString(requireType("t"), opts), "t1 where t1 = {| __index: t1, m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b |}"); +} + +TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") +{ + CheckResult result = check(R"( + type X = { { x: boolean?, y: boolean? } } + + local l1: {[string]: X} = { key = { { x = true }, { y = true } } } + local l2: {[any]: X} = { key = { { x = true }, { y = true } } } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra_2") +{ + CheckResult result = check(R"( + type X = {[any]: string | boolean} + + local x: X = { key = "str" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "expected_indexer_from_table_union") +{ + LUAU_REQUIRE_NO_ERRORS(check(R"(local a: {[string]: {number | string}} = {a = {2, 's'}})")); + LUAU_REQUIRE_NO_ERRORS(check(R"(local a: {[string]: {number | string}}? = {a = {2, 's'}})")); + LUAU_REQUIRE_NO_ERRORS(check(R"(local a: {[string]: {[string]: {string?}}?} = {["a"] = {["b"] = {"a", "b"}}})")); +} + +TEST_CASE_FIXTURE(Fixture, "prop_access_on_key_whose_types_mismatches") +{ + CheckResult result = check(R"( + local t: {number} = {} + local x = t.x + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Key 'x' not found in table '{number}'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "prop_access_on_unions_of_indexers_where_key_whose_types_mismatches") +{ + CheckResult result = check(R"( + local t: { [number]: number } | { [boolean]: number } = {} + local u = t.x + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '{number} | {| [boolean]: number |}' does not have key 'x'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "quantify_metatables_of_metatables_of_table") +{ + ScopedFastFlag sff[]{ + {"DebugLuauSharedSelf", true}, + }; + + CheckResult result = check(R"( + local T = {} + + function T:m() + return self.x, self.y + end + + function T:n() + end + + local U = setmetatable({}, {__index = T}) + + local V = setmetatable({}, {__index = U}) + + return V + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ(toString(requireType("V"), opts), "{ @metatable { __index: { @metatable { __index: {| m: ({+ x: a, y: b +}) -> (a, b), n: ({+ x: a, y: b +}) -> () |} }, { } } }, { } }"); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_even_that_table_was_never_exported_at_all") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local T = {} + + function T:m() + return self.x + end + + function T:n() + return self.y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ("{| m: ({+ x: a, y: b +}) -> a, n: ({+ x: a, y: b +}) -> b |}", toString(requireType("T"), opts)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "leaking_bad_metatable_errors") +{ + CheckResult result = check(R"( +local a = setmetatable({}, 1) +local b = a.x + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Metatable was not a table", toString(result.errors[0])); + CHECK_EQ("Type 'a' does not have key 'x'", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "scalar_is_a_subtype_of_a_compatible_polymorphic_shape_type") +{ + ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + + CheckResult result = check(R"( + local function f(s) + return s:lower() + end + + f("foo" :: string) + f("bar" :: "bar") + f("baz" :: "bar" | "baz") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type") +{ + ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + if (!FFlag::LuauNewLibraryTypeNames) + return; + + CheckResult result = check(R"( + local function f(s) + return s:absolutely_no_scalar_has_this_method() + end + + f("foo" :: string) + f("bar" :: "bar") + f("baz" :: "bar" | "baz") + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + + CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[0])); + CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[1])); + CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + Not all union options are compatible. Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[2])); +} + +TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compatible") +{ + ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; // Changes argument from table type to primitive + + CheckResult result = check(R"( + local function f(s): string + local foo = s:lower() + return s + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(string) -> string", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible") +{ + ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + if (!FFlag::LuauNewLibraryTypeNames) + return; + + CheckResult result = check(R"( + local function f(s): string + local foo = s:absolutely_no_scalar_has_this_method() + return s + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' +caused by: + The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[0])); + CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "a_free_shape_can_turn_into_a_scalar_directly") +{ + ScopedFastFlag luauScalarShapeSubtyping{"LuauScalarShapeSubtyping", true}; + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; + + CheckResult result = check(R"( + local function stringByteList(str) + local out = {} + for i = 1, #str do + table.insert(out, string.byte(str, i)) + end + return table.concat(out, ",") + end + + local x = stringByteList("xoo") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_tables_in_call_is_unsound") +{ + ScopedFastFlag sff[]{ + {"LuauInstantiateInSubtyping", true}, + }; + + CheckResult result = check(R"( + --!strict + local t = {} + function t.m(x) return x end + local a : string = t.m("hi") + local b : number = t.m(5) + function f(x : { m : (number)->number }) + x.m = function(x) return 1+x end + end + f(t) -- This shouldn't typecheck + local c : string = t.m("hi") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // TODO: test behavior is wrong until we can re-enable the covariant requirement for instantiation in subtyping + // LUAU_REQUIRE_ERRORS(result); + // CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' + // caused by: + // Property 'm' is not compatible. Type '(a) -> a' could not be converted into '(number) -> number'; different number of generic type + // parameters)"); + // // this error message is not great since the underlying issue is that the context is invariant, + // and `(number) -> number` cannot be a subtype of `(a) -> a`. +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_table_instantiation_potential_regression") +{ + CheckResult result = check(R"( +--!strict + +function f(x) + x.p = 5 + return x +end +local g : ({ p : number, q : string }) -> ({ p : number, r : boolean }) = f + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + MissingProperties* error = get(result.errors[0]); + REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); + + CHECK_EQ("r", error->properties[0]); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_has_a_side_effect") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local mt = { + __add = function(x, y) + return 123 + end, + } + + local foo = {} + setmetatable(foo, mt) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("foo")) == "{ @metatable { __add: (a, b) -> number }, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated") +{ + CheckResult result = check(R"( + local t = { + x = 5 :: NonexistingTypeWhichEndsUpReturningAnErrorType, + y = 5 + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ("{ x: *error-type*, y: number }", toString(requireType("t"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "fuzz_table_indexer_unification_can_bound_owner_to_string") +{ + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; + + CheckResult result = check(R"( +sin,_ = nil +_ = _[_.sin][_._][_][_]._ +_[_] = _ + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_extra_prop_unification_can_bound_owner_to_string") +{ + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; + + CheckResult result = check(R"( +l0,_ = nil +_ = _,_[_.n5]._[_][_][_]._ +_._.foreach[_],_ = _[_],_._ + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_typelevel_promote_on_changed_table_type") +{ + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; + + CheckResult result = check(R"( +_._,_ = nil +_ = _.foreach[_]._,_[_.n5]._[_.foreach][_][_]._ +_ = _._ + )"); + LUAU_REQUIRE_ERRORS(result); } diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 37333b19..e42cea63 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -2,19 +2,21 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" #include -LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) -LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauInstantiateInSubtyping); using namespace Luau; @@ -43,10 +45,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_error") CheckResult result = check("local a = 7 local b = 'hi' a = b"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{ - requireType("a"), - requireType("b"), - }})); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); } TEST_CASE_FIXTURE(Fixture, "tc_error_2") @@ -60,15 +59,6 @@ TEST_CASE_FIXTURE(Fixture, "tc_error_2") }})); } -TEST_CASE_FIXTURE(Fixture, "tc_function") -{ - CheckResult result = check("function five() return 5 end"); - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* fiveType = get(requireType("five")); - REQUIRE(fiveType != nullptr); -} - TEST_CASE_FIXTURE(Fixture, "infer_locals_with_nil_value") { CheckResult result = check("local f = nil; f = 'hello world'"); @@ -94,6 +84,10 @@ TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") { + ScopedFastFlag sff[]{ + {"DebugLuauDeferredConstraintResolution", false}, + }; + CheckResult result = check(R"( --!nocheck function f(x) @@ -108,457 +102,12 @@ TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "check_function_bodies") -{ - CheckResult result = check("function myFunction() local a = 0 a = true end"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 44}, Position{0, 48}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.booleanType, - }})); -} - -TEST_CASE_FIXTURE(Fixture, "infer_return_type") -{ - CheckResult result = check("function take_five() return 5 end"); - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* takeFiveType = get(requireType("take_five")); - REQUIRE(takeFiveType != nullptr); - - std::vector retVec = flatten(takeFiveType->retType).first; - REQUIRE(!retVec.empty()); - - REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); -} - -TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") -{ - CheckResult result = check("function take_five() return 5 end local five = take_five()"); - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *follow(requireType("five"))); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_call_primitives") -{ - CheckResult result = check("local foo = 5 foo()"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0]) != nullptr); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_call_tables") -{ - CheckResult result = check("local foo = {} foo()"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK(get(result.errors[0]) != nullptr); -} - -TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") -{ - CheckResult result = check(R"( - function take_five() - return 5 - end - - take_five().prop = 888 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{typeChecker.numberType}})); -} - TEST_CASE_FIXTURE(Fixture, "expr_statement") { CheckResult result = check("local foo = 5 foo()"); LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "generic_function") -{ - CheckResult result = check("function id(x) return x end local a = id(55) local b = id(nil)"); - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("a")); - CHECK_EQ(*typeChecker.nilType, *requireType("b")); -} - -TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and_size") -{ - CheckResult result = check(R"( - function f(...) end - - f(1) - f("foo", 2) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") -{ - CheckResult result = check(R"( - local T = {} - function T.f(...) - local result = {} - - for i = 1, select("#", ...) do - local dictionary = select(i, ...) - for key, value in pairs(dictionary) do - result[key] = value - end - end - - return result - end - - return T - )"); - - auto r = first(getMainModule()->getModuleScope()->returnType); - REQUIRE(r); - - TableTypeVar* ttv = getMutable(*r); - REQUIRE(ttv); - - TypeId k = ttv->props["f"].type; - REQUIRE(k); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "for_loop") -{ - CheckResult result = check(R"( - local q - for i=0, 50, 2 do - q = i - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("q")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop") -{ - CheckResult result = check(R"( - local n - local s - for i, v in pairs({ "foo" }) do - n = i - s = v - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("n")); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_next") -{ - CheckResult result = check(R"( - local n - local s - for i, v in next, { "foo", "bar" } do - n = i - s = v - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("n")); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") -{ - CheckResult result = check(R"( - local it: any - local a, b - for i, v in it do - a, b = i, v - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") -{ - CheckResult result = check(R"( - local foo = "bar" - for i, v in foo do - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") -{ - CheckResult result = check(R"( - local function keys(dictionary) - local new = {} - local index = 1 - - for key in pairs(dictionary) do - new[index] = key - index = index + 1 - end - - return new - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_with_a_custom_iterator_should_type_check") -{ - CheckResult result = check(R"( - local function range(l, h): () -> number - return function() - return l - end - end - - for n: string in range(1, 10) do - print(n) - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") -{ - CheckResult result = check(R"( - function f(x) - gobble.prop = x.otherprop - end - - local p - for _, part in i_am_not_defined do - p = part - f(part) - part.thirdprop = false - end - )"); - - CHECK_EQ(2, result.errors.size()); - - TypeId p = requireType("p"); - CHECK_EQ(*p, *typeChecker.errorType); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") -{ - CheckResult result = check(R"( - local bad_iter = 5 - - for a in bad_iter() do - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") -{ - CheckResult result = check(R"( - local function hasDivisors(value: number, table) - return false - end - - function prime_iter(state, index) - while hasDivisors(index, state) do - index += 1 - end - - state[index] = true - return index - end - - function primes1() - return prime_iter, {} - end - - function primes2() - return prime_iter, {}, "" - end - - function primes3() - return prime_iter, {}, 2 - end - - for p in primes1() do print(p) end -- mismatch in argument count - - for p in primes2() do print(p) end -- mismatch in argument types, prime_iter takes {}, number, we are given {}, string - - for p in primes3() do print(p) end -- no errror - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Arg); - CHECK_EQ(2, acm->expected); - CHECK_EQ(1, acm->actual); - - TypeMismatch* tm = get(result.errors[1]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") -{ - CheckResult result = check(R"( - function prime_iter(state, index) - return 1 - end - - for p in prime_iter do print(p) end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Arg); - CHECK_EQ(2, acm->expected); - CHECK_EQ(0, acm->actual); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") -{ - CheckResult result = check(R"( - function bar(): any - return true - end - - local a - for b in bar do - a = b - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(typeChecker.anyType, requireType("a")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") -{ - CheckResult result = check(R"( - function bar(): any - return true - end - - local a - for b in bar() do - a = b - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(typeChecker.anyType, requireType("a")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") -{ - CheckResult result = check(R"( - local bar: any - - local a - for b in bar do - a = b - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(typeChecker.anyType, requireType("a")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") -{ - CheckResult result = check(R"( - local bar: any - - local a - for b in bar() do - a = b - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(typeChecker.anyType, requireType("a")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") -{ - CheckResult result = check(R"( - local a - for b in bar do - a = b - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(typeChecker.errorType, requireType("a")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") -{ - CheckResult result = check(R"( - function bar(c) return c end - - local a - for b in bar() do - a = b - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(typeChecker.errorType, requireType("a")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") -{ - CheckResult result = check(R"( - function primes() - return function (state: number) end, 2 - end - - for p, q in primes do - q = "" - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); -} - TEST_CASE_FIXTURE(Fixture, "if_statement") { CheckResult result = check(R"( @@ -578,495 +127,6 @@ TEST_CASE_FIXTURE(Fixture, "if_statement") CHECK_EQ(*typeChecker.numberType, *requireType("b")); } -TEST_CASE_FIXTURE(Fixture, "while_loop") -{ - CheckResult result = check(R"( - local i - while true do - i = 8 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("i")); -} - -TEST_CASE_FIXTURE(Fixture, "repeat_loop") -{ - CheckResult result = check(R"( - local i - repeat - i = 'hi' - until true - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.stringType, *requireType("i")); -} - -TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") -{ - CheckResult result = check(R"( - repeat - local x = true - until x - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") -{ - CheckResult result = check(R"( - repeat - local x = true - until x - - print(x) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "table_length") -{ - CheckResult result = check(R"( - local t = {} - local s = #t - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK(nullptr != get(requireType("t"))); - CHECK_EQ(*typeChecker.numberType, *requireType("s")); -} - -TEST_CASE_FIXTURE(Fixture, "string_length") -{ - CheckResult result = check(R"( - local s = "Hello, World!" - local t = #s - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireType("t"))); -} - -TEST_CASE_FIXTURE(Fixture, "string_index") -{ - CheckResult result = check(R"( - local s = "Hello, World!" - local t = s[4] - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - NotATable* nat = get(result.errors[0]); - REQUIRE(nat); - CHECK_EQ("string", toString(nat->ty)); - - CHECK(get(requireType("t"))); -} - -TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") -{ - CheckResult result = check(R"( - local l = #this_is_not_defined - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "indexing_error_type_does_not_produce_an_error") -{ - CheckResult result = check(R"( - local originalReward = unknown.Parent.Reward:GetChildren()[1] - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer") -{ - CheckResult result = check("local a = {} a[0] = 7 a[0] = nil"); - LUAU_REQUIRE_ERROR_COUNT(0, result); -} - -TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") -{ - CheckResult result = check("local a = {} a[0] = 7 a[0] = 't'"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.stringType, - }})); -} - -TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") -{ - CheckResult result = check("local a = {a=1, b=2} a['a'] = nil"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.nilType, - }})); -} - -TEST_CASE_FIXTURE(Fixture, "dot_on_error_type_does_not_produce_an_error") -{ - CheckResult result = check(R"( - local foo = (true).x - foo.x = foo.y - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon") -{ - CheckResult result = check(R"( - local someTable = {} - - someTable.Function1 = function(Arg1) - end - - someTable.Function1() -- Argument count mismatch - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") -{ - CheckResult result = check(R"( - local someTable = {} - - someTable.Function2 = function(Arg1, Arg2) - end - - someTable.Function2() -- Argument count mismatch - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_another_overload_works") -{ - CheckResult result = check(R"( - type T = {method: ((T, number) -> number) & ((number) -> number)} - local T: T - - T.method(4) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_count") -{ - CheckResult result = check(R"( - local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) - multiply("") - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); - - ExtraInformation* ei = get(result.errors[1]); - REQUIRE(ei); - CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); -} - -TEST_CASE_FIXTURE(Fixture, "list_all_overloads_if_no_overload_takes_given_argument_count") -{ - CheckResult result = check(R"( - local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) - multiply() - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - CHECK_EQ("No overload for function accepts 0 arguments.", ge->message); - - ExtraInformation* ei = get(result.errors[1]); - REQUIRE(ei); - CHECK_EQ("Available overloads: (number) -> number; (number) -> string; and (number, number) -> number", ei->message); -} - -TEST_CASE_FIXTURE(Fixture, "dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists") -{ - CheckResult result = check(R"( - local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) - multiply(1, "") - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "infer_return_type_from_selected_overload") -{ - CheckResult result = check(R"( - type T = {method: ((T, number) -> number) & ((number) -> string)} - local T: T - - local a = T.method(T, 4) - local b = T.method(5) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireType("a"))); - CHECK_EQ("string", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "too_many_arguments") -{ - CheckResult result = check(R"( - --!nonstrict - - function g(a: number) end - - g() - - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto err = result.errors[0]; - auto acm = get(err); - REQUIRE(acm); - - CHECK_EQ(1, acm->expected); - CHECK_EQ(0, acm->actual); -} - -TEST_CASE_FIXTURE(Fixture, "any_type_propagates") -{ - CheckResult result = check(R"( - local foo: any - local bar = foo:method("argument") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("bar"))); -} - -TEST_CASE_FIXTURE(Fixture, "can_subscript_any") -{ - CheckResult result = check(R"( - local foo: any - local bar = foo[5] - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("bar"))); -} - -// Not strictly correct: metatables permit overriding this -TEST_CASE_FIXTURE(Fixture, "can_get_length_of_any") -{ - CheckResult result = check(R"( - local foo: any = {} - local bar = #foo - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("bar"))); -} - -TEST_CASE_FIXTURE(Fixture, "recursive_function") -{ - CheckResult result = check(R"( - function count(n: number) - if n == 0 then - return 0 - else - return count(n - 1) - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "lambda_form_of_local_function_cannot_be_recursive") -{ - CheckResult result = check(R"( - local f = function() return f() end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "recursive_local_function") -{ - CheckResult result = check(R"( - local function count(n: number) - if n == 0 then - return 0 - else - return count(n - 1) - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -// FIXME: This and the above case get handled very differently. It's pretty dumb. -// We really should unify the two code paths, probably by deleting AstStatFunction. -TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") -{ - CheckResult result = check(R"( - local count - function count(n: number) - if n == 0 then - return 0 - else - return count(n - 1) - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") -{ - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - - CheckResult result = check(R"( - function f() - return f - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = () -> t1", toString(requireType("f"))); -} - -TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") -{ - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - - CheckResult result = check(R"( - function f(g) - return f(f) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); -} - -TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") -{ - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - - CheckResult result = check(R"( - type F = () -> F? - local function f() - return f - end - - local g: F = f - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); -} - -// TODO: File a Jira about this -/* -TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack") -{ - CheckResult result = check(R"( - function a(x) return 1 end - a(...) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - REQUIRE(bool(getMainModule()->getModuleScope()->varargPack)); - - TypePackId varargPack = *getMainModule()->getModuleScope()->varargPack; - - auto iter = begin(varargPack); - auto endIter = end(varargPack); - - CHECK(iter != endIter); - ++iter; - CHECK(iter == endIter); - - CHECK(!iter.tail()); -} -*/ - -TEST_CASE_FIXTURE(Fixture, "method_depends_on_table") -{ - CheckResult result = check(R"( - -- This catches a bug where x:m didn't count as a use of x - -- so toposort would happily reorder a definition of - -- function x:m before the definition of x. - function g() f() end - local x = {} - function x:m() end - function f() x:m() end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") -{ - CheckResult result = check(R"( - local Get_des - function Get_des(func) - Get_des(func) - end - - local function f(d) - d:IsA("BasePart") - d.Parent:FindFirstChild("Humanoid") - d:IsA("Decal") - end - Get_des(f) - - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "another_other_higher_order_function") -{ - CheckResult result = check(R"( - local d - d:foo() - d:foo() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "statements_are_topologically_sorted") { CheckResult result = check(R"( @@ -1083,121 +143,6 @@ TEST_CASE_FIXTURE(Fixture, "statements_are_topologically_sorted") dumpErrors(result); } -TEST_CASE_FIXTURE(Fixture, "generic_table_method") -{ - CheckResult result = check(R"( - local T = {} - - function T:bar(i) - return i - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId tType = requireType("T"); - TableTypeVar* tTable = getMutable(tType); - REQUIRE(tTable != nullptr); - - TypeId barType = tTable->props["bar"].type; - REQUIRE(barType != nullptr); - - const FunctionTypeVar* ftv = get(follow(barType)); - REQUIRE_MESSAGE(ftv != nullptr, "Should be a function: " << *barType); - - std::vector args = flatten(ftv->argTypes).first; - TypeId argType = args.at(1); - - CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); -} - -TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") -{ - CheckResult result = check(R"( - local T = {} - - function T:foo() - return T:bar(5) - end - - function T:bar(i) - return i - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); - - const TableTypeVar* t = get(requireType("T")); - REQUIRE(t != nullptr); - - std::optional fooProp = get(t->props, "foo"); - REQUIRE(bool(fooProp)); - - const FunctionTypeVar* foo = get(follow(fooProp->type)); - REQUIRE(bool(foo)); - - std::optional ret_ = first(foo->retType); - REQUIRE(bool(ret_)); - TypeId ret = follow(*ret_); - - REQUIRE_EQ(getPrimitiveType(ret), PrimitiveTypeVar::Number); -} - -TEST_CASE_FIXTURE(Fixture, "methods_are_topologically_sorted") -{ - CheckResult result = check(R"( - local T = {} - - function T:foo() - return T:bar(999), T:bar("hi") - end - - function T:bar(i) - return i - end - - local a, b = T:foo() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); - - CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("a"))); - CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "local_function") -{ - CheckResult result = check(R"( - function f() - return 8 - end - - function g() - local function f() - return 'hello' - end - return f - end - - local h = g() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId h = follow(requireType("h")); - - const FunctionTypeVar* ftv = get(h); - REQUIRE(ftv != nullptr); - - std::optional rt = first(ftv->retType); - REQUIRE(bool(rt)); - - TypeId retType = follow(*rt); - CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(retType)); -} - TEST_CASE_FIXTURE(Fixture, "unify_nearly_identical_recursive_types") { CheckResult result = check(R"( @@ -1209,285 +154,11 @@ TEST_CASE_FIXTURE(Fixture, "unify_nearly_identical_recursive_types") o = p )"); -} - -/* - * We had a bug in instantiation where the argument types of 'f' and 'g' would be inferred as - * f {+ method: function(): (t2, T3...) +} - * g {+ method: function({+ method: function(): (t2, T3...) +}): (t5, T6...) +} - * - * The type of 'g' is totally wrong as t2 and t5 should be unified, as should T3 with T6. - * - * The correct unification of the argument to 'g' is - * - * {+ method: function(): (t5, T6...) +} - */ -TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") -{ - auto result = check(R"( - function f(o) - o:method() - end - - function g(o) - f(o) - end - )"); - - TypeId g = requireType("g"); - const FunctionTypeVar* gFun = get(g); - REQUIRE(gFun != nullptr); - - auto optionArg = first(gFun->argTypes); - REQUIRE(bool(optionArg)); - - TypeId arg = follow(*optionArg); - const TableTypeVar* argTable = get(arg); - REQUIRE(argTable != nullptr); - - std::optional methodProp = get(argTable->props, "method"); - REQUIRE(bool(methodProp)); - - const FunctionTypeVar* methodFunction = get(methodProp->type); - REQUIRE(methodFunction != nullptr); - - std::optional methodArg = first(methodFunction->argTypes); - REQUIRE(bool(methodArg)); - - REQUIRE_EQ(follow(*methodArg), follow(arg)); -} - -TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") -{ - CheckResult result = check(R"( - --!strict - type Node = { Parent: Node?; } - local node: Node; - node.Parent = 1 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("Node?", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") -{ - CheckResult result = check(R"( - local T = {} - - function T.f(p) - for i, v in pairs(p) do - T.f(v) - end - end - )"); LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "properly_infer_iteratee_is_a_free_table") -{ - // In this case, we cannot know the element type of the table {}. It could be anything. - // We therefore must initially ascribe a free typevar to iter. - CheckResult result = check(R"( - for iter in pairs({}) do - iter:g().p = true - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "quantify_methods_defined_using_dot_syntax_and_explicit_self_parameter") -{ - check(R"( - local T = {} - - function T.method(self) - self:method() - end - - function T.method2(self) - self:method() - end - - T:method2() - )"); -} - -TEST_CASE_FIXTURE(Fixture, "free_rhs_table_can_also_be_bound") -{ - check(R"( - local o - local v = o:i() - - function g(u) - v = u - end - - o:f(g) - o:h() - o:h() - )"); -} - -TEST_CASE_FIXTURE(Fixture, "require") -{ - fileResolver.source["game/A"] = R"( - local function hooty(x: number): string - return "Hi there!" - end - - return {hooty=hooty} - )"; - - fileResolver.source["game/B"] = R"( - local Hooty = require(game.A) - - local h -- free! - local i = Hooty.hooty(h) - )"; - - CheckResult aResult = frontend.check("game/A"); - dumpErrors(aResult); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = frontend.check("game/B"); - dumpErrors(bResult); - LUAU_REQUIRE_NO_ERRORS(bResult); - - ModulePtr b = frontend.moduleResolver.modules["game/B"]; - - REQUIRE(b != nullptr); - - dumpErrors(bResult); - - std::optional iType = requireType(b, "i"); - REQUIRE_EQ("string", toString(*iType)); - - std::optional hType = requireType(b, "h"); - REQUIRE_EQ("number", toString(*hType)); -} - -TEST_CASE_FIXTURE(Fixture, "require_types") -{ - fileResolver.source["workspace/A"] = R"( - export type Point = {x: number, y: number} - - return {} - )"; - - fileResolver.source["workspace/B"] = R"( - local Hooty = require(workspace.A) - - local h: Hooty.Point - )"; - - CheckResult bResult = frontend.check("workspace/B"); - dumpErrors(bResult); - - ModulePtr b = frontend.moduleResolver.modules["workspace/B"]; - REQUIRE(b != nullptr); - - TypeId hType = requireType(b, "h"); - REQUIRE_MESSAGE(bool(get(hType)), "Expected table but got " << toString(hType)); -} - -TEST_CASE_FIXTURE(Fixture, "require_a_variadic_function") -{ - fileResolver.source["game/A"] = R"( - local T = {} - function T.f(...) end - return T - )"; - - fileResolver.source["game/B"] = R"( - local A = require(game.A) - local f = A.f - )"; - - CheckResult result = frontend.check("game/B"); - - ModulePtr bModule = frontend.moduleResolver.getModule("game/B"); - REQUIRE(bModule != nullptr); - - TypeId f = follow(requireType(bModule, "f")); - - const FunctionTypeVar* ftv = get(f); - REQUIRE(ftv); - - auto iter = begin(ftv->argTypes); - auto endIter = end(ftv->argTypes); - - REQUIRE(iter == endIter); - REQUIRE(iter.tail()); - - CHECK(get(*iter.tail())); -} - -TEST_CASE_FIXTURE(Fixture, "assign_prop_to_table_by_calling_any_yields_any") -{ - CheckResult result = check(R"( - local f: any - local T = {} - - T.prop = f() - - return T - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TableTypeVar* ttv = getMutable(requireType("T")); - REQUIRE(ttv); - REQUIRE(ttv->props.count("prop")); - - REQUIRE_EQ("any", toString(ttv->props["prop"].type)); -} - -TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") -{ - CheckResult result = check(R"( - local p: SomeModule.DoesNotExist - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE_EQ(result.errors[0], (TypeError{Location{{1, 17}, {1, 40}}, UnknownSymbol{"SomeModule.DoesNotExist"}})); -} - -TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") -{ - const std::string sourceA = R"( - )"; - - const std::string sourceB = R"( - local Hooty = require(script.Parent.A) - )"; - - fileResolver.source["game/Workspace/A"] = sourceA; - fileResolver.source["game/Workspace/B"] = sourceB; - - frontend.check("game/Workspace/A"); - frontend.check("game/Workspace/B"); - - ModulePtr aModule = frontend.moduleResolver.modules["game/Workspace/A"]; - ModulePtr bModule = frontend.moduleResolver.modules["game/Workspace/B"]; - - CHECK(aModule->errors.empty()); - REQUIRE_EQ(1, bModule->errors.size()); - CHECK_MESSAGE(get(bModule->errors[0]), "Should be IllegalRequire: " << toString(bModule->errors[0])); - - auto hootyType = requireType(bModule, "Hooty"); - - CHECK_MESSAGE(get(follow(hootyType)) != nullptr, "Should be an error: " << toString(hootyType)); -} - -TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") +TEST_CASE_FIXTURE(BuiltinsFixture, "warn_on_lowercase_parent_property") { CheckResult result = check(R"( local M = require(script.parent.DoesNotMatter) @@ -1501,145 +172,7 @@ TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") REQUIRE_EQ("parent", ed->symbol); } -TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") -{ - CheckResult result = check(R"( - local A : any - function A.B() end - A:C() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId aType = requireType("A"); - CHECK_EQ(aType, typeChecker.anyType); -} - -TEST_CASE_FIXTURE(Fixture, "table_unifies_into_map") -{ - CheckResult result = check(R"( - local Instance: any - local UDim2: any - - function Create(instanceType) - return function(data) - local obj = Instance.new(instanceType) - for k, v in pairs(data) do - if type(k) == 'number' then - --v.Parent = obj - else - obj[k] = v - end - end - return obj - end - end - - local topbarShadow = Create'ImageLabel'{ - Name = "TopBarShadow"; - Size = UDim2.new(1, 0, 0, 3); - Position = UDim2.new(0, 0, 1, 0); - Image = "rbxasset://textures/ui/TopBar/dropshadow.png"; - BackgroundTransparency = 1; - Active = false; - Visible = false; - }; - - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") -{ - CheckResult result = check(R"( - local p = function(x) return x end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - const Luau::FunctionTypeVar* fn = get(requireType("p")); - REQUIRE(fn); - auto ret = first(fn->retType); - REQUIRE(ret); - REQUIRE(get(follow(*ret))); -} - -TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments") -{ - CheckResult result = check(R"( - function foo(a, b) - return a(b) - end - - function bar() - local c: ((number)->number, number)->number = foo -- no error - c = foo -- no error - local d: ((number)->number, string)->number = foo -- error from arg 2 (string) not being convertable to number from the call a(b) - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("((number) -> number, string) -> number", toString(tm->wantedType)); - CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") -{ - CheckResult result = check(R"( - function foo(a, b) - return a(b) - end - - function bar() - local _: (string, string)->number = foo -- string cannot be converted to (string)->number - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("(string, string) -> number", toString(tm->wantedType)); - CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "string_method") -{ - CheckResult result = check(R"( - local p = ("tacos"):len() - )"); - CHECK_EQ(0, result.errors.size()); - - CHECK_EQ(*requireType("p"), *typeChecker.numberType); -} - -TEST_CASE_FIXTURE(Fixture, "string_function_indirect") -{ - CheckResult result = check(R"( - local s:string - local l = s.lower - local p = l(s) - )"); - CHECK_EQ(0, result.errors.size()); - - CHECK_EQ(*requireType("p"), *typeChecker.stringType); -} - -TEST_CASE_FIXTURE(Fixture, "string_function_other") -{ - CheckResult result = check(R"( - local s:string - local p = s:match("foo") - )"); - CHECK_EQ(0, result.errors.size()); - - CHECK_EQ(toString(requireType("p")), "string?"); -} - -TEST_CASE_FIXTURE(Fixture, "weird_case") +TEST_CASE_FIXTURE(BuiltinsFixture, "weird_case") { CheckResult result = check(R"( local function f() return 4 end @@ -1647,87 +180,10 @@ TEST_CASE_FIXTURE(Fixture, "weird_case") )"); LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "or_joins_types") -{ - CheckResult result = check(R"( - local s = "a" or 10 - local x:string|number = s - )"); - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(*requireType("s")), "number | string"); - CHECK_EQ(toString(*requireType("x")), "number | string"); -} - -TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") -{ - CheckResult result = check(R"( - local s = "a" or 10 - local x:number|string = s - local y = x or "s" - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(*requireType("s")), "number | string"); - CHECK_EQ(toString(*requireType("y")), "number | string"); -} - -TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") -{ - CheckResult result = check(R"( - local s = "a" or "b" - local x:string = s - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(*requireType("s"), *typeChecker.stringType); -} - -TEST_CASE_FIXTURE(Fixture, "and_adds_boolean") -{ - CheckResult result = check(R"( - local s = "a" and 10 - local x:boolean|number = s - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(*requireType("s")), "boolean | number"); -} - -TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") -{ - CheckResult result = check(R"( - local s = "a" and true - local x:boolean = s - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(*requireType("x"), *typeChecker.booleanType); -} - -TEST_CASE_FIXTURE(Fixture, "and_or_ternary") -{ - CheckResult result = check(R"( - local s = (1/2) > 0.5 and "a" or 10 - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(*requireType("s")), "number | string"); -} - -TEST_CASE_FIXTURE(Fixture, "first_argument_can_be_optional") -{ - CheckResult result = check(R"( - local T = {} - function T.new(a: number?, b: number?, c: number?) return 5 end - local m = T.new() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); } TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( --!strict local s @@ -1738,8 +194,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") TEST_CASE_FIXTURE(Fixture, "occurs_check_does_not_recurse_forever_if_asked_to_traverse_a_cyclic_type") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( --!strict function u(t, w) @@ -1763,304 +217,7 @@ TEST_CASE_FIXTURE(Fixture, "crazy_complexity") } #endif -// We had a bug where a cyclic union caused a stack overflow. -// ex type U = number | U -TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") -{ - CheckResult result = check(R"( - --!strict - - function f(a, b) - a:g(b or {}) - a:g(b) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "it_is_ok_not_to_supply_enough_retvals") -{ - CheckResult result = check(R"( - function get_two() return 5, 6 end - - local a = get_two() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "duplicate_functions2") -{ - CheckResult result = check(R"( - function foo() end - - function bar() - local function foo() end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); -} - -TEST_CASE_FIXTURE(Fixture, "duplicate_functions_allowed_in_nonstrict") -{ - CheckResult result = check(R"( - --!nonstrict - function foo() end - - function foo() end - - function bar() - local function foo() end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "duplicate_functions_with_different_signatures_not_allowed_in_nonstrict") -{ - CheckResult result = check(R"( - --!nonstrict - function foo(): number - return 1 - end - foo() - - function foo(n: number): number - return 2 - end - foo() - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("() -> number", toString(tm->wantedType)); - CHECK_EQ("(number) -> number", toString(tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "tables_get_names_from_their_locals") -{ - CheckResult result = check(R"( - local T = {} - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("T", toString(requireType("T"))); -} - -TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") -{ - CheckResult result = check(R"( - function foo(arr) - local work = {} - for i = 1, #arr do - work[i] = arr[i] - end - - return arr - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); - - const FunctionTypeVar* fooType = get(requireType("foo")); - REQUIRE(fooType); - - std::optional fooArg1 = first(fooType->argTypes); - REQUIRE(fooArg1); - - const TableTypeVar* fooArg1Table = get(*fooArg1); - REQUIRE(fooArg1Table); - - CHECK_EQ(fooArg1Table->state, TableState::Generic); -} - -TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotation") -{ - CheckResult result = check(R"( - local i = 0 - function most_of_the_natural_numbers(): number? - if i < 10 then - i = i + 1 - return i - else - return nil - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* functionType = get(requireType("most_of_the_natural_numbers")); - - std::optional retType = first(functionType->retType); - REQUIRE(retType); - CHECK(get(*retType)); -} - -TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") -{ - CheckResult result = check(R"( - function apply(f, x) - return f(x) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* ftv = get(requireType("apply")); - REQUIRE(ftv != nullptr); - - std::vector argVec = flatten(ftv->argTypes).first; - - REQUIRE_EQ(2, argVec.size()); - - const FunctionTypeVar* fType = get(argVec[0]); - REQUIRE(fType != nullptr); - - std::vector fArgs = flatten(fType->argTypes).first; - - TypeId xType = argVec[1]; - - CHECK_EQ(1, fArgs.size()); - CHECK_EQ(xType, fArgs[0]); -} - -TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") -{ - CheckResult result = check(R"( - function bottomupmerge(comp, a, b, left, mid, right) - local i, j = left, mid - for k = left, right do - if i < mid and (j > right or not comp(a[j], a[i])) then - b[k] = a[i] - i = i + 1 - else - b[k] = a[j] - j = j + 1 - end - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* ftv = get(requireType("bottomupmerge")); - REQUIRE(ftv != nullptr); - - std::vector argVec = flatten(ftv->argTypes).first; - - REQUIRE_EQ(6, argVec.size()); - - const FunctionTypeVar* fType = get(argVec[0]); - REQUIRE(fType != nullptr); -} - -TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") -{ - CheckResult result = check(R"( - function swap(p) - local t = p[0] - p[0] = p[1] - p[1] = t - return nil - end - - function swapTwice(p) - swap(p) - swap(p) - return p - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* ftv = get(requireType("swapTwice")); - REQUIRE(ftv != nullptr); - - std::vector argVec = flatten(ftv->argTypes).first; - - REQUIRE_EQ(1, argVec.size()); - - const TableTypeVar* argType = get(follow(argVec[0])); - REQUIRE(argType != nullptr); - - CHECK(bool(argType->indexer)); -} - -TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") -{ - CheckResult result = check(R"( - function bottomupmerge(comp, a, b, left, mid, right) - local i, j = left, mid - for k = left, right do - if i < mid and (j > right or not comp(a[j], a[i])) then - b[k] = a[i] - i = i + 1 - else - b[k] = a[j] - j = j + 1 - end - end - end - - function mergesort(arr, comp) - local work = {} - for i = 1, #arr do - work[i] = arr[i] - end - local width = 1 - while width < #arr do - for i = 1, #arr, 2*width do - bottomupmerge(comp, arr, work, i, math.min(i+width, #arr), math.min(i+2*width-1, #arr)) - end - local temp = work - work = arr - arr = temp - width = width * 2 - end - return arr - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); - - /* - * mergesort takes two arguments: an array of some type T and a function that takes two Ts. - * We must assert that these two types are in fact the same type. - * In other words, comp(arr[x], arr[y]) is well-typed. - */ - - const FunctionTypeVar* ftv = get(requireType("mergesort")); - REQUIRE(ftv != nullptr); - - std::vector argVec = flatten(ftv->argTypes).first; - - REQUIRE_EQ(2, argVec.size()); - - const TableTypeVar* arg0 = get(follow(argVec[0])); - REQUIRE(arg0 != nullptr); - REQUIRE(bool(arg0->indexer)); - - const FunctionTypeVar* arg1 = get(follow(argVec[1])); - REQUIRE(arg1 != nullptr); - REQUIRE_EQ(2, size(arg1->argTypes)); - - std::vector arg1Args = flatten(arg1->argTypes).first; - - CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[0]); - CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); -} - -TEST_CASE_FIXTURE(Fixture, "error_types_propagate") +TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") { CheckResult result = check(R"( local err = (true).x @@ -2077,375 +234,14 @@ TEST_CASE_FIXTURE(Fixture, "error_types_propagate") CHECK_EQ("boolean", toString(err->table)); CHECK_EQ("x", err->key); - CHECK(nullptr != get(requireType("c"))); - CHECK(nullptr != get(requireType("d"))); - CHECK(nullptr != get(requireType("e"))); - CHECK(nullptr != get(requireType("f"))); -} - -TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") -{ - CheckResult result = check(R"( - local a = unknown.Parent.Reward.GetChildren() - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - UnknownSymbol* err = get(result.errors[0]); - REQUIRE(err != nullptr); - - CHECK_EQ("unknown", err->name); - - CHECK(nullptr != get(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") -{ - CheckResult result = check(R"( - local a = Utility.Create "Foo" {} - )"); - - TypeId aType = requireType("a"); - - REQUIRE_MESSAGE(nullptr != get(aType), "Not an error: " << toString(aType)); -} - -TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") -{ - CheckResult result = check(R"( - function add(a: number, b: string) - return a + tonumber(b), a .. b - end - local n, s = add(2,"3") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* functionType = get(requireType("add")); - - std::optional retType = first(functionType->retType); - CHECK_EQ(std::optional(typeChecker.numberType), retType); - CHECK_EQ(requireType("n"), typeChecker.numberType); - CHECK_EQ(requireType("s"), typeChecker.stringType); -} - -TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") -{ - CheckResult result = check(R"( - local PI=3.1415926535897931 - local SOLAR_MASS=4*PI * PI - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(requireType("SOLAR_MASS"), typeChecker.numberType); -} - -TEST_CASE_FIXTURE(Fixture, "primitive_arith_possible_metatable") -{ - CheckResult result = check(R"( - function add(a: number, b: any) - return a + b - end - local t = add(1,2) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("t"))); -} - -TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") -{ - CheckResult result = check(R"( - local a = 4 + 8 - local b = a + 9 - local s = 'hotdogs' - local t = s .. s - local c = b - a - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("number", toString(requireType("a"))); - CHECK_EQ("number", toString(requireType("b"))); - CHECK_EQ("string", toString(requireType("s"))); - CHECK_EQ("string", toString(requireType("t"))); - CHECK_EQ("number", toString(requireType("c"))); -} - -TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") -{ - CheckResult result = check(R"( - --!strict - local Vec3 = {} - Vec3.__index = Vec3 - function Vec3.new() - return setmetatable({x=0, y=0, z=0}, Vec3) - end - - export type Vec3 = typeof(Vec3.new()) - - local thefun: any = function(self, o) return self end - - local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun - - Vec3.__mul = multiply - - local a = Vec3.new() - local b = Vec3.new() - local c = a * b - local d = a * 2 - local e = a * 'cabbage' - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("Vec3", toString(requireType("a"))); - CHECK_EQ("Vec3", toString(requireType("b"))); - CHECK_EQ("Vec3", toString(requireType("c"))); - CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK(get(requireType("e"))); -} - -TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") -{ - CheckResult result = check(R"( - --!strict - local Vec3 = {} - Vec3.__index = Vec3 - function Vec3.new() - return setmetatable({x=0, y=0, z=0}, Vec3) - end - - export type Vec3 = typeof(Vec3.new()) - - local thefun: any = function(self, o) return self end - - local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun - - Vec3.__mul = multiply - - local a = Vec3.new() - local b = Vec3.new() - local c = b * a - local d = 2 * a - local e = 'cabbage' * a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("Vec3", toString(requireType("a"))); - CHECK_EQ("Vec3", toString(requireType("b"))); - CHECK_EQ("Vec3", toString(requireType("c"))); - CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK(get(requireType("e"))); -} - -TEST_CASE_FIXTURE(Fixture, "compare_numbers") -{ - CheckResult result = check(R"( - local a = 441 - local b = 0 - local c = a < b - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "compare_strings") -{ - CheckResult result = check(R"( - local a = '441' - local b = '0' - local c = a < b - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_metatable") -{ - CheckResult result = check(R"( - local a = {} - local b = {} - local c = a < b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* gen = get(result.errors[0]); - - REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") -{ - CheckResult result = check(R"( - local M = {} - function M.new() - return setmetatable({}, M) - end - type M = typeof(M.new()) - - local a = M.new() - local b = M.new() - local c = a < b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* gen = get(result.errors[0]); - REQUIRE(gen != nullptr); - REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") -{ - CheckResult result = check(R"( - --!strict - local M = {} - function M.new() - return setmetatable({}, M) - end - function M.__lt(left, right) return true end - - local a = M.new() - local b = {} - local c = a < b -- line 10 - local d = b < a -- line 11 - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - REQUIRE_EQ((Location{{10, 18}, {10, 23}}), result.errors[0].location); - - REQUIRE_EQ((Location{{11, 18}, {11, 23}}), result.errors[1].location); -} - -TEST_CASE_FIXTURE(Fixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") -{ - CheckResult result = check(R"( - --!strict - local M = {} - function M.new() - return setmetatable({}, M) - end - function M.__lt(left, right) return true end - type M = typeof(M.new()) - - local a = M.new() - local b = {} - local c = a < b -- line 10 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto err = get(result.errors[0]); - REQUIRE(err != nullptr); - - // Frail. :| - REQUIRE_EQ("Types M and b cannot be compared with < because they do not have the same metatable", err->message); -} - -TEST_CASE_FIXTURE(Fixture, "in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators") -{ - CheckResult result = check(R"( - --!nonstrict - - function maybe_a_number(): number? - return 50 - end - - local a = maybe_a_number() < maybe_a_number() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -/* - * This test case exposed an oversight in the treatment of free tables. - * Free tables, like free TypeVars, need to record the scope depth where they were created so that - * we do not erroneously let-generalize them when they are used in a nested lambda. - * - * For more information about let-generalization, see - * - * The important idea here is that the return type of Counter.new is a table with some metatable. - * That metatable *must* be the same TypeVar as the type of Counter. If it is a copy (produced by - * the generalization process), then it loses the knowledge that its metatable will have an :incr() - * method. - */ -TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") -{ - CheckResult result = check(R"( - local Counter = {} - Counter.__index = Counter - - function Counter.new() - local self = setmetatable({count=0}, Counter) - return self - end - - function Counter:incr() - self.count = 1 - return self.count - end - - local self = Counter.new() - print(self:incr()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TableTypeVar* counterType = getMutable(requireType("Counter")); - REQUIRE(counterType); - - const FunctionTypeVar* newType = get(follow(counterType->props["new"].type)); - REQUIRE(newType); - - std::optional newRetType = *first(newType->retType); - REQUIRE(newRetType); - - const MetatableTypeVar* newRet = get(follow(*newRetType)); - REQUIRE(newRet); - - const TableTypeVar* newRetMeta = get(newRet->metatable); - REQUIRE(newRetMeta); - - CHECK(newRetMeta->props.count("incr")); - CHECK_EQ(follow(newRet->metatable), follow(requireType("Counter"))); -} - -// TODO: CLI-39624 -TEST_CASE_FIXTURE(Fixture, "instantiate_tables_at_scope_level") -{ - CheckResult result = check(R"( - --!strict - local Option = {} - Option.__index = Option - function Option.Is(obj) - return (type(obj) == "table" and getmetatable(obj) == Option) - end - return Option - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") -{ - const std::string code = R"( - function f(a) - if type(a) == "boolean" then - local a1 = a - elseif a.fn() then - local a2 = a - else - local a3 = a - end - end - )"; - CheckResult result = check(code); - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); + // TODO: Should we assert anything about these tests when DCR is being used? + if (!FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("*error-type*", toString(requireType("c"))); + CHECK_EQ("*error-type*", toString(requireType("d"))); + CHECK_EQ("*error-type*", toString(requireType("e"))); + CHECK_EQ("*error-type*", toString(requireType("f"))); + } } TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowing") @@ -2463,62 +259,6 @@ TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowi LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "x_or_y_forces_both_x_and_y_to_be_of_same_type_if_either_is_free") -{ - CheckResult result = check(R"( - local function f(x, y) return x or y end - - local x = f(1, 2) - local y = f(3, "foo") - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*requireType("x"), *typeChecker.numberType); - - CHECK_EQ(result.errors[0], (TypeError{Location{{4, 23}, {4, 28}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); -} - -TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocate_memory") -{ - CheckResult result = check(R"( - ("foo") - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - )"); - - ModulePtr module = getMainModule(); - CHECK_GE(50, module->internalTypes.typeVars.size()); -} - -TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") -{ - CheckResult result = check(R"( - --!strict - function f(U) - U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() - end - )"); - - ModulePtr module = getMainModule(); - CHECK_GE(100, module->internalTypes.typeVars.size()); -} - TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") { CheckResult result = check(R"( @@ -2557,141 +297,9 @@ TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") CHECK_GE(5, module->interfaceTypes.typeVars.size()); } -TEST_CASE_FIXTURE(Fixture, "mutual_recursion") -{ - CheckResult result = check(R"( - --!strict - - function newPlayerCharacter() - startGui() -- Unknown symbol 'startGui' - end - - local characterAddedConnection: any - function startGui() - characterAddedConnection = game:GetService("Players").LocalPlayer.CharacterAdded:connect(newPlayerCharacter) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") -{ - CheckResult result = check(R"( - --!strict - local x = nil - function f() g() end - -- make sure print(x) doen't get toposorted here, breaking the mutual block - function g() x = f end - print(x) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") -{ - CheckResult result = check(R"( - --!strict - type T = { f: a, g: U } - type U = { h: a, i: T? } - local x: T = { f = 37, g = { h = 5, i = nil } } - x.g.i = x - local y: T = { f = "hi", g = { h = "lo", i = nil } } - y.g.i = y - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") -{ - CheckResult result = check(R"( - --!strict - type T = { f: a, g: U } - type U = { h: b, i: T? } - local x: T = { f = 37, g = { h = 5, i = nil } } - x.g.i = x - local y: T = { f = "hi", g = { h = 5, i = nil } } - y.g.i = y - )"); - - LUAU_REQUIRE_ERRORS(result); - - // We had a UAF in this example caused by not cloning type function arguments - ModulePtr module = frontend.moduleResolver.getModule("MainModule"); - unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); - freeze(module->interfaceTypes); - module->internalTypes.clear(); - module->astTypes.clear(); - - // Make sure the error strings don't include "VALUELESS" - for (auto error : module->errors) - CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error)); -} - -TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") -{ - // CLI-30902 - CheckResult result = check(R"( - --!strict - - type Foo = { - fooConn: () -> () | nil - } - - local Foo = {} - Foo.__index = Foo - - function Foo.new() - local self: Foo = { - fooConn = nil, - } - setmetatable(self, Foo) - - self.fooConn = function() - self:method() -- Key 'method' not found in table self - end - - return self - end - - function Foo:method() - print("foo") - end - - local foo = Foo.new() - - -- TODO This is the best our current refinement support can offer :( - local bar = foo.fooConn - if bar then bar() end - - -- foo.fooConn() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") -{ - CheckResult result = check(R"( - local a: any - local b - for _, i in pairs(a) do - b = i - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("b"))); -} - // In these tests, a successful parse is required, so we need the parser to return the AST and then we can test the recursion depth limit in type // checker. We also want it to somewhat match up with production values, so we push up the parser recursion limit a little bit instead. -TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") +TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_count") { #if defined(LUAU_ENABLE_ASAN) int limit = 250; @@ -2700,12 +308,13 @@ TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") #else int limit = 600; #endif - ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; - ScopedFastInt luauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", limit - 100}; - ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", 0}; - CHECK_NOTHROW(check("print('Hello!')")); - CHECK_THROWS_AS(check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"), std::runtime_error); + ScopedFastInt sfi{"LuauCheckRecursionLimit", limit}; + + CheckResult result = check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(nullptr != get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "check_block_recursion_limit") @@ -2732,7 +341,7 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") #if defined(LUAU_ENABLE_ASAN) int limit = 250; #elif defined(_DEBUG) || defined(_NOOPT) - int limit = 350; + int limit = 300; #else int limit = 600; #endif @@ -2742,147 +351,7 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") CheckResult result = check(R"(("foo"))" + rep(":lower()", limit)); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(nullptr != get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_basic") -{ - CheckResult result = check(R"( - local s = 10 - s += 20 - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(*requireType("s")), "number"); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_op") -{ - CheckResult result = check(R"( - local s = 10 - s += true - )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.booleanType}})); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") -{ - CheckResult result = check(R"( - local s = 'hello' - s += 10 - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); - CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable") -{ - CheckResult result = check(R"( - --!strict - type V2B = { x: number, y: number } - local v2b: V2B = { x = 0, y = 0 } - local VMT = {} - type V2 = typeof(setmetatable(v2b, VMT)) - - function VMT.__add(a: V2, b: V2): V2 - return setmetatable({ x = a.x + b.x, y = a.y + b.y }, VMT) - end - - local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) - local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) - v1 += v2 - )"); - CHECK_EQ(0, result.errors.size()); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_metatable") -{ - CheckResult result = check(R"( - --!strict - type V2B = { x: number, y: number } - local v2b: V2B = { x = 0, y = 0 } - local VMT = {} - type V2 = typeof(setmetatable(v2b, VMT)) - - function VMT.__mod(a: V2, b: V2): number - return a.x * b.x + a.y * b.y - end - - local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) - local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) - v1 %= v2 - )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - CHECK_EQ(*tm->wantedType, *requireType("v2")); - CHECK_EQ(*tm->givenType, *typeChecker.numberType); -} - -TEST_CASE_FIXTURE(Fixture, "dont_ice_if_a_TypePack_is_an_error") -{ - CheckResult result = check(R"( - --!strict - function f(s) - print(s) - return f - end - - f("foo")("bar") - )"); -} - -TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") -{ - CheckResult result = check(R"( - --!nonstrict - - function f() - return 114 - end - - return function() - return f():andThen() - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") -{ - CheckResult result = check(R"( - function onerror() end - function foo() end - xpcall(foo, onerror) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "another_indirect_function_case_where_it_is_ok_to_provide_too_many_arguments") -{ - CheckResult result = check(R"( - local mycb: (number, number) -> () - - function f() end - - mycb = f - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "call_to_any_yields_any") -{ - CheckResult result = check(R"( - local a: any - local b = a() - )"); - - REQUIRE_EQ("any", toString(requireType("b"))); + CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected CodeTooComplex but got " << toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "globals") @@ -2928,106 +397,6 @@ TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") CHECK_EQ("foo", us->name); } -TEST_CASE_FIXTURE(Fixture, "globals_everywhere") -{ - CheckResult result = check(R"( - --!nonstrict - foo = 1 - - if true then - bar = 2 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("foo"))); - CHECK_EQ("any", toString(requireType("bar"))); -} - -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfAny") -{ - CheckResult result = check(R"( -local x: any = {} -function x:y(z: number) - local s: string = z -end -)"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfSealed") -{ - CheckResult result = check(R"( -local x: {prop: number} = {prop=9999} -function x:y(z: number) - local s: string = z -end -)"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") -{ - CheckResult result = check(R"( -local x: number = 9999 -function x:y(z: number) - local s: string = z -end -)"); - - LUAU_REQUIRE_ERROR_COUNT(3, result); -} - -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") -{ - CheckResult result = check(R"( -local x = (true).foo -function x:y(z: number) - local s: string = z -end -)"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - -TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") -{ - CheckResult result = check(R"( -function f() return 1; end -function g() return 2; end -(f or g)() -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "CallAndOrOfFunctions") -{ - CheckResult result = check(R"( -function f() return 1; end -function g() return 2; end -local x = false -(x and f or g)() -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "MixedPropertiesAndIndexers") -{ - CheckResult result = check(R"( -local x = {} -x.a = "a" -x[0] = true -x.b = 37 -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") { CheckResult result = check(R"( @@ -3035,7 +404,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") local a = 1 end - print(a) -- oops! + local b = a -- oops! )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -3045,35 +414,6 @@ TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") CHECK_EQ(us->name, "a"); } -TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") -{ - CheckResult result = check(R"( - while true do - local a = 1 - end - - print(a) -- oops! - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - UnknownSymbol* us = get(result.errors[0]); - REQUIRE(us); - CHECK_EQ(us->name, "a"); -} - -TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indeces") -{ - CheckResult result = check(R"( - local key - for i, e in ipairs({}) do key = i end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - REQUIRE_EQ("number", toString(requireType("key"))); -} - TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") { CHECK_NOTHROW(check(R"( @@ -3087,468 +427,6 @@ TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") CHECK_EQ("any", toString(requireType("value"))); } -TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_nonstrict") -{ - CheckResult result = check(R"( - --!nonstrict - - local function f1(v): number? - if v then - return 1 - end - end - - local function f2(v) - if v then - return 1 - end - end - - local function f3(v): () - if v then - return - end - end - - local function f4(v) - if v then - return - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - FunctionExitsWithoutReturning* err = get(result.errors[0]); - CHECK(err); -} - -TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_strict") -{ - CheckResult result = check(R"( - --!strict - - local function f1(v): number? - if v then - return 1 - end - end - - local function f2(v) - if v then - return 1 - end - end - - local function f3(v): () - if v then - return - end - end - - local function f4(v) - if v then - return - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - FunctionExitsWithoutReturning* annotatedErr = get(result.errors[0]); - CHECK(annotatedErr); - - FunctionExitsWithoutReturning* inferredErr = get(result.errors[1]); - CHECK(inferredErr); -} - -// TEST_CASE_FIXTURE(Fixture, "infer_method_signature_of_argument") -// { -// CheckResult result = check(R"( -// function f(a) -// if a.cond then -// return a.method() -// end -// end -// )"); - -// LUAU_REQUIRE_NO_ERRORS(result); - -// CHECK_EQ("A", toString(requireType("f"))); -// } - -TEST_CASE_FIXTURE(Fixture, "warn_if_you_try_to_require_a_non_modulescript") -{ - fileResolver.source["Modules/A"] = ""; - fileResolver.sourceTypes["Modules/A"] = SourceCode::Local; - - fileResolver.source["Modules/B"] = R"( - local M = require(script.Parent.A) - )"; - - CheckResult result = frontend.check("Modules/B"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields_errors_spanning_argument") -{ - CheckResult result = check(R"( - function foo(a: number, b: string) end - - foo("Test", 123) - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK_EQ(result.errors[0], (TypeError{Location{Position{3, 12}, Position{3, 18}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.stringType, - }})); - - CHECK_EQ(result.errors[1], (TypeError{Location{Position{3, 20}, Position{3, 23}}, TypeMismatch{ - typeChecker.stringType, - typeChecker.numberType, - }})); -} - -TEST_CASE_FIXTURE(Fixture, "calling_function_with_anytypepack_doesnt_leak_free_types") -{ - CheckResult result = check(R"( - --!nonstrict - - function Test(a) - return 1, "" - end - - - local tab = {} - table.insert(tab, Test(1)); - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - ToStringOptions opts; - opts.exhaustive = true; - opts.maxTableLength = 0; - - CHECK_EQ("{any}", toString(requireType("tab"), opts)); -} - -TEST_CASE_FIXTURE(Fixture, "too_many_return_values") -{ - CheckResult result = check(R"( - --!strict - - function f() - return 55 - end - - local a, b = f() - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK(acm->context == CountMismatch::Result); -} - -TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") -{ - CheckResult result = check(R"( - --!strict - - function f(): (number, string) - return 55 - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Return); -} - -TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") -{ - CheckResult result = check(R"( - --!strict - local foo = { - value = 10 - } - local mt = {} - setmetatable(foo, mt) - - mt.__unm = function(val: typeof(foo)): string - return val.value .. "test" - end - - local a = -foo - - local b = 1+-1 - - local bar = { - value = 10 - } - local c = -bar -- disallowed - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("string", toString(requireType("a"))); - CHECK_EQ("number", toString(requireType("b"))); - - GenericError* gen = get(result.errors[0]); - REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); -} - -TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean") -{ - CheckResult result = check(R"( - local b = not "string" - local c = not (math.random() > 0.5 and "string" or 7) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("boolean", toString(requireType("b"))); - REQUIRE_EQ("boolean", toString(requireType("c"))); -} - -TEST_CASE_FIXTURE(Fixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") -{ - CheckResult result = check(R"( - --!strict - local a = "1.24" + 123 -- not allowed - - local foo = { - value = 10 - } - - local b = foo + 1 -- not allowed - - local bar = { - value = 1 - } - - local mt = {} - - setmetatable(bar, mt) - - mt.__add = function(a: typeof(bar), b: number): number - return a.value + b - end - - local c = bar + 1 -- allowed - - local d = bar + foo -- not allowed - )"); - - LUAU_REQUIRE_ERROR_COUNT(3, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); - REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); - - TypeMismatch* tm2 = get(result.errors[2]); - CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); - CHECK_EQ(*tm2->givenType, *requireType("foo")); - - GenericError* gen2 = get(result.errors[1]); - REQUIRE_EQ(gen2->message, "Binary operator '+' not supported by types 'foo' and 'number'"); -} - -// CLI-29033 -TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") -{ - CheckResult result = check(R"( - function merge(lower, greater) - if lower.y == greater.y then - end - end - )"); - - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); - } -} - -TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable") -{ - CheckResult result = check(R"( - local x - print((x == true and (x .. "y")) .. 1) - )"); - - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE(get(result.errors[0])); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK_EQ("Type 'boolean' could not be converted into 'number | string'", toString(result.errors[0])); - CHECK_EQ("Type 'boolean | string' could not be converted into 'number | string'", toString(result.errors[1])); - } -} - -TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") -{ - CheckResult result = check(R"( - local x - print("foo" .. x) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("string", toString(requireType("x"))); -} - -TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") -{ - std::vector ops = {"+", "-", "*", "/", "%", "^", ".."}; - - std::string src = R"( - function foo(a, b) - )"; - - for (const auto& op : ops) - src += "local _ = a " + op + "b\n"; - - src += "end"; - - CheckResult result = check(src); - LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); - - CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") -{ - CheckResult result = check(R"( - function foo(a, b): number - return 0 - end - - local a: (string)->number = foo - local b: (number, number)->(number, number) = foo - - local c: (string, number)->number = foo -- no error - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - auto tm1 = get(result.errors[0]); - REQUIRE(tm1); - - CHECK_EQ("(string) -> number", toString(tm1->wantedType)); - CHECK_EQ("(string, *unknown*) -> number", toString(tm1->givenType)); - - auto tm2 = get(result.errors[1]); - REQUIRE(tm2); - - CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); - CHECK_EQ("(string, *unknown*) -> number", toString(tm2->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "setmetatable_cant_be_used_to_mutate_global_types") -{ - { - Fixture fix; - - // inherit env from parent fixture checker - fix.typeChecker.globalScope = typeChecker.globalScope; - - fix.check(R"( ---!nonstrict -type MT = typeof(setmetatable) -function wtf(arg: {MT}): typeof(table) - arg = wtf(arg) -end -)"); - } - - // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down - // note: it's important for typeck to be destroyed at this point! - { - for (auto& p : typeChecker.globalScope->bindings) - { - toString(p.second.typeId); // toString walks the entire type, making sure ASAN catches access to destroyed type arenas - } - } -} - -TEST_CASE_FIXTURE(Fixture, "evil_table_unification") -{ - // this code re-infers the type of _ while processing fields of _, which can cause use-after-free - check(R"( ---!nonstrict -_ = ... -_:table(_,string)[_:gsub(_,...,n0)],_,_:gsub(_,string)[""],_:split(_,...,table)._,n0 = nil -do end -)"); -} - -TEST_CASE_FIXTURE(Fixture, "overload_is_not_a_function") -{ - check(R"( ---!nonstrict -function _(...):((typeof(not _))&(typeof(not _)))&((typeof(not _))&(typeof(not _))) -_(...)(setfenv,_,not _,"")[_] = nil -end -do end -_(...)(...,setfenv,_):_G() -)"); -} - -TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors") -{ - CheckResult result = check(R"( - type Pair = {first: T, second: U} - local a: Pair - local b: Pair - - a = b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - - CHECK_EQ("Pair", toString(tm->wantedType)); - CHECK_EQ("Pair", toString(tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs") -{ - // this has a risk of creating cyclic type packs, causing infinite loops / OOMs - check(R"( ---!nonstrict -_ += _(_,...) -repeat -_ += _(...) -until ... + _ -)"); - - check(R"( ---!nonstrict -_ += _(_(...,...),_(...)) -repeat -until _ -)"); -} - TEST_CASE_FIXTURE(Fixture, "cyclic_follow") { check(R"( @@ -3577,19 +455,6 @@ end )"); } -TEST_CASE_FIXTURE(Fixture, "and_binexps_dont_unify") -{ - CheckResult result = check(R"( - --!strict - local t = {} - while true and t[1] do - print(t[1].test) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - struct FindFreeTypeVars { bool foundOne = false; @@ -3613,43 +478,6 @@ struct FindFreeTypeVars } }; -TEST_CASE_FIXTURE(Fixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") -{ - CheckResult result = check("local x = setmetatable({})"); - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") -{ - // This code doesn't pass typechecking. We just care that it doesn't crash. - (void)check(R"( - --!nonstrict - function _:_(...) - end - - repeat - if _ then - else - _ = ... - end - until _ - - for _ in _() do - end - )"); -} - -TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition") -{ - CheckResult result = check(R"( - type A = number - type A = string -- Redefinition of type 'A', previously defined at line 1 - local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery") { CheckResult result = check(R"( @@ -3669,7 +497,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_assert") LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_replacement_name_in_error") +TEST_CASE_FIXTURE(BuiltinsFixture, "tc_after_error_recovery_no_replacement_name_in_error") { { CheckResult result = check(R"( @@ -3722,69 +550,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_replacement_name_in_error } } -TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") -{ - CheckResult result = check(R"( - local a: boolean = true - local b: boolean = false - local foo = a < b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message); -} - -TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2") -{ - CheckResult result = check(R"( - local a: number | string = "" - local b: number | string = 1 - local foo = a < b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); -} - -TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type") -{ - CheckResult result = check(R"( - type Table = { a: T } - type Wrapped = Table - local l: Wrapped = 2 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("Wrapped", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") -{ - CheckResult result = check(R"( - type Table = { a: T } - type Wrapped = (Table) -> string - local l: Wrapped = 2 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") +TEST_CASE_FIXTURE(BuiltinsFixture, "index_expr_should_be_checked") { CheckResult result = check(R"( local foo: any @@ -3800,100 +566,6 @@ TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") CHECK_EQ("x", up->key); } -TEST_CASE_FIXTURE(Fixture, "unreachable_code_after_infinite_loop") -{ - { - CheckResult result = check(R"( - function unreachablecodepath(a): number - while true do - if a then return 10 end - end - -- unreachable - end - unreachablecodepath(4) - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); - } - - { - CheckResult result = check(R"( - function reachablecodepath(a): number - while true do - if a then break end - return 10 - end - - print("x") -- correct error - end - reachablecodepath(4) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK(get(result.errors[0])); - } - - { - CheckResult result = check(R"( - function unreachablecodepath(a): number - repeat - if a then return 10 end - until false - - -- unreachable - end - unreachablecodepath(4) - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); - } - - { - CheckResult result = check(R"( - function reachablecodepath(a, b): number - repeat - if a then break end - - if b then return 10 end - until false - - print("x") -- correct error - end - reachablecodepath(4) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK(get(result.errors[0])); - } - - { - CheckResult result = check(R"( - function unreachablecodepath(a: number?): number - repeat - return 10 - until a ~= nil - - -- unreachable - end - unreachablecodepath(4) - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); - } -} - -TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") -{ - CheckResult result = check(R"( - --!strict - local _ - _ += _ and _ or _ and _ or _ and _ - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type contains a self-recursive construct that cannot be resolved", toString(result.errors[0])); -} - TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") { CheckResult result = check(R"( @@ -3909,81 +581,6 @@ TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") CHECK_EQ("(boolean | number | string)?", toString(tm->givenType)); } -// Check that recursive intersection type doesn't generate an OOM -TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") -{ - CheckResult result = check(R"( - function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any - end - type t0 = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_)) - _(_) - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") -{ - // In non-strict mode, global definition is still allowed - { - CheckResult result = check(R"( - --!nonstrict - a = a + 1 - print(a) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); - } - - // In strict mode we no longer generate two errors from lhs - { - CheckResult result = check(R"( - --!strict - a += 1 - print(a) - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); - } - - // In non-strict mode, compound assignment is not a definition, it's a modification - { - CheckResult result = check(R"( - --!nonstrict - a += 1 - print(a) - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); - } -} - -TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") -{ - CheckResult result = check(R"( - local t = {} - for _ in t do - for _ in assert(missing()) do - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise") -{ - CheckResult result = check(R"( - local foo: Id = 1 - type Id = T - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") { CheckResult result = check(R"( @@ -3994,101 +591,6 @@ TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "metatable_of_any_can_be_a_table") -{ - CheckResult result = check(R"( ---!strict -local T: any -T = {} -T.__index = T -function T.new(...) - local self = {} - setmetatable(self, T) - self:construct(...) - return self -end -function T:construct(index) -end -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") -{ - const std::string code = R"( - type A = {v:T, b:B} - type B = {v:T, a:A} - local aa:A - local bb = aa - )"; - - const std::string expected = R"( - type A = {v:T, b:B} - type B = {v:T, a:A} - local aa:A - local bb:A=aa - )"; - - CHECK_EQ(expected, decorateWithTypes(code)); - CheckResult result = check(code); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") -{ - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - - CheckResult result = check(R"( - type A = () -> (number, B) - type B = () -> (string, A) - local a: A - local b: B - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a"))); - CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "generic_param_remap") -{ - const std::string code = R"( - -- An example of a forwarded use of a type that has different type arguments than parameters - type A = {t:T, u:U, next:A?} - local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } - local bb = aa - )"; - - const std::string expected = R"( - - type A = {t:T, u:U, next:A?} - local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } - local bb:A=aa - )"; - - CHECK_EQ(expected, decorateWithTypes(code)); - CheckResult result = check(code); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") -{ - CheckResult result = check(R"( - export type Foo = number - type Foo = number - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto dtd = get(result.errors[0]); - REQUIRE(dtd); - CHECK_EQ(dtd->name, "Foo"); -} - TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") { CheckResult result = check(R"( @@ -4116,66 +618,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_ice_on_astexprerror") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator") -{ - CheckResult result = check(R"( ---!strict -local a: number? = nil -local b: number = a or 1 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator2") -{ - CheckResult result = check(R"( ---!nonstrict -local a: number? = nil -local b: number = a or 1 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "dont_strip_nil_from_rhs_or_operator") -{ - CheckResult result = check(R"( ---!strict -local a: number? = nil -local b: number = 1 or a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ("number?", toString(tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") -{ - ScopedFastFlag sffs2{"LuauGenericFunctions", true}; - - CheckResult result = check(R"( - --!strict - local tbl = {} - function tbl:abc(a: number, b: number) - return a - end - tbl:abc(1, 2) -- Line 6 - -- | Column 14 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - TypeId type = requireTypeAtPosition(Position(6, 14)); - CHECK_EQ("(tbl, number, number) -> number", toString(type)); - auto ftv = get(type); - REQUIRE(ftv); - CHECK(ftv->hasSelf); -} - TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does") { CheckResult result = check(R"( @@ -4193,493 +635,6 @@ TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does") REQUIRE_MESSAGE(get(e) != nullptr, "Expected UnknownSymbol, but got " << e); } -TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") -{ - ScopedFastFlag sffs3{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - - CheckResult result = check(R"( - type Node = { value: T, child: Node? } - - local function visitor(node: Node?) - local a: Node - - if node then - a = node.child -- Observe the output of the error message. - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto e = get(result.errors[0]); - CHECK_EQ("Node?", toString(e->givenType)); - CHECK_EQ("Node", toString(e->wantedType)); -} - -TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") -{ - CheckResult result = check(R"( - type Array = { [number]: T } - type Fiber = { id: number } - type null = {} - - local fiberStack: Array = {} - local index = 0 - - local function f(fiber: Fiber) - local a = fiber ~= fiberStack[index] - local b = fiberStack[index] ~= fiber - end - - return f - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "general_require_call_expression") -{ - fileResolver.source["game/A"] = R"( ---!strict -return { def = 4 } - )"; - - fileResolver.source["game/B"] = R"( ---!strict -local tbl = { abc = require(game.A) } -local a : string = "" -a = tbl.abc.def - )"; - - CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "general_require_type_mismatch") -{ - fileResolver.source["game/A"] = R"( -return { def = 4 } - )"; - - fileResolver.source["game/B"] = R"( -local tbl: string = require(game.A) - )"; - - CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") -{ - fileResolver.source["workspace/A"] = R"( - export type myvec2 = {x: number, y: number} - return {} - )"; - - fileResolver.source["workspace/B"] = R"( - export type myvec3 = {x: number, y: number, z: number} - return {} - )"; - - fileResolver.source["workspace/C"] = R"( - local Foo, Bar = require(workspace.A), require(workspace.B) - - local a: Foo.myvec2 - local b: Bar.myvec3 - )"; - - CheckResult result = frontend.check("workspace/C"); - LUAU_REQUIRE_NO_ERRORS(result); - ModulePtr m = frontend.moduleResolver.modules["workspace/C"]; - - REQUIRE(m != nullptr); - - std::optional aTypeId = lookupName(m->getModuleScope(), "a"); - REQUIRE(aTypeId); - const Luau::TableTypeVar* aType = get(follow(*aTypeId)); - REQUIRE(aType); - REQUIRE(aType->props.size() == 2); - - std::optional bTypeId = lookupName(m->getModuleScope(), "b"); - REQUIRE(bTypeId); - const Luau::TableTypeVar* bType = get(follow(*bTypeId)); - REQUIRE(bType); - REQUIRE(bType->props.size() == 3); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") -{ - CheckResult result = check("type t10 = typeof(table)"); - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId ty = getGlobalBinding(frontend.typeChecker, "table"); - CHECK_EQ(toString(ty), "table"); - - const TableTypeVar* ttv = get(ty); - REQUIRE(ttv); - - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") -{ - CheckResult result = check(R"( -type Cool = { a: number, b: string } -local c: Cool = { a = 1, b = "s" } -type NotCool = Cool -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "Cool"); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") -{ - CheckResult result = check(R"( -type Cool = { a: number, b: string } -type NotCool = Cool -local c: Cool = { a = 1, b = "s" } -local d: NotCool = { a = 1, b = "s" } -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "Cool"); - - ty = requireType("d"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "NotCool"); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation") -{ - CheckResult result = check(R"( -local c = { a = 1, b = "s" } -type Cool = typeof(c) -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - CHECK_EQ(ttv->name, "Cool"); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") -{ - ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true}; - - fileResolver.source["game/A"] = R"( -export type X = { a: number, b: X? } -return {} - )"; - - CheckResult aResult = frontend.check("game/A"); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - std::optional ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - std::optional ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(follow(*ty1), follow(*ty2)); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") -{ - ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true}; - - fileResolver.source["game/A"] = R"( -export type X = { a: T, b: U, C: X? } -return {} - )"; - - CheckResult aResult = frontend.check("game/A"); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - std::optional ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - std::optional ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true})); - - bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); - CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); -} - -TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") -{ - CheckResult result = check(R"( ---!nonstrict -local f = {} -function f:foo(a: number, b: number) end - -function bar(...) - f.foo(f, 1, ...) -end - -bar(2) -)"); - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") -{ - CheckResult result = check(R"( -local function f(a: typeof(f)) end -)"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Unknown global 'f'", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning") -{ - CheckResult result = check(R"( ---!nonstrict -local l0:any,l61:t0 = _,math -while _ do -_() -end -function _():t0 -end -type t0 = any -)"); - - std::optional ty = requireType("math"); - REQUIRE(ty); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "bound_free_table_export_is_ok") -{ - CheckResult result = check(R"( -local n = {} -function n:Clone() end - -local m = {} - -function m.a(x) - x:Clone() -end - -function m.b() - m.a(n) -end - -return m -)"); - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") -{ - TypeId mathTy = requireType(typeChecker.globalScope, "math"); - REQUIRE(mathTy); - TableTypeVar* ttv = getMutable(mathTy); - REQUIRE(ttv); - const FunctionTypeVar* ftv = get(ttv->props["frexp"].type); - REQUIRE(ftv); - auto original = ftv->level; - - CheckResult result = check("local a = math.frexp"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK(ftv->level.level == original.level); - CHECK(ftv->level.subLevel == original.subLevel); -} - -TEST_CASE_FIXTURE(Fixture, "table_indexing_error_location") -{ - CheckResult result = check(R"( -local foo = {42} -local bar: number? -local baz = foo[bar] - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); -} - -TEST_CASE_FIXTURE(Fixture, "table_simple_call") -{ - CheckResult result = check(R"( -local a = setmetatable({ x = 2 }, { - __call = function(self) - return (self.x :: number) * 2 -- should work without annotation in the future - end -}) -local b = a() -local c = a(2) -- too many arguments - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak") -{ - CheckResult result = check(R"( -function get() - return function(obj) return true end -end - -export type f = typeof(get()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak") -{ - CheckResult result = check(R"( -function get() - return {a = 1, b = function(obj) return true end} -end - -export type f = typeof(get()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "custom_require_global") -{ - CheckResult result = check(R"( ---!nonstrict -require = function(a) end - -local crash = require(game.A) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap") -{ - ScopedFastFlag sff1{"LuauEqConstraint", true}; - - CheckResult result = check(R"( - local function f(a: string | number, b: boolean | number) - return a == b - end - )"); - - // This doesn't produce any errors but for the wrong reasons. - // This unit test serves as a reminder to not try and unify the operands on `==`/`~=`. - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "access_index_metamethod_that_returns_variadic") -{ - CheckResult result = check(R"( - type Foo = {x: string} - local t = {} - setmetatable(t, { - __index = function(x: string): ...Foo - return {x = x} - end - }) - - local foo = t.bar - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - ToStringOptions o; - o.exhaustive = true; - CHECK_EQ("{| x: string |}", toString(requireType("foo"), o)); -} - -TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks") -{ - CheckResult result = check(R"( - type ( ... ) ( ) ; - ( ... ) ( - - ... ) ( - ... ) - type = ( ... ) ; - ( ... ) ( ) ( ... ) ; - ( ... ) "" - )"); - - CHECK_LE(0, result.errors.size()); -} - -TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks2") -{ - CheckResult result = check(R"( - function _(l0:((typeof((pcall)))|((((t0)->())|(typeof(-67108864)))|(any)))|(any),...):(((typeof(0))|(any))|(any),typeof(-67108864),any) - xpcall(_,_,_) - _(_,_,_) - end - )"); - - CHECK_LE(0, result.errors.size()); -} - -TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") -{ - CheckResult result = check(R"( - function _(l0:t0): (any, ()->()) - end - - type t0 = t0 | {} - )"); - - CHECK_LE(0, result.errors.size()); - - std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); - REQUIRE(t0); - CHECK(get(t0->type)); - - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { - return get(err); - }); - CHECK(it != result.errors.end()); -} - TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") { CheckResult result = check(R"( @@ -4691,11 +646,12 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") _(nil) )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERRORS(result); std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK(get(t0->type)); + + CHECK_EQ("*error-type*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -4703,7 +659,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") CHECK(it != result.errors.end()); } -TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional2") +TEST_CASE_FIXTURE(BuiltinsFixture, "no_stack_overflow_from_isoptional2") { CheckResult result = check(R"( function _(l0:({})|(t0)):((((typeof((xpcall)))|(t96))|(t13))&(t96),()->typeof(...)) @@ -4730,26 +686,10 @@ TEST_CASE_FIXTURE(Fixture, "no_infinite_loop_when_trying_to_unify_uh_this") _() )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") -{ - CheckResult result = check(R"( - local l0,l0 - repeat - type t0 = ((any)|((any)&((any)|((any)&((any)|(any))))))&(t0) - function _(l0):(t0)&(t0) - while nil do - end - end - until _(_)(_)._ - )"); - - CHECK_LE(0, result.errors.size()); -} - -TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") +TEST_CASE_FIXTURE(BuiltinsFixture, "no_heap_use_after_free_error") { CheckResult result = check(R"( --!nonstrict @@ -4757,405 +697,17 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") local l0 do end while _ do - function _:_() - _ += _(_._(_:n0(xpcall,_))) - end - end - )"); - - CHECK_LE(0, result.errors.size()); -} - -TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back") -{ - ScopedFastFlag sff{"LuauLogTableTypeVarBoundTo", true}; - - fileResolver.source["Module/Backend/Types"] = R"( - export type Fiber = { - return_: Fiber? - } - return {} - )"; - - fileResolver.source["Module/Backend"] = R"( - local Types = require(script.Types) - type Fiber = Types.Fiber - type ReactRenderer = { findFiberByHostInstance: () -> Fiber? } - - local function attach(renderer): () - local function getPrimaryFiber(fiber) - local alternate = fiber.alternate - return fiber - end - - local function getFiberIDForNative() - local fiber = renderer.findFiberByHostInstance() - fiber = fiber.return_ - return getPrimaryFiber(fiber) + function _:_() + _ += _(_._(_:n0(xpcall,_))) end end - - function culprit(renderer: ReactRenderer): () - attach(renderer) - end - - return culprit - )"; - - CheckResult result = frontend.check("Module/Backend"); -} - -TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") -{ - CheckResult result = check(R"( - type Tree = { data: T, children: {Tree} } - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- this would be an infinite type if we allowed it - type Tree = { data: T, children: {Tree<{T}>} } )"); LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "record_matching_overload") -{ - ScopedFastFlag sffs("LuauStoreMatchingOverloadFnType", true); - - CheckResult result = check(R"( - type Overload = ((string) -> string) & ((number) -> number) - local abc: Overload - abc(1) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // AstExprCall is the node that has the overload stored on it. - // findTypeAtPosition will look at the AstExprLocal, but this is not what - // we want to look at. - std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), Position(3, 10)); - REQUIRE_GE(ancestry.size(), 2); - AstExpr* parentExpr = ancestry[ancestry.size() - 2]->asExpr(); - REQUIRE(bool(parentExpr)); - REQUIRE(parentExpr->is()); - - ModulePtr module = getMainModule(); - auto it = module->astOverloadResolvedTypes.find(parentExpr); - REQUIRE(it != module->astOverloadResolvedTypes.end()); - CHECK_EQ(toString(it->second), "(number) -> number"); -} - -TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") -{ - ScopedFastFlag luauInferFunctionArgsFix("LuauInferFunctionArgsFix", true); - - // Simple direct arg to arg propagation - CheckResult result = check(R"( -type Table = { x: number, y: number } -local function f(a: (Table) -> number) return a({x = 1, y = 2}) end -f(function(a) return a.x + a.y end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // An optional funciton is accepted, but since we already provide a function, nil can be ignored - result = check(R"( -type Table = { x: number, y: number } -local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end -f(function(a) return a.x + a.y end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Make sure self calls match correct index - result = check(R"( -type Table = { x: number, y: number } -local x = {} -x.b = {x = 1, y = 2} -function x:f(a: (Table) -> number) return a(self.b) end -x:f(function(a) return a.x + a.y end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Mix inferred and explicit argument types - result = check(R"( -function f(a: (a: number, b: number, c: boolean) -> number) return a(1, 2, true) end -f(function(a: number, b, c) return c and a + b or b - a end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Anonymous function has a varyadic pack - result = check(R"( -type Table = { x: number, y: number } -local function f(a: (Table) -> number) return a({x = 1, y = 2}) end -f(function(...) return select(1, ...).z end) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); - - // Can't accept more arguments than provided - result = check(R"( -function f(a: (a: number, b: number) -> number) return a(1, 2) end -f(function(a, b, c, ...) return a + b end) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(number, number, a) -> number' could not be converted into '(number, number) -> number'", toString(result.errors[0])); - - // Infer from varyadic packs into elements - result = check(R"( -function f(a: (...number) -> number) return a(1, 2) end -f(function(a, b) return a + b end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Infer from varyadic packs into varyadic packs - result = check(R"( -type Table = { x: number, y: number } -function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end -f(function(a, ...) local b = ... return b.z end) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); - - // Return type inference - result = check(R"( -type Table = { x: number, y: number } -function f(a: (number) -> Table) return a(4) end -f(function(x) return x * 2 end) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - - // Return type doesn't inference 'nil' - result = check(R"( -function f(a: (number) -> nil) return a(4) end -f(function(x) print(x) end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") -{ - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); - - CheckResult result = check(R"( -local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end -return sum(2, 3, function(a, b) return a + b end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - result = check(R"( -local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end -local a = {1, 2, 3} -local r = map(a, function(a) return a + a > 100 end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{boolean}", toString(requireType("r"))); - - check(R"( -local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end -local a = {1, 2, 3} -local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{| c: number, s: number |}", toString(requireType("r"))); -} - -TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") -{ - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); - - CheckResult result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end - -local g12: typeof(g1) & typeof(g2) - -g12(1, function(x) return x + x end) -g12(1, 2, function(x, y) return x + y end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end - -local g12: typeof(g1) & typeof(g2) - -g12({x=1}, function(x) return {x=-x.x} end) -g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") -{ - CheckResult result = check(R"( - type Tree = { data: T, children: Forest } - type Forest = {Tree} - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- OK because forwarded types are used with their parameters. - type Tree = { data: T, children: Forest } - type Forest = {Tree<{T}>} - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- Not OK because forwarded types are used with different types than their parameters. - type Forest = {Tree<{T}>} - type Tree = { data: T, children: Forest } - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") -{ - CheckResult result = check(R"( - type Tree1 = { data: T, children: {Tree2} } - type Tree2 = { data: U, children: {Tree1} } - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - type Tree1 = { data: T, children: {Tree2} } - type Tree2 = { data: U, children: {Tree1} } - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") -{ - CheckResult result = check(R"( - function f(x) return x[1] end - -- x has type X? for a free type variable X - local x = f ({}) - type ContainsFree = { this: a, that: typeof(x) } - type ContainsContainsFree = { that: ContainsFree } - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") -{ - CheckResult result = check(R"( -local a = {{x=4}, {x=7}, {x=1}} -table.sort(a, function(x, y) return x.x < y.x end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") -{ - CheckResult result = check(R"( -type Table = { x: number, y: number } -local f: (Table) -> number = function(t) return t.x + t.y end - -type TableWithFunc = { x: number, y: number, f: (number, number) -> number } -local a: TableWithFunc = { x = 3, y = 4, f = function(a, b) return a + b end } - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "do_not_infer_generic_functions") -{ - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); - - CheckResult result = check(R"( -local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end - -local function sumrec(f: typeof(sum)) - return sum(2, 3, function(a, b) return a + b end) -end - -local b = sumrec(sum) -- ok -local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") -{ - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - - CheckResult result = check(R"( -local function f(): {string|number} - return {1, "b", 3} -end - -local function g(): (number, {string|number}) - return 4, {1, "b", 3} -end - -local function h(): ...{string|number} - return {4}, {1, "b", 3}, {"s"} -end - -local function i(): ...{string|number} - return {1, "b", 3}, h() -end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "infer_type_assertion_value_type") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local function f() return {4, "b", 3} :: {string|number} @@ -5167,8 +719,6 @@ end TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local a: (number, number) -> number = function(a, b) return a - b end @@ -5182,10 +732,8 @@ b, c = {2, "s"}, {"b", 4} LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types_mutable_lval") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_assignment_value_types_mutable_lval") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local a = {} a.x = 2 @@ -5195,65 +743,11 @@ a = setmetatable(a, { __call = function(x) end }) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "refine_and_or") +TEST_CASE_FIXTURE(Fixture, "infer_through_group_expr") { - ScopedFastFlag sff{"LuauSlightlyMoreFlexibleBinaryPredicates", true}; - CheckResult result = check(R"( - local t: {x: number?}? = {x = nil} - local u = t and t.x or 5 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("number", toString(requireType("u"))); -} - -TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") -{ - ScopedFastFlag sffs[] = { - {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - {"LuauExtraNilRecovery", true}, - }; - - CheckResult result = check(R"( - local t: {x: number?}? = {x = nil} - local u = t.x and t or 5 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); - CHECK_EQ("number | {| x: number? |}", toString(requireType("u"))); -} - -TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") -{ - ScopedFastFlag sffs[] = { - {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - {"LuauExtraNilRecovery", true}, - }; - - CheckResult result = check(R"( - local t: {x: number?}? = {x = nil} - local u = t and t.x == 5 or t.x == 31337 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); - CHECK_EQ("boolean", toString(requireType("u"))); -} - -TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") -{ - ScopedFastFlag luauFollowInTypeFunApply("LuauFollowInTypeFunApply", true); - - CheckResult result = check(R"( -type A = { x: number } -local a: A = { x = 1 } -local b = a -type B = typeof(b) -type X = T -local c: X +local function f(a: (number, number) -> number) return a(1, 3) end +f(((function(a, b) return a + b end))) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -5261,46 +755,421 @@ local c: X TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - { - CheckResult result = check(R"(local a = if true then "true" else "false")"); - LUAU_REQUIRE_NO_ERRORS(result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); - } + CheckResult result = check(R"(local a = if true then "true" else "false")"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions2") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - { - // Test expression containing elseif - CheckResult result = check(R"( + // Test expression containing elseif + CheckResult result = check(R"( local a = if false then "a" elseif false then "b" else "c" - )"); + )"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_type_union") +{ + CheckResult result = check(R"(local a: number? = if true then 42 else nil)"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a"), {true}), "number?"); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_1") +{ + CheckResult result = check(R"( +type X = {number | string} +local a: X = if true then {"1", 2, 3} else {4, 5, 6} +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a"), {true}), "{number | string}"); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_2") +{ + CheckResult result = check(R"( +local a: number? = if true then 1 else nil +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tc_if_else_expressions_expected_type_3") +{ + CheckResult result = check(R"( +local function times(n: any, f: () -> T) + local result: {T} = {} + local res = f() + table.insert(result, if true then res else n) + return result +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_interpolated_string_basic") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + CheckResult result = check(R"( + local foo: string = `hello {"world"}` + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_interpolated_string_with_invalid_expression") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + CheckResult result = check(R"( + local function f(x: number) end + + local foo: string = `hello {f("uh oh")}` + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_interpolated_string_constant_type") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + CheckResult result = check(R"( + local foo: "hello" = `hello` + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * If it wasn't instantly obvious, we have the fuzzer to thank for this gem of a test. + * + * We had an issue here where the scope for the `if` block here would + * have an elevated TypeLevel even though there is no function nesting going on. + * This would result in a free typevar for the type of _ that was much higher than + * it should be. This type would be erroneously quantified in the definition of `aaa`. + * This in turn caused an ice when evaluating `_()` in the while loop. + */ +TEST_CASE_FIXTURE(Fixture, "free_typevars_introduced_within_control_flow_constructs_do_not_get_an_elevated_TypeLevel") +{ + check(R"( + --!strict + if _ then + _[_], _ = nil + _() + end + + local aaa = function():typeof(_) return 1 end + + if aaa then + while _() do + end + end + )"); + + // No ice()? No problem. +} + +/* + * This is a bit elaborate. Bear with me. + * + * The type of _ becomes free with the first statement. With the second, we unify it with a function. + * + * At this point, it is important that the newly created fresh types of this new function type are promoted + * to the same level as the original free type. If we do not, they are incorrectly ascribed the level of the + * containing function. + * + * If this is allowed to happen, the final lambda erroneously quantifies the type of _ to something ridiculous + * just before we typecheck the invocation to _. + */ +TEST_CASE_FIXTURE(Fixture, "fuzzer_found_this") +{ + check(R"( + l0, _ = nil + + local function p() + _() + end + + a = _( + function():(typeof(p),typeof(_)) + end + )[nil] + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_crash") +{ + CheckResult result = check(R"( +local function getIt() + local y + y = setmetatable({}, y) + return y +end +local a = getIt() +local b = getIt() +local c = a or b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "bound_typepack_promote") +{ + // No assertions should trigger + check(R"( +local function p() + local this = {} + this.pf = foo() + function this:IsActive() end + function this:Start(o) end + return this +end + +local function h(tp, o) + ep = tp + tp:Start(o) + tp.pf.Connect(function() + ep:IsActive() + end) +end + +function on() + local t = p() + h(t) +end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") +{ + CheckResult result = check(R"( + --!strict + --!nolint + + type FieldSpecifier = { + fieldName: string, + } + + type ReadFieldOptions = FieldSpecifier & { from: number? } + + type Policies = { + getStoreFieldName: (self: Policies, fieldSpec: FieldSpecifier) -> string, + } + + local Policies = {} + + local function foo(p: Policies) + end + + function Policies:getStoreFieldName(specifier: FieldSpecifier): string + return "" + end + + function Policies:readField(options: ReadFieldOptions) + local _ = self:getStoreFieldName(options) + --[[ + Type error: + TypeError { "MainModule", Location { { line = 25, col = 16 }, { line = 25, col = 20 } }, TypeMismatch { Policies, {- getStoreFieldName: (tp1) -> (a, b...) -} } } + ]] + foo(self) + end + )"); + + if (FFlag::LuauInstantiateInSubtyping) + { + // though this didn't error before the flag, it seems as though it should error since fields of a table are invariant. + // the user's intent would likely be that these "method" fields would be read-only, but without an annotation, accepting this should be + // unsound. + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ( + R"(Type 't1 where t1 = {+ getStoreFieldName: (t1, {| fieldName: string |} & {| from: number? |}) -> (a, b...) +}' could not be converted into 'Policies' +caused by: + Property 'getStoreFieldName' is not compatible. Type 't1 where t1 = ({+ getStoreFieldName: t1 +}, {| fieldName: string |} & {| from: number? |}) -> (a, b...)' could not be converted into '(Policies, FieldSpecifier) -> string' +caused by: + Argument #2 type is not compatible. Type 'FieldSpecifier' could not be converted into 'FieldSpecifier & {| from: number? |}' +caused by: + Not all intersection parts are compatible. Table type 'FieldSpecifier' not compatible with type '{| from: number? |}' because the former has extra field 'fieldName')", + toString(result.errors[0])); + } + else + { LUAU_REQUIRE_NO_ERRORS(result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); } } -TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3") +TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + ScopedFastInt sfi("LuauTypeInferRecursionLimit", 2); - { - CheckResult result = check(R"(local a = if true then "true" else 42)"); - // We currently require both true/false expressions to unify to the same type. However, we do intend to lift - // this restriction in the future. - LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); - } + CheckResult result = check(R"( + function complex() + function _(l0:t0): (any, ()->()) + return 0,_ + end + type t0 = t0 | {} + _(nil) + end + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") +{ + ScopedFastInt sfi("LuauTypeInferRecursionLimit", 10); + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : a&b&c&d&e&f&g&h&(i?) + local y : (a&b&c&d&e&f&g&h&i)? = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Internal error: Code is too complex to typecheck! Consider adding type annotations around this area", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "type_infer_cache_limit_normalizer") +{ + ScopedFastInt sfi("LuauNormalizeCacheLimit", 10); + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : ((number) -> number) & ((string) -> string) & ((nil) -> nil) & (({}) -> {}) + local y : (number | string | nil | {}) -> (number | string | nil | {}) = x + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Internal error: Code is too complex to typecheck! Consider adding type annotations around this area", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") +{ + CheckResult result = check(R"( + local obj = {} + + function obj:Method() + self.fieldA = function(object) + if object.a then + self.arr[object] = true + elseif object.b then + self.fieldB[object] = object:Connect(function(arg) + self.arr[arg] = nil + end) + end + end + end + + return obj + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "types_stored_in_astResolvedTypes") +{ + CheckResult result = check(R"( +type alias = typeof("hello") +local function foo(param: alias) +end + )"); + + auto node = findNodeAtPosition(*getMainSourceModule(), {2, 16}); + auto ty = lookupType("alias"); + REQUIRE(node); + REQUIRE(node->is()); + REQUIRE(ty); + + auto func = node->as(); + REQUIRE(func->args.size == 1); + + auto arg = *func->args.begin(); + auto annotation = arg->annotation; + + CHECK_EQ(*getMainModule()->astResolvedTypes.find(annotation), *ty); +} + +TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_of_higher_order_function") +{ + CheckResult result = check(R"( + function higher(cb: (number) -> ()) end + + higher(function(n) -- no error here. n : number + local e: string = n -- error here. n /: string + end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + Location location = result.errors[0].location; + CHECK(location.begin.line == 4); + CHECK(location.end.line == 4); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + function validate(stats, hits, misses) + local checked = {} + + for _,l in ipairs(hits) do + if not (stats[l] and stats[l] > 0) then + return false, string.format("expected line %d to be hit", l) + end + checked[l] = true + end + + for _,l in ipairs(misses) do + if not (stats[l] and stats[l] == 0) then + return false, string.format("expected line %d to be missed", l) + end + checked[l] = true + end + + for k,v in pairs(stats) do + if type(k) == "number" and not checked[k] then + return false, string.format("expected line %d to be absent", k) + end + end + + return true + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "fuzz_free_table_type_change_during_index_check") +{ + ScopedFastFlag luauFollowInLvalueIndexCheck{"LuauFollowInLvalueIndexCheck", true}; + + CheckResult result = check(R"( +local _ = nil +while _["" >= _] do +end + )"); + + LUAU_REQUIRE_ERRORS(result); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 91ac9f06..b1abdf7c 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -1,5 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -14,7 +14,9 @@ struct TryUnifyFixture : Fixture TypeArena arena; ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; InternalErrorReporter iceHandler; - Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, &iceHandler}; + UnifierSharedState unifierState{&iceHandler}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&unifierState}}; + Unifier state{NotNull{&normalizer}, Mode::Strict, NotNull{globalScope.get()}, Location{}, Variance::Covariant}; }; TEST_SUITE_BEGIN("TryUnifyTests"); @@ -24,7 +26,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") TypeVar numberOne{TypeVariant{PrimitiveTypeVar{PrimitiveTypeVar::Number}}}; TypeVar numberTwo = numberOne; - state.tryUnify(&numberOne, &numberTwo); + state.tryUnify(&numberTwo, &numberOne); CHECK(state.errors.empty()); } @@ -37,9 +39,11 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") TypeVar functionTwo{TypeVariant{ FunctionTypeVar(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))}}; - state.tryUnify(&functionOne, &functionTwo); + state.tryUnify(&functionTwo, &functionOne); CHECK(state.errors.empty()); + state.log.commit(); + CHECK_EQ(functionOne, functionTwo); } @@ -57,7 +61,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") TypeVar functionTwoSaved = functionTwo; - state.tryUnify(&functionOne, &functionTwo); + state.tryUnify(&functionTwo, &functionOne); CHECK(!state.errors.empty()); CHECK_EQ(functionOne, functionOneSaved); @@ -76,10 +80,12 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); - state.tryUnify(&tableOne, &tableTwo); + state.tryUnify(&tableTwo, &tableOne); CHECK(state.errors.empty()); + state.log.commit(); + CHECK_EQ(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } @@ -97,15 +103,75 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); - state.tryUnify(&tableOne, &tableTwo); + state.tryUnify(&tableTwo, &tableOne); CHECK_EQ(1, state.errors.size()); - state.log.rollback(); - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } +TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_never") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f(arg : string & number) : never + return arg + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_anything") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f(arg : string & number) : boolean + return arg + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_never") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + {"LuauUninhabitedSubAnything2", true}, + }; + + CheckResult result = check(R"( + function f(arg : { prop : string & number }) : never + return arg + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_table_sub_anything") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + {"LuauUninhabitedSubAnything2", true}, + }; + + CheckResult result = check(R"( + function f(arg : { prop : string & number }) : boolean + return arg + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_unified_with_errorType") { CheckResult result = check(R"( @@ -117,9 +183,24 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeId bType = requireType("b"); + CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("*error-type*", toString(requireType("b"))); +} - CHECK_MESSAGE(get(bType), "Should be an error: " << toString(bType)); +TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") +{ + CheckResult result = check(R"( + function f(arg: number) return arg end + local a + local b + local c = f(a, b) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("*error-type*", toString(requireType("b"))); + CHECK_EQ("number", toString(requireType("c"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails") @@ -138,7 +219,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails" )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("(number) -> (boolean)", toString(requireType("f"))); + CHECK_EQ("(number) -> boolean", toString(requireType("f"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") @@ -146,7 +227,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") TypePackVar testPack{TypePack{{typeChecker.numberType, typeChecker.stringType}, std::nullopt}}; TypePackVar variadicPack{VariadicTypePack{typeChecker.numberType}}; - state.tryUnify(&variadicPack, &testPack); + state.tryUnify(&testPack, &variadicPack); CHECK(!state.errors.empty()); } @@ -156,24 +237,12 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") TypePackVar a{TypePack{{typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType, typeChecker.booleanType}}}; TypePackVar b{TypePack{{typeChecker.numberType, typeChecker.stringType}, &variadicPack}}; - state.tryUnify(&a, &b); + state.tryUnify(&b, &a); CHECK(state.errors.empty()); } -TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_variadic_pack_with_error_should_work") -{ - TypePackId variadicPack = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}}); - TypePackId errorPack = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{Unifiable::Error{}})}); - - state.tryUnify(variadicPack, errorPack); - REQUIRE_EQ(0, state.errors.size()); -} - TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") { - ScopedFastFlag sffs2{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( --!strict local function f(...: T): ...T @@ -190,10 +259,8 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") CHECK_EQ(toString(tm->wantedType), "string"); } -TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unification") +TEST_CASE_FIXTURE(BuiltinsFixture, "cli_41095_concat_log_in_sealed_table_unification") { - ScopedFastFlag sffs2("LuauGenericFunctions", true); - CheckResult result = check(R"( --!strict table.insert() @@ -204,4 +271,109 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unifica CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); } +TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") +{ + TypePackId threeNumbers = arena.addTypePack(TypePack{{typeChecker.numberType, typeChecker.numberType, typeChecker.numberType}, std::nullopt}); + TypePackId numberAndFreeTail = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); + + ErrorVec unifyErrors = state.canUnify(numberAndFreeTail, threeNumbers); + CHECK(unifyErrors.size() == 0); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") +{ + TypeVar redirect{FreeTypeVar{TypeLevel{}}}; + TypeVar table{TableTypeVar{}}; + TypeVar metatable{MetatableTypeVar{&redirect, &table}}; + redirect = BoundTypeVar{&metatable}; // Now we have a metatable that is recursive on the table type + TypeVar variant{UnionTypeVar{{&metatable, typeChecker.numberType}}}; + + state.tryUnify(&metatable, &variant); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") +{ + TypePackVar free{FreeTypePack{TypeLevel{}}}; + TypePackVar target{TypePack{}}; + + TypeVar func{FunctionTypeVar{&free, &free}}; + + state.tryUnify(&free, &target); + // Shouldn't assert or error. + state.tryUnify(&func, typeChecker.anyType); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") +{ + TypeId a = arena.addType(TypeVar{FreeTypeVar{TypeLevel{}}}); + TypeId b = typeChecker.numberType; + + state.tryUnify(a, b); + state.log.commit(); + + CHECK_EQ(a->owningArena, &arena); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") +{ + TypePackId a = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); + TypePackId b = typeChecker.anyTypePack; + + state.tryUnify(a, b); + state.log.commit(); + + CHECK_EQ(a->owningArena, &arena); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table") +{ + ScopedFastFlag sff("DebugLuauDeferredConstraintResolution", true); + + TableTypeVar::Props freeProps{ + {"foo", {typeChecker.numberType}}, + }; + + TypeId free = arena.addType(TableTypeVar{freeProps, std::nullopt, TypeLevel{}, TableState::Free}); + + TableTypeVar::Props indexProps{ + {"foo", {typeChecker.stringType}}, + }; + + TypeId index = arena.addType(TableTypeVar{indexProps, std::nullopt, TypeLevel{}, TableState::Sealed}); + + TableTypeVar::Props mtProps{ + {"__index", {index}}, + }; + + TypeId mt = arena.addType(TableTypeVar{mtProps, std::nullopt, TypeLevel{}, TableState::Sealed}); + + TypeId target = arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); + TypeId metatable = arena.addType(MetatableTypeVar{target, mt}); + + state.tryUnify(metatable, free); + state.log.commit(); + + REQUIRE_EQ(state.errors.size(), 1); + + std::string expected = "Type '{ @metatable {| __index: {| foo: string |} |}, { } }' could not be converted into '{- foo: number -}'\n" + "caused by:\n" + " Type 'number' could not be converted into 'string'"; + CHECK_EQ(toString(state.errors[0]), expected); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") +{ + ScopedFastFlag luauTxnLogTypePackIterator{"LuauTxnLogTypePackIterator", true}; + + TypePackVar variadicAny{VariadicTypePack{typeChecker.anyType}}; + TypePackVar packTmp{TypePack{{typeChecker.anyType}, &variadicAny}}; + TypePackVar packSub{TypePack{{typeChecker.anyType, typeChecker.anyType}, &packTmp}}; + + TypeVar freeTy{FreeTypeVar{TypeLevel{}}}; + TypePackVar freeTp{FreeTypePack{TypeLevel{}}}; + TypePackVar packSuper{TypePack{{&freeTy}, &freeTp}}; + + state.tryUnify(&packSub, &packSuper); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 5f7f2847..0e4074f7 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -25,11 +24,11 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return") const FunctionTypeVar* takeTwoType = get(requireType("take_two")); REQUIRE(takeTwoType != nullptr); - const auto& [returns, tail] = flatten(takeTwoType->retType); + const auto& [returns, tail] = flatten(takeTwoType->retTypes); CHECK_EQ(2, returns.size()); - CHECK_EQ(typeChecker.numberType, returns[0]); - CHECK_EQ(typeChecker.numberType, returns[1]); + CHECK_EQ(typeChecker.numberType, follow(returns[0])); + CHECK_EQ(typeChecker.numberType, follow(returns[1])); CHECK(!tail); } @@ -72,12 +71,12 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac const FunctionTypeVar* takeOneMoreType = get(requireType("take_three")); REQUIRE(takeOneMoreType != nullptr); - const auto& [rets, tail] = flatten(takeOneMoreType->retType); + const auto& [rets, tail] = flatten(takeOneMoreType->retTypes); REQUIRE_EQ(3, rets.size()); - CHECK_EQ(typeChecker.numberType, rets[0]); - CHECK_EQ(typeChecker.numberType, rets[1]); - CHECK_EQ(typeChecker.numberType, rets[2]); + CHECK_EQ(typeChecker.numberType, follow(rets[0])); + CHECK_EQ(typeChecker.numberType, follow(rets[1])); + CHECK_EQ(typeChecker.numberType, follow(rets[2])); CHECK(!tail); } @@ -92,26 +91,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* applyType = get(requireType("apply")); - REQUIRE(applyType != nullptr); - - std::vector applyArgs = flatten(applyType->argTypes).first; - REQUIRE_EQ(3, applyArgs.size()); - - const FunctionTypeVar* fType = get(applyArgs[0]); - REQUIRE(fType != nullptr); - - const FunctionTypeVar* gType = get(applyArgs[1]); - REQUIRE(gType != nullptr); - - std::vector gArgs = flatten(gType->argTypes).first; - REQUIRE_EQ(1, gArgs.size()); - - // function(function(t1, T2...): (t3, T4...), function(t5): (t1, T2...), t5): (t3, T4...) - - REQUIRE_EQ(*gArgs[0], *applyArgs[2]); - REQUIRE_EQ(toString(fType->argTypes), toString(gType->retType)); - REQUIRE_EQ(toString(fType->retType), toString(applyType->retType)); + CHECK_EQ("((b...) -> (c...), (a) -> (b...), a) -> (c...)", toString(requireType("apply"))); } TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned") @@ -123,10 +103,10 @@ TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned") LUAU_REQUIRE_NO_ERRORS(result); const FunctionTypeVar* fTy = get(requireType("f")); REQUIRE(fTy != nullptr); - CHECK_EQ(0, size(fTy->retType)); + CHECK_EQ(0, size(fTy->retTypes)); const FunctionTypeVar* gTy = get(requireType("g")); REQUIRE(gTy != nullptr); - CHECK_EQ(0, size(gTy->retType)); + CHECK_EQ(0, size(gTy->retTypes)); } TEST_CASE_FIXTURE(Fixture, "no_return_size_should_be_zero") @@ -143,15 +123,15 @@ TEST_CASE_FIXTURE(Fixture, "no_return_size_should_be_zero") const FunctionTypeVar* fTy = get(requireType("f")); REQUIRE(fTy != nullptr); - CHECK_EQ(1, size(follow(fTy->retType))); + CHECK_EQ(1, size(follow(fTy->retTypes))); const FunctionTypeVar* gTy = get(requireType("g")); REQUIRE(gTy != nullptr); - CHECK_EQ(0, size(gTy->retType)); + CHECK_EQ(0, size(gTy->retTypes)); const FunctionTypeVar* hTy = get(requireType("h")); REQUIRE(hTy != nullptr); - CHECK_EQ(0, size(hTy->retType)); + CHECK_EQ(0, size(hTy->retTypes)); } TEST_CASE_FIXTURE(Fixture, "varargs_inference_through_multiple_scopes") @@ -212,7 +192,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_packs") TypePackId listOfStrings = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.stringType}}); // clang-format off - addGlobalBinding(typeChecker, "foo", + addGlobalBinding(frontend, "foo", arena.addType( FunctionTypeVar{ listOfNumbers, @@ -221,7 +201,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_packs") ), "@test" ); - addGlobalBinding(typeChecker, "bar", + addGlobalBinding(frontend, "bar", arena.addType( FunctionTypeVar{ arena.addTypePack({{typeChecker.numberType}, listOfStrings}), @@ -263,8 +243,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_pack_syntax") CHECK_EQ(toString(requireType("foo")), "(...number) -> ()"); } -// CLI-45791 -TEST_CASE_FIXTURE(UnfrozenFixture, "type_pack_hidden_free_tail_infinite_growth") +TEST_CASE_FIXTURE(Fixture, "type_pack_hidden_free_tail_infinite_growth") { CheckResult result = check(R"( --!nonstrict @@ -285,7 +264,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail") { CheckResult result = check(R"( local _ = function():((...any)->(...any),()->()) - return function() end, function() end + return function() end, function() end end for y in _() do end @@ -294,4 +273,785 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs") +{ + CheckResult result = check(R"( +type Packed = (T...) -> T... +local a: Packed<> +local b: Packed +local c: Packed + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tf = lookupType("Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "(T...) -> (T...)"); + CHECK_EQ(toString(requireType("a")), "() -> ()"); + CHECK_EQ(toString(requireType("b")), "(number) -> number"); + CHECK_EQ(toString(requireType("c")), "(string, number) -> (string, number)"); + + result = check(R"( +-- (U..., T) cannot be parsed right now +type Packed = { f: (a: T, U...) -> (T, U...) } +local a: Packed +local b: Packed +local c: Packed + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + tf = lookupType("Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Packed"); + CHECK_EQ(toString(*tf, {true}), "{| f: (T, U...) -> (T, U...) |}"); + + auto ttvA = get(requireType("a")); + REQUIRE(ttvA); + CHECK_EQ(toString(requireType("a")), "Packed"); + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> number |}"); + REQUIRE(ttvA->instantiatedTypeParams.size() == 1); + REQUIRE(ttvA->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number"); + CHECK_EQ(toString(ttvA->instantiatedTypePackParams[0], {true}), ""); + + auto ttvB = get(requireType("b")); + REQUIRE(ttvB); + CHECK_EQ(toString(requireType("b")), "Packed"); + CHECK_EQ(toString(requireType("b"), {true}), "{| f: (string, number) -> (string, number) |}"); + REQUIRE(ttvB->instantiatedTypeParams.size() == 1); + REQUIRE(ttvB->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvB->instantiatedTypeParams[0], {true}), "string"); + CHECK_EQ(toString(ttvB->instantiatedTypePackParams[0], {true}), "number"); + + auto ttvC = get(requireType("c")); + REQUIRE(ttvC); + CHECK_EQ(toString(requireType("c")), "Packed"); + CHECK_EQ(toString(requireType("c"), {true}), "{| f: (string, number, boolean) -> (string, number, boolean) |}"); + REQUIRE(ttvC->instantiatedTypeParams.size() == 1); + REQUIRE(ttvC->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvC->instantiatedTypeParams[0], {true}), "string"); + CHECK_EQ(toString(ttvC->instantiatedTypePackParams[0], {true}), "number, boolean"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_type_packs_import") +{ + fileResolver.source["game/A"] = R"( +export type Packed = { a: T, b: (U...) -> () } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +local a: Import.Packed +local b: Import.Packed +local c: Import.Packed +local d: { a: typeof(c) } + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + auto tf = lookupImportedType("Import", "Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Packed"); + CHECK_EQ(toString(*tf, {true}), "{| a: T, b: (U...) -> () |}"); + + CHECK_EQ(toString(requireType("a"), {true}), "{| a: number, b: () -> () |}"); + CHECK_EQ(toString(requireType("b"), {true}), "{| a: string, b: (number) -> () |}"); + CHECK_EQ(toString(requireType("c"), {true}), "{| a: string, b: (number, boolean) -> () |}"); + CHECK_EQ(toString(requireType("d")), "{| a: Packed |}"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_pack_type_parameters") +{ + fileResolver.source["game/A"] = R"( +export type Packed = { a: T, b: (U...) -> () } +return {} + )"; + + CheckResult cResult = check(R"( +local Import = require(game.A) +type Alias = Import.Packed +local a: Alias + +type B = Import.Packed +type C = Import.Packed + )"); + LUAU_REQUIRE_NO_ERRORS(cResult); + + auto tf = lookupType("Alias"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Alias"); + CHECK_EQ(toString(*tf, {true}), "{| a: S, b: (T, R...) -> () |}"); + + CHECK_EQ(toString(requireType("a"), {true}), "{| a: string, b: (number, boolean) -> () |}"); + + tf = lookupType("B"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "B"); + CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (X...) -> () |}"); + + tf = lookupType("C"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "C"); + CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (number, X...) -> () |}"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested") +{ + CheckResult result = check(R"( +type Packed1 = (T...) -> (T...) +type Packed2 = (Packed1, T...) -> (Packed1, T...) +type Packed3 = (Packed2, T...) -> (Packed2, T...) +type Packed4 = (Packed3, T...) -> (Packed3, T...) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tf = lookupType("Packed4"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), + "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...) -> " + "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic") +{ + CheckResult result = check(R"( +type X = (T...) -> (string, T...) + +type D = X<...number> +type E = X<(number, ...string)> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("D")), "(...number) -> (string, ...number)"); + CHECK_EQ(toString(*lookupType("E")), "(number, ...string) -> (string, number, ...string)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi") +{ + CheckResult result = check(R"( +type Y = (T...) -> (U...) +type A = Y +type B = Y<(number, ...string), S...> + +type Z = (T) -> (U...) +type E = Z +type F = Z + +type W = (T, U...) -> (T, V...) +type H = W +type I = W + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)"); + CHECK_EQ(toString(*lookupType("B")), "(number, ...string) -> (S...)"); + + CHECK_EQ(toString(*lookupType("E")), "(number) -> (S...)"); + CHECK_EQ(toString(*lookupType("F")), "(number) -> (string, S...)"); + + CHECK_EQ(toString(*lookupType("H")), "(number, S...) -> (number, R...)"); + CHECK_EQ(toString(*lookupType("I")), "(number, string, S...) -> (number, R...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") +{ + CheckResult result = check(R"( +type X = (T...) -> (T...) + +type A = X<(S...)> +type B = X<()> +type C = X<(number)> +type D = X<(number, string)> +type E = X<(...number)> +type F = X<(string, ...number)> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)"); + CHECK_EQ(toString(*lookupType("B")), "() -> ()"); + CHECK_EQ(toString(*lookupType("C")), "(number) -> number"); + CHECK_EQ(toString(*lookupType("D")), "(number, string) -> (number, string)"); + CHECK_EQ(toString(*lookupType("E")), "(...number) -> (...number)"); + CHECK_EQ(toString(*lookupType("F")), "(string, ...number) -> (string, ...number)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") +{ + CheckResult result = check(R"( +type Y = (T...) -> (U...) + +type A = Y<(number, string), (boolean)> +type B = Y<(), ()> +type C = Y<...string, (number, S...)> +type D = Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(number, string) -> boolean"); + CHECK_EQ(toString(*lookupType("B")), "() -> ()"); + CHECK_EQ(toString(*lookupType("C")), "(...string) -> (number, S...)"); + CHECK_EQ(toString(*lookupType("D")), "(X...) -> (number, string, X...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring") +{ + CheckResult result = check(R"( +type Y = { f: (T...) -> (U...) } + +local a: Y<(number, string), (boolean)> +local b: Y<(), ()> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (boolean)>"); + CHECK_EQ(toString(requireType("b")), "Y<(), ()>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible") +{ + CheckResult result = check(R"( +type X = () -> T +type Y = (T) -> U + +type A = X<(number)> +type B = Y<(number), (boolean)> +type C = Y<(number), boolean> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "() -> number"); + CHECK_EQ(toString(*lookupType("B")), "(number) -> boolean"); + CHECK_EQ(toString(*lookupType("C")), "(number) -> boolean"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors") +{ + CheckResult result = check(R"( +type Packed = (T, U) -> (V...) +local b: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects at least 2 type arguments, but only 1 is specified"); + + result = check(R"( +type Packed = (T, U) -> () +type B = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 0 type pack arguments, but 1 is specified"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameters must come before type pack parameters"); + + result = check(R"( +type Packed = (T) -> U +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type arguments, but only 1 is specified"); + + result = check(R"( +type Packed = (T...) -> T... +local a: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameter list is required"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed<> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but none are specified"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but only 1 is specified"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_explicit") +{ + CheckResult result = check(R"( +type Y = { a: T, b: U } + +local a: Y = { a = 2, b = 3 } +local b: Y = { a = 2, b = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + + result = check(R"( +type Y = { a: T } + +local a: Y = { a = 2 } +local b: Y<> = { a = "s" } +local c: Y = { a = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + CHECK_EQ(toString(requireType("c")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_self") +{ + CheckResult result = check(R"( +type Y = { a: T, b: U } + +local a: Y = { a = 2, b = 3 } +local b: Y = { a = "h", b = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + + result = check(R"( +type Y string> = { a: T, b: U } + +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y string>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_chained") +{ + CheckResult result = check(R"( +type Y = { a: T, b: U, c: V } + +local a: Y +local b: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_explicit") +{ + CheckResult result = check(R"( +type Y = { a: (T...) -> () } +local a: Y<> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_ty") +{ + CheckResult result = check(R"( +type Y = { a: T, b: (U...) -> T } + +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_tp") +{ + CheckResult result = check(R"( +type Y = { a: (T...) -> U... } +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (number, string)>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_chained_tp") +{ + CheckResult result = check(R"( +type Y = { a: (T...) -> U..., b: (T...) -> V... } +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (number, string), (number, string)>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_mixed_self") +{ + CheckResult result = check(R"( +type Y = { a: (T, U, V...) -> W... } +local a: Y +local b: Y +local c: Y +local d: Y ()> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + CHECK_EQ(toString(requireType("c")), "Y"); + CHECK_EQ(toString(requireType("d")), "Y ()>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors") +{ + CheckResult result = check(R"( +type Y = { a: T } +local a: Y = { a = 2 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'"); + + result = check(R"( +type Y = { a: (T...) -> () } +local a: Y<> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'"); + + result = check(R"( +type Y = { a: (T) -> U... } +local a: Y<...number> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Y' expects at least 1 type argument, but none are specified"); + + result = check(R"( +type Packed = (T) -> T +local a: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameter list is required"); + + result = check(R"( +type Y = { a: T } +local a: Y + )"); + + LUAU_REQUIRE_ERRORS(result); + + result = check(R"( +type Y = { a: T } +local a: Y<...number> + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_default_export") +{ + fileResolver.source["Module/Types"] = R"( +export type A = { a: T, b: U } +export type B = { a: T, b: U } +export type C string> = { a: T, b: U } +export type D = { a: T, b: U, c: V } +export type E = { a: (T...) -> () } +export type F = { a: T, b: (U...) -> T } +export type G = { b: (U...) -> T... } +export type H = { b: (T...) -> T... } +return {} + )"; + + CheckResult resultTypes = frontend.check("Module/Types"); + LUAU_REQUIRE_NO_ERRORS(resultTypes); + + fileResolver.source["Module/Users"] = R"( +local Types = require(script.Parent.Types) + +local a: Types.A +local b: Types.B +local c: Types.C +local d: Types.D +local e: Types.E<> +local eVoid: Types.E<()> +local f: Types.F +local g: Types.G<...number> +local h: Types.H<> + )"; + + CheckResult resultUsers = frontend.check("Module/Users"); + LUAU_REQUIRE_NO_ERRORS(resultUsers); + + CHECK_EQ(toString(requireType("Module/Users", "a")), "A"); + CHECK_EQ(toString(requireType("Module/Users", "b")), "B"); + CHECK_EQ(toString(requireType("Module/Users", "c")), "C string>"); + CHECK_EQ(toString(requireType("Module/Users", "d")), "D"); + CHECK_EQ(toString(requireType("Module/Users", "e")), "E"); + CHECK_EQ(toString(requireType("Module/Users", "eVoid")), "E<>"); + CHECK_EQ(toString(requireType("Module/Users", "f")), "F"); + CHECK_EQ(toString(requireType("Module/Users", "g")), "G<...number, ()>"); + CHECK_EQ(toString(requireType("Module/Users", "h")), "H<>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_skip_brackets") +{ + CheckResult result = check(R"( +type Y = (T...) -> number +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "(...string) -> number"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_confusing_types") +{ + CheckResult result = check(R"( +type A = (T, V...) -> (U, W...) +type B = A +type C = A + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("B"), {true}), "(string, ...any) -> (number, ...any)"); + CHECK_EQ(toString(*lookupType("C"), {true}), "(string, boolean) -> (number, boolean)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_recursive_type") +{ + CheckResult result = check(R"( +type F ()> = (K) -> V +type R = { m: F } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("R"), {true}), "t1 where t1 = {| m: (t1) -> (t1) -> () |}"); +} + +TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check") +{ + CheckResult result = check(R"( +local a: () -> (number, ...string) +local b: () -> (number, ...boolean) +a = b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> (number, ...boolean)' could not be converted into '() -> (number, ...string)' +caused by: + Type 'boolean' could not be converted into 'string')"); +} + +// TODO: File a Jira about this +/* +TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack") +{ + CheckResult result = check(R"( + function a(x) return 1 end + a(...) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE(bool(getMainModule()->getModuleScope()->varargPack)); + + TypePackId varargPack = *getMainModule()->getModuleScope()->varargPack; + + auto iter = begin(varargPack); + auto endIter = end(varargPack); + + CHECK(iter != endIter); + ++iter; + CHECK(iter == endIter); + + CHECK(!iter.tail()); +} +*/ + +TEST_CASE_FIXTURE(Fixture, "dont_ice_if_a_TypePack_is_an_error") +{ + CheckResult result = check(R"( + --!strict + function f(s) + print(s) + return f + end + + f("foo")("bar") + )"); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs") +{ + // this has a risk of creating cyclic type packs, causing infinite loops / OOMs + check(R"( +--!nonstrict +_ += _(_,...) +repeat +_ += _(...) +until ... + _ +)"); + + check(R"( +--!nonstrict +_ += _(_(...,...),_(...)) +repeat +until _ +)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "detect_cyclic_typepacks") +{ + CheckResult result = check(R"( + type ( ... ) ( ) ; + ( ... ) ( - - ... ) ( - ... ) + type = ( ... ) ; + ( ... ) ( ) ( ... ) ; + ( ... ) "" + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "detect_cyclic_typepacks2") +{ + CheckResult result = check(R"( + function _(l0:((typeof((pcall)))|((((t0)->())|(typeof(-67108864)))|(any)))|(any),...):(((typeof(0))|(any))|(any),typeof(-67108864),any) + xpcall(_,_,_) + _(_,_,_) + end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments") +{ + CheckResult result = check(R"( + function foo(...: string): number + return 1 + end + + function bar(...: number): number + return foo(...) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type 'number' could not be converted into 'string'"); +} + +TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments_free") +{ + CheckResult result = check(R"( + function foo(...: T...): T... + return ... + end + + function bar(...: number): boolean + return foo(...) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type 'number' could not be converted into 'boolean'"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_packs_with_tails_in_vararg_adjustment") +{ + std::optional sff; + if (FFlag::DebugLuauDeferredConstraintResolution) + sff = {"LuauInstantiateInSubtyping", true}; + + CheckResult result = check(R"( + local function wrapReject(fn: (self: any, ...TArg) -> ...TResult): (self: any, ...TArg) -> ...TResult + return function(self, ...) + local arguments = { ... } + local ok, result = pcall(function() + return fn(self, table.unpack(arguments)) + end) + return result + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "generalize_expectedTypes_with_proper_scope") +{ + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + {"LuauInstantiateInSubtyping", true}, + }; + + CheckResult result = check(R"( + local function f(fn: () -> ...TResult): () -> ...TResult + return function() + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "fuzz_typepack_iter_follow") +{ + ScopedFastFlag luauTxnLogTypePackIterator{"LuauTxnLogTypePackIterator", true}; + + CheckResult result = check(R"( +local _ +local _ = _,_(),_(_) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_typepack_iter_follow_2") +{ + ScopedFastFlag luauTxnLogTypePackIterator{"LuauTxnLogTypePackIterator", true}; + + CheckResult result = check(R"( +function test(name, searchTerm) + local found = string.find(name:lower(), searchTerm:lower()) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index ae4d836b..adfc61b6 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -7,8 +6,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauEqConstraint) - using namespace Luau; TEST_SUITE_BEGIN("UnionTypes"); @@ -52,7 +49,7 @@ TEST_CASE_FIXTURE(Fixture, "allow_more_specific_assign") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign") +TEST_CASE_FIXTURE(Fixture, "disallow_less_specific_assign") { CheckResult result = check(R"( local a:number = 10 @@ -63,7 +60,7 @@ TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign2") +TEST_CASE_FIXTURE(Fixture, "disallow_less_specific_assign2") { CheckResult result = check(R"( local a:number? = 10 @@ -104,7 +101,7 @@ TEST_CASE_FIXTURE(Fixture, "optional_arguments_table2") REQUIRE(!result.errors.empty()); } -TEST_CASE_FIXTURE(Fixture, "error_takes_optional_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "error_takes_optional_arguments") { CheckResult result = check(R"( error("message") @@ -181,8 +178,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_optional_property") TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") { - ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true); - CheckResult result = check(R"( type A = {x: number} type B = {} @@ -202,8 +197,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->key, "x"); - TypeId r = requireType("r"); - CHECK_MESSAGE(get(r), "Expected error, got " << toString(r)); + CHECK_EQ("*error-type*", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") @@ -237,27 +231,11 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") local z = a == c )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.booleanType, *requireType("x")); - CHECK_EQ(*typeChecker.booleanType, *requireType("y")); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("(number | string)?", toString(*tm->wantedType)); - CHECK_EQ("boolean | number", toString(*tm->givenType)); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "optional_union_members") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = { a = { x = 1, y = 2 }, b = 3 } type A = typeof(a) @@ -273,14 +251,12 @@ local c = bf.a.y TEST_CASE_FIXTURE(Fixture, "optional_union_functions") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( -local a = {} -function a.foo(x:number, y:number) return x + y end -type A = typeof(a) -local b: A? = a -local c = b.foo(1, 2) + local a = {} + function a.foo(x:number, y:number) return x + y end + type A = typeof(a) + local b: A? = a + local c = b.foo(1, 2) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -290,8 +266,6 @@ local c = b.foo(1, 2) TEST_CASE_FIXTURE(Fixture, "optional_union_methods") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = {} function a:foo(x:number, y:number) return x + y end @@ -320,12 +294,11 @@ return f() REQUIRE(acm); CHECK_EQ(1, acm->expected); CHECK_EQ(0, acm->actual); + CHECK_FALSE(acm->isVariadic); } TEST_CASE_FIXTURE(Fixture, "optional_field_access_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = { x: number } local b: A? = { x = 2 } @@ -341,8 +314,6 @@ local d = b.y TEST_CASE_FIXTURE(Fixture, "optional_index_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = {number} local a: A? = {1, 2, 3} @@ -355,8 +326,6 @@ local b = a[1] TEST_CASE_FIXTURE(Fixture, "optional_call_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = (number) -> number local a: A? = function(a) return -a end @@ -369,8 +338,6 @@ local b = a(4) TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = { x: number } local a: A? = { x = 2 } @@ -392,8 +359,6 @@ a.x = 2 TEST_CASE_FIXTURE(Fixture, "optional_length_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = {number} local a: A? = {1, 2, 3} @@ -406,9 +371,6 @@ local b = #a TEST_CASE_FIXTURE(Fixture, "optional_missing_key_error_details") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true); - CheckResult result = check(R"( type A = { x: number, y: number } type B = { x: number, y: number } @@ -433,12 +395,27 @@ local e = a.z CHECK_EQ("Type 'A | B | C | D' does not have key 'z'", toString(result.errors[3])); } -TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") +TEST_CASE_FIXTURE(Fixture, "optional_iteration") { - ScopedFastFlag luauSealedTableUnifyOptionalFix("LuauSealedTableUnifyOptionalFix", true); + ScopedFastFlag luauNilIterator{"LuauNilIterator", true}; CheckResult result = check(R"( -local x: { x: number } = { x = 3 } +function foo(values: {number}?) + local s = 0 + for _, value in values do + s += value + end +end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '{number}?' could be nil", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "unify_unsealed_table_union_check") +{ + CheckResult result = check(R"( +local x = { x = 3 } type A = number? type B = string? local y: { x: number, y: A | B } @@ -448,7 +425,7 @@ y = x LUAU_REQUIRE_NO_ERRORS(result); result = check(R"( -local x: { x: number } = { x = 3 } +local x = { x = 3 } local a: number? = 2 local y = {} @@ -461,4 +438,328 @@ y = x LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") +{ + CheckResult result = check(R"( + -- the difference between this and unify_unsealed_table_union_check is the type annotation on x +local t = { x = 3, y = true } +local x: { x: number } = t +type A = number? +type B = string? +local y: { x: number, y: A | B } +-- Shouldn't typecheck! +y = x +-- If it does, we can convert any type to any other type +y.y = 5 +local oh : boolean = t.y + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_union_part") +{ + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X | Y | Z + +local a: XYZ +local b: { w: number } = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'X | Y | Z' could not be converted into '{| w: number |}' +caused by: + Not all union options are compatible. Table type 'X' not compatible with type '{| w: number |}' because the former is missing field 'w')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_union_all") +{ + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X | Y | Z + +local a: XYZ = { w = 4 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X | Y | Z'; none of the union options are compatible)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_optional") +{ + CheckResult result = check(R"( +type X = { x: number } + +local a: X? = { w = 4 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X?' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'X' because the former is missing field 'x')"); +} + +// We had a bug where a cyclic union caused a stack overflow. +// ex type U = number | U +TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") +{ + CheckResult result = check(R"( + --!strict + + function f(a, b) + a:g(b or {}) + a:g(b) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") +{ + CheckResult result = check(R"( + type A = { x: number, y: (number) -> string } | { z: number, y: (number) -> string } + + local a:A = nil + + function a.y(x) + return tostring(x * 2) + end + + function a.y(x: string): number + return tonumber(x) or 0 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + // NOTE: union normalization will improve this message + CHECK_EQ(toString(result.errors[0]), + R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); +} + +TEST_CASE_FIXTURE(Fixture, "union_true_and_false") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : boolean + local y1 : (true | false) = x -- OK + local y2 : (true | false | (string & number)) = x -- OK + local y3 : (true | (string & number) | false) = x -- OK + local y4 : (true | (boolean & true) | false) = x -- OK + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (number) -> number? + local y : ((number?) -> number?) | ((number) -> number) = x -- OK + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_generic_functions") +{ + CheckResult result = check(R"( + local x : (a) -> a? + local y : ((a?) -> a?) | ((b) -> b) = x -- Not OK + )"); + + // TODO: should this example typecheck? + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_generic_typepack_functions") +{ + CheckResult result = check(R"( + local x : (number, a...) -> (number?, a...) + local y : ((number?, a...) -> (number?, a...)) | ((number, b...) -> (number, b...)) = x -- Not OK + )"); + + // TODO: should this example typecheck? + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generics") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : (a) -> a? + local y : ((a?) -> nil) | ((a) -> a) = x -- OK + local z : ((b?) -> nil) | ((b) -> b) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + "Type '(a) -> a?' could not be converted into '((b) -> b) | ((b?) -> nil)'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generic_typepacks") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : (number, a...) -> (number?, a...) + local y : ((number | string, a...) -> (number, a...)) | ((number?, a...) -> (nil, a...)) = x -- OK + local z : ((number) -> number) | ((number?, a...) -> (number?, a...)) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(number, a...) -> (number?, a...)' could not be converted into '((number) -> number) | ((number?, " + "a...) -> (number?, a...))'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_arities") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (number) -> number? + local y : ((number?) -> number) | ((number | string) -> nil) = x -- OK + local z : ((number, string?) -> number) | ((number) -> nil) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(number) -> number?' could not be converted into '((number) -> nil) | ((number, string?) -> " + "number)'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_arities") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : () -> (number | string) + local y : (() -> number) | (() -> string) = x -- OK + local z : (() -> number) | (() -> (string, string)) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '() -> number | string' could not be converted into '(() -> (string, string)) | (() -> number)'; none " + "of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_variadics") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (...nil) -> (...number?) + local y : ((...string?) -> (...number)) | ((...number?) -> nil) = x -- OK + local z : ((...string?) -> (...number)) | ((...string?) -> nil) = x -- OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(...nil) -> (...number?)' could not be converted into '((...string?) -> (...number)) | ((...string?) " + "-> nil)'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_variadics") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (number) -> () + local y : ((number?) -> ()) | ((...number) -> ()) = x -- OK + local z : ((number?) -> ()) | ((...number?) -> ()) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + "Type '(number) -> ()' could not be converted into '((...number?) -> ()) | ((number?) -> ())'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_variadics") +{ + ScopedFastFlag sffs[]{ + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : () -> (number?, ...number) + local y : (() -> (...number)) | (() -> nil) = x -- OK + local z : (() -> (...number)) | (() -> number) = x -- OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '() -> (number?, ...number)' could not be converted into '(() -> (...number)) | (() -> number)'; none " + "of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function f(t): { x: number } | { x: string } + local x = t.x + return t + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({| x: number |} | {| x: string |}) -> {| x: number |} | {| x: string |}", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types_2") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function f(t: { x: number } | { x: string }) + return t.x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({| x: number |} | {| x: string |}) -> number | string", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp new file mode 100644 index 00000000..9d1e46c5 --- /dev/null +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -0,0 +1,318 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferUnknownNever"); + +TEST_CASE_FIXTURE(Fixture, "string_subtype_and_unknown_supertype") +{ + CheckResult result = check(R"( + local function f(x: string) + local foo: unknown = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_subtype_and_string_supertype") +{ + CheckResult result = check(R"( + local function f(x: unknown) + local foo: string = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_is_reflexive") +{ + CheckResult result = check(R"( + local function f(x: unknown) + local foo: unknown = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_subtype_and_never_supertype") +{ + CheckResult result = check(R"( + local function f(x: string) + local foo: never = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "never_subtype_and_string_supertype") +{ + CheckResult result = check(R"( + local function f(x: never) + local foo: string = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "never_is_reflexive") +{ + CheckResult result = check(R"( + local function f(x: never) + local foo: never = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_is_optional_because_it_too_encompasses_nil") +{ + CheckResult result = check(R"( + local t: {x: unknown} = {} + )"); +} + +TEST_CASE_FIXTURE(Fixture, "table_with_prop_of_type_never_is_uninhabitable") +{ + CheckResult result = check(R"( + local t: {x: never} = {} + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "table_with_prop_of_type_never_is_also_reflexive") +{ + CheckResult result = check(R"( + local t: {x: never} = {x = 5 :: never} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "array_like_table_of_never_is_inhabitable") +{ + CheckResult result = check(R"( + local t: {never} = {} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable") +{ + CheckResult result = check(R"( + local function f() return "foo", 5 :: never end + + local x, y, z = f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); + CHECK_EQ("never", toString(requireType("y"))); + CHECK_EQ("never", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable2") +{ + CheckResult result = check(R"( + local function f(): (string, never) return "", 5 :: never end + local function g(): (never, string) return 5 :: never, "" end + + local x1, x2 = f() + local y1, y2 = g() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x1"))); + CHECK_EQ("never", toString(requireType("x2"))); + CHECK_EQ("never", toString(requireType("y1"))); + CHECK_EQ("never", toString(requireType("y2"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_never") +{ + CheckResult result = check(R"( + local x: never = 5 :: never + local z = x.y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "call_never") +{ + CheckResult result = check(R"( + local f: never = 5 :: never + local x, y, z = f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); + CHECK_EQ("never", toString(requireType("y"))); + CHECK_EQ("never", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_local_which_is_never") +{ + CheckResult result = check(R"( + local t: never + t = 3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_global_which_is_never") +{ + CheckResult result = check(R"( + --!nonstrict + t = 5 :: never + t = "" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_prop_which_is_never") +{ + CheckResult result = check(R"( + local t: never + t.x = 5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never") +{ + CheckResult result = check(R"( + local t: never + t[5] = 7 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never") +{ + CheckResult result = check(R"( + for i, v in (5 :: never) do + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "pick_never_from_variadic_type_pack") +{ + CheckResult result = check(R"( + local function f(...: never) + local x, y = (...) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_union_of_tables_for_properties_that_is_never") +{ + CheckResult result = check(R"( + type Disjoint = {foo: never, bar: unknown, tag: "ok"} | {foo: never, baz: unknown, tag: "err"} + local disjoint: Disjoint = {foo = 5 :: never, bar = true, tag = "ok"} + local foo = disjoint.foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_union_of_tables_for_properties_that_is_sorta_never") +{ + CheckResult result = check(R"( + type Disjoint = {foo: string, bar: unknown, tag: "ok"} | {foo: never, baz: unknown, tag: "err"} + local disjoint: Disjoint = {foo = 5 :: never, bar = true, tag = "ok"} + local foo = disjoint.foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "unary_minus_of_never") +{ + CheckResult result = check(R"( + local x = -(5 :: never) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "length_of_never") +{ + ScopedFastFlag sff{"LuauNeverTypesAndOperatorsInference", true}; + + CheckResult result = check(R"( + local x = #({} :: never) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators") +{ + ScopedFastFlag sff[]{ + {"LuauUnknownAndNeverType", true}, + {"LuauNeverTypesAndOperatorsInference", true}, + {"LuauTryhardAnd", true}, + }; + + CheckResult result = check(R"( + local function ord(x: nil, y) + return x ~= nil and x > y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // Widening doesn't normalize yet, so the result is a bit strange + CHECK_EQ("(nil, a) -> boolean | boolean", toString(requireType("ord"))); +} + +TEST_CASE_FIXTURE(Fixture, "math_operators_and_never") +{ + ScopedFastFlag sff[]{ + {"LuauUnknownAndNeverType", true}, + {"LuauNeverTypesAndOperatorsInference", true}, + }; + + CheckResult result = check(R"( + local function mul(x: nil, y) + return x ~= nil and x * y -- infers boolean | never, which is normalized into boolean + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(nil, a) -> boolean", toString(requireType("mul"))); +} + +TEST_SUITE_END(); diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index 8b056544..1087a24c 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -26,7 +25,7 @@ struct TypePackFixture TypePackId freshTypePack() { - typePacks.emplace_back(new TypePackVar{Unifiable::Free{{}}}); + typePacks.emplace_back(new TypePackVar{Unifiable::Free{TypeLevel{}}}); return typePacks.back().get(); } @@ -198,4 +197,18 @@ TEST_CASE_FIXTURE(TypePackFixture, "std_distance") CHECK_EQ(4, std::distance(b, e)); } +TEST_CASE("content_reassignment") +{ + TypePackVar myError{Unifiable::Error{}, /*presistent*/ true}; + + TypeArena arena; + + TypePackId futureError = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); + asMutable(futureError)->reassign(myError); + + CHECK(get(futureError) != nullptr); + CHECK(!futureError->persistent); + CHECK(futureError->owningArena == &arena); +} + TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 98ce9f93..5dd1b1bc 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -1,7 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" #include "Fixture.h" #include "ScopedFlags.h" @@ -10,8 +11,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauGenericFunctions); - TEST_SUITE_BEGIN("TypeVarTests"); TEST_CASE_FIXTURE(Fixture, "primitives_are_equal") @@ -183,6 +182,22 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_empty_union") CHECK(actual.empty()); } +TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_only_cyclic_union") +{ + TypeVar tv{UnionTypeVar{}}; + auto utv = getMutable(&tv); + utv->options.push_back(&tv); + utv->options.push_back(&tv); + + std::vector actual(begin(utv), end(utv)); + CHECK(actual.empty()); +} + + +/* FIXME: This test is pretty weird. It would be much nicer if we could + * perform this operation without a TypeChecker so that we don't have to jam + * all this state into it to make stuff work. + */ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") { TypeVar ftv11{FreeTypeVar{TypeLevel{}}}; @@ -258,10 +273,180 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypeId root = &ttvTweenResult; typeChecker.currentModule = std::make_shared(); + typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared(singletonTypes->anyTypePack)); TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{}); - CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result)); + CHECK_EQ("{| f: t1 |} where t1 = () -> {| f: () -> {| f: ({| f: t1 |}) -> (), signal: {| f: (any) -> () |} |} |}", toString(result)); +} + +TEST_CASE("tagging_tables") +{ + TypeVar ttv{TableTypeVar{}}; + CHECK(!Luau::hasTag(&ttv, "foo")); + Luau::attachTag(&ttv, "foo"); + CHECK(Luau::hasTag(&ttv, "foo")); +} + +TEST_CASE("tagging_classes") +{ + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; + CHECK(!Luau::hasTag(&base, "foo")); + Luau::attachTag(&base, "foo"); + CHECK(Luau::hasTag(&base, "foo")); +} + +TEST_CASE("tagging_subclasses") +{ + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; + TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr, "Test"}}; + + CHECK(!Luau::hasTag(&base, "foo")); + CHECK(!Luau::hasTag(&derived, "foo")); + + Luau::attachTag(&base, "foo"); + CHECK(Luau::hasTag(&base, "foo")); + CHECK(Luau::hasTag(&derived, "foo")); + + Luau::attachTag(&derived, "bar"); + CHECK(!Luau::hasTag(&base, "bar")); + CHECK(Luau::hasTag(&derived, "bar")); +} + +TEST_CASE("tagging_functions") +{ + TypePackVar empty{TypePack{}}; + TypeVar ftv{FunctionTypeVar{&empty, &empty}}; + CHECK(!Luau::hasTag(&ftv, "foo")); + Luau::attachTag(&ftv, "foo"); + CHECK(Luau::hasTag(&ftv, "foo")); +} + +TEST_CASE("tagging_props") +{ + Property prop{}; + CHECK(!Luau::hasTag(prop, "foo")); + Luau::attachTag(prop, "foo"); + CHECK(Luau::hasTag(prop, "foo")); +} + +struct VisitCountTracker final : TypeVarOnceVisitor +{ + std::unordered_map tyVisits; + std::unordered_map tpVisits; + + void cycle(TypeId) override {} + void cycle(TypePackId) override {} + + template + bool operator()(TypeId ty, const T& t) + { + return visit(ty); + } + + template + bool operator()(TypePackId tp, const T&) + { + return visit(tp); + } + + bool visit(TypeId ty) override + { + tyVisits[ty]++; + return true; + } + + bool visit(TypePackId tp) override + { + tpVisits[tp]++; + return true; + } +}; + +TEST_CASE_FIXTURE(Fixture, "visit_once") +{ + CheckResult result = check(R"( +type T = { a: number, b: () -> () } +local b: (T, T, T) -> T +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId bType = requireType("b"); + + VisitCountTracker tester; + tester.traverse(bType); + + for (auto [_, count] : tester.tyVisits) + CHECK_EQ(count, 1); + + for (auto [_, count] : tester.tpVisits) + CHECK_EQ(count, 1); +} + +TEST_CASE("isString_on_string_singletons") +{ + TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + CHECK(isString(&helloString)); +} + +TEST_CASE("isString_on_unions_of_various_string_singletons") +{ + TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; + TypeVar union_{UnionTypeVar{{&helloString, &byeString}}}; + + CHECK(isString(&union_)); +} + +TEST_CASE("proof_that_isString_uses_all_of") +{ + TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; + TypeVar booleanType{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}}; + TypeVar union_{UnionTypeVar{{&helloString, &byeString, &booleanType}}}; + + CHECK(!isString(&union_)); +} + +TEST_CASE("isBoolean_on_boolean_singletons") +{ + TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + CHECK(isBoolean(&trueBool)); +} + +TEST_CASE("isBoolean_on_unions_of_true_or_false_singletons") +{ + TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; + TypeVar union_{UnionTypeVar{{&trueBool, &falseBool}}}; + + CHECK(isBoolean(&union_)); +} + +TEST_CASE("proof_that_isBoolean_uses_all_of") +{ + TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; + TypeVar stringType{PrimitiveTypeVar{PrimitiveTypeVar::String}}; + TypeVar union_{UnionTypeVar{{&trueBool, &falseBool, &stringType}}}; + + CHECK(!isBoolean(&union_)); +} + +TEST_CASE("content_reassignment") +{ + TypeVar myAny{AnyTypeVar{}, /*presistent*/ true}; + myAny.documentationSymbol = "@global/any"; + + TypeArena arena; + + TypeId futureAny = arena.addType(FreeTypeVar{TypeLevel{}}); + asMutable(futureAny)->reassign(myAny); + + CHECK(get(futureAny) != nullptr); + CHECK(!futureAny->persistent); + CHECK(futureAny->documentationSymbol == "@global/any"); + CHECK(futureAny->owningArena == &arena); } TEST_SUITE_END(); diff --git a/tests/Variant.test.cpp b/tests/Variant.test.cpp index fcf37875..83eec519 100644 --- a/tests/Variant.test.cpp +++ b/tests/Variant.test.cpp @@ -13,6 +13,25 @@ struct Foo int x = 42; }; +struct Bar +{ + explicit Bar(int x) + : prop(x * 2) + { + ++count; + } + + ~Bar() + { + --count; + } + + int prop; + static int count; +}; + +int Bar::count = 0; + TEST_SUITE_BEGIN("Variant"); TEST_CASE("DefaultCtor") @@ -46,6 +65,29 @@ TEST_CASE("Create") CHECK(get_if(&v3)->x == 3); } +TEST_CASE("Emplace") +{ + { + Variant v1; + + CHECK(0 == Bar::count); + int& i = v1.emplace(5); + CHECK(5 == i); + + CHECK(0 == Bar::count); + + CHECK(get_if(&v1) == &i); + + Bar& bar = v1.emplace(11); + CHECK(22 == bar.prop); + CHECK(1 == Bar::count); + + CHECK(get_if(&v1) == &bar); + } + + CHECK(0 == Bar::count); +} + TEST_CASE("NonPOD") { // initialize (copy) @@ -175,4 +217,35 @@ TEST_CASE("Visit") CHECK(r3 == "1231147"); } +struct MoveOnly +{ + MoveOnly() = default; + + MoveOnly(const MoveOnly&) = delete; + MoveOnly& operator=(const MoveOnly&) = delete; + + MoveOnly(MoveOnly&&) = default; + MoveOnly& operator=(MoveOnly&&) = default; +}; + +TEST_CASE("Move") +{ + Variant v1 = MoveOnly{}; + Variant v2 = std::move(v1); +} + +TEST_CASE("MoveWithCopyableAlternative") +{ + Variant v1 = std::string{"Hello, world! I am longer than a normal hello world string to avoid SSO."}; + Variant v2 = std::move(v1); + + std::string* s1 = get_if(&v1); + REQUIRE(s1); + CHECK(*s1 == ""); + + std::string* s2 = get_if(&v2); + REQUIRE(s2); + CHECK(*s2 == "Hello, world! I am longer than a normal hello world string to avoid SSO."); +} + TEST_SUITE_END(); diff --git a/tests/VisitTypeVar.test.cpp b/tests/VisitTypeVar.test.cpp new file mode 100644 index 00000000..4fba694a --- /dev/null +++ b/tests/VisitTypeVar.test.cpp @@ -0,0 +1,41 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "Luau/RecursionCounter.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTINT(LuauVisitRecursionLimit) + +TEST_SUITE_BEGIN("VisitTypeVar"); + +TEST_CASE_FIXTURE(Fixture, "throw_when_limit_is_exceeded") +{ + ScopedFastInt sfi{"LuauVisitRecursionLimit", 3}; + + CheckResult result = check(R"( + local t : {a: {b: {c: {d: {e: boolean}}}}} + )"); + + TypeId tType = requireType("t"); + + CHECK_THROWS_AS(toString(tType), RecursionLimitException); +} + +TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") +{ + ScopedFastInt sfi{"LuauVisitRecursionLimit", 8}; + + CheckResult result = check(R"( + local t : {a: {b: {c: {d: {e: boolean}}}}} + )"); + + TypeId tType = requireType("t"); + + (void)toString(tType); +} + +TEST_SUITE_END(); diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua index 5e03b055..27416623 100644 --- a/tests/conformance/apicalls.lua +++ b/tests/conformance/apicalls.lua @@ -2,7 +2,24 @@ print('testing function calls through API') function add(a, b) - return a + b + return a + b +end + +local m = { __eq = function(a, b) return a.a == b.a end } + +function create_with_tm(x) + return setmetatable({ a = x }, m) +end + +local gen = 0 +function incuv() + gen += 1 + return gen +end + +pi = 3.1415926 +function getpi() + return pi end return('OK') diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 687fff1e..7a05f8e9 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -49,6 +49,12 @@ assert((function() _G.foo = 1 return _G['foo'] end)() == 1) assert((function() _G['bar'] = 1 return _G.bar end)() == 1) assert((function() local a = 1 (function () a = 2 end)() return a end)() == 2) +-- assignments with local conflicts +assert((function() local a, b = 1, {} a, b[a] = 43, -1 return a + b[1] end)() == 42) +assert((function() local a = {} local b = a a[1], a = 43, -1 return a + b[1] end)() == 42) +assert((function() local a, b = 1, {} a, b[a] = (function() return 43, -1 end)() return a + b[1] end)() == 42) +assert((function() local a = {} local b = a a[1], a = (function() return 43, -1 end)() return a + b[1] end)() == 42) + -- upvalues assert((function() local a = 1 function foo() return a end return foo() end)() == 1) @@ -87,6 +93,10 @@ assert((function() local a = 1 a = a * 2 return a end)() == 2) assert((function() local a = 1 a = a / 2 return a end)() == 0.5) assert((function() local a = 5 a = a % 2 return a end)() == 1) assert((function() local a = 3 a = a ^ 2 return a end)() == 9) +assert((function() local a = 3 a = a ^ 3 return a end)() == 27) +assert((function() local a = 9 a = a ^ 0.5 return a end)() == 3) +assert((function() local a = -2 a = a ^ 2 return a end)() == 4) +assert((function() local a = -2 a = a ^ 0.5 return tostring(a) end)() == "nan") assert((function() local a = '1' a = a .. '2' return a end)() == "12") assert((function() local a = '1' a = a .. '2' .. '3' return a end)() == "123") @@ -105,6 +115,11 @@ assert((function() local a = nil a = a and 2 return a end)() == nil) assert((function() local a = 1 a = a or 2 return a end)() == 1) assert((function() local a = nil a = a or 2 return a end)() == 2) +assert((function() local a a = 1 local b = 2 b = a and b return b end)() == 2) +assert((function() local a a = nil local b = 2 b = a and b return b end)() == nil) +assert((function() local a a = 1 local b = 2 b = a or b return b end)() == 1) +assert((function() local a a = nil local b = 2 b = a or b return b end)() == 2) + -- binary arithmetics coerces strings to numbers (sadly) assert(1 + "2" == 3) assert(2 * "0xa" == 20) @@ -120,6 +135,10 @@ assert((function() return #'g' end)() == 1) assert((function() local a = 1 a = -a return a end)() == -1) +-- __len metamethod +assert((function() local ud = newproxy(true) getmetatable(ud).__len = function() return 42 end return #ud end)() == 42) +assert((function() local t = {} setmetatable(t, { __len = function() return 42 end }) return #t end)() == 42) + -- while/repeat assert((function() local a = 10 local b = 1 while a > 1 do b = b * 2 a = a - 1 end return b end)() == 512) assert((function() local a = 10 local b = 1 repeat b = b * 2 a = a - 1 until a == 1 return b end)() == 512) @@ -321,6 +340,10 @@ assert((function() local t = {6, 9, 7} t[4.5] = 10 return t[4.5] end)() == 10) assert((function() local t = {6, 9, 7} t['a'] = 11 return t['a'] end)() == 11) assert((function() local t = {6, 9, 7} setmetatable(t, { __newindex = function(t,i,v) rawset(t, i * 10, v) end }) t[1] = 17 t[5] = 1 return concat(t[1],t[5],t[50]) end)() == "17,nil,1") +-- userdata access +assert((function() local ud = newproxy(true) getmetatable(ud).__index = function(ud,i) return i * 10 end return ud[2] end)() == 20) +assert((function() local ud = newproxy(true) getmetatable(ud).__index = function() return function(self, i) return i * 10 end end return ud:meow(2) end)() == 20) + -- and/or -- rhs is a constant assert((function() local a = 1 a = a and 2 return a end)() == 2) @@ -441,7 +464,8 @@ assert((function() a = {} b = {} mt = { __eq = function(l, r) return #l == #r en assert((function() a = {} b = {} function eq(l, r) return #l == #r end setmetatable(a, {__eq = eq}) setmetatable(b, {__eq = eq}) return concat(a == b, a ~= b) end)() == "true,false") assert((function() a = {} b = {} setmetatable(a, {__eq = function(l, r) return #l == #r end}) setmetatable(b, {__eq = function(l, r) return #l == #r end}) return concat(a == b, a ~= b) end)() == "false,true") --- userdata, reference equality (no mt) +-- userdata, reference equality (no mt or mt.__eq) +assert((function() a = newproxy() return concat(a == newproxy(),a ~= newproxy()) end)() == "false,true") assert((function() a = newproxy(true) return concat(a == newproxy(true),a ~= newproxy(true)) end)() == "false,true") -- rawequal @@ -455,9 +479,15 @@ assert(rawequal("a", "a") == true) assert(rawequal("a", "b") == false) assert((function() a = {} b = {} mt = { __eq = function(l, r) return #l == #r end } setmetatable(a, mt) setmetatable(b, mt) return concat(a == b, rawequal(a, b)) end)() == "true,false") +-- rawequal fallback +assert(concat(pcall(rawequal, "a", "a")) == "true,true") +assert(concat(pcall(rawequal, "a", "b")) == "true,false") +assert(concat(pcall(rawequal, "a", nil)) == "true,false") +assert(pcall(rawequal, "a") == false) + -- metatable ops local function vec3t(x, y, z) - return setmetatable({ x=x, y=y, z=z}, { + return setmetatable({x=x, y=y, z=z}, { __add = function(l, r) return vec3t(l.x + r.x, l.y + r.y, l.z + r.z) end, __sub = function(l, r) return vec3t(l.x - r.x, l.y - r.y, l.z - r.z) end, __mul = function(l, r) return type(r) == "number" and vec3t(l.x * r, l.y * r, l.z * r) or vec3t(l.x * r.x, l.y * r.y, l.z * r.z) end, @@ -488,6 +518,9 @@ assert((function() function cmp(a,b) return ab,a>=b end return concat( assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('abc', 'abd')) end)() == "true,true,false,false") assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('ab\\0c', 'ab\\0d')) end)() == "true,true,false,false") assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('ab\\0c', 'ab\\0')) end)() == "false,false,true,true") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('\\0a', '\\0b')) end)() == "true,true,false,false") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('a', 'a\\0')) end)() == "true,true,false,false") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('a', '\200')) end)() == "true,true,false,false") -- array access assert((function() local a = {4,5,6} return a[3] end)() == 6) @@ -676,7 +709,11 @@ end assert(chainTest(100) == "v0,v100") -- this validates import fallbacks +assert(idontexist == nil) +assert(math.idontexist == nil) assert(pcall(function() return idontexist.a end) == false) +assert(pcall(function() return math.pow.a end) == false) +assert(pcall(function() return math.a.b end) == false) -- make sure that NaN is preserved by the bytecode compiler local realnan = tostring(math.abs(0)/math.abs(0)) @@ -718,16 +755,20 @@ assert((function() local abs = math.abs function foo(...) return abs(...) end re -- NOTE: getfenv breaks fastcalls for the remainder of the source! hence why this is delayed until the end function testgetfenv() getfenv() + + -- declare constant so that at O2 this test doesn't interfere with constant folding which we can't deoptimize + local negfive negfive = -5 + -- getfenv breaks fastcalls (we assume we can't rely on knowing the semantics), but behavior shouldn't change - assert((function() return math.abs(-5) end)() == 5) - assert((function() local abs = math.abs return abs(-5) end)() == 5) - assert((function() local abs = math.abs function foo() return abs(-5) end return foo() end)() == 5) + assert((function() return math.abs(negfive) end)() == 5) + assert((function() local abs = math.abs return abs(negfive) end)() == 5) + assert((function() local abs = math.abs function foo() return abs(negfive) end return foo() end)() == 5) -- ... unless you actually reassign the function :D getfenv().math = { abs = function(n) return n*n end } - assert((function() return math.abs(-5) end)() == 25) - assert((function() local abs = math.abs return abs(-5) end)() == 25) - assert((function() local abs = math.abs function foo() return abs(-5) end return foo() end)() == 25) + assert((function() return math.abs(negfive) end)() == 25) + assert((function() local abs = math.abs return abs(negfive) end)() == 25) + assert((function() local abs = math.abs function foo() return abs(negfive) end return foo() end)() == 25) end -- you need to have enough arguments and arguments of the right type; if you don't, we'll fallback to the regular code. This checks coercions @@ -826,6 +867,17 @@ assert((function() return sum end)() == 105) +-- shrinking array part +assert((function() + local t = table.create(100, 42) + for i=1,90 do t[i] = nil end + t[101] = 42 + local sum = 0 + for _,v in ipairs(t) do sum += v end + for _,v in pairs(t) do sum += v end + return sum +end)() == 462) + -- upvalues: recursive capture assert((function() local function fact(n) return n < 1 and 1 or n * fact(n-1) end return fact(5) end)() == 120) @@ -871,9 +923,21 @@ assert((function() return table.concat(res, ',') end)() == "6,8,10") +-- typeof and type require an argument +assert(pcall(typeof) == false) +assert(pcall(type) == false) + -- typeof == type in absence of custom userdata assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number,nil,table,userdata") +-- type/typeof/newproxy interaction with metatables: __type doesn't work intentionally to avoid spoofing +assert((function() + local ud = newproxy(true) + getmetatable(ud).__type = "number" + + return concat(type(ud),typeof(ud)) +end)() == "userdata,userdata") + testgetfenv() -- DONT MOVE THIS LINE -return'OK' +return 'OK' diff --git a/tests/conformance/bitwise.lua b/tests/conformance/bitwise.lua index 6efa5960..3b117892 100644 --- a/tests/conformance/bitwise.lua +++ b/tests/conformance/bitwise.lua @@ -71,6 +71,7 @@ for _, b in pairs(c) do assert(bit32.bxor(b) == b) assert(bit32.bxor(b, b) == 0) assert(bit32.bxor(b, 0) == b) + assert(bit32.bxor(b, b, b) == b) assert(bit32.bnot(b) ~= b) assert(bit32.bnot(bit32.bnot(b)) == b) assert(bit32.bnot(b) == 2^32 - 1 - b) @@ -100,6 +101,12 @@ assert(bit32.extract(0xa0001111, 28, 4) == 0xa) assert(bit32.extract(0xa0001111, 31, 1) == 1) assert(bit32.extract(0x50000111, 31, 1) == 0) assert(bit32.extract(0xf2345679, 0, 32) == 0xf2345679) +assert(bit32.extract(0xa0001111, 16) == 0) +assert(bit32.extract(0xa0001111, 31) == 1) +assert(bit32.extract(42, 1, 3) == 5) + +local pos pos = 1 +assert(bit32.extract(42, pos, 3) == 5) -- test bit32.extract builtin instead of bit32.extractk assert(not pcall(bit32.extract, 0, -1)) assert(not pcall(bit32.extract, 0, 32)) @@ -113,6 +120,20 @@ assert(bit32.replace(0, -1, 4) == 2^4) assert(bit32.replace(-1, 0, 31) == 2^31 - 1) assert(bit32.replace(-1, 0, 1, 2) == 2^32 - 7) +-- testing countlz/countrc +assert(bit32.countlz(0) == 32) +assert(bit32.countlz(42) == 26) +assert(bit32.countlz(0xffffffff) == 0) +assert(bit32.countlz(0x80000000) == 0) +assert(bit32.countlz(0x7fffffff) == 1) + +assert(bit32.countrz(0) == 32) +assert(bit32.countrz(1) == 0) +assert(bit32.countrz(42) == 1) +assert(bit32.countrz(0x80000000) == 31) +assert(bit32.countrz(0x40000000) == 30) +assert(bit32.countrz(0x7fffffff) == 0) + --[[ This test verifies a fix in luauF_replace() where if the 4th parameter was not a number, but the first three are numbers, it will @@ -127,14 +148,21 @@ assert(bit32.lrotate("0x12345678", 4) == 0x23456781) assert(bit32.rrotate("0x12345678", -4) == 0x23456781) assert(bit32.arshift("0x12345678", 1) == 0x12345678 / 2) assert(bit32.arshift("-1", 32) == 0xffffffff) +assert(bit32.arshift("-1", 1) == 0xffffffff) assert(bit32.bnot("1") == 0xfffffffe) assert(bit32.band("1", 3) == 1) assert(bit32.band(1, "3") == 1) +assert(bit32.band(1, 3, "5") == 1) assert(bit32.bor("1", 2) == 3) assert(bit32.bor(1, "2") == 3) +assert(bit32.bor(1, 3, "5") == 7) assert(bit32.bxor("1", 3) == 2) assert(bit32.bxor(1, "3") == 2) +assert(bit32.bxor(1, 3, "5") == 7) assert(bit32.btest(1, "3") == true) assert(bit32.btest("1", 3) == true) +assert(bit32.countlz("42") == 26) +assert(bit32.countrz("42") == 1) +assert(bit32.extract("42", 1, 3) == 5) return('OK') diff --git a/tests/conformance/calls.lua b/tests/conformance/calls.lua index 7f9610a3..621a921a 100644 --- a/tests/conformance/calls.lua +++ b/tests/conformance/calls.lua @@ -226,4 +226,14 @@ assert((function () return nil end)(4) == nil) assert((function () local a; return a end)(4) == nil) assert((function (a) return a end)() == nil) +-- C-stack overflow while handling C-stack overflow +if not limitedstack then + local function loop () + assert(pcall(loop)) + end + + local err, msg = xpcall(loop, loop) + assert(not err and string.find(msg, "error")) +end + return('OK') diff --git a/tests/conformance/closure.lua b/tests/conformance/closure.lua index 79f8d9c2..7b057354 100644 --- a/tests/conformance/closure.lua +++ b/tests/conformance/closure.lua @@ -319,7 +319,7 @@ end assert(a == 5^4) --- access to locals of collected corroutines +-- access to locals of collected coroutines local C = {}; setmetatable(C, {__mode = "kv"}) local x = coroutine.wrap (function () local a = 10 @@ -419,11 +419,20 @@ co = coroutine.create(function () return loadstring("return a")() end) -a = {a = 15} --- debug.setfenv(co, a) --- assert(debug.getfenv(co) == a) --- assert(select(2, coroutine.resume(co)) == a) --- assert(select(2, coroutine.resume(co)) == a.a) +-- large closure size +do + local a1, a2, a3, a4, a5, a6, a7, a8, a9, a0 + local b1, b2, b3, b4, b5, b6, b7, b8, b9, b0 + local c1, c2, c3, c4, c5, c6, c7, c8, c9, c0 + local d1, d2, d3, d4, d5, d6, d7, d8, d9, d0 + local f = function() + return + a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a0 + + b1 + b2 + b3 + b4 + b5 + b6 + b7 + b8 + b9 + b0 + + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c0 + + d1 + d2 + d3 + d4 + d5 + d6 + d7 + d8 + d9 + d0 + end +end -return'OK' +return 'OK' diff --git a/tests/conformance/constructs.lua b/tests/conformance/constructs.lua index 16c63b00..f133501f 100644 --- a/tests/conformance/constructs.lua +++ b/tests/conformance/constructs.lua @@ -237,4 +237,4 @@ repeat i = i+1 until i==c -return'OK' +return 'OK' diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index 73c3833d..b4f81bba 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -185,7 +185,7 @@ end assert(a == 5^4) --- access to locals of collected corroutines +-- access to locals of collected coroutines local C = {}; setmetatable(C, {__mode = "kv"}) local x = coroutine.wrap (function () local a = 10 @@ -319,4 +319,67 @@ for i=0,30 do assert(#T2 == 1 or T2[#T2] == 42) end -return'OK' +-- test coroutine.close +do + -- ok to close a dead coroutine + local co = coroutine.create(type) + assert(coroutine.resume(co, "testing 'coroutine.close'")) + assert(coroutine.status(co) == "dead") + local st, msg = coroutine.close(co) + assert(st and msg == nil) + -- also ok to close it again + st, msg = coroutine.close(co) + assert(st and msg == nil) + + + -- cannot close the running coroutine + coroutine.wrap(function() + local st, msg = pcall(coroutine.close, coroutine.running()) + assert(not st and string.find(msg, "running")) + end)() + + -- cannot close a "normal" coroutine + coroutine.wrap(function() + local co = coroutine.running() + coroutine.wrap(function () + local st, msg = pcall(coroutine.close, co) + assert(not st and string.find(msg, "normal")) + end)() + end)() + + -- closing a coroutine after an error + local co = coroutine.create(error) + local obj = {42} + local st, msg = coroutine.resume(co, obj) + assert(not st and msg == obj) + st, msg = coroutine.close(co) + assert(not st and msg == obj) + -- after closing, no more errors + st, msg = coroutine.close(co) + assert(st and msg == nil) + + -- closing a coroutine that has outstanding upvalues + local f + local co = coroutine.create(function() + local a = 42 + f = function() return a end + coroutine.yield() + a = 20 + end) + coroutine.resume(co) + assert(f() == 42) + st, msg = coroutine.close(co) + assert(st and msg == nil) + assert(f() == 42) + + -- closing a coroutine with a large stack + co = coroutine.create(function() + local function f(depth) return if depth > 0 then f(depth - 1) + depth else 0 end + coroutine.yield(f(100)) + end) + assert(coroutine.resume(co)) + st, msg = coroutine.close(co) + assert(st and msg == nil) +end + +return 'OK' diff --git a/tests/conformance/coverage.lua b/tests/conformance/coverage.lua new file mode 100644 index 00000000..14d843a4 --- /dev/null +++ b/tests/conformance/coverage.lua @@ -0,0 +1,72 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing coverage") + +function foo() + local x = 1 + local y = 2 + assert(x + y) +end + +function bar() + local function one(x) + return x + end + + local two = function(x) + return x + end + + one(1) +end + +function validate(stats, hits, misses) + local checked = {} + + for _,l in ipairs(hits) do + if not (stats[l] and stats[l] > 0) then + return false, string.format("expected line %d to be hit", l) + end + checked[l] = true + end + + for _,l in ipairs(misses) do + if not (stats[l] and stats[l] == 0) then + return false, string.format("expected line %d to be missed", l) + end + checked[l] = true + end + + for k,v in pairs(stats) do + if type(k) == "number" and not checked[k] then + return false, string.format("expected line %d to be absent", k) + end + end + + return true +end + +foo() +c = getcoverage(foo) +assert(#c == 1) +assert(c[1].name == "foo") +assert(c[1].linedefined == 4) +assert(c[1].depth == 0) +assert(validate(c[1], {5, 6, 7}, {})) + +bar() +c = getcoverage(bar) +assert(#c == 3) +assert(c[1].name == "bar") +assert(c[1].linedefined == 10) +assert(c[1].depth == 0) +assert(validate(c[1], {11, 15, 19}, {})) +assert(c[2].name == "one") +assert(c[2].linedefined == 11) +assert(c[2].depth == 1) +assert(validate(c[2], {12}, {})) +assert(c[3].name == nil) +assert(c[3].linedefined == 15) +assert(c[3].depth == 1) +assert(validate(c[3], {}, {16})) + +return 'OK' diff --git a/tests/conformance/datetime.lua b/tests/conformance/datetime.lua index 21ef60d7..dc73948b 100644 --- a/tests/conformance/datetime.lua +++ b/tests/conformance/datetime.lua @@ -16,6 +16,7 @@ D = os.date("*t", t) assert(os.date(string.rep("%d", 1000), t) == string.rep(os.date("%d", t), 1000)) assert(os.date(string.rep("%", 200)) == string.rep("%", 100)) +assert(os.date("", -1) == nil) local function checkDateTable (t) local D = os.date("!*t", t) @@ -74,4 +75,4 @@ assert(os.difftime(t1,t2) == 60*2-19) assert(os.time({ year = 1970, day = 1, month = 1, hour = 0}) == 0) -return'OK' +return 'OK' diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index ee79a14f..0c8cc2d8 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -37,6 +37,7 @@ coroutine.resume(co2, 0 / 0, 42) assert(debug.traceback(co2) == "debug.lua:31 function halp\n") assert(debug.info(co2, 0, "l") == 31) +assert(debug.info(co2, 0, "f") == halp) -- info errors function qux(...) @@ -75,6 +76,9 @@ assert(baz(co, 2, "n") == nil) assert(baz(math.sqrt, "n") == "sqrt") assert(baz(math.sqrt, "f") == math.sqrt) -- yes this is pointless +local t = { foo = function() return 1 end } +assert(baz(t.foo, "n") == "foo") + -- info multi-arg returns function quux(...) return {debug.info(...)} @@ -98,4 +102,13 @@ assert(quuz(function(...) end) == "0 true") assert(quuz(function(a, b) end) == "2 false") assert(quuz(function(a, b, ...) end) == "2 true") -return'OK' +-- info linedefined & line +function testlinedefined() + local line = debug.info(1, "l") + local linedefined = debug.info(testlinedefined, "l") + assert(linedefined + 1 == line) +end + +testlinedefined() + +return 'OK' diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua index 5e69fc6b..c773013b 100644 --- a/tests/conformance/debugger.lua +++ b/tests/conformance/debugger.lua @@ -3,14 +3,14 @@ print "testing debugger" -- note, this file can't run in isolation from C tests local a = 5 -function foo(b) +function foo(b, ...) print("in foo", b) a = 6 end breakpoint(8) -foo(50) +foo(50, 42) breakpoint(16) -- next line print("here") @@ -45,4 +45,28 @@ breakpoint(38) -- break inside corobad() local co = coroutine.create(corobad) assert(coroutine.resume(co) == false) -- this breaks, resumes and dies! +function bar() + print("in bar") +end + +breakpoint(49) +breakpoint(49, false) -- validate that disabling breakpoints works + +bar() + +local function breakpointSetFromMetamethod() + local a = setmetatable({}, { + __index = function() + breakpoint(67) + return 2 + end + }) + + local b = a.x + + assert(b == 2) +end + +breakpointSetFromMetamethod() + return 'OK' diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index eded14e9..57d2b693 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -34,15 +34,15 @@ assert(doit("error('hi', 0)") == 'hi') assert(doit("unpack({}, 1, n=2^30)")) assert(doit("a=math.sin()")) assert(not doit("tostring(1)") and doit("tostring()")) -assert(doit"tonumber()") -assert(doit"repeat until 1; a") +assert(doit("tonumber()")) +assert(doit("repeat until 1; a")) checksyntax("break label", "", "label", 1) -assert(doit";") -assert(doit"a=1;;") -assert(doit"return;;") -assert(doit"assert(false)") -assert(doit"assert(nil)") -assert(doit"a=math.sin\n(3)") +assert(doit(";")) +assert(doit("a=1;;")) +assert(doit("return;;")) +assert(doit("assert(false)")) +assert(doit("assert(nil)")) +assert(doit("a=math.sin\n(3)")) assert(doit("function a (... , ...) end")) assert(doit("function a (, ...) end")) @@ -59,7 +59,7 @@ checkmessage("a=1; local a,bbbb=2,3; a = math.sin(1) and bbbb(3)", "local 'bbbb'") checkmessage("a={}; do local a=1 end a:bbbb(3)", "method 'bbbb'") checkmessage("local a={}; a.bbbb(3)", "field 'bbbb'") -assert(not string.find(doit"a={13}; local bbbb=1; a[bbbb](3)", "'bbbb'")) +assert(not string.find(doit("a={13}; local bbbb=1; a[bbbb](3)"), "'bbbb'")) checkmessage("a={13}; local bbbb=1; a[bbbb](3)", "number") aaa = nil @@ -67,14 +67,14 @@ checkmessage("aaa.bbb:ddd(9)", "global 'aaa'") checkmessage("local aaa={bbb=1}; aaa.bbb:ddd(9)", "field 'bbb'") checkmessage("local aaa={bbb={}}; aaa.bbb:ddd(9)", "method 'ddd'") checkmessage("local a,b,c; (function () a = b+1 end)()", "upvalue 'b'") -assert(not doit"local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)") +assert(not doit("local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)")) checkmessage("b=1; local aaa='a'; x=aaa+b", "local 'aaa'") checkmessage("aaa={}; x=3/aaa", "global 'aaa'") checkmessage("aaa='2'; b=nil;x=aaa*b", "global 'b'") checkmessage("aaa={}; x=-aaa", "global 'aaa'") -assert(not string.find(doit"aaa={}; x=(aaa or aaa)+(aaa and aaa)", "'aaa'")) -assert(not string.find(doit"aaa={}; (aaa or aaa)()", "'aaa'")) +assert(not string.find(doit("aaa={}; x=(aaa or aaa)+(aaa and aaa)"), "'aaa'")) +assert(not string.find(doit("aaa={}; (aaa or aaa)()"), "'aaa'")) checkmessage([[aaa=9 repeat until 3==3 @@ -122,10 +122,10 @@ function lineerror (s) return line and line+0 end -assert(lineerror"local a\n for i=1,'a' do \n print(i) \n end" == 2) --- assert(lineerror"\n local a \n for k,v in 3 \n do \n print(k) \n end" == 3) --- assert(lineerror"\n\n for k,v in \n 3 \n do \n print(k) \n end" == 4) -assert(lineerror"function a.x.y ()\na=a+1\nend" == 1) +assert(lineerror("local a\n for i=1,'a' do \n print(i) \n end") == 2) +-- assert(lineerror("\n local a \n for k,v in 3 \n do \n print(k) \n end") == 3) +-- assert(lineerror("\n\n for k,v in \n 3 \n do \n print(k) \n end") == 4) +assert(lineerror("function a.x.y ()\na=a+1\nend") == 1) local p = [[ function g() f() end @@ -167,6 +167,81 @@ if not limitedstack then end end +-- C stack overflow +if not limitedstack then + local count = 1 + local cso = setmetatable({}, { + __index = function(self, i) + count = count + 1 + return self[i] + end, + __newindex = function(self, i, v) + count = count + 1 + self[i] = v + end, + __tostring = function(self) + count = count + 1 + return tostring(self) + end + }) + + local ehline + local function ehassert(cond) + if not cond then + ehline = debug.info(2, "l") + error() + end + end + + local userdata = newproxy(true) + getmetatable(userdata).__index = print + assert(debug.info(print, "s") == "[C]") + + local s, e = xpcall(tostring, function(e) + ehassert(string.find(e, "C stack overflow")) + print("after __tostring C stack overflow", count) -- 198: 1 resume + 1 xpcall + 198 luaB_tostring calls (which runs our __tostring successfully 197 times, erroring on the last attempt) + ehassert(count > 1) + + local ps, pe + + -- __tostring overflow (lua_call) + count = 1 + ps, pe = pcall(tostring, cso) + print("after __tostring overflow in handler", count) -- 23: xpcall error handler + pcall + 23 luaB_tostring calls + ehassert(not ps and string.find(pe, "error in error handling")) + ehassert(count > 1) + + -- __index overflow (callTMres) + count = 1 + ps, pe = pcall(function() return cso[cso] end) + print("after __index overflow in handler", count) -- 23: xpcall error handler + pcall + 23 __index calls + ehassert(not ps and string.find(pe, "error in error handling")) + ehassert(count > 1) + + -- __newindex overflow (callTM) + count = 1 + ps, pe = pcall(function() cso[cso] = "kohuke" end) + print("after __newindex overflow in handler", count) -- 23: xpcall error handler + pcall + 23 __newindex calls + ehassert(not ps and string.find(pe, "error in error handling")) + ehassert(count > 1) + + -- test various C __index invocations on userdata + ehassert(pcall(function() return userdata[userdata] end)) -- LOP_GETTABLE + ehassert(pcall(function() return userdata[1] end)) -- LOP_GETTABLEN + ehassert(pcall(function() return userdata.StringConstant end)) -- LOP_GETTABLEKS (luau_callTM) + + -- lua_resume test + local coro = coroutine.create(function() end) + ps, pe = coroutine.resume(coro) + ehassert(not ps and string.find(pe, "C stack overflow")) + + return true + end, cso) + + assert(not s) + assert(e == true, "error in xpcall eh, line " .. tostring(ehline)) +end + --[[ local i=1 while stack[i] ~= l1 do @@ -220,8 +295,9 @@ end -- testing syntax limits +local syntaxdepth = if limitedstack then 200 else 500 local function testrep (init, rep) - local s = "local a; "..init .. string.rep(rep, 300) + local s = "local a; "..init .. string.rep(rep, syntaxdepth) local a,b = loadstring(s) assert(not a) -- and string.find(b, "syntax levels")) end @@ -260,8 +336,7 @@ local a,b = loadstring(s) assert(not a) --assert(string.find(b, "line 2")) --- Test for CLI-28786 --- The xpcall is intentially going to cause an exception +-- The xpcall is intentionally going to cause an exception -- followed by a forced exception in the error handler. -- If the secondary handler isn't trapped, it will cause -- the unit test to fail. If the xpcall captures the @@ -281,16 +356,56 @@ coroutine.wrap(function() assert(not pcall(debug.getinfo, coroutine.running(), 0, ">")) end)() --- arith errors +-- loadstring chunk truncation +local a,b = loadstring("nope", "@short") +assert(not a and b:match('[^ ]+') == "short:1:") + +local a,b = loadstring("nope", "@" .. string.rep("thisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilities", 10)) +assert(not a and b:match('[^ ]+') == "...wontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilities:1:") + +local a,b = loadstring("nope", "=short") +assert(not a and b:match('[^ ]+') == "short:1:") + +local a,b = loadstring("nope", "=" .. string.rep("thisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilities", 10)) +assert(not a and b:match('[^ ]+') == "thisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbuffe:1:") + function ecall(fn, ...) local ok, err = pcall(fn, ...) assert(not ok) - return err:sub(err:find(": ") + 2, #err) + return err:sub((err:find(": ") or -1) + 2, #err) end +-- arith errors assert(ecall(function() return nil + 5 end) == "attempt to perform arithmetic (add) on nil and number") assert(ecall(function() return "a" + "b" end) == "attempt to perform arithmetic (add) on string") assert(ecall(function() return 1 > nil end) == "attempt to compare nil < number") -- note reversed order (by design) assert(ecall(function() return "a" <= 5 end) == "attempt to compare string <= number") +-- table errors +assert(ecall(function() local t = {} t[nil] = 2 end) == "table index is nil") +assert(ecall(function() local t = {} t[0/0] = 2 end) == "table index is NaN") + +assert(ecall(function() local t = {} rawset(t, nil, 2) end) == "table index is nil") +assert(ecall(function() local t = {} rawset(t, 0/0, 2) end) == "table index is NaN") + +assert(ecall(function() local t = {} t[nil] = nil end) == "table index is nil") +assert(ecall(function() local t = {} t[0/0] = nil end) == "table index is NaN") + +assert(ecall(function() local t = {} rawset(t, nil, nil) end) == "table index is nil") +assert(ecall(function() local t = {} rawset(t, 0/0, nil) end) == "table index is NaN") + +-- for loop type errors +assert(ecall(function() for i='a',2 do end end) == "invalid 'for' initial value (number expected, got string)") +assert(ecall(function() for i=1,'a' do end end) == "invalid 'for' limit (number expected, got string)") +assert(ecall(function() for i=1,2,'a' do end end) == "invalid 'for' step (number expected, got string)") + +-- method call errors +assert(ecall(function() ({}):foo() end) == "attempt to call missing method 'foo' of table") +assert(ecall(function() (""):foo() end) == "attempt to call missing method 'foo' of string") +assert(ecall(function() (42):foo() end) == "attempt to index number with 'foo'") +assert(ecall(function() ({foo=42}):foo() end) == "attempt to call a number value") +assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = {} ud:foo() end) == "attempt to call missing method 'foo' of userdata") +assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = function() end ud:foo() end) == "attempt to call missing method 'foo' of userdata") +assert(ecall(function() local ud = newproxy(true) getmetatable(ud).__index = function() error("nope") end ud:foo() end) == "nope") + return('OK') diff --git a/tests/conformance/events.lua b/tests/conformance/events.lua index 32e090ac..94314c3f 100644 --- a/tests/conformance/events.lua +++ b/tests/conformance/events.lua @@ -4,26 +4,20 @@ print('testing metatables') local unpack = table.unpack -X = 20; B = 30 - -local _G = getfenv() -setfenv(1, setmetatable({}, {__index=_G})) - -collectgarbage() - -X = X+10 -assert(X == 30 and _G.X == 20) -B = false -assert(B == false) -B = nil -assert(B == 30) - assert(getmetatable{} == nil) assert(getmetatable(4) == nil) assert(getmetatable(nil) == nil) a={}; setmetatable(a, {__metatable = "xuxu", __tostring=function(x) return x.name end}) assert(getmetatable(a) == "xuxu") +ud=newproxy(true); getmetatable(ud).__metatable = "xuxu" +assert(getmetatable(ud) == "xuxu") + +assert(pcall(getmetatable) == false) +assert(pcall(function() return getmetatable() end) == false) +assert(select(2, pcall(getmetatable, {})) == nil) +assert(select(2, pcall(getmetatable, ud)) == "xuxu") + local res,err = pcall(tostring, a) assert(not res and err == "'__tostring' must return a string") -- cannot change a protected metatable @@ -296,14 +290,8 @@ x = c(3,4,5) assert(i == 3 and x[1] == 3 and x[3] == 5) -assert(_G.X == 20) -assert(_G == getfenv(0)) - print'+' -local _g = _G -setfenv(1, setmetatable({}, {__index=function (_,k) return _g[k] end})) - -- testing proxies assert(getmetatable(newproxy()) == nil) assert(getmetatable(newproxy(false)) == nil) @@ -386,4 +374,117 @@ do assert(t.X) -- fails if table flags are set incorrectly end +do + -- verify __len behavior & error handling + local t = {1} + + setmetatable(t, {}) + assert(#t == 1) + + setmetatable(t, { __len = rawlen }) + assert(#t == 1) + + setmetatable(t, { __len = function() return 42 end }) + assert(#t == 42) + + setmetatable(t, { __len = 42 }) + local ok, err = pcall(function() return #t end) + assert(not ok and err:match("attempt to call a number value")) + + setmetatable(t, { __len = function() end }) + local ok, err = pcall(function() return #t end) + assert(not ok and err:match("'__len' must return a number")) + + setmetatable(t, { __len = error }) + local ok, err = pcall(function() return #t end) + assert(not ok and err == t) +end + +-- verify rawlen behavior +do + local t = {1} + setmetatable(t, { __len = 42 }) + + assert(rawlen(t) == 1) + assert(rawlen("foo") == 3) + + local ok, err = pcall(function() return rawlen(42) end) + assert(not ok and err:match("table or string expected")) +end + +-- verify that NaN/nil keys are passed to __newindex even though table assignment with them anywhere in the chain fails +do + assert(pcall(function() local t = {} t[nil] = 5 end) == false) + assert(pcall(function() local t = {} setmetatable(t, { __newindex = {} }) t[nil] = 5 end) == false) + assert(pcall(function() local t = {} setmetatable(t, { __newindex = function() end }) t[nil] = 5 end) == true) + + assert(pcall(function() local t = {} t[0/0] = 5 end) == false) + assert(pcall(function() local t = {} setmetatable(t, { __newindex = {} }) t[0/0] = 5 end) == false) + assert(pcall(function() local t = {} setmetatable(t, { __newindex = function() end }) t[0/0] = 5 end) == true) +end + +-- verify that __newindex gets called for frozen tables but only if the assignment is to a key absent from the table +do + local ni = {} + local t = table.create(2) + + t[1] = 42 + -- t[2] is semantically absent with storage allocated for it + + t.a = 1 + t.b = 2 + t.b = nil -- this sets 'b' value to nil but leaves key as is to exercise more internal paths -- no observable behavior change expected between b and other absent keys + + setmetatable(t, { __newindex = function(_, k, v) + assert(v == 42) + table.insert(ni, k) + end }) + table.freeze(t) + + -- "redundant" combinations are there to test all three of SETTABLEN/SETTABLEKS/SETTABLE + assert(pcall(function() t.a = 42 end) == false) + assert(pcall(function() t[1] = 42 end) == false) + assert(pcall(function() local key key = "a" t[key] = 42 end) == false) + assert(pcall(function() local key key = 1 t[key] = 42 end) == false) + + -- now repeat the same for keys absent from the table: b (semantically absent), c (physically absent), 2 (semantically absent), 3 (physically absent) + assert(pcall(function() t.b = 42 end) == true) + assert(pcall(function() t.c = 42 end) == true) + assert(pcall(function() local key key = "b" t[key] = 42 end) == true) + assert(pcall(function() local key key = "c" t[key] = 42 end) == true) + assert(pcall(function() t[2] = 42 end) == true) + assert(pcall(function() t[3] = 42 end) == true) + assert(pcall(function() local key key = 2 t[key] = 42 end) == true) + assert(pcall(function() local key key = 3 t[key] = 42 end) == true) + + -- validate the assignment sequence + local ei = { "b", "c", "b", "c", 2, 3, 2, 3 } + assert(#ni == #ei) + for k,v in ni do + assert(ei[k] == v) + end +end + +function testfenv() + X = 20; B = 30 + + local _G = getfenv() + setfenv(1, setmetatable({}, {__index=_G})) + + X = X+10 + assert(X == 30 and _G.X == 20) + B = false + assert(B == false) + B = nil + assert(B == 30) + + assert(_G.X == 20) + assert(_G == getfenv(0)) + + assert(pcall(getfenv, 10) == false) + assert(pcall(setfenv, setfenv, {}) == false) +end + +testfenv() -- DONT MOVE THIS LINE + return 'OK' diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index fd4b4de1..c49dbe77 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -77,7 +77,7 @@ end local function dosteps (siz) collectgarbage() - collectgarbage"stop" + collectgarbage("stop") local a = {} for i=1,100 do a[i] = {{}}; local b = {} end local x = gcinfo() @@ -99,11 +99,11 @@ assert(dosteps(10000) == 1) do local x = gcinfo() collectgarbage() - collectgarbage"stop" + collectgarbage("stop") repeat local a = {} until gcinfo() > 1000 - collectgarbage"restart" + collectgarbage("restart") repeat local a = {} until gcinfo() < 1000 @@ -123,7 +123,7 @@ for n in pairs(b) do end b = nil collectgarbage() -for n in pairs(a) do error'cannot be here' end +for n in pairs(a) do error("cannot be here") end for i=1,lim do a[i] = i end for i=1,lim do assert(a[i] == i) end @@ -180,6 +180,11 @@ x,y,z=nil collectgarbage() assert(next(a) == string.rep('$', 11)) +-- shrinking tables reduce their capacity; confirming the shrinking is difficult but we can at least test the surface level behavior +a = {}; setmetatable(a, {__mode = 'ks'}) +for i=1,lim do a[{}] = i end +collectgarbage() +assert(next(a) == nil) -- testing userdata collectgarbage("stop") -- stop collection @@ -247,25 +252,30 @@ if not rawget(_G, "_soft") then end -- create many threads with self-references and open upvalues -local thread_id = 0 -local threads = {} +do + local thread_id = 0 + local threads = {} -function fn(thread) - local x = {} - threads[thread_id] = function() - thread = x - end - coroutine.yield() + function fn(thread) + local x = {} + threads[thread_id] = function() + thread = x + end + coroutine.yield() + end + + while thread_id < 1000 do + local thread = coroutine.create(fn) + coroutine.resume(thread, thread) + thread_id = thread_id + 1 + end + + collectgarbage() + + -- ensure that we no longer have a lot of reachable threads for subsequent tests + threads = {} end -while thread_id < 1000 do - local thread = coroutine.create(fn) - coroutine.resume(thread, thread) - thread_id = thread_id + 1 -end - - - -- create a userdata to be collected when state is closed do local newproxy,assert,type,print,getmetatable = @@ -277,7 +287,7 @@ do assert(getmetatable(o) == tt) -- create new objects during GC local a = 'xuxu'..(10+3)..'joao', {} - ___Glob = o -- ressurect object! + ___Glob = o -- resurrect object! newproxy(o) -- creates a new one with same metatable print(">>> closing state " .. "<<<\n") end @@ -291,4 +301,53 @@ do for i = 1,10 do table.insert(___Glob, newproxy(true)) end end +-- create threads that die together with their unmarked upvalues +do + local t = {} + + for i = 1,100 do + local c = coroutine.wrap(function() + local uv = {i + 1} + local function f() + return uv[1] * 10 + end + coroutine.yield(uv[1]) + uv = {i + 2} + coroutine.yield(f()) + end) + + assert(c() == i + 1) + table.insert(t, c) + end + + for i = 1,100 do + t[i] = nil + end + + collectgarbage() +end + +-- create a lot of threads with upvalues to force a case where full gc happens after we've marked some upvalues +do + local t = {} + for i = 1,100 do + local c = coroutine.wrap(function() + local uv = {i + 1} + local function f() + return uv[1] * 10 + end + coroutine.yield(uv[1]) + uv = {i + 2} + coroutine.yield(f()) + end) + + assert(c() == i + 1) + table.insert(t, c) + end + + t = {} + + collectgarbage() +end + return('OK') diff --git a/tests/conformance/interrupt.lua b/tests/conformance/interrupt.lua new file mode 100644 index 00000000..d4b7c80a --- /dev/null +++ b/tests/conformance/interrupt.lua @@ -0,0 +1,20 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing interrupts") + +function foo() + for i=1,10 do end + return +end + +foo() + +function bar() + local i = 0 + while i < 10 do + i += i + 1 + end +end + +bar() + +return "OK" diff --git a/tests/conformance/iter.lua b/tests/conformance/iter.lua new file mode 100644 index 00000000..5f8f1a89 --- /dev/null +++ b/tests/conformance/iter.lua @@ -0,0 +1,216 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print('testing iteration') + +-- basic for loop tests +do + local a + for a,b in pairs{} do error("not here") end + for i=1,0 do error("not here") end + for i=0,1,-1 do error("not here") end + a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) + a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) + a = 0; for i=0, 1, 0.1 do a=a+1 end; assert(a==11) +end + +-- precision tests for for loops +do + local a + --a = 0; for i=1, 0, -0.01 do a=a+1 end; assert(a==101) + a = 0; for i=0, 0.999999999, 0.1 do a=a+1 end; assert(a==10) + a = 0; for i=1, 1, 1 do a=a+1 end; assert(a==1) + a = 0; for i=1e10, 1e10, -1 do a=a+1 end; assert(a==1) + a = 0; for i=1, 0.99999, 1 do a=a+1 end; assert(a==0) + a = 0; for i=99999, 1e5, -1 do a=a+1 end; assert(a==0) + a = 0; for i=1, 0.99999, -1 do a=a+1 end; assert(a==1) +end + +-- for loops do string->number coercion +do + local a = 0; for i="10","1","-2" do a=a+1 end; assert(a==5) +end + +-- generic for with function iterators +do + local function f (n, p) + local t = {}; for i=1,p do t[i] = i*10 end + return function (_,n) + if n > 0 then + n = n-1 + return n, unpack(t) + end + end, nil, n + end + + local x = 0 + for n,a,b,c,d in f(5,3) do + x = x+1 + assert(a == 10 and b == 20 and c == 30 and d == nil) + end + assert(x == 5) +end + +-- generic for with __call (tables) +do + local f = {} + setmetatable(f, { __call = function(_, _, n) if n > 0 then return n - 1 end end }) + + local x = 0 + for n in f, nil, 5 do + x += n + end + assert(x == 10) +end + +-- generic for with __call (userdata) +do + local f = newproxy(true) + getmetatable(f).__call = function(_, _, n) if n > 0 then return n - 1 end end + + local x = 0 + for n in f, nil, 5 do + x += n + end + assert(x == 10) +end + +-- generic for with pairs +do + local x = 0 + for k, v in pairs({a = 1, b = 2, c = 3}) do + x += v + end + assert(x == 6) +end + +-- generic for with pairs with holes +do + local x = 0 + for k, v in pairs({1, 2, 3, nil, 5}) do + x += v + end + assert(x == 11) +end + +-- generic for with ipairs +do + local x = 0 + for k, v in ipairs({1, 2, 3, nil, 5}) do + x += v + end + assert(x == 6) +end + +-- generic for with __iter (tables) +do + local f = {} + setmetatable(f, { __iter = function(x) + assert(f == x) + return next, {1, 2, 3, 4} + end }) + + local x = 0 + for n in f do + x += n + end + assert(x == 10) +end + +-- generic for with __iter (userdata) +do + local f = newproxy(true) + getmetatable(f).__iter = function(x) + assert(f == x) + return next, {1, 2, 3, 4} + end + + local x = 0 + for n in f do + x += n + end + assert(x == 10) +end + +-- generic for with tables (dictionary) +do + local x = 0 + for k, v in {a = 1, b = 2, c = 3} do + print(k, v) + x += v + end + assert(x == 6) +end + +-- generic for with tables (arrays) +do + local x = '' + for k, v in {1, 2, 3, nil, 5} do + x ..= tostring(v) + end + assert(x == "1235") +end + +-- generic for with tables (mixed) +do + local x = 0 + for k, v in {1, 2, 3, nil, 5, a = 1, b = 2, c = 3} do + x += v + end + assert(x == 17) +end + +-- generic for over a non-iterable object +do + local ok, err = pcall(function() for x in 42 do end end) + assert(not ok and err:match("attempt to iterate")) +end + +-- generic for over an iterable object that doesn't return a function +do + local obj = {} + setmetatable(obj, { __iter = function() end }) + + local ok, err = pcall(function() for x in obj do end end) + assert(not ok and err:match("attempt to call a nil value")) +end + +-- it's okay to iterate through a table with a single variable +do + local x = 0 + for k in {1, 2, 3, 4, 5} do + x += k + end + assert(x == 15) +end + +-- all extra variables should be set to nil during builtin traversal +do + local x = 0 + for k,v,a,b,c,d,e in {1, 2, 3, 4, 5} do + x += k + assert(a == nil and b == nil and c == nil and d == nil and e == nil) + end + assert(x == 15) +end + +-- pairs/ipairs/next may be substituted through getfenv +-- however, they *must* be substituted with functions - we don't support them falling back to generalized iteration +function testgetfenv() + local env = getfenv(1) + env.pairs = function() return "nope" end + env.ipairs = function() return "nope" end + env.next = {1, 2, 3} + + local ok, err = pcall(function() for k, v in pairs({}) do end end) + assert(not ok and err:match("attempt to iterate over a string value")) + + local ok, err = pcall(function() for k, v in ipairs({}) do end end) + assert(not ok and err:match("attempt to iterate over a string value")) + + local ok, err = pcall(function() for k, v in next, {} do end end) + assert(not ok and err:match("attempt to iterate over a table value")) +end + +testgetfenv() -- DONT MOVE THIS LINE + +return"OK" diff --git a/tests/conformance/locals.lua b/tests/conformance/locals.lua index cbe5f92d..2d8d004b 100644 --- a/tests/conformance/locals.lua +++ b/tests/conformance/locals.lua @@ -117,7 +117,7 @@ if rawget(_G, "querytab") then local t = querytab(a) for k,_ in pairs(a) do a[k] = nil end - collectgarbage() -- restore GC and collect dead fiels in `a' + collectgarbage() -- restore GC and collect dead fields in `a' for i=0,t-1 do local k = querytab(a, i) assert(k == nil or type(k) == 'number' or k == 'alo') diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 5e8b9398..972c399b 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -26,6 +26,7 @@ function f(...) end end +assert(pcall(tonumber) == false) assert(tonumber{} == nil) assert(tonumber'+0.01' == 1/100 and tonumber'+.01' == 0.01 and tonumber'.01' == 0.01 and tonumber'-1.' == -1 and @@ -151,14 +152,34 @@ assert(eq(a[1000][3], 1000/3, 0.001)) print('+') do -- testing NaN - local NaN = 10e500 - 10e400 + local NaN -- to avoid constant folding + NaN = 10e500 - 10e400 + assert(NaN ~= NaN) + assert(not (NaN == NaN)) + assert(not (NaN < NaN)) assert(not (NaN <= NaN)) assert(not (NaN > NaN)) assert(not (NaN >= NaN)) + + assert(not (0 == NaN)) assert(not (0 < NaN)) + assert(not (0 <= NaN)) + assert(not (0 > NaN)) + assert(not (0 >= NaN)) + + assert(not (NaN == 0)) assert(not (NaN < 0)) + assert(not (NaN <= 0)) + assert(not (NaN > 0)) + assert(not (NaN >= 0)) + + assert(if NaN < 0 then false else true) + assert(if NaN <= 0 then false else true) + assert(if NaN > 0 then false else true) + assert(if NaN >= 0 then false else true) + local a = {} assert(not pcall(function () a[NaN] = 1 end)) assert(a[NaN] == nil) @@ -172,7 +193,7 @@ end a = nil --- testing implicit convertions +-- testing implicit conversions local a,b = '10', '20' assert(a*b == 200 and a+b == 30 and a-b == -10 and a/b == 0.5 and -b == -20) @@ -214,6 +235,16 @@ assert(flag); assert(select(2, pcall(math.random, 1, 2, 3)):match("wrong number of arguments")) +-- min/max +assert(math.min(1) == 1) +assert(math.min(1, 2) == 1) +assert(math.min(1, 2, -1) == -1) +assert(math.min(1, -1, 2) == -1) +assert(math.max(1) == 1) +assert(math.max(1, 2) == 2) +assert(math.max(1, 2, -1) == 2) +assert(math.max(1, -1, 2) == 2) + -- noise assert(math.noise(0.5) == 0) assert(math.noise(0.5, 0.5) == -0.25) @@ -252,6 +283,13 @@ assert(math.fmod(-3, 2) == -1) assert(math.fmod(3, -2) == 1) assert(math.fmod(-3, -2) == -1) +-- pow +assert(math.pow(2, 0) == 1) +assert(math.pow(2, 2) == 4) +assert(math.pow(4, 0.5) == 2) +assert(math.pow(-2, 2) == 4) +assert(tostring(math.pow(-2, 0.5)) == "nan") + -- most of the tests above go through fastcall path -- to make sure the basic implementations are also correct we test these functions with string->number coercions assert(math.abs("-4") == 4) @@ -276,8 +314,10 @@ assert(math.log("10", 10) == 1) assert(math.log("9", 3) == 2) assert(math.max("1", 2) == 2) assert(math.max(2, "1") == 2) +assert(math.max(1, 2, "3") == 3) assert(math.min("1", 2) == 1) assert(math.min(2, "1") == 1) +assert(math.min(1, 2, "3") == 1) local v,f = math.modf("1.5") assert(v == 1 and f == 0.5) assert(math.pow("2", 2) == 4) @@ -288,9 +328,15 @@ assert(math.sqrt("4") == 2) assert(math.tanh("0") == 0) assert(math.tan("0") == 0) assert(math.clamp("0", 2, 3) == 2) +assert(math.clamp("4", 2, 3) == 3) assert(math.sign("2") == 1) assert(math.sign("-2") == -1) assert(math.sign("0") == 0) assert(math.round("1.8") == 2) +-- test that fastcalls return correct number of results +assert(select('#', math.floor(1.4)) == 1) +assert(select('#', math.ceil(1.6)) == 1) +assert(select('#', math.sqrt(9)) == 1) + return('OK') diff --git a/tests/conformance/move.lua b/tests/conformance/move.lua index 3f28b4b3..27a96ffc 100644 --- a/tests/conformance/move.lua +++ b/tests/conformance/move.lua @@ -74,4 +74,6 @@ checkerror("wrap around", table.move, {}, 1, maxI, 2) checkerror("wrap around", table.move, {}, 1, 2, maxI) checkerror("wrap around", table.move, {}, minI, -2, 2) +checkerror("readonly", table.move, table.freeze({}), 1, 1, 1) + return"OK" diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index a2072d2c..969209fc 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -144,4 +144,21 @@ coroutine.resume(co) resumeerror(co, "fail") checkresults({ true, false, "fail" }, coroutine.resume(co)) -return'OK' +-- stack overflow needs to happen at the call limit +local calllimit = 20000 +function recurse(n) return n <= 1 and 1 or recurse(n-1) + 1 end + +-- we use one frame for top-level function and one frame is the service frame for coroutines +assert(recurse(calllimit - 2) == calllimit - 2) + +-- note that when calling through pcall, pcall eats one more frame +checkresults({ true, calllimit - 3 }, pcall(recurse, calllimit - 3)) +checkerror(pcall(recurse, calllimit - 2)) + +-- xpcall handler runs in context of the stack frame, but this works just fine since we allow extra stack consumption past stack overflow +checkresults({ false, "ok" }, xpcall(recurse, function() return string.reverse("ko") end, calllimit - 2)) + +-- however, if xpcall handler itself runs out of extra stack space, we get "error in error handling" +checkresults({ false, "error in error handling" }, xpcall(recurse, function() return recurse(calllimit) end, calllimit - 2)) + +return 'OK' diff --git a/tests/conformance/pm.lua b/tests/conformance/pm.lua index 9a113964..263759ac 100644 --- a/tests/conformance/pm.lua +++ b/tests/conformance/pm.lua @@ -21,9 +21,9 @@ a,b = string.find('alo', '') assert(a == 1 and b == 0) a,b = string.find('a\0o a\0o a\0o', 'a', 1) -- first position assert(a == 1 and b == 1) -a,b = string.find('a\0o a\0o a\0o', 'a\0o', 2) -- starts in the midle +a,b = string.find('a\0o a\0o a\0o', 'a\0o', 2) -- starts in the middle assert(a == 5 and b == 7) -a,b = string.find('a\0o a\0o a\0o', 'a\0o', 9) -- starts in the midle +a,b = string.find('a\0o a\0o a\0o', 'a\0o', 9) -- starts in the middle assert(a == 9 and b == 11) a,b = string.find('a\0a\0a\0a\0\0ab', '\0ab', 2); -- finds at the end assert(a == 9 and b == 11); diff --git a/tests/conformance/safeenv.lua b/tests/conformance/safeenv.lua new file mode 100644 index 00000000..3a430b5f --- /dev/null +++ b/tests/conformance/safeenv.lua @@ -0,0 +1,23 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("safeenv reset") + +local function envChangeInMetamethod() + -- declare constant so that at O2 this test doesn't interfere with constant folding which we can't deoptimize + local ten + ten = 10 + + local a = setmetatable({}, { + __index = function() + getfenv().math = { abs = function(n) return n*n end } + return 2 + end + }) + + local b = a.x + + assert(math.abs(ten) == 100) +end + +envChangeInMetamethod() + +return"OK" diff --git a/tests/conformance/strconv.lua b/tests/conformance/strconv.lua new file mode 100644 index 00000000..85ad0295 --- /dev/null +++ b/tests/conformance/strconv.lua @@ -0,0 +1,51 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print("testing string-number conversion") + +-- zero +assert(tostring(0) == "0") +assert(tostring(0/-1) == "-0") + +-- specials +assert(tostring(1/0) == "inf") +assert(tostring(-1/0) == "-inf") +assert(tostring(0/0) == "nan") + +-- integers +assert(tostring(1) == "1") +assert(tostring(42) == "42") +assert(tostring(-4294967296) == "-4294967296") +assert(tostring(9007199254740991) == "9007199254740991") + +-- decimals +assert(tostring(0.5) == "0.5") +assert(tostring(0.1) == "0.1") +assert(tostring(-0.17) == "-0.17") +assert(tostring(math.pi) == "3.141592653589793") + +-- fuzzing corpus +assert(tostring(5.4536123983019448e-311) == "5.453612398302e-311") +assert(tostring(5.4834368411298348e-311) == "5.48343684113e-311") +assert(tostring(4.4154895841930002e-305) == "4.415489584193e-305") +assert(tostring(1125968630513728) == "1125968630513728") +assert(tostring(3.3951932655938423e-313) == "3.3951932656e-313") +assert(tostring(1.625) == "1.625") +assert(tostring(4.9406564584124654e-324) == "5.e-324") +assert(tostring(2.0049288280105384) == "2.0049288280105384") +assert(tostring(3.0517578125e-05) == "0.000030517578125") +assert(tostring(1.383544921875) == "1.383544921875") +assert(tostring(3.0053350932691001) == "3.0053350932691") +assert(tostring(0.0001373291015625) == "0.0001373291015625") +assert(tostring(-1.9490628022799998e+289) == "-1.94906280228e+289") +assert(tostring(-0.00610404721867928) == "-0.00610404721867928") +assert(tostring(0.00014495849609375) == "0.00014495849609375") +assert(tostring(0.453125) == "0.453125") +assert(tostring(-4.2375343999999997e+73) == "-4.2375344e+73") +assert(tostring(1.3202313930270133e-192) == "1.3202313930270133e-192") +assert(tostring(3.6984408976312836e+19) == "36984408976312840000") +assert(tostring(2.0563000527063302) == "2.05630005270633") +assert(tostring(4.8970527433648997e-260) == "4.8970527433649e-260") +assert(tostring(1.62890625) == "1.62890625") +assert(tostring(1.1295093211933533e+65) == "1.1295093211933533e+65") + +return "OK" diff --git a/tests/conformance/stringinterp.lua b/tests/conformance/stringinterp.lua new file mode 100644 index 00000000..efb25bae --- /dev/null +++ b/tests/conformance/stringinterp.lua @@ -0,0 +1,59 @@ +local function assertEq(left, right) + assert(typeof(left) == "string", "left is a " .. typeof(left)) + assert(typeof(right) == "string", "right is a " .. typeof(right)) + + if left ~= right then + error(string.format("%q ~= %q", left, right)) + end +end + +assertEq(`hello {"world"}`, "hello world") +assertEq(`Welcome {"to"} {"Luau"}!`, "Welcome to Luau!") + +assertEq(`2 + 2 = {2 + 2}`, "2 + 2 = 4") + +assertEq(`{1} {2} {3} {4} {5} {6} {7}`, "1 2 3 4 5 6 7") + +local combo = {5, 2, 8, 9} +assertEq(`The lock combinations are: {table.concat(combo, ", ")}`, "The lock combinations are: 5, 2, 8, 9") + +assertEq(`true = {true}`, "true = true") + +local name = "Luau" +assertEq(`Welcome to { + name +}!`, "Welcome to Luau!") + +local nameNotConstantEvaluated = (function() return "Luau" end)() +assertEq(`Welcome to {nameNotConstantEvaluated}!`, "Welcome to Luau!") + +assertEq(`This {localName} does not exist`, "This nil does not exist") + +assertEq(`Welcome to \ +{name}!`, "Welcome to \nLuau!") + +assertEq(`empty`, "empty") + +assertEq(`Escaped brace: \{}`, "Escaped brace: {}") +assertEq(`Escaped brace \{} with {"expression"}`, "Escaped brace {} with expression") +assertEq(`Backslash \ that escapes the space is not a part of the string...`, "Backslash that escapes the space is not a part of the string...") +assertEq(`Escaped backslash \\`, "Escaped backslash \\") +assertEq(`Escaped backtick: \``, "Escaped backtick: `") + +assertEq(`Hello {`from inside {"a nested string"}`}`, "Hello from inside a nested string") + +assertEq(`1 {`2 {`3 {4}`}`}`, "1 2 3 4") + +local health = 50 +assert(`You have {health}% health` == "You have 50% health") + +local function shadowsString(string) + return `Value is {string}` +end + +assertEq(shadowsString("hello"), "Value is hello") +assertEq(shadowsString(1), "Value is 1") + +assertEq(`\u{0041}\t`, "A\t") + +return "OK" diff --git a/tests/conformance/strings.lua b/tests/conformance/strings.lua index 98a5721a..61bac726 100644 --- a/tests/conformance/strings.lua +++ b/tests/conformance/strings.lua @@ -48,6 +48,7 @@ assert(string.find("", "") == 1) assert(string.find('', 'aaa', 1) == nil) assert(('alo(.)alo'):find('(.)', 1, 1) == 4) assert(string.find('', '1', 2) == nil) +assert(string.find('123', '2', 0) == 2) print('+') assert(string.len("") == 0) @@ -88,6 +89,8 @@ assert(string.lower("\0ABCc%$") == "\0abcc%$") assert(string.rep('teste', 0) == '') assert(string.rep('tés\00tê', 2) == 'tés\0têtés\000tê') assert(string.rep('', 10) == '') +assert(string.rep('', 1e9) == '') +assert(pcall(string.rep, 'x', 2e9) == false) assert(string.reverse"" == "") assert(string.reverse"\0\1\2\3" == "\3\2\1\0") @@ -126,10 +129,45 @@ assert(string.format("-%.20s.20s", string.rep("%", 2000)) == "-"..string.rep("%" assert(string.format('"-%20s.20s"', string.rep("%", 2000)) == string.format("%q", "-"..string.rep("%", 2000)..".20s")) +assert(string.format("%o %u %x %X", -1, -1, -1, -1) == "1777777777777777777777 18446744073709551615 ffffffffffffffff FFFFFFFFFFFFFFFF") + +assert(string.format("%e %E", 1.5, -1.5) == "1.500000e+00 -1.500000E+00") + +assert(pcall(string.format, "%##################d", 1) == false) +assert(pcall(string.format, "%.123d", 1) == false) +assert(pcall(string.format, "%?", 1) == false) -- longest number that can be formated assert(string.len(string.format('%99.99f', -1e308)) >= 100) +local function return_one_thing() + return "hi" +end +local function return_two_nils() + return nil, nil +end + +assert(string.format("%*", return_one_thing()) == "hi") +assert(string.format("%* %*", return_two_nils()) == "nil nil") +assert(pcall(function() + string.format("%* %* %*", return_two_nils()) +end) == false) + +assert(string.format("%*", "a\0b\0c") == "a\0b\0c") +assert(string.format("%*", string.rep("doge", 3000)) == string.rep("doge", 3000)) +assert(string.format("%*", 42) == "42") +assert(string.format("%*", true) == "true") + +assert(string.format("%*", setmetatable({}, { __tostring = function() return "ok" end })) == "ok") + +local ud = newproxy(true) +getmetatable(ud).__tostring = function() return "good" end +assert(string.format("%*", ud) == "good") + +assert(pcall(function() + string.format("%#*", "bad form") +end) == false) + assert(loadstring("return 1\n--comentário sem EOL no final")() == 1) @@ -151,6 +189,26 @@ assert(table.concat(a, ",", 2) == "b,c") assert(table.concat(a, ",", 3) == "c") assert(table.concat(a, ",", 4) == "") +-- string.split +do + local function eq(a, b) + if #a ~= #b then + return false + end + for i=1,#a do + if a[i] ~= b[i] then + return false + end + end + return true + end + + assert(eq(string.split("abc", ""), {'a', 'b', 'c'})) + assert(eq(string.split("abc", "b"), {'a', 'c'})) + assert(eq(string.split("abc", "d"), {'abc'})) + assert(eq(string.split("abc", "c"), {'ab', ''})) +end + --[[ local locales = { "ptb", "ISO-8859-1", "pt_BR" } local function trylocale (w) diff --git a/tests/conformance/nextvar.lua b/tests/conformance/tables.lua similarity index 64% rename from tests/conformance/nextvar.lua rename to tests/conformance/tables.lua index 7f9b7596..7ae80cc4 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/tables.lua @@ -87,35 +87,59 @@ print'+' -- testing tables dynamically built local lim = 130 -local a = {}; a[2] = 1; check(a, 0, 1) -a = {}; a[0] = 1; check(a, 0, 1); a[2] = 1; check(a, 0, 2) -a = {}; a[0] = 1; a[1] = 1; check(a, 1, 1) -a = {} -for i = 1,lim do - a[i] = 1 - assert(#a == i) - check(a, mp2(i), 0) + +do + local a = {}; a[2] = 1; check(a, 0, 1) + a = {}; a[0] = 1; check(a, 0, 1); a[2] = 1; check(a, 0, 2) + a = {}; a[0] = 1; a[1] = 1; check(a, 1, 1) + a = {} + for i = 1,lim do + a[i] = 1 + assert(#a == i) + check(a, mp2(i), 0) + end end -a = {} -for i = 1,lim do - a['a'..i] = 1 - assert(#a == 0) - check(a, 0, mp2(i)) +do + local a = {} + for i = 1,lim do + a['a'..i] = 1 + assert(#a == 0) + check(a, 0, mp2(i)) + end end -a = {} -for i=1,16 do a[i] = i end -check(a, 16, 0) -for i=1,11 do a[i] = nil end -for i=30,40 do a[i] = nil end -- force a rehash (?) -check(a, 0, 8) -a[10] = 1 -for i=30,40 do a[i] = nil end -- force a rehash (?) -check(a, 0, 8) -for i=1,14 do a[i] = nil end -for i=30,50 do a[i] = nil end -- force a rehash (?) -check(a, 0, 4) +do + local a = {} + for i=1,16 do a[i] = i end + check(a, 16, 0) + for i=1,11 do a[i] = nil end + for i=30,40 do a[i] = nil end -- force a rehash (?) + check(a, 0, 8) + a[10] = 1 + for i=30,40 do a[i] = nil end -- force a rehash (?) + check(a, 0, 8) + for i=1,14 do a[i] = nil end + for i=30,50 do a[i] = nil end -- force a rehash (?) + check(a, 0, 4) +end + +do -- rehash moving elements from array to hash + local a = {} + for i = 1, 100 do a[i] = i end + check(a, 128, 0) + + for i = 5, 95 do a[i] = nil end + check(a, 128, 0) + + a.x = 1 -- force a re-hash + check(a, 4, 8) + + for i = 1, 4 do assert(a[i] == i) end + for i = 5, 95 do assert(a[i] == nil) end + for i = 96, 100 do assert(a[i] == i) end + assert(a.x == 1) +end -- reverse filling for i=1,lim do @@ -368,48 +392,6 @@ assert(next(a,nil) == 1000 and next(a,1000) == nil) assert(next({}) == nil) assert(next({}, nil) == nil) -for a,b in pairs{} do error"not here" end -for i=1,0 do error'not here' end -for i=0,1,-1 do error'not here' end -a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) -a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) - -a = 0; for i=0, 1, 0.1 do a=a+1 end; assert(a==11) --- precision problems ---a = 0; for i=1, 0, -0.01 do a=a+1 end; assert(a==101) -a = 0; for i=0, 0.999999999, 0.1 do a=a+1 end; assert(a==10) -a = 0; for i=1, 1, 1 do a=a+1 end; assert(a==1) -a = 0; for i=1e10, 1e10, -1 do a=a+1 end; assert(a==1) -a = 0; for i=1, 0.99999, 1 do a=a+1 end; assert(a==0) -a = 0; for i=99999, 1e5, -1 do a=a+1 end; assert(a==0) -a = 0; for i=1, 0.99999, -1 do a=a+1 end; assert(a==1) - --- conversion -a = 0; for i="10","1","-2" do a=a+1 end; assert(a==5) - - -collectgarbage() - - --- testing generic 'for' - -local function f (n, p) - local t = {}; for i=1,p do t[i] = i*10 end - return function (_,n) - if n > 0 then - n = n-1 - return n, unpack(t) - end - end, nil, n -end - -local x = 0 -for n,a,b,c,d in f(5,3) do - x = x+1 - assert(a == 10 and b == 20 and c == 30 and d == nil) -end -assert(x == 5) - -- testing table.create and table.find do local t = table.create(5) @@ -512,4 +494,196 @@ do assert(#t == 7) end +-- test clone +do + local t = {a = 1, b = 2, 3, 4, 5} + local tt = table.clone(t) + + assert(#tt == 3) + assert(tt.a == 1 and tt.b == 2) + + t.c = 3 + assert(tt.c == nil) + + t = table.freeze({"test"}) + tt = table.clone(t) + assert(table.isfrozen(t) and not table.isfrozen(tt)) + + t = setmetatable({}, {}) + tt = table.clone(t) + assert(getmetatable(t) == getmetatable(tt)) + + t = setmetatable({}, {__metatable = "protected"}) + assert(not pcall(table.clone, t)) + + function order(t) + local r = '' + for k,v in pairs(t) do + r ..= tostring(v) + end + return v + end + + t = {a = 1, b = 2, c = 3, d = 4, e = 5, f = 6} + tt = table.clone(t) + assert(order(t) == order(tt)) + + assert(not pcall(table.clone)) + assert(not pcall(table.clone, 42)) +end + +-- test boundary invariant maintenance during rehash +do + local arr = table.create(5, 42) + + arr[1] = nil + arr.a = 'a' -- trigger rehash + + assert(#arr == 5) -- technically 0 is also valid, but it happens to be 5 because array capacity is 5 +end + +-- test boundary invariant maintenance when replacing hash keys +do + local arr = {} + arr.a = 'a' + arr.a = nil + arr[1] = 1 -- should rehash and resize array part, otherwise # won't find the boundary in array part + + assert(#arr == 1) +end + +-- test boundary invariant maintenance when table is filled from the end +do + local arr = {} + for i=5,2,-1 do + arr[i] = i + assert(#arr == 0) + end + arr[1] = 1 + assert(#arr == 5) +end + +-- test boundary invariant maintenance when table is filled using SETLIST opcode +do + local arr = {[2]=2,1} + assert(#arr == 2) +end + +-- test boundary invariant maintenance when table is filled using table.move +do + local t1 = {1, 2, 3, 4, 5} + local t2 = {[6] = 6} + + table.move(t1, 1, 5, 1, t2) + assert(#t2 == 6) +end + +-- test table.unpack fastcall for rejecting large unpacks +do + local ok, res = pcall(function() + local a = table.create(7999, 0) + local b = table.create(8000, 0) + + local at = { table.unpack(a) } + local bt = { table.unpack(b) } + end) + + assert(not ok) +end + +-- test iteration with lightuserdata keys +do + function countud() + local t = {} + t[makelud(1)] = 1 + t[makelud(2)] = 2 + + local count = 0 + for k,v in pairs(t) do + count += v + end + + return count + end + + assert(countud() == 3) +end + +-- test iteration with lightuserdata keys with a substituted environment +do + local env = { makelud = makelud, pairs = pairs } + setfenv(countud, env) + assert(countud() == 3) +end + +-- test __newindex-as-a-table indirection: this had memory safety bugs in Lua 5.1.0 +do + local hit = false + + local grandparent = {} + grandparent.__newindex = function(s,k,v) + assert(k == "foo" and v == 10) + hit = true + end + + local parent = {} + parent.__newindex = parent + setmetatable(parent, grandparent) + + local child = setmetatable({}, parent) + child.foo = 10 + + assert(hit and child.foo == nil and parent.foo == nil) +end + +-- testing next x GC of deleted keys +do + local co = coroutine.wrap(function (t) + for k, v in pairs(t) do + local k1 = next(t) -- all previous keys were deleted + assert(k == k1) -- current key is the first in the table + t[k] = nil + local expected = (type(k) == "table" and k[1] or + type(k) == "function" and k() or + string.sub(k, 1, 1)) + assert(expected == v) + coroutine.yield(v) + end + end) + local t = {} + t[{1}] = 1 -- add several unanchored, collectable keys + t[{2}] = 2 + t[string.rep("a", 50)] = "a" -- long string + t[string.rep("b", 50)] = "b" + t[{3}] = 3 + t[string.rep("c", 10)] = "c" -- short string + t[function () return 10 end] = 10 + local count = 7 + while co(t) do + collectgarbage("collect") -- collect dead keys + count = count - 1 + end + assert(count == 0 and next(t) == nil) -- traversed the whole table +end + +-- test error cases for table functions +do + assert(pcall(table.insert, {}) == false) + assert(pcall(table.insert, {}, 1, 2, 3) == false) + assert(pcall(table.insert, table.freeze({1, 2, 3}), 4) == false) + assert(pcall(table.insert, table.freeze({1, 2, 3}), 1, 4) == false) + + assert(pcall(table.remove, table.freeze({1})) == false) + + assert(pcall(table.concat, {true}) == false) + + assert(pcall(table.create) == false) + assert(pcall(table.create, -1) == false) + assert(pcall(table.create, 1e9) == false) + + assert(pcall(table.find, {}, 42, 0) == false) + + assert(pcall(table.clear, table.freeze({})) == false) +end + return"OK" diff --git a/tests/conformance/tmerror.lua b/tests/conformance/tmerror.lua new file mode 100644 index 00000000..1ad4dd16 --- /dev/null +++ b/tests/conformance/tmerror.lua @@ -0,0 +1,15 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes + +-- Generate an error (i.e. throw an exception) inside a tag method which is indirectly +-- called via pcall. +-- This test is meant to detect a regression in handling errors inside a tag method + +local testtable = {} +setmetatable(testtable, { __index = function() error("Error") end }) + +pcall(function() + testtable.missingmethod() +end) + +return('OK') diff --git a/tests/conformance/tpack.lua b/tests/conformance/tpack.lua index 835bf564..b240f482 100644 --- a/tests/conformance/tpack.lua +++ b/tests/conformance/tpack.lua @@ -306,6 +306,8 @@ do -- testing initial position assert(i == 4 and p == 17) local i, p = unpack("!4 i4", x, -#x) assert(i == 1 and p == 5) + local i, p = unpack("!4 i4", x, 0) + assert(i == 1 and p == 5) -- limits for i = 1, #x + 1 do diff --git a/tests/conformance/types.lua b/tests/conformance/types.lua index cdddceef..3539b34d 100644 --- a/tests/conformance/types.lua +++ b/tests/conformance/types.lua @@ -10,9 +10,6 @@ local ignore = -- what follows is a set of mismatches that hopefully eventually will go down to 0 "_G.require", -- need to move to Roblox type defs - "_G.utf8.nfcnormalize", -- need to move to Roblox type defs - "_G.utf8.nfdnormalize", -- need to move to Roblox type defs - "_G.utf8.graphemes", -- need to move to Roblox type defs } function verify(real, rtti, path) diff --git a/tests/conformance/userdata.lua b/tests/conformance/userdata.lua new file mode 100644 index 00000000..98392e25 --- /dev/null +++ b/tests/conformance/userdata.lua @@ -0,0 +1,45 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print('testing userdata') + +-- int64 is a userdata type defined in C++ + +-- equality +assert(int64(1) == int64(1)) +assert(int64(1) ~= int64(2)) + +-- relational +assert(not (int64(1) < int64(1))) +assert(int64(1) < int64(2)) +assert(int64(1) <= int64(1)) +assert(int64(1) <= int64(2)) + +-- arithmetics +assert(-int64(2) == int64(-2)) + +assert(int64(1) + int64(2) == int64(3)) +assert(int64(1) - int64(2) == int64(-1)) +assert(int64(2) * int64(3) == int64(6)) +assert(int64(4) / int64(2) == int64(2)) +assert(int64(4) % int64(3) == int64(1)) +assert(int64(2) ^ int64(3) == int64(8)) + +assert(int64(1) + 2 == int64(3)) +assert(int64(1) - 2 == int64(-1)) +assert(int64(2) * 3 == int64(6)) +assert(int64(4) / 2 == int64(2)) +assert(int64(4) % 3 == int64(1)) +assert(int64(2) ^ 3 == int64(8)) + +-- tostring +assert(tostring(int64(2)) == "2") + +-- index/newindex; note, mutable userdatas aren't very idiomatic but we need to test this +do + local v = int64(42) + assert(v.value == 42) + v.value = 4 + assert(v.value == 4) + assert(v == int64(4)) +end + +return 'OK' diff --git a/tests/conformance/utf8.lua b/tests/conformance/utf8.lua index 024cb16d..bfd7a1ac 100644 --- a/tests/conformance/utf8.lua +++ b/tests/conformance/utf8.lua @@ -205,4 +205,4 @@ for p, c in string.gmatch(x, "()(" .. utf8.charpattern .. ")") do end end -return'OK' +return 'OK' diff --git a/tests/conformance/vararg.lua b/tests/conformance/vararg.lua index 5aaa422b..178c56b8 100644 --- a/tests/conformance/vararg.lua +++ b/tests/conformance/vararg.lua @@ -105,6 +105,51 @@ assert(a==5 and b==4 and c==3 and d==2 and e==1) a,b,c,d,e = f(4) assert(a==nil and b==nil and c==nil and d==nil and e==nil) +-- select tests +a = {select(3, unpack{10,20,30,40})} +assert(table.getn(a) == 2 and a[1] == 30 and a[2] == 40) +a = {select(1)} +assert(next(a) == nil) +a = {select(-1, 3, 5, 7)} +assert(a[1] == 7 and a[2] == nil) +a = {select(-2, 3, 5, 7)} +assert(a[1] == 5 and a[2] == 7 and a[3] == nil) +pcall(select, 10000) +pcall(select, -10000) + +-- select(_, ...) has special optimizations so it needs extra testing +function selectone(n, ...) + local e = select(n, ...) + return e +end + +function selectmany(n, ...) + return table.concat({select(n, ...)}, ',') +end + +assert(selectone('#') == 0) +assert(selectmany('#') == "0") + +assert(selectone('#', 10, 20, 30) == 3) +assert(selectmany('#', 10, 20, 30) == "3") + +assert(selectone(1, 10, 20, 30) == 10) +assert(selectmany(1, 10, 20, 30) == "10,20,30") + +assert(selectone(2, 10, 20, 30) == 20) +assert(selectmany(2, 10, 20, 30) == "20,30") + +assert(selectone(3, 10, 20, 30) == 30) +assert(selectmany(3, 10, 20, 30) == "30") + +assert(selectone(4, 10, 20, 30) == nil) +assert(selectmany(4, 10, 20, 30) == "") + +assert(selectone(-2, 10, 20, 30) == 20) +assert(selectmany(-2, 10, 20, 30) == "20,30") + +assert(selectone('3', 10, 20, 30) == 30) +assert(selectmany('3', 10, 20, 30) == "30") -- varargs for main chunks f = loadstring[[ return {...} ]] @@ -122,16 +167,5 @@ f = loadstring[[ assert(f("a", "b", nil, {}, assert)) assert(f()) -a = {select(3, unpack{10,20,30,40})} -assert(table.getn(a) == 2 and a[1] == 30 and a[2] == 40) -a = {select(1)} -assert(next(a) == nil) -a = {select(-1, 3, 5, 7)} -assert(a[1] == 7 and a[2] == nil) -a = {select(-2, 3, 5, 7)} -assert(a[1] == 5 and a[2] == 7 and a[3] == nil) -pcall(select, 10000) -pcall(select, -10000) - return('OK') diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index 620f646a..6164e929 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -1,6 +1,9 @@ -- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details print('testing vectors') +-- detect vector size +local vector_size = if pcall(function() return vector(0, 0, 0).w end) then 4 else 3 + -- equality assert(vector(1, 2, 3) == vector(1, 2, 3)) assert(vector(0, 1, 2) == vector(-0, 1, 2)) @@ -13,8 +16,14 @@ assert(not rawequal(vector(1, 2, 3), vector(1, 2, 4))) -- type & tostring assert(type(vector(1, 2, 3)) == "vector") -assert(tostring(vector(1, 2, 3)) == "1, 2, 3") -assert(tostring(vector(-1, 2, 0.5)) == "-1, 2, 0.5") + +if vector_size == 4 then + assert(tostring(vector(1, 2, 3, 4)) == "1, 2, 3, 4") + assert(tostring(vector(-1, 2, 0.5, 0)) == "-1, 2, 0.5, 0") +else + assert(tostring(vector(1, 2, 3)) == "1, 2, 3") + assert(tostring(vector(-1, 2, 0.5)) == "-1, 2, 0.5") +end local t = {} @@ -42,12 +51,19 @@ assert(8 * vector(8, 16, 24) == vector(64, 128, 192)); assert(vector(1, 2, 4) * '8' == vector(8, 16, 32)); assert('8' * vector(8, 16, 24) == vector(64, 128, 192)); -assert(vector(1, 2, 4) / vector(8, 16, 24) == vector(1/8, 2/16, 4/24)); +if vector_size == 4 then + assert(vector(1, 2, 4, 8) / vector(8, 16, 24, 32) == vector(1/8, 2/16, 4/24, 8/32)); + assert(8 / vector(8, 16, 24, 32) == vector(1, 1/2, 1/3, 1/4)); + assert('8' / vector(8, 16, 24, 32) == vector(1, 1/2, 1/3, 1/4)); +else + assert(vector(1, 2, 4) / vector(8, 16, 24, 1) == vector(1/8, 2/16, 4/24)); + assert(8 / vector(8, 16, 24) == vector(1, 1/2, 1/3)); + assert('8' / vector(8, 16, 24) == vector(1, 1/2, 1/3)); +end + assert(vector(1, 2, 4) / 8 == vector(1/8, 1/4, 1/2)); assert(vector(1, 2, 4) / (1 / val) == vector(1/8, 2/8, 4/8)); -assert(8 / vector(8, 16, 24) == vector(1, 1/2, 1/3)); assert(vector(1, 2, 4) / '8' == vector(1/8, 1/4, 1/2)); -assert('8' / vector(8, 16, 24) == vector(1, 1/2, 1/3)); assert(-vector(1, 2, 4) == vector(-1, -2, -4)); @@ -71,4 +87,34 @@ assert(pcall(function() local t = {} rawset(t, vector(0/0, 2, 3), 1) end) == fal -- make sure we cover both builtin and C impl assert(vector(1, 2, 4) == vector("1", "2", "4")) +-- validate component access (both cases) +assert(vector(1, 2, 3).x == 1) +assert(vector(1, 2, 3).X == 1) +assert(vector(1, 2, 3).y == 2) +assert(vector(1, 2, 3).Y == 2) +assert(vector(1, 2, 3).z == 3) +assert(vector(1, 2, 3).Z == 3) + +-- additional checks for 4-component vectors +if vector_size == 4 then + assert(vector(1, 2, 3, 4).w == 4) + assert(vector(1, 2, 3, 4).W == 4) +end + +-- negative zero should hash the same as zero +-- note: our earlier test only really checks the low hash bit, so in absence of perfect avalanche it's insufficient +do + local larget = {} + for i = 1, 2^14 do + larget[vector(0, 0, i)] = true + end + + larget[vector(0, 0, 0)] = 42 + + assert(larget[vector(0, 0, 0)] == 42) + assert(larget[vector(0, 0, -0)] == 42) + assert(larget[vector(0, -0, 0)] == 42) + assert(larget[vector(-0, 0, 0)] == 42) +end + return 'OK' diff --git a/tests/main.cpp b/tests/main.cpp index ed17070c..82ce4e16 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -2,12 +2,17 @@ #include "Luau/Common.h" #define DOCTEST_CONFIG_IMPLEMENT +// Our calls to parseOption/parseFlag don't provide a prefix so set the prefix to the empty string. +#define DOCTEST_CONFIG_OPTIONS_PREFIX "" #include "doctest.h" #ifdef _WIN32 #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN #endif +#ifndef NOMINMAX +#define NOMINMAX +#endif #include // IsDebuggerPresent #endif @@ -18,6 +23,19 @@ #include +// Indicates if verbose output is enabled; can be overridden via --verbose +// Currently, this enables output from 'print', but other verbose output could be enabled eventually. +bool verbose = false; + +// Default optimization level for conformance test; can be overridden via -On +int optimizationLevel = 1; + +// Run conformance tests with native code generation +bool codegen = false; + +// Something to seed a pseudorandom number generator with +std::optional randomSeed; + static bool skipFastFlag(const char* flagName) { if (strncmp(flagName, "Test", 4) == 0) @@ -46,12 +64,12 @@ static bool debuggerPresent() #endif } -static int assertionHandler(const char* expr, const char* file, int line) +static int testAssertionHandler(const char* expr, const char* file, int line, const char* function) { if (debuggerPresent()) LUAU_DEBUGBREAK(); - ADD_FAIL_AT(file, line, "Assertion failed: ", expr); + ADD_FAIL_AT(file, line, "Assertion failed: ", std::string(expr)); return 1; } @@ -212,7 +230,7 @@ static void setFastFlags(const std::vector& flags) int main(int argc, char** argv) { - Luau::assertHandler() = assertionHandler; + Luau::assertHandler() = testAssertionHandler; doctest::registerReporter("boost", 0, true); @@ -235,6 +253,35 @@ int main(int argc, char** argv) return 0; } + if (doctest::parseFlag(argc, argv, "--verbose")) + { + verbose = true; + } + + if (doctest::parseFlag(argc, argv, "--codegen")) + { + codegen = true; + } + + int level = -1; + if (doctest::parseIntOption(argc, argv, "-O", doctest::option_int, level)) + { + if (level < 0 || level > 2) + std::cerr << "Optimization level must be between 0 and 2 inclusive." << std::endl; + else + optimizationLevel = level; + } + + int rseed = -1; + if (doctest::parseIntOption(argc, argv, "--random-seed=", doctest::option_int, rseed)) + randomSeed = unsigned(rseed); + + if (doctest::parseOption(argc, argv, "--randomize") && !randomSeed) + { + randomSeed = unsigned(time(nullptr)); + printf("Using RNG seed %u\n", *randomSeed); + } + if (std::vector flags; doctest::parseCommaSepArgs(argc, argv, "--fflags=", flags)) setFastFlags(flags); @@ -261,7 +308,16 @@ int main(int argc, char** argv) } } - return context.run(); + int result = context.run(); + if (doctest::parseFlag(argc, argv, "--help") || doctest::parseFlag(argc, argv, "-h")) + { + printf("Additional command line options:\n"); + printf(" -O[n] Changes default optimization level (1) for conformance runs\n"); + printf(" --verbose Enables verbose output (e.g. lua 'print' statements)\n"); + printf(" --fflags= Sets specified fast flags\n"); + printf(" --list-fflags List all fast flags\n"); + printf(" --randomize Use a random RNG seed\n"); + printf(" --random-seed=n Use a particular RNG seed\n"); + } + return result; } - - diff --git a/tools/faillist.txt b/tools/faillist.txt new file mode 100644 index 00000000..5d6779f4 --- /dev/null +++ b/tools/faillist.txt @@ -0,0 +1,495 @@ +AnnotationTests.corecursive_types_error_on_tight_loop +AnnotationTests.duplicate_type_param_name +AnnotationTests.for_loop_counter_annotation_is_checked +AnnotationTests.generic_aliases_are_cloned_properly +AnnotationTests.instantiation_clone_has_to_follow +AnnotationTests.luau_print_is_not_special_without_the_flag +AnnotationTests.occurs_check_on_cyclic_intersection_typevar +AnnotationTests.occurs_check_on_cyclic_union_typevar +AnnotationTests.too_many_type_params +AnnotationTests.two_type_params +AnnotationTests.unknown_type_reference_generates_error +AstQuery.last_argument_function_call_type +AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method +AstQuery::getDocumentationSymbolAtPosition.overloaded_fn +AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop +AutocompleteTest.autocomplete_first_function_arg_expected_type +AutocompleteTest.autocomplete_interpolated_string_as_singleton +AutocompleteTest.autocomplete_interpolated_string_constant +AutocompleteTest.autocomplete_interpolated_string_expression +AutocompleteTest.autocomplete_interpolated_string_expression_with_comments +AutocompleteTest.autocomplete_oop_implicit_self +AutocompleteTest.autocomplete_string_singleton_escape +AutocompleteTest.autocomplete_string_singletons +AutocompleteTest.autocompleteProp_index_function_metamethod_is_variadic +AutocompleteTest.do_compatible_self_calls +AutocompleteTest.do_wrong_compatible_self_calls +AutocompleteTest.keyword_methods +AutocompleteTest.no_incompatible_self_calls +AutocompleteTest.no_wrong_compatible_self_calls_with_generics +AutocompleteTest.string_singleton_as_table_key +AutocompleteTest.suggest_external_module_type +AutocompleteTest.suggest_table_keys +AutocompleteTest.type_correct_argument_type_suggestion +AutocompleteTest.type_correct_expected_argument_type_pack_suggestion +AutocompleteTest.type_correct_expected_argument_type_suggestion +AutocompleteTest.type_correct_expected_argument_type_suggestion_optional +AutocompleteTest.type_correct_expected_argument_type_suggestion_self +AutocompleteTest.type_correct_expected_return_type_pack_suggestion +AutocompleteTest.type_correct_expected_return_type_suggestion +AutocompleteTest.type_correct_function_no_parenthesis +AutocompleteTest.type_correct_function_return_types +AutocompleteTest.type_correct_keywords +AutocompleteTest.type_correct_suggestion_for_overloads +AutocompleteTest.type_correct_suggestion_in_argument +AutocompleteTest.type_correct_suggestion_in_table +BuiltinTests.aliased_string_format +BuiltinTests.assert_removes_falsy_types +BuiltinTests.assert_removes_falsy_types2 +BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type +BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy +BuiltinTests.bad_select_should_not_crash +BuiltinTests.coroutine_wrap_anything_goes +BuiltinTests.debug_info_is_crazy +BuiltinTests.debug_traceback_is_crazy +BuiltinTests.dont_add_definitions_to_persistent_types +BuiltinTests.find_capture_types +BuiltinTests.find_capture_types2 +BuiltinTests.find_capture_types3 +BuiltinTests.gmatch_capture_types_balanced_escaped_parens +BuiltinTests.gmatch_capture_types_default_capture +BuiltinTests.gmatch_capture_types_parens_in_sets_are_ignored +BuiltinTests.gmatch_definition +BuiltinTests.ipairs_iterator_should_infer_types_and_type_check +BuiltinTests.match_capture_types +BuiltinTests.match_capture_types2 +BuiltinTests.math_max_checks_for_numbers +BuiltinTests.next_iterator_should_infer_types_and_type_check +BuiltinTests.pairs_iterator_should_infer_types_and_type_check +BuiltinTests.see_thru_select +BuiltinTests.select_slightly_out_of_range +BuiltinTests.select_way_out_of_range +BuiltinTests.select_with_decimal_argument_is_rounded_down +BuiltinTests.set_metatable_needs_arguments +BuiltinTests.setmetatable_should_not_mutate_persisted_types +BuiltinTests.sort_with_bad_predicate +BuiltinTests.string_format_as_method +BuiltinTests.string_format_correctly_ordered_types +BuiltinTests.string_format_report_all_type_errors_at_correct_positions +BuiltinTests.string_format_use_correct_argument2 +BuiltinTests.table_freeze_is_generic +BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload +BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload +BuiltinTests.table_pack +BuiltinTests.table_pack_reduce +BuiltinTests.table_pack_variadic +BuiltinTests.tonumber_returns_optional_number_type +DefinitionTests.class_definition_overload_metamethods +DefinitionTests.class_definition_string_props +DefinitionTests.declaring_generic_functions +DefinitionTests.definition_file_classes +FrontendTest.environments +FrontendTest.it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded +FrontendTest.nocheck_cycle_used_by_checked +FrontendTest.reexport_cyclic_type +FrontendTest.trace_requires_in_nonstrict_mode +GenericsTests.apply_type_function_nested_generics1 +GenericsTests.apply_type_function_nested_generics2 +GenericsTests.better_mismatch_error_messages +GenericsTests.check_generic_typepack_function +GenericsTests.check_mutual_generic_functions +GenericsTests.correctly_instantiate_polymorphic_member_functions +GenericsTests.do_not_infer_generic_functions +GenericsTests.duplicate_generic_type_packs +GenericsTests.duplicate_generic_types +GenericsTests.generic_argument_count_too_few +GenericsTests.generic_argument_count_too_many +GenericsTests.generic_factories +GenericsTests.generic_functions_should_be_memory_safe +GenericsTests.generic_table_method +GenericsTests.generic_type_pack_parentheses +GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments +GenericsTests.infer_generic_function_function_argument +GenericsTests.infer_generic_function_function_argument_overloaded +GenericsTests.infer_generic_lib_function_function_argument +GenericsTests.infer_generic_property +GenericsTests.instantiated_function_argument_names +GenericsTests.instantiation_sharing_types +GenericsTests.no_stack_overflow_from_quantifying +GenericsTests.reject_clashing_generic_and_pack_names +GenericsTests.self_recursive_instantiated_param +IntersectionTypes.no_stack_overflow_from_flattenintersection +IntersectionTypes.select_correct_union_fn +IntersectionTypes.should_still_pick_an_overload_whose_arguments_are_unions +IntersectionTypes.table_intersection_write_sealed +IntersectionTypes.table_intersection_write_sealed_indirect +IntersectionTypes.table_write_sealed_indirect +ModuleTests.any_persistance_does_not_leak +ModuleTests.clone_self_property +ModuleTests.deepClone_cyclic_table +NonstrictModeTests.for_in_iterator_variables_are_any +NonstrictModeTests.function_parameters_are_any +NonstrictModeTests.inconsistent_module_return_types_are_ok +NonstrictModeTests.inconsistent_return_types_are_ok +NonstrictModeTests.infer_nullary_function +NonstrictModeTests.infer_the_maximum_number_of_values_the_function_could_return +NonstrictModeTests.inline_table_props_are_also_any +NonstrictModeTests.local_tables_are_not_any +NonstrictModeTests.locals_are_any_by_default +NonstrictModeTests.offer_a_hint_if_you_use_a_dot_instead_of_a_colon +NonstrictModeTests.parameters_having_type_any_are_optional +NonstrictModeTests.table_dot_insert_and_recursive_calls +NonstrictModeTests.table_props_are_any +Normalize.cyclic_table_normalizes_sensibly +ParseErrorRecovery.generic_type_list_recovery +ParseErrorRecovery.recovery_of_parenthesized_expressions +ParserTests.parse_nesting_based_end_detection_failsafe_earlier +ParserTests.parse_nesting_based_end_detection_local_function +ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal +ProvisionalTests.bail_early_if_unification_is_too_complicated +ProvisionalTests.discriminate_from_x_not_equal_to_nil +ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack +ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean +ProvisionalTests.free_options_cannot_be_unified_together +ProvisionalTests.generic_type_leak_to_module_interface_variadic +ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns +ProvisionalTests.lvalue_equals_another_lvalue_with_no_overlap +ProvisionalTests.pcall_returns_at_least_two_value_but_function_returns_nothing +ProvisionalTests.setmetatable_constrains_free_type_into_free_table +ProvisionalTests.specialization_binds_with_prototypes_too_early +ProvisionalTests.table_insert_with_a_singleton_argument +ProvisionalTests.typeguard_inference_incomplete +ProvisionalTests.weirditer_should_not_loop_forever +ProvisionalTests.while_body_are_also_refined +RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string +RefinementTest.call_an_incompatible_function_after_using_typeguard +RefinementTest.correctly_lookup_property_whose_base_was_previously_refined2 +RefinementTest.discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false +RefinementTest.discriminate_tag +RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union +RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil +RefinementTest.narrow_property_of_a_bounded_variable +RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true +RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table +RefinementTest.refine_unknowns +RefinementTest.type_guard_can_filter_for_intersection_of_tables +RefinementTest.type_guard_narrowed_into_nothingness +RefinementTest.type_narrow_for_all_the_userdata +RefinementTest.type_narrow_to_vector +RefinementTest.typeguard_cast_free_table_to_vector +RefinementTest.typeguard_in_assert_position +RefinementTest.typeguard_narrows_for_table +RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table +RefinementTest.x_is_not_instance_or_else_not_part +RuntimeLimits.typescript_port_of_Result_type +TableTests.a_free_shape_can_turn_into_a_scalar_directly +TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible +TableTests.access_index_metamethod_that_returns_variadic +TableTests.accidentally_checked_prop_in_opposite_branch +TableTests.builtin_table_names +TableTests.call_method +TableTests.call_method_with_explicit_self_argument +TableTests.cannot_augment_sealed_table +TableTests.casting_sealed_tables_with_props_into_table_with_indexer +TableTests.casting_tables_with_props_into_table_with_indexer3 +TableTests.casting_tables_with_props_into_table_with_indexer4 +TableTests.checked_prop_too_early +TableTests.defining_a_method_for_a_local_unsealed_table_is_ok +TableTests.defining_a_self_method_for_a_local_unsealed_table_is_ok +TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar +TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index +TableTests.dont_quantify_table_that_belongs_to_outer_scope +TableTests.dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table +TableTests.dont_suggest_exact_match_keys +TableTests.error_detailed_metatable_prop +TableTests.expected_indexer_from_table_union +TableTests.expected_indexer_value_type_extra +TableTests.expected_indexer_value_type_extra_2 +TableTests.explicitly_typed_table +TableTests.explicitly_typed_table_error +TableTests.explicitly_typed_table_with_indexer +TableTests.found_like_key_in_table_function_call +TableTests.found_like_key_in_table_property_access +TableTests.found_multiple_like_keys +TableTests.function_calls_produces_sealed_table_given_unsealed_table +TableTests.generic_table_instantiation_potential_regression +TableTests.getmetatable_returns_pointer_to_metatable +TableTests.give_up_after_one_metatable_index_look_up +TableTests.hide_table_error_properties +TableTests.indexer_on_sealed_table_must_unify_with_free_table +TableTests.indexing_from_a_table_should_prefer_properties_when_possible +TableTests.inequality_operators_imply_exactly_matching_types +TableTests.infer_array_2 +TableTests.infer_indexer_from_value_property_in_literal +TableTests.inferred_return_type_of_free_table +TableTests.inferring_crazy_table_should_also_be_quick +TableTests.instantiate_table_cloning_3 +TableTests.instantiate_tables_at_scope_level +TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound +TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound +TableTests.leaking_bad_metatable_errors +TableTests.less_exponential_blowup_please +TableTests.meta_add_inferred +TableTests.metatable_mismatch_should_fail +TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred +TableTests.mixed_tables_with_implicit_numbered_keys +TableTests.nil_assign_doesnt_hit_indexer +TableTests.nil_assign_doesnt_hit_no_indexer +TableTests.okay_to_add_property_to_unsealed_tables_by_function_call +TableTests.only_ascribe_synthetic_names_at_module_scope +TableTests.oop_indexer_works +TableTests.oop_polymorphic +TableTests.open_table_unification_2 +TableTests.persistent_sealed_table_is_immutable +TableTests.property_lookup_through_tabletypevar_metatable +TableTests.quantify_even_that_table_was_never_exported_at_all +TableTests.quantify_metatables_of_metatables_of_table +TableTests.quantifying_a_bound_var_works +TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table +TableTests.result_is_always_any_if_lhs_is_any +TableTests.result_is_bool_for_equality_operators_if_lhs_is_any +TableTests.right_table_missing_key2 +TableTests.scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type +TableTests.shared_selfs +TableTests.shared_selfs_from_free_param +TableTests.shared_selfs_through_metatables +TableTests.table_call_metamethod_basic +TableTests.table_function_check_use_after_free +TableTests.table_indexing_error_location +TableTests.table_insert_should_cope_with_optional_properties_in_nonstrict +TableTests.table_insert_should_cope_with_optional_properties_in_strict +TableTests.table_param_row_polymorphism_2 +TableTests.table_param_row_polymorphism_3 +TableTests.table_simple_call +TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors +TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors +TableTests.tables_get_names_from_their_locals +TableTests.tc_member_function +TableTests.tc_member_function_2 +TableTests.unification_of_unions_in_a_self_referential_type +TableTests.unifying_tables_shouldnt_uaf1 +TableTests.unifying_tables_shouldnt_uaf2 +TableTests.used_colon_correctly +TableTests.used_colon_instead_of_dot +TableTests.used_dot_instead_of_colon +TableTests.used_dot_instead_of_colon_but_correctly +ToDot.bound_table +ToDot.function +ToDot.table +ToString.exhaustive_toString_of_cyclic_table +ToString.function_type_with_argument_names_and_self +ToString.function_type_with_argument_names_generic +ToString.toStringDetailed2 +ToString.toStringErrorPack +ToString.toStringNamedFunction_generic_pack +ToString.toStringNamedFunction_hide_self_param +ToString.toStringNamedFunction_hide_type_params +ToString.toStringNamedFunction_id +ToString.toStringNamedFunction_include_self_param +ToString.toStringNamedFunction_map +ToString.toStringNamedFunction_variadics +TryUnifyTests.cli_41095_concat_log_in_sealed_table_unification +TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType +TryUnifyTests.result_of_failed_typepack_unification_is_constrained +TryUnifyTests.typepack_unification_should_trim_free_tails +TryUnifyTests.variadics_should_use_reversed_properly +TypeAliases.cannot_create_cyclic_type_with_unknown_module +TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any +TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2 +TypeAliases.generic_param_remap +TypeAliases.mismatched_generic_type_param +TypeAliases.mutually_recursive_types_restriction_not_ok_1 +TypeAliases.mutually_recursive_types_restriction_not_ok_2 +TypeAliases.mutually_recursive_types_swapsies_not_ok +TypeAliases.recursive_types_restriction_not_ok +TypeAliases.report_shadowed_aliases +TypeAliases.stringify_optional_parameterized_alias +TypeAliases.stringify_type_alias_of_recursive_template_table_type +TypeAliases.stringify_type_alias_of_recursive_template_table_type2 +TypeAliases.type_alias_fwd_declaration_is_precise +TypeAliases.type_alias_local_mutation +TypeAliases.type_alias_local_rename +TypeAliases.type_alias_of_an_imported_recursive_generic_type +TypeInfer.check_type_infer_recursion_count +TypeInfer.checking_should_not_ice +TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error +TypeInfer.dont_report_type_errors_within_an_AstExprError +TypeInfer.dont_report_type_errors_within_an_AstStatError +TypeInfer.follow_on_new_types_in_substitution +TypeInfer.fuzz_free_table_type_change_during_index_check +TypeInfer.globals +TypeInfer.globals2 +TypeInfer.infer_assignment_value_types_mutable_lval +TypeInfer.it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict +TypeInfer.no_stack_overflow_from_isoptional +TypeInfer.tc_after_error_recovery_no_replacement_name_in_error +TypeInfer.tc_if_else_expressions_expected_type_3 +TypeInfer.tc_interpolated_string_basic +TypeInfer.tc_interpolated_string_with_invalid_expression +TypeInfer.type_infer_recursion_limit_no_ice +TypeInfer.type_infer_recursion_limit_normalizer +TypeInferAnyError.for_in_loop_iterator_is_any2 +TypeInferAnyError.metatable_of_any_can_be_a_table +TypeInferClasses.can_read_prop_of_base_class_using_string +TypeInferClasses.class_type_mismatch_with_name_conflict +TypeInferClasses.classes_without_overloaded_operators_cannot_be_added +TypeInferClasses.detailed_class_unification_error +TypeInferClasses.higher_order_function_arguments_are_contravariant +TypeInferClasses.index_instance_property +TypeInferClasses.optional_class_field_access_error +TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties +TypeInferClasses.warn_when_prop_almost_matches +TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class +TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types +TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument +TypeInferFunctions.cannot_hoist_interior_defns_into_signature +TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists +TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site +TypeInferFunctions.dont_mutate_the_underlying_head_of_typepack_when_calling_with_self +TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict +TypeInferFunctions.first_argument_can_be_optional +TypeInferFunctions.function_cast_error_uses_correct_language +TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 +TypeInferFunctions.function_decl_non_self_unsealed_overwrite +TypeInferFunctions.function_does_not_return_enough_values +TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer +TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict +TypeInferFunctions.improved_function_arg_mismatch_errors +TypeInferFunctions.infer_anonymous_function_arguments +TypeInferFunctions.infer_return_type_from_selected_overload +TypeInferFunctions.infer_that_function_does_not_return_a_table +TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count +TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count +TypeInferFunctions.luau_subtyping_is_np_hard +TypeInferFunctions.no_lossy_function_type +TypeInferFunctions.occurs_check_failure_in_function_return_type +TypeInferFunctions.record_matching_overload +TypeInferFunctions.report_exiting_without_return_nonstrict +TypeInferFunctions.report_exiting_without_return_strict +TypeInferFunctions.return_type_by_overload +TypeInferFunctions.too_few_arguments_variadic +TypeInferFunctions.too_few_arguments_variadic_generic +TypeInferFunctions.too_few_arguments_variadic_generic2 +TypeInferFunctions.too_many_arguments +TypeInferFunctions.too_many_arguments_error_location +TypeInferFunctions.too_many_return_values +TypeInferFunctions.too_many_return_values_in_parentheses +TypeInferFunctions.too_many_return_values_no_function +TypeInferFunctions.vararg_function_is_quantified +TypeInferLoops.for_in_loop +TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values +TypeInferLoops.for_in_loop_with_next +TypeInferLoops.for_in_with_generic_next +TypeInferLoops.loop_iter_metamethod_ok_with_inference +TypeInferLoops.loop_iter_no_indexer_nonstrict +TypeInferLoops.loop_iter_trailing_nil +TypeInferLoops.properly_infer_iteratee_is_a_free_table +TypeInferLoops.unreachable_code_after_infinite_loop +TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free +TypeInferModules.bound_free_table_export_is_ok +TypeInferModules.custom_require_global +TypeInferModules.do_not_modify_imported_types_4 +TypeInferModules.do_not_modify_imported_types_5 +TypeInferModules.module_type_conflict +TypeInferModules.module_type_conflict_instantiated +TypeInferModules.require_a_variadic_function +TypeInferModules.type_error_of_unknown_qualified_type +TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works +TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2 +TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon +TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory +TypeInferOOP.method_depends_on_table +TypeInferOOP.methods_are_topologically_sorted +TypeInferOOP.nonstrict_self_mismatch_tail +TypeInferOOP.object_constructor_can_refer_to_method_of_self +TypeInferOOP.table_oop +TypeInferOperators.CallAndOrOfFunctions +TypeInferOperators.CallOrOfFunctions +TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable +TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable +TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators +TypeInferOperators.cli_38355_recursive_union +TypeInferOperators.compound_assign_metatable +TypeInferOperators.compound_assign_mismatch_metatable +TypeInferOperators.compound_assign_mismatch_op +TypeInferOperators.compound_assign_mismatch_result +TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops +TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators +TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown +TypeInferOperators.operator_eq_completely_incompatible +TypeInferOperators.or_joins_types_with_no_superfluous_union +TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not +TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection +TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs +TypeInferOperators.UnknownGlobalCompoundAssign +TypeInferOperators.unrelated_classes_cannot_be_compared +TypeInferOperators.unrelated_primitives_cannot_be_compared +TypeInferPrimitives.CheckMethodsOfNumber +TypeInferPrimitives.string_index +TypeInferUnknownNever.assign_to_global_which_is_never +TypeInferUnknownNever.assign_to_local_which_is_never +TypeInferUnknownNever.assign_to_prop_which_is_never +TypeInferUnknownNever.assign_to_subscript_which_is_never +TypeInferUnknownNever.call_never +TypeInferUnknownNever.dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators +TypeInferUnknownNever.index_on_union_of_tables_for_properties_that_is_sorta_never +TypeInferUnknownNever.math_operators_and_never +TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable +TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 +TypeInferUnknownNever.unary_minus_of_never +TypePackTests.detect_cyclic_typepacks2 +TypePackTests.pack_tail_unification_check +TypePackTests.self_and_varargs_should_work +TypePackTests.type_alias_backwards_compatible +TypePackTests.type_alias_default_export +TypePackTests.type_alias_default_mixed_self +TypePackTests.type_alias_default_type_chained +TypePackTests.type_alias_default_type_errors +TypePackTests.type_alias_default_type_pack_self_chained_tp +TypePackTests.type_alias_default_type_pack_self_tp +TypePackTests.type_alias_default_type_self +TypePackTests.type_alias_defaults_confusing_types +TypePackTests.type_alias_defaults_recursive_type +TypePackTests.type_alias_type_pack_multi +TypePackTests.type_alias_type_pack_variadic +TypePackTests.type_alias_type_packs +TypePackTests.type_alias_type_packs_errors +TypePackTests.type_alias_type_packs_import +TypePackTests.type_alias_type_packs_nested +TypePackTests.type_pack_type_parameters +TypePackTests.unify_variadic_tails_in_arguments +TypePackTests.unify_variadic_tails_in_arguments_free +TypePackTests.variadic_packs +TypeSingletons.error_detailed_tagged_union_mismatch_bool +TypeSingletons.error_detailed_tagged_union_mismatch_string +TypeSingletons.function_call_with_singletons +TypeSingletons.function_call_with_singletons_mismatch +TypeSingletons.indexing_on_string_singletons +TypeSingletons.indexing_on_union_of_string_singletons +TypeSingletons.overloaded_function_call_with_singletons +TypeSingletons.overloaded_function_call_with_singletons_mismatch +TypeSingletons.return_type_of_f_is_not_widened +TypeSingletons.table_properties_singleton_strings_mismatch +TypeSingletons.table_properties_type_error_escapes +TypeSingletons.taking_the_length_of_string_singleton +TypeSingletons.taking_the_length_of_union_of_string_singleton +TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton +TypeSingletons.widening_happens_almost_everywhere +TypeSingletons.widening_happens_almost_everywhere_except_for_tables +UnionTypes.error_detailed_optional +UnionTypes.error_detailed_union_all +UnionTypes.index_on_a_union_type_with_missing_property +UnionTypes.index_on_a_union_type_with_one_optional_property +UnionTypes.index_on_a_union_type_with_one_property_of_type_any +UnionTypes.optional_assignment_errors +UnionTypes.optional_call_error +UnionTypes.optional_field_access_error +UnionTypes.optional_index_error +UnionTypes.optional_iteration +UnionTypes.optional_length_error +UnionTypes.optional_missing_key_error_details +UnionTypes.optional_union_follow +UnionTypes.optional_union_functions +UnionTypes.optional_union_members +UnionTypes.optional_union_methods +UnionTypes.table_union_write_indirect diff --git a/tools/gdb-printers.py b/tools/gdb_printers.py similarity index 77% rename from tools/gdb-printers.py rename to tools/gdb_printers.py index c711c5e2..017b9f95 100644 --- a/tools/gdb-printers.py +++ b/tools/gdb_printers.py @@ -11,9 +11,9 @@ class VariantPrinter: return type.name + " [" + str(value) + "]" def match_printer(val): - type = val.type.strip_typedefs() - if type.name and type.name.startswith('Luau::Variant<'): - return VariantPrinter(val) - return None + type = val.type.strip_typedefs() + if type.name and type.name.startswith('Luau::Variant<'): + return VariantPrinter(val) + return None gdb.pretty_printers.append(match_printer) diff --git a/tools/heapgraph.py b/tools/heapgraph.py index 106db549..b8dc207f 100644 --- a/tools/heapgraph.py +++ b/tools/heapgraph.py @@ -7,10 +7,20 @@ # The result of analysis is a .svg file which can be viewed in a browser # To generate these dumps, use luaC_dump, ideally preceded by luaC_fullgc +import argparse import json import sys import svg +argumentParser = argparse.ArgumentParser(description='Luau heap snapshot analyzer') + +argumentParser.add_argument('--split', dest = 'split', type = str, default = 'none', help = 'Perform additional root split using memory categories', choices = ['none', 'custom', 'all']) + +argumentParser.add_argument('snapshot') +argumentParser.add_argument('snapshotnew', nargs='?') + +arguments = argumentParser.parse_args() + class Node(svg.Node): def __init__(self): svg.Node.__init__(self) @@ -29,17 +39,29 @@ class Node(svg.Node): def details(self, root): return "{} ({:,} bytes, {:.1%}); self: {:,} bytes in {:,} objects".format(self.name, self.width, self.width / root.width, self.size, self.count) +def getkey(heap, obj, key): + pairs = obj.get("pairs", []) + for i in range(0, len(pairs), 2): + if pairs[i] and heap[pairs[i]]["type"] == "string" and heap[pairs[i]]["data"] == key: + if pairs[i + 1] and heap[pairs[i + 1]]["type"] == "string": + return heap[pairs[i + 1]]["data"] + else: + return None + return None + # load files -if len(sys.argv) == 2: +if arguments.snapshotnew == None: dumpold = None - with open(sys.argv[1]) as f: + with open(arguments.snapshot) as f: dump = json.load(f) else: - with open(sys.argv[1]) as f: + with open(arguments.snapshot) as f: dumpold = json.load(f) - with open(sys.argv[2]) as f: + with open(arguments.snapshotnew) as f: dump = json.load(f) +heap = dump["objects"] + # reachability analysis: how much of the heap is reachable from roots? visited = set() queue = [] @@ -56,7 +78,7 @@ while offset < len(queue): continue visited.add(addr) - obj = dump["objects"][addr] + obj = heap[addr] if not dumpold or not addr in dumpold["objects"]: node.count += 1 @@ -65,17 +87,27 @@ while offset < len(queue): if obj["type"] == "table": pairs = obj.get("pairs", []) + weakkey = False + weakval = False + + if "metatable" in obj: + modemt = getkey(heap, heap[obj["metatable"]], "__mode") + if modemt: + weakkey = "k" in modemt + weakval = "v" in modemt for i in range(0, len(pairs), 2): key = pairs[i+0] val = pairs[i+1] - if key and val and dump["objects"][key]["type"] == "string": + if key and heap[key]["type"] == "string": + # string keys are always strong queue.append((key, node)) - queue.append((val, node.child(dump["objects"][key]["data"]))) + if val and not weakval: + queue.append((val, node.child(heap[key]["data"]))) else: - if key: + if key and not weakkey: queue.append((key, node)) - if val: + if val and not weakval: queue.append((val, node)) for a in obj.get("array", []): @@ -87,7 +119,7 @@ while offset < len(queue): source = "" if "proto" in obj: - proto = dump["objects"][obj["proto"]] + proto = heap[obj["proto"]] if "source" in proto: source = proto["source"] @@ -100,8 +132,16 @@ while offset < len(queue): queue.append((obj["metatable"], node.child("__meta"))) elif obj["type"] == "thread": queue.append((obj["env"], node.child("__env"))) - for a in obj.get("stack", []): - queue.append((a, node.child("__stack"))) + stack = obj.get("stack") + stacknames = obj.get("stacknames", []) + stacknode = node.child("__stack") + framenode = None + for i in range(len(stack)): + name = stacknames[i] if stacknames else None + if name and name.startswith("frame:"): + framenode = stacknode.child(name[6:]) + name = None + queue.append((stack[i], framenode.child(name) if framenode and name else framenode or stacknode)) elif obj["type"] == "proto": for a in obj.get("constants", []): queue.append((a, node)) @@ -111,12 +151,15 @@ while offset < len(queue): if "object" in obj: queue.append((obj["object"], node)) -def annotateContainedCategories(node): +def annotateContainedCategories(node, start): for obj in node.objects: + if obj["cat"] < start: + obj["cat"] = 0 + node.categories.add(obj["cat"]) for child in node.children.values(): - annotateContainedCategories(child) + annotateContainedCategories(child, start) for cat in child.categories: node.categories.add(cat) @@ -172,9 +215,11 @@ def splitIntoCategories(root): return result -# temporarily disabled because it makes FG harder to read, maybe this should be a separate command line option? -if dump["stats"].get("categories") and False: - annotateContainedCategories(root) +if dump["stats"].get("categories") and arguments.split != 'none': + if arguments.split == 'custom': + annotateContainedCategories(root, 128) + else: + annotateContainedCategories(root, 0) root = splitIntoCategories(root) diff --git a/tools/heapstat.py b/tools/heapstat.py index 838521b6..7337aa44 100644 --- a/tools/heapstat.py +++ b/tools/heapstat.py @@ -15,12 +15,20 @@ def updatesize(d, k, s): def sortedsize(p): return sorted(p, key = lambda s: s[1][1], reverse = True) +def getkey(heap, obj, key): + pairs = obj.get("pairs", []) + for i in range(0, len(pairs), 2): + if pairs[i] and heap[pairs[i]]["type"] == "string" and heap[pairs[i]]["data"] == key: + if pairs[i + 1] and heap[pairs[i + 1]]["type"] == "string": + return heap[pairs[i + 1]]["data"] + else: + return None + return None + with open(sys.argv[1]) as f: dump = json.load(f) heap = dump["objects"] -type_addr = next((addr for addr,obj in heap.items() if obj["type"] == "string" and obj["data"] == "__type"), None) - size_type = {} size_udata = {} size_category = {} @@ -33,11 +41,7 @@ for addr, obj in heap.items(): if obj["type"] == "userdata" and "metatable" in obj: metatable = heap[obj["metatable"]] - pairs = metatable.get("pairs", []) - typemt = "unknown" - for i in range(0, len(pairs), 2): - if type_addr and pairs[i] == type_addr and pairs[i + 1] and heap[pairs[i + 1]]["type"] == "string": - typemt = heap[pairs[i + 1]]["data"] + typemt = getkey(heap, metatable, "__type") or "unknown" updatesize(size_udata, typemt, obj["size"]) print("objects by type:") @@ -55,5 +59,6 @@ if len(size_category) != 0: print("objects by category:") for type, (count, size) in sortedsize(size_category.items()): - name = dump["stats"]["categories"][type]["name"] + cat = dump["stats"]["categories"][type] + name = cat["name"] if "name" in cat else str(type) print(name.ljust(30), str(size).rjust(8), "bytes", str(count).rjust(5), "objects") diff --git a/tools/lldb_formatters.lldb b/tools/lldb_formatters.lldb new file mode 100644 index 00000000..f10faa94 --- /dev/null +++ b/tools/lldb_formatters.lldb @@ -0,0 +1,10 @@ +type synthetic add -x "^Luau::detail::DenseHashTable<.*>$" -l lldb_formatters.DenseHashTableSyntheticChildrenProvider +type summary add "Luau::Symbol" -F lldb_formatters.luau_symbol_summary + +type synthetic add -x "^Luau::Variant<.+>$" -l lldb_formatters.LuauVariantSyntheticChildrenProvider +type summary add -x "^Luau::Variant<.+>$" -F lldb_formatters.luau_variant_summary + +type synthetic add -x "^Luau::AstArray<.+>$" -l lldb_formatters.AstArraySyntheticChildrenProvider + +type summary add --summary-string "${var.line}:${var.column}" Luau::Position +type summary add --summary-string "${var.begin}-${var.end}" Luau::Location diff --git a/tools/lldb_formatters.py b/tools/lldb_formatters.py new file mode 100644 index 00000000..19fc0f54 --- /dev/null +++ b/tools/lldb_formatters.py @@ -0,0 +1,215 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# HACK: LLDB's python API doesn't afford anything helpful for getting at variadic template parameters. +# We're forced to resort to parsing names as strings. +def templateParams(s): + depth = 0 + start = s.find("<") + 1 + result = [] + for i, c in enumerate(s[start:], start): + if c == "<": + depth += 1 + elif c == ">": + if depth == 0: + result.append(s[start:i].strip()) + break + depth -= 1 + elif c == "," and depth == 0: + result.append(s[start:i].strip()) + start = i + 1 + return result + + +def getType(target, typeName): + stars = 0 + + typeName = typeName.strip() + while typeName.endswith("*"): + stars += 1 + typeName = typeName[:-1] + + if typeName.startswith("const "): + typeName = typeName[6:] + + ty = target.FindFirstType(typeName.strip()) + for _ in range(stars): + ty = ty.GetPointerType() + + return ty + + +def luau_variant_summary(valobj, internal_dict, options): + return valobj.GetChildMemberWithName("type").GetSummary()[1:-1] + + +class LuauVariantSyntheticChildrenProvider: + node_names = ["type", "value"] + + def __init__(self, valobj, internal_dict): + self.valobj = valobj + self.type_index = None + self.current_type = None + self.type_params = [] + self.stored_value = None + + def num_children(self): + return len(self.node_names) + + def has_children(self): + return True + + def get_child_index(self, name): + try: + return self.node_names.index(name) + except ValueError: + return -1 + + def get_child_at_index(self, index): + try: + node = self.node_names[index] + except IndexError: + return None + + if node == "type": + if self.current_type: + return self.valobj.CreateValueFromExpression( + node, f'(const char*)"{self.current_type.GetDisplayTypeName()}"' + ) + else: + return self.valobj.CreateValueFromExpression( + node, '(const char*)""' + ) + elif node == "value": + if self.stored_value is not None: + if self.current_type is not None: + return self.valobj.CreateValueFromData( + node, self.stored_value.GetData(), self.current_type + ) + else: + return self.valobj.CreateValueExpression( + node, '(const char*)""' + ) + else: + return self.valobj.CreateValueFromExpression( + node, '(const char*)""' + ) + else: + return None + + def update(self): + self.type_index = self.valobj.GetChildMemberWithName( + "typeId" + ).GetValueAsSigned() + self.type_params = templateParams( + self.valobj.GetType().GetCanonicalType().GetName() + ) + + if len(self.type_params) > self.type_index: + self.current_type = getType( + self.valobj.GetTarget(), self.type_params[self.type_index] + ) + + if self.current_type: + storage = self.valobj.GetChildMemberWithName("storage") + self.stored_value = storage.Cast(self.current_type) + else: + self.stored_value = None + else: + self.current_type = None + self.stored_value = None + + return False + + +class DenseHashTableSyntheticChildrenProvider: + def __init__(self, valobj, internal_dict): + """this call should initialize the Python object using valobj as the variable to provide synthetic children for""" + self.valobj = valobj + self.update() + + def num_children(self): + """this call should return the number of children that you want your object to have""" + return self.capacity + + def get_child_index(self, name): + """this call should return the index of the synthetic child whose name is given as argument""" + try: + if name.startswith("[") and name.endswith("]"): + return int(name[1:-1]) + else: + return -1 + except Exception as e: + print("get_child_index exception", e) + return -1 + + def get_child_at_index(self, index): + """this call should return a new LLDB SBValue object representing the child at the index given as argument""" + try: + dataMember = self.valobj.GetChildMemberWithName("data") + + data = dataMember.GetPointeeData(index) + + return self.valobj.CreateValueFromData( + f"[{index}]", + data, + dataMember.Dereference().GetType(), + ) + + except Exception as e: + print("get_child_at_index error", e) + + def update(self): + """this call should be used to update the internal state of this Python object whenever the state of the variables in LLDB changes.[1] + Also, this method is invoked before any other method in the interface.""" + self.capacity = self.valobj.GetChildMemberWithName( + "capacity" + ).GetValueAsUnsigned() + + def has_children(self): + """this call should return True if this object might have children, and False if this object can be guaranteed not to have children.[2]""" + return True + + +def luau_symbol_summary(valobj, internal_dict, options): + local = valobj.GetChildMemberWithName("local") + global_ = valobj.GetChildMemberWithName("global").GetChildMemberWithName("value") + + if local.GetValueAsUnsigned() != 0: + return f'local {local.GetChildMemberWithName("name").GetChildMemberWithName("value").GetSummary()}' + elif global_.GetValueAsUnsigned() != 0: + return f"global {global_.GetSummary()}" + else: + return "???" + + +class AstArraySyntheticChildrenProvider: + def __init__(self, valobj, internal_dict): + self.valobj = valobj + + def num_children(self): + return self.size + + def get_child_index(self, name): + try: + if name.startswith("[") and name.endswith("]"): + return int(name[1:-1]) + else: + return -1 + except Exception as e: + print("get_child_index error:", e) + + def get_child_at_index(self, index): + try: + dataMember = self.valobj.GetChildMemberWithName("data") + data = dataMember.GetPointeeData(index) + return self.valobj.CreateValueFromData( + f"[{index}]", data, dataMember.Dereference().GetType() + ) + except Exception as e: + print("get_child_index error:", e) + + def update(self): + self.size = self.valobj.GetChildMemberWithName("size").GetValueAsUnsigned() + + def has_children(self): + return True diff --git a/tools/lvmexecute_split.py b/tools/lvmexecute_split.py new file mode 100644 index 00000000..f4a78960 --- /dev/null +++ b/tools/lvmexecute_split.py @@ -0,0 +1,112 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# This code can be used to split lvmexecute.cpp VM switch into separate functions for use as native code generation fallbacks +import sys +import re + +input = sys.stdin.readlines() + +inst = "" + +header = """// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand +#pragma once + +#include + +struct lua_State; +struct Closure; +typedef uint32_t Instruction; +typedef struct lua_TValue TValue; +typedef TValue* StkId; + +""" + +source = """// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand +#include "Fallbacks.h" +#include "FallbacksProlog.h" + +""" + +function = "" +signature = "" + +includeInsts = ["LOP_NEWCLOSURE", "LOP_NAMECALL", "LOP_FORGPREP", "LOP_GETVARARGS", "LOP_DUPCLOSURE", "LOP_PREPVARARGS", "LOP_COVERAGE", "LOP_BREAK", "LOP_GETGLOBAL", "LOP_SETGLOBAL", "LOP_GETTABLEKS", "LOP_SETTABLEKS"] + +state = 0 + +# parse with the state machine +for line in input: + # find the start of an instruction + if state == 0: + match = re.match("\s+VM_CASE\((LOP_[A-Z_0-9]+)\)", line) + + if match: + inst = match[1] + signature = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)" + function = signature + "\n" + function += "{\n" + function += " [[maybe_unused]] Closure* cl = clvalue(L->ci->func);\n" + state = 1 + + # first line of the instruction which is "{" + elif state == 1: + assert(line == " {\n") + state = 2 + + # find the end of an instruction + elif state == 2: + # remove jumps back into the native code + if line == "#if LUA_CUSTOM_EXECUTION\n": + state = 3 + continue + + if line[0] == ' ': + finalline = line[12:-1] + "\n" + else: + finalline = line + + finalline = finalline.replace("VM_NEXT();", "return pc;"); + finalline = finalline.replace("goto exit;", "return NULL;"); + finalline = finalline.replace("return;", "return NULL;"); + + function += finalline + match = re.match(" }", line) + + if match: + # break is not supported + if inst == "LOP_BREAK": + function = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)\n" + function += "{\n LUAU_ASSERT(!\"Unsupported deprecated opcode\");\n LUAU_UNREACHABLE();\n}\n" + # handle fallthrough + elif inst == "LOP_NAMECALL": + function = function[:-len(finalline)] + function += " return pc;\n}\n" + + if inst in includeInsts: + header += signature + ";\n" + source += function + "\n" + + state = 0 + + # skip LUA_CUSTOM_EXECUTION code blocks + elif state == 3: + if line == "#endif\n": + state = 4 + continue + + # skip extra line + elif state == 4: + state = 2 + +# make sure we found the ending +assert(state == 0) + +with open("Fallbacks.h", "w") as fp: + fp.writelines(header) + +with open("Fallbacks.cpp", "w") as fp: + fp.writelines(source) diff --git a/tools/natvis/Analysis.natvis b/tools/natvis/Analysis.natvis new file mode 100644 index 00000000..7d03dd3f --- /dev/null +++ b/tools/natvis/Analysis.natvis @@ -0,0 +1,142 @@ + + + + + AnyTypeVar + + + + {{ typeId=0, value={*($T1*)storage} }} + {{ typeId=1, value={*($T2*)storage} }} + {{ typeId=2, value={*($T3*)storage} }} + {{ typeId=3, value={*($T4*)storage} }} + {{ typeId=4, value={*($T5*)storage} }} + {{ typeId=5, value={*($T6*)storage} }} + {{ typeId=6, value={*($T7*)storage} }} + {{ typeId=7, value={*($T8*)storage} }} + {{ typeId=8, value={*($T9*)storage} }} + {{ typeId=9, value={*($T10*)storage} }} + {{ typeId=10, value={*($T11*)storage} }} + {{ typeId=11, value={*($T12*)storage} }} + {{ typeId=12, value={*($T13*)storage} }} + {{ typeId=13, value={*($T14*)storage} }} + {{ typeId=14, value={*($T15*)storage} }} + {{ typeId=15, value={*($T16*)storage} }} + {{ typeId=16, value={*($T17*)storage} }} + {{ typeId=17, value={*($T18*)storage} }} + {{ typeId=18, value={*($T19*)storage} }} + {{ typeId=19, value={*($T20*)storage} }} + {{ typeId=20, value={*($T21*)storage} }} + {{ typeId=21, value={*($T22*)storage} }} + {{ typeId=22, value={*($T23*)storage} }} + {{ typeId=23, value={*($T24*)storage} }} + {{ typeId=24, value={*($T25*)storage} }} + {{ typeId=25, value={*($T26*)storage} }} + {{ typeId=26, value={*($T27*)storage} }} + {{ typeId=27, value={*($T28*)storage} }} + {{ typeId=28, value={*($T29*)storage} }} + {{ typeId=29, value={*($T30*)storage} }} + {{ typeId=30, value={*($T31*)storage} }} + {{ typeId=31, value={*($T32*)storage} }} + {{ typeId=32, value={*($T33*)storage} }} + {{ typeId=33, value={*($T34*)storage} }} + {{ typeId=34, value={*($T35*)storage} }} + {{ typeId=35, value={*($T36*)storage} }} + {{ typeId=36, value={*($T37*)storage} }} + {{ typeId=37, value={*($T38*)storage} }} + {{ typeId=38, value={*($T39*)storage} }} + {{ typeId=39, value={*($T40*)storage} }} + {{ typeId=40, value={*($T41*)storage} }} + {{ typeId=41, value={*($T42*)storage} }} + {{ typeId=42, value={*($T43*)storage} }} + {{ typeId=43, value={*($T44*)storage} }} + {{ typeId=44, value={*($T45*)storage} }} + {{ typeId=45, value={*($T46*)storage} }} + {{ typeId=46, value={*($T47*)storage} }} + {{ typeId=47, value={*($T48*)storage} }} + {{ typeId=48, value={*($T49*)storage} }} + {{ typeId=49, value={*($T50*)storage} }} + {{ typeId=50, value={*($T51*)storage} }} + {{ typeId=51, value={*($T52*)storage} }} + {{ typeId=52, value={*($T53*)storage} }} + {{ typeId=53, value={*($T54*)storage} }} + {{ typeId=54, value={*($T55*)storage} }} + {{ typeId=55, value={*($T56*)storage} }} + {{ typeId=56, value={*($T57*)storage} }} + {{ typeId=57, value={*($T58*)storage} }} + {{ typeId=58, value={*($T59*)storage} }} + {{ typeId=59, value={*($T60*)storage} }} + {{ typeId=60, value={*($T61*)storage} }} + {{ typeId=61, value={*($T62*)storage} }} + {{ typeId=62, value={*($T63*)storage} }} + {{ typeId=63, value={*($T64*)storage} }} + + typeId + *($T1*)storage + *($T2*)storage + *($T3*)storage + *($T4*)storage + *($T5*)storage + *($T6*)storage + *($T7*)storage + *($T8*)storage + *($T9*)storage + *($T10*)storage + *($T11*)storage + *($T12*)storage + *($T13*)storage + *($T14*)storage + *($T15*)storage + *($T16*)storage + *($T17*)storage + *($T18*)storage + *($T19*)storage + *($T20*)storage + *($T21*)storage + *($T22*)storage + *($T23*)storage + *($T24*)storage + *($T25*)storage + *($T26*)storage + *($T27*)storage + *($T28*)storage + *($T29*)storage + *($T30*)storage + *($T31*)storage + *($T32*)storage + *($T33*)storage + *($T34*)storage + *($T35*)storage + *($T36*)storage + *($T37*)storage + *($T38*)storage + *($T39*)storage + *($T40*)storage + *($T41*)storage + *($T42*)storage + *($T43*)storage + *($T44*)storage + *($T45*)storage + *($T46*)storage + *($T47*)storage + *($T48*)storage + *($T49*)storage + *($T50*)storage + *($T51*)storage + *($T52*)storage + *($T53*)storage + *($T54*)storage + *($T55*)storage + *($T56*)storage + *($T57*)storage + *($T58*)storage + *($T59*)storage + *($T60*)storage + *($T61*)storage + *($T62*)storage + *($T63*)storage + *($T64*)storage + + + + diff --git a/tools/natvis/Ast.natvis b/tools/natvis/Ast.natvis new file mode 100644 index 00000000..18e7b762 --- /dev/null +++ b/tools/natvis/Ast.natvis @@ -0,0 +1,46 @@ + + + + + AstArray size={size} + + size + + size + data + + + + + + + size_ + + size_ + storage._Mypair._Myval2._Myfirst + offset + + + + + + {value,na} + + + + {name.value,na} + + + + local {local->name.value,na} + global {global.value,na} + + + + {line}:{column} + + + + {begin}-{end} + + + diff --git a/tools/natvis/CodeGen.natvis b/tools/natvis/CodeGen.natvis new file mode 100644 index 00000000..5ff6e143 --- /dev/null +++ b/tools/natvis/CodeGen.natvis @@ -0,0 +1,56 @@ + + + + + noreg + rip + + al + cl + dl + bl + + eax + ecx + edx + ebx + esp + ebp + esi + edi + e{(int)index,d}d + + rax + rcx + rdx + rbx + rsp + rbp + rsi + rdi + r{(int)index,d} + + xmm{(int)index,d} + + ymm{(int)index,d} + + + + {base} + {memSize,en} ptr[{base} + {index}*{(int)scale,d} + {imm}] + {memSize,en} ptr[{index}*{(int)scale,d} + {imm}] + {memSize,en} ptr[{base} + {imm}] + {memSize,en} ptr[{imm}] + {imm} + + base + imm + memSize + base + index + scale + imm + + + + diff --git a/tools/natvis/VM.natvis b/tools/natvis/VM.natvis new file mode 100644 index 00000000..e45e4e28 --- /dev/null +++ b/tools/natvis/VM.natvis @@ -0,0 +1,280 @@ + + + + + nil + {(bool)value.b} + lightuserdata {value.p} + number = {value.n} + vector = {value.v[0]}, {value.v[1]}, {*(float*)&extra} + {value.gc->ts} + {value.gc->h} + function {value.gc->cl,view(short)} + userdata {value.gc->u} + thread {value.gc->th} + proto {value.gc->p} + upvalue {value.gc->uv} + deadkey + empty + + value.p + value.gc->ts + value.gc->h + value.gc->cl + value.gc->cl + value.gc->u + value.gc->th + value.gc->p + value.gc->uv + + fixed ({(int)value.gc->gch.marked}) + black ({(int)value.gc->gch.marked}) + white ({(int)value.gc->gch.marked}) + white ({(int)value.gc->gch.marked}) + gray ({(int)value.gc->gch.marked}) + + + + + + nil + {(bool)value.b} + lightuserdata {value.p} + number = {value.n} + vector = {value.v[0]}, {value.v[1]}, {*(float*)&extra} + {value.gc->ts} + {value.gc->h} + function {value.gc->cl,view(short)} + userdata {value.gc->u} + thread {value.gc->th} + proto {value.gc->p} + upvalue {value.gc->uv} + deadkey + empty + + (void**)value.p + value.gc->ts + value.gc->h + value.gc->cl + value.gc->cl + value.gc->u + value.gc->th + value.gc->p + value.gc->uv + + next + + + + + {key,na} = {val} + --- + + + + table + + metatable + + + [size] {1<<lsizenode} + + + 1<<lsizenode + node[$i] + + + + + [size] {sizearray} + + + sizearray + array[$i] + + + + + + + + + + + + + 1 + + + + + metatable->node[i].val + + + + i = i + 1 + + + "unknown",sb + + + tag + len + metatable + data + + + + + {c.f,na} + {l.p,na} + {c} + {l} + invalid + + + + {data,s} + + + + + + + + + + + empty + none + + + {proto()->source->data,sb}:{line()} function {proto()->debugname->data,sb}() + {proto()->source->data,sb}:{line()} function() + + + =[C] function {cl().c.debugname,sb}() {cl().c.f,na} + =[C] {cl().c.f,na} + + + + {ci,na} + thread + + + {ci-base_ci} frames + + + ci-base_ci + + + base_ci[ci-base_ci - $i] + + + + + + {top-base} values + + + top-base + base + + + + + {top-stack} values + + + top-stack + stack + + + + + + + openupval + u.open.threadnext + this + + + + gt + userdata + + + + + {source->data,sb}:{linedefined} function {debugname->data,sb} [{(int)numparams} arg, {(int)nups} upval] + {source->data,sb}:{linedefined} [{(int)numparams} arg, {(int)nups} upval] + + debugname + + constants + + + sizek + k[$i] + + + + + locals + + + sizelocvars + locvars[$i] + + + + + bytecode + + + sizecode + code[$i] + + + + + functions + + + sizep + p[$i] + + + + + upvals + + + sizeupvalues + upvalues[$i] + + + + + source + + + + + + + {(lua_Type)tt} + + + fixed ({(int)marked}) + black ({(int)marked}) + white ({(int)marked}) + white ({(int)marked}) + gray ({(int)marked}) + unknown + + memcat + + + + diff --git a/tools/numprint.py b/tools/numprint.py new file mode 100644 index 00000000..47ad36d9 --- /dev/null +++ b/tools/numprint.py @@ -0,0 +1,82 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# This code can be used to generate power tables for Schubfach algorithm (see lnumprint.cpp) + +import math +import sys + +(_, pow10min, pow10max, compact) = sys.argv +pow10min = int(pow10min) +pow10max = int(pow10max) +compact = compact == "True" + +# extract high 128 bits of the value +def high128(n, roundup): + L = math.ceil(math.log2(n)) + + r = 0 + for i in range(L - 128, L): + if i >= 0 and (n & (1 << i)) != 0: + r |= (1 << (i - L + 128)) + + return r + (1 if roundup else 0) + +def pow10approx(n): + if n == 0: + return 1 << 127 + elif n > 0: + return high128(10**n, 5**n >= 2**128) + else: + # 10^-n is a binary fraction that can't be represented in floating point + # we need to extract top 128 bits of the fraction starting from the first 1 + # to get there, we need to divide 2^k by 10^n for a sufficiently large k and repeat the extraction process + p = 10**-n + k = 2**128 * 16**-n # this guarantees that the fraction has more than 128 extra bits + return high128(k // p, True) + +def pow5_64(n): + assert(n >= 0) + if n == 0: + return 1 << 63 + else: + return high128(5**n, False) >> 64 + +if not compact: + print("// kPow10Table", pow10min, "..", pow10max) + print("{") + for p in range(pow10min, pow10max + 1): + h = hex(pow10approx(p))[2:] + assert(len(h) == 32) + print(" {0x%s, 0x%s}," % (h[0:16].upper(), h[16:32].upper())) + print("}") +else: + print("// kPow5Table") + print("{") + for i in range(16): + print(" " + hex(pow5_64(i)) + ",") + print("}") + print("// kPow10Table", pow10min, "..", pow10max) + print("{") + for p in range(pow10min, pow10max + 1, 16): + base = pow10approx(p) + errw = 0 + for i in range(16): + real = pow10approx(p + i) + appr = (base * pow5_64(i)) >> 64 + scale = 1 if appr < (1 << 127) else 0 # 1-bit scale + + offset = (appr << scale) - real + assert(offset >= -4 and offset <= 3) # 3-bit offset + assert((appr << scale) >> 64 == real >> 64) # offset only affects low half + assert((appr << scale) - offset == real) # validate full reconstruction + + err = (scale << 3) | (offset + 4) + errw |= err << (i * 4) + + hbase = hex(base)[2:] + assert(len(hbase) == 32) + assert(errw < 1 << 64) + + print(" {0x%s, 0x%s, 0x%16x}," % (hbase[0:16], hbase[16:32], errw)) + print("}") diff --git a/tools/patchtests.py b/tools/patchtests.py new file mode 100644 index 00000000..56970c9f --- /dev/null +++ b/tools/patchtests.py @@ -0,0 +1,84 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# This code can be used to patch Compiler.test.cpp following bytecode changes, based on error output +import sys +import re + +(_, filename) = sys.argv +input = sys.stdin.readlines() + +errors = [] + +# 0: looking for error, 1: looking for replacement, 2: collecting replacement, 3: collecting initial text +state = 0 + +# parse input into errors[] with the state machine; this is using doctest output and expects multi-line match failures +for line in input: + if state == 0: + if sys.platform == "win32": + match = re.match("[^(]+\((\d+)\): ERROR: CHECK_EQ", line) + else: + match = re.match("tests/[^:]+:(\d+): ERROR: CHECK_EQ", line) + + if match: + error_line = int(match[1]) + state = 1 + elif state == 1: + if re.match("\s*values: CHECK_EQ\(\s*$", line): + error_repl = [] + state = 2 + elif re.match("\s*values: CHECK_EQ", line): + state = 0 # skipping single-line checks since we can't patch them + elif state == 2: + if line.strip() == ",": + error_orig = [] + state = 3 + else: + error_repl.append(line) + elif state == 3: + if line.strip() == ")": + errors.append((error_line, error_orig, error_repl)) + state = 0 + else: + error_orig.append(line) + +# make sure we fully process each individual check +assert(state == 0) + +errors.sort(key = lambda e: e[0]) + +with open(filename, "r") as fp: + source = fp.readlines() + +# patch source text into result[] using errors[]; we expect every match to appear at or after the line error was reported at +result = [] + +current = 0 +index = 0 +target = 0 + +while index < len(source): + line = source[index] + error = errors[current] if current < len(errors) else None + + if error: + target = error[0] if sys.platform != "win32" else error[0] - len(error[1]) - 1 + + if not error or index < target or line != error[1][0]: + result.append(line) + index += 1 + else: + # validate that the patch has a complete match in source text + for v in range(len(error[1])): + assert(source[index + v] == error[1][v]) + + result += error[2] + index += len(error[1]) + current += 1 + +# make sure we patch all errors +assert(current == len(errors)) + +with open(filename, "w") as fp: + fp.writelines(result) diff --git a/tools/perfgraph.py b/tools/perfgraph.py index 95baef9c..eb6b68ce 100644 --- a/tools/perfgraph.py +++ b/tools/perfgraph.py @@ -4,8 +4,13 @@ # Given a profile dump, this tool generates a flame graph based on the stacks listed in the profile # The result of analysis is a .svg file which can be viewed in a browser -import sys import svg +import argparse +import json + +argumentParser = argparse.ArgumentParser(description='Generate flamegraph SVG from Luau sampling profiler dumps') +argumentParser.add_argument('source_file', type=open) +argumentParser.add_argument('--json', dest='useJson',action='store_const',const=1,default=0,help='Parse source_file as JSON') class Node(svg.Node): def __init__(self): @@ -27,26 +32,139 @@ class Node(svg.Node): def details(self, root): return "Function: {} [{}:{}] ({:,} usec, {:.1%}); self: {:,} usec".format(self.function, self.source, self.line, self.width, self.width / root.width, self.ticks) -with open(sys.argv[1]) as f: - dump = f.readlines() -root = Node() +def nodeFromCallstackListFile(source_file): + dump = source_file.readlines() + root = Node() -for l in dump: - ticks, stack = l.strip().split(" ", 1) - node = root + for l in dump: + ticks, stack = l.strip().split(" ", 1) + node = root - for f in reversed(stack.split(";")): - source, function, line = f.split(",") + for f in reversed(stack.split(";")): + source, function, line = f.split(",") - child = node.child(f) - child.function = function - child.source = source - child.line = int(line) if len(line) > 0 else 0 + child = node.child(f) + child.function = function + child.source = source + child.line = int(line) if len(line) > 0 else 0 + + node = child + + node.ticks += int(ticks) + + return root + + +def getDuration(nodes, nid): + node = nodes[nid - 1] + total = node['TotalDuration'] + + for cid in node['NodeIds']: + total -= nodes[cid - 1]['TotalDuration'] + + return total + +def getFunctionKey(fn): + return fn['Source'] + "," + fn['Name'] + "," + str(fn['Line']) + +def recursivelyBuildNodeTree(nodes, functions, parent, fid, nid): + ninfo = nodes[nid - 1] + finfo = functions[fid - 1] + + child = parent.child(getFunctionKey(finfo)) + child.source = finfo['Source'] + child.function = finfo['Name'] + child.line = int(finfo['Line']) if finfo['Line'] > 0 else 0 + + child.ticks = getDuration(nodes, nid) + + assert(len(ninfo['FunctionIds']) == len(ninfo['NodeIds'])) + + for i in range(0, len(ninfo['FunctionIds'])): + recursivelyBuildNodeTree(nodes, functions, child, ninfo['FunctionIds'][i], ninfo['NodeIds'][i]) + + return + +def nodeFromJSONV2(dump): + assert(dump['Version'] == 2) + + nodes = dump['Nodes'] + functions = dump['Functions'] + categories = dump['Categories'] + + root = Node() + + for category in categories: + nid = category['NodeId'] + node = nodes[nid - 1] + name = category['Name'] + + child = root.child(name) + child.function = name + child.ticks = getDuration(nodes, nid) + + assert(len(node['FunctionIds']) == len(node['NodeIds'])) + + for i in range(0, len(node['FunctionIds'])): + recursivelyBuildNodeTree(nodes, functions, child, node['FunctionIds'][i], node['NodeIds'][i]) + + return root + +def getDurationV1(obj): + total = obj['TotalDuration'] + + if 'Children' in obj: + for key, obj in obj['Children'].items(): + total -= obj['TotalDuration'] + + return total + + +def nodeFromJSONObject(node, key, obj): + source, function, line = key.split(",") + + node.function = function + node.source = source + node.line = int(line) if len(line) > 0 else 0 + + node.ticks = getDurationV1(obj) + + if 'Children' in obj: + for key, obj in obj['Children'].items(): + nodeFromJSONObject(node.child(key), key, obj) + + return node + +def nodeFromJSONV1(dump): + assert(dump['Version'] == 1) + root = Node() + + if 'Children' in dump: + for key, obj in dump['Children'].items(): + nodeFromJSONObject(root.child(key), key, obj) + + return root + +def nodeFromJSONFile(source_file): + dump = json.load(source_file) + + if dump['Version'] == 2: + return nodeFromJSONV2(dump) + elif dump['Version'] == 1: + return nodeFromJSONV1(dump) + + return Node() + + +arguments = argumentParser.parse_args() + +if arguments.useJson: + root = nodeFromJSONFile(arguments.source_file) +else: + root = nodeFromCallstackListFile(arguments.source_file) - node = child - node.ticks += int(ticks) svg.layout(root, lambda n: n.ticks) svg.display(root, "Flame Graph", "hot", flip = True) diff --git a/tools/perfstat.py b/tools/perfstat.py new file mode 100644 index 00000000..e5cfd117 --- /dev/null +++ b/tools/perfstat.py @@ -0,0 +1,65 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a profile dump, this tool displays top functions based on the stacks listed in the profile + +import argparse + +class Node: + def __init__(self): + self.function = "" + self.source = "" + self.line = 0 + self.hier_ticks = 0 + self.self_ticks = 0 + + def title(self): + if self.line > 0: + return "{} ({}:{})".format(self.function, self.source, self.line) + else: + return self.function + +argumentParser = argparse.ArgumentParser(description='Display summary statistics from Luau sampling profiler dumps') +argumentParser.add_argument('source_file', type=open) +argumentParser.add_argument('--limit', dest='limit', type=int, default=10, help='Display top N functions') + +arguments = argumentParser.parse_args() + +dump = arguments.source_file.readlines() + +stats = {} +total = 0 +total_gc = 0 + +for l in dump: + ticks, stack = l.strip().split(" ", 1) + hier = {} + + for f in reversed(stack.split(";")): + source, function, line = f.split(",") + node = stats.setdefault(f, Node()) + + node.function = function + node.source = source + node.line = int(line) if len(line) > 0 else 0 + + if not node in hier: + node.hier_ticks += int(ticks) + hier[node] = True + + total += int(ticks) + node.self_ticks += int(ticks) + + if node.source == "GC": + total_gc += int(ticks) + +if total > 0: + print(f"Runtime: {total:,} usec ({100.0 * total_gc / total:.2f}% GC)") + print() + print("Top functions (self time):") + for n in sorted(stats.values(), key=lambda node: node.self_ticks, reverse=True)[:arguments.limit]: + print(f"{n.self_ticks:12,} usec ({100.0 * n.self_ticks / total:.2f}%): {n.title()}") + print() + print("Top functions (total time):") + for n in sorted(stats.values(), key=lambda node: node.hier_ticks, reverse=True)[:arguments.limit]: + print(f"{n.hier_ticks:12,} usec ({100.0 * n.hier_ticks / total:.2f}%): {n.title()}") diff --git a/tools/stack-usage-reporter.py b/tools/stack-usage-reporter.py new file mode 100644 index 00000000..91e74887 --- /dev/null +++ b/tools/stack-usage-reporter.py @@ -0,0 +1,173 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# The purpose of this script is to analyze disassembly generated by objdump or +# dumpbin to print (or to compare) the stack usage of functions/methods. +# This is a quickly written script, so it is quite possible it may not handle +# all code properly. +# +# The script expects the user to create a text assembly dump to be passed to +# the script. +# +# objdump Example +# objdump --demangle --disassemble objfile.o > objfile.s +# +# dumpbin Example +# dumpbin /disasm objfile.obj > objfile.s +# +# If the script is passed a single file, then all stack size information that +# is found it printed. If two files are passed, then the script compares the +# stack usage of the two files (useful for A/B comparisons). +# Currently more than two input files are not supported. (But adding support shouldn't +# be very difficult.) +# +# Note: The script only handles x64 disassembly. Supporting x86 is likely +# trivial, but ARM support could be difficult. +# Thus far the script has been tested with MSVC on Win64 and clang on OSX. + +import argparse +import re + +blank_re = re.compile('\s*') + +class LineReader: + def __init__(self, lines): + self.lines = list(reversed(lines)) + def get_line(self): + return self.lines.pop(-1) + def peek_line(self): + return self.lines[-1] + def consume_blank_lines(self): + while blank_re.fullmatch(self.peek_line()): + self.get_line() + def is_empty(self): + return len(self.lines) == 0 + +def parse_objdump_assembly(in_file): + results = {} + text_section_re = re.compile('Disassembly of section __TEXT,__text:\s*') + symbol_re = re.compile('[^<]*<(.*)>:\s*') + stack_alloc = re.compile('.*subq\s*\$(\d*), %rsp\s*') + + lr = LineReader(in_file.readlines()) + + def find_stack_alloc_size(): + while True: + if lr.is_empty(): + return None + if blank_re.fullmatch(lr.peek_line()): + return None + + line = lr.get_line() + mo = stack_alloc.fullmatch(line) + if mo: + lr.consume_blank_lines() + return int(mo.group(1)) + + # Find beginning of disassembly + while not text_section_re.fullmatch(lr.get_line()): + pass + + # Scan for symbols + while not lr.is_empty(): + lr.consume_blank_lines() + if lr.is_empty(): + break + line = lr.get_line() + mo = symbol_re.fullmatch(line) + # Found a symbol + if mo: + symbol = mo.group(1) + stack_size = find_stack_alloc_size() + if stack_size != None: + results[symbol] = stack_size + + return results + +def parse_dumpbin_assembly(in_file): + results = {} + + file_type_re = re.compile('File Type: COFF OBJECT\s*') + symbol_re = re.compile('[^(]*\((.*)\):\s*') + summary_re = re.compile('\s*Summary\s*') + stack_alloc = re.compile('.*sub\s*rsp,([A-Z0-9]*)h\s*') + + lr = LineReader(in_file.readlines()) + + def find_stack_alloc_size(): + while True: + if lr.is_empty(): + return None + if blank_re.fullmatch(lr.peek_line()): + return None + + line = lr.get_line() + mo = stack_alloc.fullmatch(line) + if mo: + lr.consume_blank_lines() + return int(mo.group(1), 16) # return value in decimal + + # Find beginning of disassembly + while not file_type_re.fullmatch(lr.get_line()): + pass + + # Scan for symbols + while not lr.is_empty(): + lr.consume_blank_lines() + if lr.is_empty(): + break + line = lr.get_line() + if summary_re.fullmatch(line): + break + mo = symbol_re.fullmatch(line) + # Found a symbol + if mo: + symbol = mo.group(1) + stack_size = find_stack_alloc_size() + if stack_size != None: + results[symbol] = stack_size + return results + +def main(): + parser = argparse.ArgumentParser(description='Tool used for reporting or comparing the stack usage of functions/methods') + parser.add_argument('--format', choices=['dumpbin', 'objdump'], required=True, help='Specifies the program used to generate the input files') + parser.add_argument('--input', action='append', required=True, help='Input assembly file. This option may be specified multiple times.') + parser.add_argument('--md-output', action='store_true', help='Show table output in markdown format') + parser.add_argument('--only-diffs', action='store_true', help='Only show stack info when it differs between the input files') + args = parser.parse_args() + + parsers = {'dumpbin': parse_dumpbin_assembly, 'objdump' : parse_objdump_assembly} + parse_func = parsers[args.format] + + input_results = [] + for input_name in args.input: + with open(input_name) as in_file: + results = parse_func(in_file) + input_results.append(results) + + if len(input_results) == 1: + # Print out the results sorted by size + size_sorted = sorted([(size, symbol) for symbol, size in results.items()], reverse=True) + print(input_name) + for size, symbol in size_sorted: + print(f'{size:10}\t{symbol}') + print() + elif len(input_results) == 2: + common_symbols = set(input_results[0].keys()).intersection(set(input_results[1].keys())) + print(f'Found {len(common_symbols)} common symbols') + stack_sizes = sorted([(input_results[0][sym], input_results[1][sym], sym) for sym in common_symbols], reverse=True) + if args.md_output: + print('Before | After | Symbol') + print('-- | -- | --') + for size0, size1, symbol in stack_sizes: + if args.only_diffs and size0 == size1: + continue + if args.md_output: + print(f'{size0} | {size1} | {symbol}') + else: + print(f'{size0:10}\t{size1:10}\t{symbol}') + else: + print("TODO support more than 2 inputs") + +if __name__ == '__main__': + main() diff --git a/tools/svg.py b/tools/svg.py index 3b3bb28c..21200eeb 100644 --- a/tools/svg.py +++ b/tools/svg.py @@ -452,19 +452,22 @@ def display(root, title, colors, flip = False): .replace("$gradient-start", gradient_start) .replace("$gradient-end", gradient_end) .replace("$height", str(svgheight)) - .replace("$status", str(svgheight - 16 + 3)) + .replace("$status", str((svgheight - 16 + 3 if flip else 3 * 16 - 3))) .replace("$flip", str(int(flip))) ) framewidth = 1200 - 20 + def pixels(x): + return float(x) / root.width * framewidth if root.width > 0 else 0 + for n in root.subtree(): - if n.width / root.width * framewidth < 0.1: + if pixels(n.width) < 0.1: continue - x = 10 + n.offset / root.width * framewidth + x = 10 + pixels(n.offset) y = (maxdepth - 1 - n.depth if flip else n.depth) * 16 + 3 * 16 - width = n.width / root.width * framewidth + width = pixels(n.width) height = 15 if colors == "cold": diff --git a/tools/test_dcr.py b/tools/test_dcr.py new file mode 100644 index 00000000..6d553b64 --- /dev/null +++ b/tools/test_dcr.py @@ -0,0 +1,227 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +import argparse +import os.path +import subprocess as sp +import sys +import xml.sax as x +import colorama as c + +c.init() + +SCRIPT_PATH = os.path.split(sys.argv[0])[0] +FAIL_LIST_PATH = os.path.join(SCRIPT_PATH, "faillist.txt") + + +def loadFailList(): + with open(FAIL_LIST_PATH) as f: + return set(map(str.strip, f.readlines())) + + +def safeParseInt(i, default=0): + try: + return int(i) + except ValueError: + return default + + +def makeDottedName(path): + return ".".join(path) + + +class Handler(x.ContentHandler): + def __init__(self, failList): + self.currentTest = [] + self.failList = failList # Set of dotted test names that are expected to fail + + self.results = {} # {DottedName: TrueIfTheTestPassed} + + self.numSkippedTests = 0 + + self.pass_count = 0 + self.fail_count = 0 + self.test_count = 0 + + self.crashed_tests = [] + + def startElement(self, name, attrs): + if name == "TestSuite": + self.currentTest.append(attrs["name"]) + elif name == "TestCase": + self.currentTest.append(attrs["name"]) + + elif name == "OverallResultsAsserts": + if self.currentTest: + passed = attrs["test_case_success"] == "true" + + dottedName = makeDottedName(self.currentTest) + + # Sometimes we get multiple XML trees for the same test. All of + # them must report a pass in order for us to consider the test + # to have passed. + r = self.results.get(dottedName, True) + self.results[dottedName] = r and passed + + self.test_count += 1 + if passed: + self.pass_count += 1 + else: + self.fail_count += 1 + + elif name == "OverallResultsTestCases": + self.numSkippedTests = safeParseInt(attrs.get("skipped", 0)) + + elif name == "Exception": + if attrs.get("crash") == "true": + self.crashed_tests.append(makeDottedName(self.currentTest)) + + def endElement(self, name): + if name == "TestCase": + self.currentTest.pop() + + elif name == "TestSuite": + self.currentTest.pop() + + +def print_stderr(*args, **kw): + print(*args, **kw, file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser( + description="Run Luau.UnitTest with deferred constraint resolution enabled" + ) + parser.add_argument( + "path", action="store", help="Path to the Luau.UnitTest executable" + ) + parser.add_argument( + "--dump", + dest="dump", + action="store_true", + help="Instead of doing any processing, dump the raw output of the test run. Useful for debugging this tool.", + ) + parser.add_argument( + "--write", + dest="write", + action="store_true", + help="Write a new faillist.txt after running tests.", + ) + + parser.add_argument("--randomize", action="store_true", help="Pick a random seed") + + parser.add_argument( + "--random-seed", + action="store", + dest="random_seed", + type=int, + help="Accept a specific RNG seed", + ) + + args = parser.parse_args() + + failList = loadFailList() + + commandLine = [ + args.path, + "--reporters=xml", + "--fflags=true,DebugLuauDeferredConstraintResolution=true", + ] + + if args.random_seed: + commandLine.append("--random-seed=" + str(args.random_seed)) + elif args.randomize: + commandLine.append("--randomize") + + print_stderr(">", " ".join(commandLine)) + + p = sp.Popen( + commandLine, + stdout=sp.PIPE, + ) + + handler = Handler(failList) + + if args.dump: + for line in p.stdout: + sys.stdout.buffer.write(line) + return + else: + try: + x.parse(p.stdout, handler) + except x.SAXParseException as e: + print_stderr( + f"XML parsing failed during test {makeDottedName(handler.currentTest)}. That probably means that the test crashed" + ) + sys.exit(1) + + p.wait() + + unexpected_fails = 0 + unexpected_passes = 0 + + for testName, passed in handler.results.items(): + if passed and testName in failList: + unexpected_passes += 1 + print_stderr( + f"UNEXPECTED: {c.Fore.RED}{testName}{c.Fore.RESET} should have failed" + ) + elif not passed and testName not in failList: + unexpected_fails += 1 + print_stderr( + f"UNEXPECTED: {c.Fore.GREEN}{testName}{c.Fore.RESET} should have passed" + ) + + if unexpected_fails or unexpected_passes: + print_stderr("") + print_stderr(f"Unexpected fails: {unexpected_fails}") + print_stderr(f"Unexpected passes: {unexpected_passes}") + + pass_percent = int(handler.pass_count / handler.test_count * 100) + + print_stderr("") + print_stderr( + f"{handler.pass_count} of {handler.test_count} tests passed. ({pass_percent}%)" + ) + print_stderr(f"{handler.fail_count} tests failed.") + + if args.write: + newFailList = sorted( + ( + dottedName + for dottedName, passed in handler.results.items() + if not passed + ), + key=str.lower, + ) + with open(FAIL_LIST_PATH, "w", newline="\n") as f: + for name in newFailList: + print(name, file=f) + print_stderr("Updated faillist.txt") + + if handler.crashed_tests: + print_stderr() + for test in handler.crashed_tests: + print_stderr( + f"{c.Fore.RED}{test}{c.Fore.RESET} threw an exception and crashed the test process!" + ) + + if handler.numSkippedTests > 0: + print_stderr(f"{handler.numSkippedTests} test(s) were skipped!") + + ok = ( + not handler.crashed_tests + and handler.numSkippedTests == 0 + and all( + not passed == (dottedName in failList) + for dottedName, passed in handler.results.items() + ) + ) + + if ok: + print_stderr("Everything in order!") + + sys.exit(0 if ok else 1) + + +if __name__ == "__main__": + main() diff --git a/tools/tracegraph.py b/tools/tracegraph.py new file mode 100644 index 00000000..a46423e7 --- /dev/null +++ b/tools/tracegraph.py @@ -0,0 +1,95 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a trace event file, this tool generates a flame graph based on the event scopes present in the file +# The result of analysis is a .svg file which can be viewed in a browser + +import sys +import svg +import json + +class Node(svg.Node): + def __init__(self): + svg.Node.__init__(self) + self.caption = "" + self.description = "" + self.ticks = 0 + + def text(self): + return self.caption + + def title(self): + return self.caption + + def details(self, root): + return "{} ({:,} usec, {:.1%}); self: {:,} usec".format(self.description, self.width, self.width / root.width, self.ticks) + +with open(sys.argv[1]) as f: + dump = f.read() + +root = Node() + +# Finish the file +if not dump.endswith("]"): + dump += "{}]" + +data = json.loads(dump) + +stacks = {} + +for l in data: + if len(l) == 0: + continue + + # Track stack of each thread, but aggregate values together + tid = l["tid"] + + if not tid in stacks: + stacks[tid] = [] + stack = stacks[tid] + + if l["ph"] == 'B': + stack.append(l) + elif l["ph"] == 'E': + node = root + + for e in stack: + caption = e["name"] + description = '' + + if "args" in e: + for arg in e["args"]: + if len(description) != 0: + description += ", " + + description += "{}: {}".format(arg, e["args"][arg]) + + child = node.child(caption + description) + child.caption = caption + child.description = description + + node = child + + begin = stack[-1] + + ticks = l["ts"] - begin["ts"] + rawticks = ticks + + # Flame graph requires ticks without children duration + if "childts" in begin: + ticks -= begin["childts"] + + node.ticks += int(ticks) + + stack.pop() + + if len(stack): + parent = stack[-1] + + if "childts" in parent: + parent["childts"] += rawticks + else: + parent["childts"] = rawticks + +svg.layout(root, lambda n: n.ticks) +svg.display(root, "Flame Graph", "hot", flip = True)