This commit is contained in:
RigidOfficial 2022-12-22 22:55:18 +04:00
commit bb2d6420e3
507 changed files with 105284 additions and 27610 deletions

View file

@ -1,5 +1,5 @@
blank_issues_enabled: false blank_issues_enabled: false
contact_links: contact_links:
- name: Help and support - name: Questions
url: https://github.com/Roblox/luau/discussions url: https://github.com/Roblox/luau/discussions
about: Please use GitHub Discussions if you have questions or need help. about: Please use GitHub Discussions if you have questions or need help.

1
.github/codecov.yml vendored Normal file
View file

@ -0,0 +1 @@
comment: false

185
.github/workflows/benchmark-dev.yml vendored Normal file
View file

@ -0,0 +1,185 @@
name: benchmark-dev
on:
push:
branches:
- master
paths-ignore:
- "docs/**"
- "papers/**"
- "rfcs/**"
- "*.md"
jobs:
windows:
name: windows-${{matrix.arch}}
strategy:
fail-fast: false
matrix:
os: [windows-latest]
arch: [Win32, x64]
bench:
- {
script: "run-benchmarks",
timeout: 12,
title: "Luau Benchmarks",
}
benchResultsRepo:
- { name: "luau-lang/benchmark-data", branch: "main" }
runs-on: ${{ matrix.os }}
steps:
- name: Checkout Luau repository
uses: actions/checkout@v3
- name: Build Luau
shell: bash # necessary for fail-fast
run: |
mkdir build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release
cmake --build . --target Luau.Repl.CLI --config Release
cmake --build . --target Luau.Analyze.CLI --config Release
- name: Move build files to root
run: |
move build/Release/* .
- uses: actions/setup-python@v3
with:
python-version: "3.9"
architecture: "x64"
- name: Install python dependencies
run: |
python -m pip install requests
python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose
- name: Run benchmark
run: |
python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt
- name: Push benchmark results
id: pushBenchmarkAttempt1
continue-on-error: true
uses: ./.github/workflows/push-results
with:
repository: ${{ matrix.benchResultsRepo.name }}
branch: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})"
bench_tool: "benchmarkluau"
bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
- name: Push benchmark results (Attempt 2)
id: pushBenchmarkAttempt2
continue-on-error: true
if: steps.pushBenchmarkAttempt1.outcome == 'failure'
uses: ./.github/workflows/push-results
with:
repository: ${{ matrix.benchResultsRepo.name }}
branch: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})"
bench_tool: "benchmarkluau"
bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
- name: Push benchmark results (Attempt 3)
id: pushBenchmarkAttempt3
continue-on-error: true
if: steps.pushBenchmarkAttempt2.outcome == 'failure'
uses: ./.github/workflows/push-results
with:
repository: ${{ matrix.benchResultsRepo.name }}
branch: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})"
bench_tool: "benchmarkluau"
bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
unix:
name: ${{matrix.os}}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
bench:
- {
script: "run-benchmarks",
timeout: 12,
title: "Luau Benchmarks",
}
benchResultsRepo:
- { name: "luau-lang/benchmark-data", branch: "main" }
runs-on: ${{ matrix.os }}
steps:
- name: Checkout Luau repository
uses: actions/checkout@v3
- name: Build Luau
run: make config=release luau luau-analyze
- uses: actions/setup-python@v3
with:
python-version: "3.9"
architecture: "x64"
- name: Install python dependencies
run: |
python -m pip install requests
python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose
- name: Run benchmark
run: |
python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt
- name: Push benchmark results
id: pushBenchmarkAttempt1
continue-on-error: true
uses: ./.github/workflows/push-results
with:
repository: ${{ matrix.benchResultsRepo.name }}
branch: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
bench_name: ${{ matrix.bench.title }}
bench_tool: "benchmarkluau"
bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
- name: Push benchmark results (Attempt 2)
id: pushBenchmarkAttempt2
continue-on-error: true
if: steps.pushBenchmarkAttempt1.outcome == 'failure'
uses: ./.github/workflows/push-results
with:
repository: ${{ matrix.benchResultsRepo.name }}
branch: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
bench_name: ${{ matrix.bench.title }}
bench_tool: "benchmarkluau"
bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
- name: Push benchmark results (Attempt 3)
id: pushBenchmarkAttempt3
continue-on-error: true
if: steps.pushBenchmarkAttempt2.outcome == 'failure'
uses: ./.github/workflows/push-results
with:
repository: ${{ matrix.benchResultsRepo.name }}
branch: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
bench_name: ${{ matrix.bench.title }}
bench_tool: "benchmarkluau"
bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"

140
.github/workflows/benchmark.yml vendored Normal file
View file

@ -0,0 +1,140 @@
name: benchmark
on:
push:
branches:
- master
paths-ignore:
- "docs/**"
- "papers/**"
- "rfcs/**"
- "*.md"
jobs:
callgrind:
strategy:
matrix:
os: [ubuntu-22.04]
benchResultsRepo:
- { name: "luau-lang/benchmark-data", branch: "main" }
runs-on: ${{ matrix.os }}
steps:
- name: Checkout Luau repository
uses: actions/checkout@v3
- name: Install valgrind
run: |
sudo apt-get install valgrind
- name: Build Luau (gcc)
run: |
CXX=g++ make config=profile luau
cp luau luau-gcc
- name: Build Luau (codegen)
run: |
make config=profile clean
CXX=clang++ make config=profile native=1 luau
cp luau luau-codegen
- name: Build Luau (clang)
run: |
make config=profile clean
CXX=clang++ make config=profile luau luau-analyze
- name: Run benchmark (bench-gcc)
run: |
python bench/bench.py --callgrind --vm "./luau-gcc -O2" | tee -a bench-gcc-output.txt
- name: Run benchmark (bench)
run: |
python bench/bench.py --callgrind --vm "./luau -O2" | tee -a bench-output.txt
- name: Run benchmark (bench-codegen)
run: |
python bench/bench.py --callgrind --vm "./luau-codegen --codegen -O2" | tee -a bench-codegen-output.txt
- name: Run benchmark (analyze)
run: |
filter() {
awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau-analyze"}'
}
valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-nonstrict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-strict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/regex.lua 2>&1 | filter regex-nonstrict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/regex.lua 2>&1 | filter regex-strict | tee -a analyze-output.txt
- name: Run benchmark (compile)
run: |
filter() {
awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau --compile"}'
}
valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt
- name: Checkout benchmark results
uses: actions/checkout@v3
with:
repository: ${{ matrix.benchResultsRepo.name }}
ref: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
- name: Store results (bench)
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: callgrind clang
tool: "benchmarkluau"
output-file-path: ./bench-output.txt
external-data-json-path: ./gh-pages/bench.json
- name: Store results (bench-codegen)
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: callgrind codegen
tool: "benchmarkluau"
output-file-path: ./bench-codegen-output.txt
external-data-json-path: ./gh-pages/bench-codegen.json
- name: Store results (bench-gcc)
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: callgrind gcc
tool: "benchmarkluau"
output-file-path: ./bench-gcc-output.txt
external-data-json-path: ./gh-pages/bench-gcc.json
- name: Store results (analyze)
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: luau-analyze
tool: "benchmarkluau"
output-file-path: ./analyze-output.txt
external-data-json-path: ./gh-pages/analyze.json
- name: Store results (compile)
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: luau --compile
tool: "benchmarkluau"
output-file-path: ./compile-output.txt
external-data-json-path: ./gh-pages/compile.json
- name: Push benchmark results
if: github.event_name == 'push'
run: |
echo "Pushing benchmark results..."
cd gh-pages
git config user.name github-actions
git config user.email github@users.noreply.github.com
git add *.json
git commit -m "Add benchmarks results for ${{ github.sha }}"
git push
cd ..

View file

@ -4,6 +4,11 @@ on:
push: push:
branches: branches:
- 'master' - 'master'
paths-ignore:
- 'docs/**'
- 'papers/**'
- 'rfcs/**'
- '*.md'
pull_request: pull_request:
paths-ignore: paths-ignore:
- 'docs/**' - 'docs/**'
@ -20,15 +25,24 @@ jobs:
runs-on: ${{matrix.os}}-latest runs-on: ${{matrix.os}}-latest
steps: steps:
- uses: actions/checkout@v1 - uses: actions/checkout@v1
- name: make test - name: make tests
run: | run: |
make -j2 config=sanitize test make -j2 config=sanitize werror=1 native=1 luau-tests
- name: make test w/flags - name: run tests
run: | run: |
make -j2 config=sanitize flags=true test ./luau-tests
./luau-tests --fflags=true
- name: run extra conformance tests
run: |
./luau-tests -ts=Conformance -O2
./luau-tests -ts=Conformance -O2 --fflags=true
./luau-tests -ts=Conformance --codegen
./luau-tests -ts=Conformance --codegen --fflags=true
./luau-tests -ts=Conformance --codegen -O2
./luau-tests -ts=Conformance --codegen -O2 --fflags=true
- name: make cli - name: make cli
run: | run: |
make -j2 config=sanitize luau luau-analyze # match config with tests to improve build time make -j2 config=sanitize werror=1 luau luau-analyze # match config with tests to improve build time
./luau tests/conformance/assert.lua ./luau tests/conformance/assert.lua
./luau-analyze tests/conformance/assert.lua ./luau-analyze tests/conformance/assert.lua
@ -40,18 +54,25 @@ jobs:
steps: steps:
- uses: actions/checkout@v1 - uses: actions/checkout@v1
- name: cmake configure - name: cmake configure
run: cmake . -A ${{matrix.arch}} run: cmake . -A ${{matrix.arch}} -DLUAU_WERROR=ON -DLUAU_NATIVE=ON
- name: cmake test - name: cmake build
run: cmake --build . --target Luau.UnitTest Luau.Conformance --config Debug
- name: run tests
shell: bash # necessary for fail-fast shell: bash # necessary for fail-fast
run: | run: |
cmake --build . --target Luau.UnitTest Luau.Conformance --config Debug
Debug/Luau.UnitTest.exe Debug/Luau.UnitTest.exe
Debug/Luau.Conformance.exe Debug/Luau.Conformance.exe
- name: cmake test w/flags
shell: bash # necessary for fail-fast
run: |
Debug/Luau.UnitTest.exe --fflags=true Debug/Luau.UnitTest.exe --fflags=true
Debug/Luau.Conformance.exe --fflags=true Debug/Luau.Conformance.exe --fflags=true
- name: run extra conformance tests
shell: bash # necessary for fail-fast
run: |
Debug/Luau.Conformance.exe -O2
Debug/Luau.Conformance.exe -O2 --fflags=true
Debug/Luau.Conformance.exe --codegen
Debug/Luau.Conformance.exe --codegen --fflags=true
Debug/Luau.Conformance.exe --codegen -O2
Debug/Luau.Conformance.exe --codegen -O2 --fflags=true
- name: cmake cli - name: cmake cli
shell: bash # necessary for fail-fast shell: bash # necessary for fail-fast
run: | run: |
@ -62,19 +83,34 @@ jobs:
coverage: coverage:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v1 - uses: actions/checkout@v2
- name: install - name: install
run: | run: |
sudo apt install llvm sudo apt install llvm
- name: make coverage - name: make coverage
run: | run: |
CXX=clang++-10 make -j2 config=coverage coverage CXX=clang++ make -j2 config=coverage native=1 coverage
- name: upload coverage - name: upload coverage
uses: coverallsapp/github-action@master uses: codecov/codecov-action@v3
with: with:
path-to-lcov: ./coverage.info files: ./coverage.info
github-token: ${{ secrets.GITHUB_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
- uses: actions/upload-artifact@v2
web:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v2
with: with:
name: coverage repository: emscripten-core/emsdk
path: coverage path: emsdk
- name: emsdk install
run: |
cd emsdk
./emsdk install latest
./emsdk activate latest
- name: make
run: |
source emsdk/emsdk_env.sh
emcmake cmake . -DLUAU_BUILD_WEB=ON -DCMAKE_BUILD_TYPE=Release
make -j2 Luau.Web

83
.github/workflows/new-release.yml vendored Normal file
View file

@ -0,0 +1,83 @@
name: new-release
on:
workflow_dispatch:
inputs:
version:
required: true
description: Release version including 0.
permissions:
contents: write
jobs:
create-release:
runs-on: ubuntu-latest
outputs:
upload_url: ${{ steps.create_release.outputs.upload_url }}
steps:
- name: create release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ github.event.inputs.version }}
release_name: ${{ github.event.inputs.version }}
draft: true
build:
needs: ["create-release"]
strategy:
matrix:
os: [ubuntu, macos, windows]
name: ${{matrix.os}}
runs-on: ${{matrix.os}}-latest
steps:
- uses: actions/checkout@v1
- name: configure
run: cmake . -DCMAKE_BUILD_TYPE=Release
- name: build
run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI --config Release -j 2
- name: pack
if: matrix.os != 'windows'
run: zip luau-${{matrix.os}}.zip luau*
- name: pack
if: matrix.os == 'windows'
run: 7z a luau-${{matrix.os}}.zip .\Release\luau*.exe
- uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ needs.create-release.outputs.upload_url }}
asset_path: luau-${{matrix.os}}.zip
asset_name: luau-${{matrix.os}}.zip
asset_content_type: application/octet-stream
web:
needs: ["create-release"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v2
with:
repository: emscripten-core/emsdk
path: emsdk
- name: emsdk install
run: |
cd emsdk
./emsdk install latest
./emsdk activate latest
- name: make
run: |
source emsdk/emsdk_env.sh
emcmake cmake . -DLUAU_BUILD_WEB=ON -DCMAKE_BUILD_TYPE=Release
make -j2 Luau.Web
- uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ needs.create-release.outputs.upload_url }}
asset_path: Luau.Web.js
asset_name: Luau.Web.js
asset_content_type: application/octet-stream

View file

@ -0,0 +1,63 @@
name: Checkout & push results
description: Checkout a given repo and push results to GitHub
inputs:
repository:
required: true
type: string
description: The benchmark results repository to check out
branch:
required: true
type: string
description: The benchmark results repository's branch to check out
token:
required: true
type: string
description: The GitHub token to use for pushing results
path:
required: true
type: string
description: The path to check out the results repository to
bench_name:
required: true
type: string
bench_tool:
required: true
type: string
bench_output_file_path:
required: true
type: string
bench_external_data_json_path:
required: true
type: string
runs:
using: "composite"
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
repository: ${{ inputs.repository }}
ref: ${{ inputs.branch }}
token: ${{ inputs.token }}
path: ${{ inputs.path }}
- name: Store results
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: ${{ inputs.bench_name }}
tool: ${{ inputs.bench_tool }}
gh-pages-branch: ${{ inputs.branch }}
output-file-path: ${{ inputs.bench_output_file_path }}
external-data-json-path: ${{ inputs.bench_external_data_json_path }}
- name: Push benchmark results
shell: bash
run: |
echo "Pushing benchmark results..."
cd gh-pages
git config user.name github-actions
git config user.email github@users.noreply.github.com
git add *.json
git commit -m "Add benchmarks results for ${{ github.sha }}"
git push
cd ..

View file

@ -8,6 +8,7 @@ on:
- 'docs/**' - 'docs/**'
- 'papers/**' - 'papers/**'
- 'rfcs/**' - 'rfcs/**'
- '*.md'
jobs: jobs:
build: build:
@ -21,7 +22,7 @@ jobs:
- name: configure - name: configure
run: cmake . -DCMAKE_BUILD_TYPE=Release run: cmake . -DCMAKE_BUILD_TYPE=Release
- name: build - name: build
run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI --config Release run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI --config Release -j 2
- uses: actions/upload-artifact@v2 - uses: actions/upload-artifact@v2
if: matrix.os != 'windows' if: matrix.os != 'windows'
with: with:
@ -32,3 +33,26 @@ jobs:
with: with:
name: luau-${{matrix.os}} name: luau-${{matrix.os}}
path: Release\luau*.exe path: Release\luau*.exe
web:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v2
with:
repository: emscripten-core/emsdk
path: emsdk
- name: emsdk install
run: |
cd emsdk
./emsdk install latest
./emsdk activate latest
- name: make
run: |
source emsdk/emsdk_env.sh
emcmake cmake . -DLUAU_BUILD_WEB=ON -DCMAKE_BUILD_TYPE=Release
make -j2 Luau.Web
- uses: actions/upload-artifact@v2
with:
name: Luau.Web.js
path: Luau.Web.js

20
.gitignore vendored
View file

@ -1,7 +1,13 @@
build/ /build/
^coverage/ /build[.-]*/
^fuzz/luau.pb.* /coverage/
^crash-* /.vs/
^default.prof* /.vscode/
^fuzz-* /fuzz/luau.pb.*
^luau$ /crash-*
/default.prof*
/fuzz-*
/luau
/luau-tests
/luau-analyze
__pycache__

View file

@ -0,0 +1,42 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/NotNull.h"
#include "Luau/Substitution.h"
#include "Luau/TypeVar.h"
#include <memory>
namespace Luau
{
struct TypeArena;
struct Scope;
struct InternalErrorReporter;
using ScopePtr = std::shared_ptr<Scope>;
// A substitution which replaces free types by any
struct Anyification : Substitution
{
Anyification(TypeArena* arena, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter* iceHandler, TypeId anyType,
TypePackId anyTypePack);
Anyification(TypeArena* arena, const ScopePtr& scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter* iceHandler, TypeId anyType,
TypePackId anyTypePack);
NotNull<Scope> scope;
NotNull<SingletonTypes> singletonTypes;
InternalErrorReporter* iceHandler;
TypeId anyType;
TypePackId anyTypePack;
bool normalizationTooComplex = false;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
bool ignoreChildren(TypeId ty) override;
bool ignoreChildren(TypePackId ty) override;
};
} // namespace Luau

View file

@ -0,0 +1,32 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Substitution.h"
#include "Luau/TxnLog.h"
#include "Luau/TypeVar.h"
namespace Luau
{
// A substitution which replaces the type parameters of a type function by arguments
struct ApplyTypeFunction : Substitution
{
ApplyTypeFunction(TypeArena* arena)
: Substitution(TxnLog::empty(), arena)
, encounteredForwardedType(false)
{
}
// Never set under deferred constraint resolution.
bool encounteredForwardedType;
std::unordered_map<TypeId, TypeId> typeArguments;
std::unordered_map<TypePackId, TypePackId> typePackArguments;
bool ignoreChildren(TypeId ty) override;
bool ignoreChildren(TypePackId tp) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
};
} // namespace Luau

View file

@ -2,12 +2,15 @@
#pragma once #pragma once
#include <string> #include <string>
#include <vector>
namespace Luau namespace Luau
{ {
class AstNode; class AstNode;
struct Comment;
std::string toJson(AstNode* node); std::string toJson(AstNode* node);
std::string toJson(AstNode* node, const std::vector<Comment>& commentLocations);
} // namespace Luau } // namespace Luau

View file

@ -42,12 +42,28 @@ struct ExprOrLocal
{ {
return expr ? expr->location : (local ? local->location : std::optional<Location>{}); return expr ? expr->location : (local ? local->location : std::optional<Location>{});
} }
std::optional<AstName> getName()
{
if (expr)
{
if (AstName name = getIdentifier(expr); name.value)
{
return name;
}
}
else if (local)
{
return local->name;
}
return std::nullopt;
}
private: private:
AstExpr* expr = nullptr; AstExpr* expr = nullptr;
AstLocal* local = nullptr; AstLocal* local = nullptr;
}; };
std::vector<AstNode*> findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos);
std::vector<AstNode*> findAstAncestryOfPosition(const SourceModule& source, Position pos); std::vector<AstNode*> findAstAncestryOfPosition(const SourceModule& source, Position pos);
AstNode* findNodeAtPosition(const SourceModule& source, Position pos); AstNode* findNodeAtPosition(const SourceModule& source, Position pos);
AstExpr* findExprAtPosition(const SourceModule& source, Position pos); AstExpr* findExprAtPosition(const SourceModule& source, Position pos);

View file

@ -19,6 +19,17 @@ struct TypeChecker;
using ModulePtr = std::shared_ptr<Module>; using ModulePtr = std::shared_ptr<Module>;
enum class AutocompleteContext
{
Unknown,
Expression,
Statement,
Property,
Type,
Keyword,
String,
};
enum class AutocompleteEntryKind enum class AutocompleteEntryKind
{ {
Property, Property,
@ -66,11 +77,13 @@ struct AutocompleteResult
{ {
AutocompleteEntryMap entryMap; AutocompleteEntryMap entryMap;
std::vector<AstNode*> ancestry; std::vector<AstNode*> ancestry;
AutocompleteContext context = AutocompleteContext::Unknown;
AutocompleteResult() = default; AutocompleteResult() = default;
AutocompleteResult(AutocompleteEntryMap entryMap, std::vector<AstNode*> ancestry) AutocompleteResult(AutocompleteEntryMap entryMap, std::vector<AstNode*> ancestry, AutocompleteContext context)
: entryMap(std::move(entryMap)) : entryMap(std::move(entryMap))
, ancestry(std::move(ancestry)) , ancestry(std::move(ancestry))
, context(context)
{ {
} }
}; };
@ -78,14 +91,6 @@ struct AutocompleteResult
using ModuleName = std::string; using ModuleName = std::string;
using StringCompletionCallback = std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassTypeVar*> ctx)>; using StringCompletionCallback = std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassTypeVar*> ctx)>;
struct OwningAutocompleteResult
{
AutocompleteResult result;
ModulePtr module;
std::unique_ptr<SourceModule> sourceModule;
};
AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback);
OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback);
} // namespace Luau } // namespace Luau

View file

@ -1,12 +1,22 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "TypeInfer.h" #include "Luau/Scope.h"
#include "Luau/TypeVar.h"
#include <optional>
namespace Luau namespace Luau
{ {
void registerBuiltinTypes(TypeChecker& typeChecker); struct Frontend;
struct TypeChecker;
struct TypeArena;
void registerBuiltinTypes(Frontend& frontend);
void registerBuiltinGlobals(TypeChecker& typeChecker);
void registerBuiltinGlobals(Frontend& frontend);
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types); TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types);
TypeId makeIntersection(TypeArena& arena, std::vector<TypeId>&& types); TypeId makeIntersection(TypeArena& arena, std::vector<TypeId>&& types);
@ -14,6 +24,7 @@ TypeId makeIntersection(TypeArena& arena, std::vector<TypeId>&& types);
/** Build an optional 't' /** Build an optional 't'
*/ */
TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t); TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t);
TypeId makeOption(Frontend& frontend, TypeArena& arena, TypeId t);
/** Small utility function for building up type definitions from C++. /** Small utility function for building up type definitions from C++.
*/ */
@ -33,19 +44,25 @@ TypeId makeFunction( // Polymorphic
std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes); std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes);
void attachMagicFunction(TypeId ty, MagicFunction fn); void attachMagicFunction(TypeId ty, MagicFunction fn);
void attachFunctionTag(TypeId ty, std::string constraint); void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn);
void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn);
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt); Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt);
void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName); void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName);
std::string getBuiltinDefinitionSource(); std::string getBuiltinDefinitionSource();
void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName);
void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding); void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding);
void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName);
void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName);
void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding); void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding);
std::optional<Binding> tryGetGlobalBinding(TypeChecker& typeChecker, const std::string& name); void addGlobalBinding(Frontend& frontend, const std::string& name, TypeId ty, const std::string& packageName);
void addGlobalBinding(Frontend& frontend, const std::string& name, Binding binding);
void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName);
void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, Binding binding);
std::optional<Binding> tryGetGlobalBinding(Frontend& frontend, const std::string& name);
Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name); Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name);
TypeId getGlobalBinding(Frontend& frontend, const std::string& name);
TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name); TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name);
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,32 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <Luau/NotNull.h>
#include "Luau/TypeArena.h"
#include "Luau/TypeVar.h"
#include <unordered_map>
namespace Luau
{
// Only exposed so they can be unit tested.
using SeenTypes = std::unordered_map<TypeId, TypeId>;
using SeenTypePacks = std::unordered_map<TypePackId, TypePackId>;
struct CloneState
{
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
int recursionCount = 0;
};
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState);
TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);
TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false);
TypeId shallowClone(TypeId ty, NotNull<TypeArena> dest);
} // namespace Luau

View file

@ -17,12 +17,9 @@ constexpr const char* kConfigName = ".luaurc";
struct Config struct Config
{ {
Config() Config();
{
enabledLint.setDefaults();
}
Mode mode = Mode::NoCheck; Mode mode;
ParseOptions parseOptions; ParseOptions parseOptions;

View file

@ -0,0 +1,70 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Def.h"
#include "Luau/TypedAllocator.h"
#include "Luau/Variant.h"
#include <memory>
namespace Luau
{
struct TypeVar;
using TypeId = const TypeVar*;
struct Negation;
struct Conjunction;
struct Disjunction;
struct Equivalence;
struct Proposition;
using Connective = Variant<Negation, Conjunction, Disjunction, Equivalence, Proposition>;
using ConnectiveId = Connective*; // Can and most likely is nullptr.
struct Negation
{
ConnectiveId connective;
};
struct Conjunction
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Disjunction
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Equivalence
{
ConnectiveId lhs;
ConnectiveId rhs;
};
struct Proposition
{
DefId def;
TypeId discriminantTy;
};
template<typename T>
const T* get(ConnectiveId connective)
{
return get_if<T>(connective);
}
struct ConnectiveArena
{
TypedAllocator<Connective> allocator;
ConnectiveId negation(ConnectiveId connective);
ConnectiveId conjunction(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId disjunction(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId equivalence(ConnectiveId lhs, ConnectiveId rhs);
ConnectiveId proposition(DefId def, TypeId discriminantTy);
};
} // namespace Luau

View file

@ -0,0 +1,207 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h" // Used for some of the enumerations
#include "Luau/Def.h"
#include "Luau/DenseHash.h"
#include "Luau/NotNull.h"
#include "Luau/TypeVar.h"
#include "Luau/Variant.h"
#include <string>
#include <memory>
#include <vector>
namespace Luau
{
struct Scope;
struct TypeVar;
using TypeId = const TypeVar*;
struct TypePackVar;
using TypePackId = const TypePackVar*;
// subType <: superType
struct SubtypeConstraint
{
TypeId subType;
TypeId superType;
};
// subPack <: superPack
struct PackSubtypeConstraint
{
TypePackId subPack;
TypePackId superPack;
};
// generalizedType ~ gen sourceType
struct GeneralizationConstraint
{
TypeId generalizedType;
TypeId sourceType;
};
// subType ~ inst superType
struct InstantiationConstraint
{
TypeId subType;
TypeId superType;
};
struct UnaryConstraint
{
AstExprUnary::Op op;
TypeId operandType;
TypeId resultType;
};
// let L : leftType
// let R : rightType
// in
// L op R : resultType
struct BinaryConstraint
{
AstExprBinary::Op op;
TypeId leftType;
TypeId rightType;
TypeId resultType;
// When we dispatch this constraint, we update the key at this map to record
// the overload that we selected.
AstExpr* expr;
DenseHashMap<const AstExpr*, TypeId>* astOriginalCallTypes;
DenseHashMap<const AstExpr*, TypeId>* astOverloadResolvedTypes;
};
// iteratee is iterable
// iterators is the iteration types.
struct IterableConstraint
{
TypePackId iterator;
TypePackId variables;
};
// name(namedType) = name
struct NameConstraint
{
TypeId namedType;
std::string name;
};
// target ~ inst target
struct TypeAliasExpansionConstraint
{
// Must be a PendingExpansionTypeVar.
TypeId target;
};
struct FunctionCallConstraint
{
std::vector<NotNull<const struct Constraint>> innerConstraints;
TypeId fn;
TypePackId argsPack;
TypePackId result;
class AstExprCall* callSite;
};
// result ~ prim ExpectedType SomeSingletonType MultitonType
//
// If ExpectedType is potentially a singleton (an actual singleton or a union
// that contains a singleton), then result ~ SomeSingletonType
//
// else result ~ MultitonType
struct PrimitiveTypeConstraint
{
TypeId resultType;
TypeId expectedType;
TypeId singletonType;
TypeId multitonType;
};
// result ~ hasProp type "prop_name"
//
// If the subject is a table, bind the result to the named prop. If the table
// has an indexer, bind it to the index result type. If the subject is a union,
// bind the result to the union of its constituents' properties.
//
// It would be nice to get rid of this constraint and someday replace it with
//
// T <: {p: X}
//
// Where {} describes an inexact shape type.
struct HasPropConstraint
{
TypeId resultType;
TypeId subjectType;
std::string prop;
};
// result ~ setProp subjectType ["prop", "prop2", ...] propType
//
// If the subject is a table or table-like thing that already has the named
// property chain, we unify propType with that existing property type.
//
// If the subject is a free table, we augment it in place.
//
// If the subject is an unsealed table, result is an augmented table that
// includes that new prop.
struct SetPropConstraint
{
TypeId resultType;
TypeId subjectType;
std::vector<std::string> path;
TypeId propType;
};
// if negation:
// result ~ if isSingleton D then ~D else unknown where D = discriminantType
// if not negation:
// result ~ if isSingleton D then D else unknown where D = discriminantType
struct SingletonOrTopTypeConstraint
{
TypeId resultType;
TypeId discriminantType;
bool negated;
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint,
BinaryConstraint, IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint,
HasPropConstraint, SetPropConstraint, SingletonOrTopTypeConstraint>;
struct Constraint
{
Constraint(NotNull<Scope> scope, const Location& location, ConstraintV&& c);
Constraint(const Constraint&) = delete;
Constraint& operator=(const Constraint&) = delete;
NotNull<Scope> scope;
Location location; // TODO: Extract this out into only the constraints that needs a location. Not all constraints needs locations.
ConstraintV c;
std::vector<NotNull<Constraint>> dependencies;
};
using ConstraintPtr = std::unique_ptr<Constraint>;
inline Constraint& asMutable(const Constraint& c)
{
return const_cast<Constraint&>(c);
}
template<typename T>
T* getMutable(Constraint& c)
{
return ::Luau::get_if<T>(&c.c);
}
template<typename T>
const T* get(const Constraint& c)
{
return getMutable<T>(asMutable(c));
}
} // namespace Luau

View file

@ -0,0 +1,279 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Connective.h"
#include "Luau/Constraint.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/Module.h"
#include "Luau/ModuleResolver.h"
#include "Luau/NotNull.h"
#include "Luau/Symbol.h"
#include "Luau/TypeVar.h"
#include "Luau/Variant.h"
#include <memory>
#include <vector>
#include <unordered_map>
namespace Luau
{
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
struct DcrLogger;
struct Inference
{
TypeId ty = nullptr;
ConnectiveId connective = nullptr;
Inference() = default;
explicit Inference(TypeId ty, ConnectiveId connective = nullptr)
: ty(ty)
, connective(connective)
{
}
};
struct InferencePack
{
TypePackId tp = nullptr;
std::vector<ConnectiveId> connectives;
InferencePack() = default;
explicit InferencePack(TypePackId tp, const std::vector<ConnectiveId>& connectives = {})
: tp(tp)
, connectives(connectives)
{
}
};
struct ConstraintGraphBuilder
{
// A list of all the scopes in the module. This vector holds ownership of the
// scope pointers; the scopes themselves borrow pointers to other scopes to
// define the scope hierarchy.
std::vector<std::pair<Location, ScopePtr>> scopes;
ModuleName moduleName;
ModulePtr module;
NotNull<SingletonTypes> singletonTypes;
const NotNull<TypeArena> arena;
// The root scope of the module we're generating constraints for.
// This is null when the CGB is initially constructed.
Scope* rootScope;
// Constraints that go straight to the solver.
std::vector<ConstraintPtr> constraints;
// Constraints that do not go to the solver right away. Other constraints
// will enqueue them during solving.
std::vector<ConstraintPtr> unqueuedConstraints;
// A mapping of AST node to TypeId.
DenseHashMap<const AstExpr*, TypeId> astTypes{nullptr};
// A mapping of AST node to TypePackId.
DenseHashMap<const AstExpr*, TypePackId> astTypePacks{nullptr};
// If the node was applied as a function, this is the unspecialized type of
// that expression.
DenseHashMap<const AstExpr*, TypeId> astOriginalCallTypes{nullptr};
// If overload resolution was performed on this element, this is the
// overload that was selected.
DenseHashMap<const AstExpr*, TypeId> astOverloadResolvedTypes{nullptr};
// Types resolved from type annotations. Analogous to astTypes.
DenseHashMap<const AstType*, TypeId> astResolvedTypes{nullptr};
// Type packs resolved from type annotations. Analogous to astTypePacks.
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
// Defining scopes for AST nodes.
DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr};
NotNull<const DataFlowGraph> dfg;
ConnectiveArena connectiveArena;
int recursionCount = 0;
// It is pretty uncommon for constraint generation to itself produce errors, but it can happen.
std::vector<TypeError> errors;
// Needed to resolve modules to make 'require' import types properly.
NotNull<ModuleResolver> moduleResolver;
// Occasionally constraint generation needs to produce an ICE.
const NotNull<InternalErrorReporter> ice;
ScopePtr globalScope;
DcrLogger* logger;
ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull<ModuleResolver> moduleResolver,
NotNull<SingletonTypes> singletonTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, DcrLogger* logger,
NotNull<DataFlowGraph> dfg);
/**
* Fabricates a new free type belonging to a given scope.
* @param scope the scope the free type belongs to.
*/
TypeId freshType(const ScopePtr& scope);
/**
* Fabricates a new free type pack belonging to a given scope.
* @param scope the scope the free type pack belongs to.
*/
TypePackId freshTypePack(const ScopePtr& scope);
/**
* Fabricates a scope that is a child of another scope.
* @param node the lexical node that the scope belongs to.
* @param parent the parent scope of the new scope. Must not be null.
*/
ScopePtr childScope(AstNode* node, const ScopePtr& parent);
/**
* Adds a new constraint with no dependencies to a given scope.
* @param scope the scope to add the constraint to.
* @param cv the constraint variant to add.
* @return the pointer to the inserted constraint
*/
NotNull<Constraint> addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv);
/**
* Adds a constraint to a given scope.
* @param scope the scope to add the constraint to. Must not be null.
* @param c the constraint to add.
* @return the pointer to the inserted constraint
*/
NotNull<Constraint> addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
void applyRefinements(const ScopePtr& scope, Location location, ConnectiveId connective);
/**
* The entry point to the ConstraintGraphBuilder. This will construct a set
* of scopes, constraints, and free types that can be solved later.
* @param block the root block to generate constraints for.
*/
void visit(AstStatBlock* block);
void visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block);
void visit(const ScopePtr& scope, AstStat* stat);
void visit(const ScopePtr& scope, AstStatBlock* block);
void visit(const ScopePtr& scope, AstStatLocal* local);
void visit(const ScopePtr& scope, AstStatFor* for_);
void visit(const ScopePtr& scope, AstStatForIn* forIn);
void visit(const ScopePtr& scope, AstStatWhile* while_);
void visit(const ScopePtr& scope, AstStatRepeat* repeat);
void visit(const ScopePtr& scope, AstStatLocalFunction* function);
void visit(const ScopePtr& scope, AstStatFunction* function);
void visit(const ScopePtr& scope, AstStatReturn* ret);
void visit(const ScopePtr& scope, AstStatAssign* assign);
void visit(const ScopePtr& scope, AstStatCompoundAssign* assign);
void visit(const ScopePtr& scope, AstStatIf* ifStatement);
void visit(const ScopePtr& scope, AstStatTypeAlias* alias);
void visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal);
void visit(const ScopePtr& scope, AstStatDeclareClass* declareClass);
void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction);
void visit(const ScopePtr& scope, AstStatError* error);
InferencePack checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<TypeId>& expectedTypes = {});
InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector<TypeId>& expectedTypes = {});
InferencePack checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector<TypeId>& expectedTypes);
/**
* Checks an expression that is expected to evaluate to one type.
* @param scope the scope the expression is contained within.
* @param expr the expression to check.
* @param expectedType the type of the expression that is expected from its
* surrounding context. Used to implement bidirectional type checking.
* @return the type of the expression.
*/
Inference check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {}, bool forceSingleton = false);
Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprLocal* local);
Inference check(const ScopePtr& scope, AstExprGlobal* global);
Inference check(const ScopePtr& scope, AstExprIndexName* indexName);
Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
Inference check(const ScopePtr& scope, AstExprUnary* unary);
Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, ConnectiveId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
TypePackId checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs);
TypeId checkLValue(const ScopePtr& scope, AstExpr* expr);
struct FunctionSignature
{
// The type of the function.
TypeId signature;
// The scope that encompasses the function's signature. May be nullptr
// if there was no need for a signature scope (the function has no
// generics).
ScopePtr signatureScope;
// The scope that encompasses the function's body. Is a child scope of
// signatureScope, if present.
ScopePtr bodyScope;
};
FunctionSignature checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn, std::optional<TypeId> expectedType = {});
/**
* Checks the body of a function expression.
* @param scope the interior scope of the body of the function.
* @param fn the function expression to check.
*/
void checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn);
/**
* Resolves a type from its AST annotation.
* @param scope the scope that the type annotation appears within.
* @param ty the AST annotation to resolve.
* @param topLevel whether the annotation is a "top-level" annotation.
* @return the type of the AST annotation.
**/
TypeId resolveType(const ScopePtr& scope, AstType* ty, bool topLevel = false);
/**
* Resolves a type pack from its AST annotation.
* @param scope the scope that the type annotation appears within.
* @param tp the AST annotation to resolve.
* @return the type pack of the AST annotation.
**/
TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list);
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(const ScopePtr& scope, AstArray<AstGenericType> generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(const ScopePtr& scope, AstArray<AstGenericTypePack> packs);
Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack);
void reportError(Location location, TypeErrorData err);
void reportCodeTooComplex(Location location);
/** Scan the program for global definitions.
*
* ConstraintGraphBuilder needs to differentiate between globals and accesses to undefined symbols. Doing this "for
* real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an
* initial scan of the AST and note what globals are defined.
*/
void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program);
};
/** Borrow a vector of pointers from a vector of owning pointers to constraints.
*/
std::vector<NotNull<Constraint>> borrowConstraints(const std::vector<ConstraintPtr>& constraints);
} // namespace Luau

View file

@ -0,0 +1,230 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Constraint.h"
#include "Luau/Error.h"
#include "Luau/Module.h"
#include "Luau/Normalize.h"
#include "Luau/ToString.h"
#include "Luau/TypeVar.h"
#include "Luau/Variant.h"
#include <vector>
namespace Luau
{
struct DcrLogger;
// TypeId, TypePackId, or Constraint*. It is impossible to know which, but we
// never dereference this pointer.
using BlockedConstraintId = const void*;
struct ModuleResolver;
struct InstantiationSignature
{
TypeFun fn;
std::vector<TypeId> arguments;
std::vector<TypePackId> packArguments;
bool operator==(const InstantiationSignature& rhs) const;
bool operator!=(const InstantiationSignature& rhs) const
{
return !((*this) == rhs);
}
};
struct HashInstantiationSignature
{
size_t operator()(const InstantiationSignature& signature) const;
};
struct ConstraintSolver
{
TypeArena* arena;
NotNull<SingletonTypes> singletonTypes;
InternalErrorReporter iceReporter;
NotNull<Normalizer> normalizer;
// The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints;
NotNull<Scope> rootScope;
ModuleName currentModuleName;
// Constraints that the solver has generated, rather than sourcing from the
// scope tree.
std::vector<std::unique_ptr<Constraint>> solverConstraints;
// This includes every constraint that has not been fully solved.
// A constraint can be both blocked and unsolved, for instance.
std::vector<NotNull<const Constraint>> unsolvedConstraints;
// A mapping of constraint pointer to how many things the constraint is
// blocked on. Can be empty or 0 for constraints that are not blocked on
// anything.
std::unordered_map<NotNull<const Constraint>, size_t> blockedConstraints;
// A mapping of type/pack pointers to the constraints they block.
std::unordered_map<BlockedConstraintId, std::vector<NotNull<const Constraint>>> blocked;
// Memoized instantiations of type aliases.
DenseHashMap<InstantiationSignature, TypeId, HashInstantiationSignature> instantiatedAliases{{}};
// Recorded errors that take place within the solver.
ErrorVec errors;
NotNull<ModuleResolver> moduleResolver;
std::vector<RequireCycle> requireCycles;
DcrLogger* logger;
explicit ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger);
// Randomize the order in which to dispatch constraints
void randomize(unsigned seed);
/**
* Attempts to dispatch all pending constraints and reach a type solution
* that satisfies all of the constraints.
**/
void run();
bool isDone();
void finalizeModule();
/** Attempt to dispatch a constraint. Returns true if it was successful. If
* tryDispatch() returns false, the constraint remains in the unsolved set
* and will be retried later.
*/
bool tryDispatch(NotNull<const Constraint> c, bool force);
bool tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const InstantiationConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const UnaryConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const BinaryConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const FunctionCallConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const SetPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint);
// for a, ... in some_table do
// also handles __iter metamethod
bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
// for a, ... in next_function, t, ... do
bool tryDispatchIterableFunction(
TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
std::optional<TypeId> lookupTableProp(TypeId subjectType, const std::string& propName);
void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);
/**
* Block a constraint on the resolution of a TypeVar.
* @returns false always. This is just to allow tryDispatch to return the result of block()
*/
bool block(TypeId target, NotNull<const Constraint> constraint);
bool block(TypePackId target, NotNull<const Constraint> constraint);
// Traverse the type. If any blocked or pending typevars are found, block
// the constraint on them.
//
// Returns false if a type blocks the constraint.
//
// FIXME: This use of a boolean for the return result is an appalling
// interface.
bool recursiveBlock(TypeId target, NotNull<const Constraint> constraint);
bool recursiveBlock(TypePackId target, NotNull<const Constraint> constraint);
void unblock(NotNull<const Constraint> progressed);
void unblock(TypeId progressed);
void unblock(TypePackId progressed);
void unblock(const std::vector<TypeId>& types);
void unblock(const std::vector<TypePackId>& packs);
/**
* @returns true if the TypeId is in a blocked state.
*/
bool isBlocked(TypeId ty);
/**
* @returns true if the TypePackId is in a blocked state.
*/
bool isBlocked(TypePackId tp);
/**
* Returns whether the constraint is blocked on anything.
* @param constraint the constraint to check.
*/
bool isBlocked(NotNull<const Constraint> constraint);
/**
* Creates a new Unifier and performs a single unification operation. Commits
* the result.
* @param subType the sub-type to unify.
* @param superType the super-type to unify.
*/
void unify(TypeId subType, TypeId superType, NotNull<Scope> scope);
/**
* Creates a new Unifier and performs a single unification operation. Commits
* the result.
* @param subPack the sub-type pack to unify.
* @param superPack the super-type pack to unify.
*/
void unify(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope);
/** Pushes a new solver constraint to the solver.
* @param cv the body of the constraint.
**/
void pushConstraint(NotNull<Scope> scope, const Location& location, ConstraintV cv);
/**
* Attempts to resolve a module from its module information. Returns the
* module-level return type of the module, or the error type if one cannot
* be found. Reports errors to the solver if the module cannot be found or
* the require is illegal.
* @param module the module information to look up.
* @param location the location where the require is taking place; used for
* error locations.
**/
TypeId resolveModule(const ModuleInfo& module, const Location& location);
void reportError(TypeErrorData&& data, const Location& location);
void reportError(TypeError e);
private:
/**
* Marks a constraint as being blocked on a type or type pack. The constraint
* solver will not attempt to dispatch blocked constraints until their
* dependencies have made progress.
* @param target the type or type pack pointer that the constraint is blocked on.
* @param constraint the constraint to block.
**/
void block_(BlockedConstraintId target, NotNull<const Constraint> constraint);
/**
* Informs the solver that progress has been made on a type or type pack. The
* solver will wake up all constraints that are blocked on the type or type pack,
* and will resume attempting to dispatch them.
* @param progressed the type or type pack pointer that has progressed.
**/
void unblock_(BlockedConstraintId progressed);
TypeId errorRecoveryType() const;
TypePackId errorRecoveryTypePack() const;
TypeId unionOfTypes(TypeId a, TypeId b, NotNull<Scope> scope, bool unifyFreeTypes);
ToStringOptions opts;
};
void dump(NotNull<Scope> rootScope, struct ToStringOptions& opts);
} // namespace Luau

View file

@ -0,0 +1,120 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
// Do not include LValue. It should never be used here.
#include "Luau/Ast.h"
#include "Luau/DenseHash.h"
#include "Luau/Def.h"
#include "Luau/Symbol.h"
#include <unordered_map>
namespace Luau
{
struct DataFlowGraph
{
DataFlowGraph(DataFlowGraph&&) = default;
DataFlowGraph& operator=(DataFlowGraph&&) = default;
// TODO: AstExprLocal, AstExprGlobal, and AstLocal* are guaranteed never to return nullopt.
// We leave them to return an optional as we build it out, but the end state is for them to return a non-optional DefId.
std::optional<DefId> getDef(const AstExpr* expr) const;
std::optional<DefId> getDef(const AstLocal* local) const;
/// Retrieve the Def that corresponds to the given Symbol.
///
/// We do not perform dataflow analysis on globals, so this function always
/// yields nullopt when passed a global Symbol.
std::optional<DefId> getDef(const Symbol& symbol) const;
private:
DataFlowGraph() = default;
DataFlowGraph(const DataFlowGraph&) = delete;
DataFlowGraph& operator=(const DataFlowGraph&) = delete;
DefArena arena;
DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr};
DenseHashMap<const AstLocal*, const Def*> localDefs{nullptr};
friend struct DataFlowGraphBuilder;
};
struct DfgScope
{
DfgScope* parent;
DenseHashMap<Symbol, const Def*> bindings{Symbol{}};
};
struct ExpressionFlowGraph
{
std::optional<DefId> def;
};
// Currently unsound. We do not presently track the control flow of the program.
// Additionally, we do not presently track assignments.
struct DataFlowGraphBuilder
{
static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle);
private:
DataFlowGraphBuilder() = default;
DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete;
DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete;
DataFlowGraph graph;
NotNull<DefArena> arena{&graph.arena};
struct InternalErrorReporter* handle;
std::vector<std::unique_ptr<DfgScope>> scopes;
// Does not belong in DataFlowGraphBuilder, but the old solver allows properties to escape the scope they were defined in,
// so we will need to be able to emulate this same behavior here too. We can kill this once we have better flow sensitivity.
DenseHashMap<const Def*, std::unordered_map<std::string, const Def*>> props{nullptr};
DfgScope* childScope(DfgScope* scope);
std::optional<DefId> use(DfgScope* scope, Symbol symbol, AstExpr* e);
DefId use(DefId def, AstExprIndexName* e);
void visit(DfgScope* scope, AstStatBlock* b);
void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b);
// TODO: visit type aliases
void visit(DfgScope* scope, AstStat* s);
void visit(DfgScope* scope, AstStatIf* i);
void visit(DfgScope* scope, AstStatWhile* w);
void visit(DfgScope* scope, AstStatRepeat* r);
void visit(DfgScope* scope, AstStatBreak* b);
void visit(DfgScope* scope, AstStatContinue* c);
void visit(DfgScope* scope, AstStatReturn* r);
void visit(DfgScope* scope, AstStatExpr* e);
void visit(DfgScope* scope, AstStatLocal* l);
void visit(DfgScope* scope, AstStatFor* f);
void visit(DfgScope* scope, AstStatForIn* f);
void visit(DfgScope* scope, AstStatAssign* a);
void visit(DfgScope* scope, AstStatCompoundAssign* c);
void visit(DfgScope* scope, AstStatFunction* f);
void visit(DfgScope* scope, AstStatLocalFunction* l);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExpr* e);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprLocal* l);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprGlobal* g);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprCall* c);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexName* i);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIndexExpr* i);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprFunction* f);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTable* t);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprUnary* u);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprBinary* b);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprTypeAssertion* t);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprIfElse* i);
ExpressionFlowGraph visitExpr(DfgScope* scope, AstExprInterpString* i);
// TODO: visitLValue
// TODO: visitTypes (because of typeof which has access to values namespace, needs unreachable scope)
// TODO: visitTypePacks (because of typeof which has access to values namespace, needs unreachable scope)
};
} // namespace Luau

View file

@ -0,0 +1,133 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Constraint.h"
#include "Luau/NotNull.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Luau/Error.h"
#include "Luau/Variant.h"
#include <optional>
#include <string>
#include <vector>
namespace Luau
{
struct ErrorSnapshot
{
std::string message;
Location location;
};
struct BindingSnapshot
{
std::string typeId;
std::string typeString;
Location location;
};
struct TypeBindingSnapshot
{
std::string typeId;
std::string typeString;
};
struct ConstraintGenerationLog
{
std::string source;
std::unordered_map<std::string, Location> constraintLocations;
std::vector<ErrorSnapshot> errors;
};
struct ScopeSnapshot
{
std::unordered_map<Name, BindingSnapshot> bindings;
std::unordered_map<Name, TypeBindingSnapshot> typeBindings;
std::unordered_map<Name, TypeBindingSnapshot> typePackBindings;
std::vector<ScopeSnapshot> children;
};
enum class ConstraintBlockKind
{
TypeId,
TypePackId,
ConstraintId,
};
struct ConstraintBlock
{
ConstraintBlockKind kind;
std::string stringification;
};
struct ConstraintSnapshot
{
std::string stringification;
std::vector<ConstraintBlock> blocks;
};
struct BoundarySnapshot
{
std::unordered_map<std::string, ConstraintSnapshot> constraints;
ScopeSnapshot rootScope;
};
struct StepSnapshot
{
std::string currentConstraint;
bool forced;
std::unordered_map<std::string, ConstraintSnapshot> unsolvedConstraints;
ScopeSnapshot rootScope;
};
struct TypeSolveLog
{
BoundarySnapshot initialState;
std::vector<StepSnapshot> stepStates;
BoundarySnapshot finalState;
};
struct TypeCheckLog
{
std::vector<ErrorSnapshot> errors;
};
using ConstraintBlockTarget = Variant<TypeId, TypePackId, NotNull<const Constraint>>;
struct DcrLogger
{
std::string compileOutput();
void captureSource(std::string source);
void captureGenerationError(const TypeError& error);
void captureConstraintLocation(NotNull<const Constraint> constraint, Location location);
void pushBlock(NotNull<const Constraint> constraint, TypeId block);
void pushBlock(NotNull<const Constraint> constraint, TypePackId block);
void pushBlock(NotNull<const Constraint> constraint, NotNull<const Constraint> block);
void popBlock(TypeId block);
void popBlock(TypePackId block);
void popBlock(NotNull<const Constraint> block);
void captureInitialSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints);
StepSnapshot prepareStepSnapshot(
const Scope* rootScope, NotNull<const Constraint> current, bool force, const std::vector<NotNull<const Constraint>>& unsolvedConstraints);
void commitStepSnapshot(StepSnapshot snapshot);
void captureFinalSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints);
void captureTypeCheckError(const TypeError& error);
private:
ConstraintGenerationLog generationLog;
std::unordered_map<NotNull<const Constraint>, std::vector<ConstraintBlockTarget>> constraintBlocks;
TypeSolveLog solveLog;
TypeCheckLog checkLog;
ToStringOptions opts;
std::vector<ConstraintBlock> snapshotBlocks(NotNull<const Constraint> constraint);
};
} // namespace Luau

View file

@ -0,0 +1,91 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/NotNull.h"
#include "Luau/TypedAllocator.h"
#include "Luau/Variant.h"
#include <string>
#include <optional>
namespace Luau
{
struct Def;
using DefId = NotNull<const Def>;
struct FieldMetadata
{
DefId parent;
std::string propName;
};
/**
* A cell is a "single-object" value.
*
* Leaky implementation note: sometimes "multiple-object" values, but none of which were interesting enough to warrant creating a phi node instead.
* That can happen because there's no point in creating a phi node that points to either resultant in `if math.random() > 0.5 then 5 else "hello"`.
* This might become of utmost importance if we wanted to do some backward reasoning, e.g. if `5` is taken, then `cond` must be `truthy`.
*/
struct Cell
{
std::optional<struct FieldMetadata> field;
};
/**
* A phi node is a union of cells.
*
* We need this because we're statically evaluating a program, and sometimes a place may be assigned with
* different cells, and when that happens, we need a special data type that merges in all the cells
* that will flow into that specific place. For example, consider this simple program:
*
* ```
* x-1
* if cond() then
* x-2 = 5
* else
* x-3 = "hello"
* end
* x-4 : {x-2, x-3}
* ```
*
* At x-4, we know for a fact statically that either `5` or `"hello"` can flow into the variable `x` after the branch, but
* we cannot make any definitive decisions about which one, so we just take in both.
*/
struct Phi
{
std::vector<DefId> operands;
};
/**
* We statically approximate a value at runtime using a symbolic value, which we call a Def.
*
* DataFlowGraphBuilder will allocate these defs as a stand-in for some Luau values, and bind them to places that
* can hold a Luau value, and then observes how those defs will commute as it statically evaluate the program.
*
* It must also be noted that defs are a cyclic graph, so it is not safe to recursively traverse into it expecting it to terminate.
*/
struct Def
{
using V = Variant<struct Cell, struct Phi>;
V v;
};
template<typename T>
const T* get(DefId def)
{
return get_if<T>(&def->v);
}
struct DefArena
{
TypedAllocator<Def> allocator;
DefId freshCell();
DefId freshCell(DefId parent, const std::string& prop);
// TODO: implement once we have cases where we need to merge in definitions
// DefId phi(const std::vector<DefId>& defs);
};
} // namespace Luau

View file

@ -1,3 +1,4 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
@ -12,10 +13,18 @@ namespace Luau
struct FunctionDocumentation; struct FunctionDocumentation;
struct TableDocumentation; struct TableDocumentation;
struct OverloadedFunctionDocumentation; struct OverloadedFunctionDocumentation;
struct BasicDocumentation;
using Documentation = Luau::Variant<std::string, FunctionDocumentation, TableDocumentation, OverloadedFunctionDocumentation>; using Documentation = Luau::Variant<BasicDocumentation, FunctionDocumentation, TableDocumentation, OverloadedFunctionDocumentation>;
using DocumentationSymbol = std::string; using DocumentationSymbol = std::string;
struct BasicDocumentation
{
std::string documentation;
std::string learnMoreLink;
std::string codeSample;
};
struct FunctionParameterDocumentation struct FunctionParameterDocumentation
{ {
std::string name; std::string name;
@ -29,6 +38,8 @@ struct FunctionDocumentation
std::string documentation; std::string documentation;
std::vector<FunctionParameterDocumentation> parameters; std::vector<FunctionParameterDocumentation> parameters;
std::vector<DocumentationSymbol> returns; std::vector<DocumentationSymbol> returns;
std::string learnMoreLink;
std::string codeSample;
}; };
struct OverloadedFunctionDocumentation struct OverloadedFunctionDocumentation
@ -43,6 +54,8 @@ struct TableDocumentation
{ {
std::string documentation; std::string documentation;
Luau::DenseHashMap<std::string, DocumentationSymbol> keys; Luau::DenseHashMap<std::string, DocumentationSymbol> keys;
std::string learnMoreLink;
std::string codeSample;
}; };
using DocumentationDatabase = Luau::DenseHashMap<DocumentationSymbol, Documentation>; using DocumentationDatabase = Luau::DenseHashMap<DocumentationSymbol, Documentation>;

View file

@ -1,7 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/FileResolver.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
@ -9,10 +8,33 @@
namespace Luau namespace Luau
{ {
struct FileResolver;
struct TypeArena;
struct TypeError;
struct TypeMismatch struct TypeMismatch
{ {
TypeId wantedType; enum Context
TypeId givenType; {
CovariantContext,
InvariantContext
};
TypeMismatch() = default;
TypeMismatch(TypeId wantedType, TypeId givenType);
TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason);
TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional<TypeError> error);
TypeMismatch(TypeId wantedType, TypeId givenType, Context context);
TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, Context context);
TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional<TypeError> error, Context context);
TypeId wantedType = nullptr;
TypeId givenType = nullptr;
Context context = CovariantContext;
std::string reason;
std::shared_ptr<TypeError> error;
bool operator==(const TypeMismatch& rhs) const; bool operator==(const TypeMismatch& rhs) const;
}; };
@ -23,7 +45,6 @@ struct UnknownSymbol
{ {
Binding, Binding,
Type, Type,
Generic
}; };
Name name; Name name;
Context context; Context context;
@ -71,7 +92,7 @@ struct OnlyTablesCanHaveMethods
struct DuplicateTypeDefinition struct DuplicateTypeDefinition
{ {
Name name; Name name;
Location previousLocation; std::optional<Location> previousLocation;
bool operator==(const DuplicateTypeDefinition& rhs) const; bool operator==(const DuplicateTypeDefinition& rhs) const;
}; };
@ -81,12 +102,16 @@ struct CountMismatch
enum Context enum Context
{ {
Arg, Arg,
Result, FunctionResult,
ExprListResult,
Return, Return,
}; };
size_t expected; size_t expected;
std::optional<size_t> maximum;
size_t actual; size_t actual;
Context context = Arg; Context context = Arg;
bool isVariadic = false;
std::string function;
bool operator==(const CountMismatch& rhs) const; bool operator==(const CountMismatch& rhs) const;
}; };
@ -98,8 +123,6 @@ struct FunctionDoesNotTakeSelf
struct FunctionRequiresSelf struct FunctionRequiresSelf
{ {
int requiredExtraNils = 0;
bool operator==(const FunctionRequiresSelf& rhs) const; bool operator==(const FunctionRequiresSelf& rhs) const;
}; };
@ -120,6 +143,7 @@ struct IncorrectGenericParameterCount
Name name; Name name;
TypeFun typeFun; TypeFun typeFun;
size_t actualParameters; size_t actualParameters;
size_t actualPackParameters;
bool operator==(const IncorrectGenericParameterCount& rhs) const; bool operator==(const IncorrectGenericParameterCount& rhs) const;
}; };
@ -159,6 +183,13 @@ struct GenericError
bool operator==(const GenericError& rhs) const; bool operator==(const GenericError& rhs) const;
}; };
struct InternalError
{
std::string message;
bool operator==(const InternalError& rhs) const;
};
struct CannotCallNonFunction struct CannotCallNonFunction
{ {
TypeId ty; TypeId ty;
@ -267,11 +298,57 @@ struct MissingUnionProperty
bool operator==(const MissingUnionProperty& rhs) const; bool operator==(const MissingUnionProperty& rhs) const;
}; };
struct TypesAreUnrelated
{
TypeId left;
TypeId right;
bool operator==(const TypesAreUnrelated& rhs) const;
};
struct NormalizationTooComplex
{
bool operator==(const NormalizationTooComplex&) const
{
return true;
}
};
struct TypePackMismatch
{
TypePackId wantedTp;
TypePackId givenTp;
bool operator==(const TypePackMismatch& rhs) const;
};
struct DynamicPropertyLookupOnClassesUnsafe
{
TypeId ty;
bool operator==(const DynamicPropertyLookupOnClassesUnsafe& rhs) const;
};
using TypeErrorData = Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods, using TypeErrorData = Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods,
DuplicateTypeDefinition, CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire, DuplicateTypeDefinition, CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire,
IncorrectGenericParameterCount, SyntaxError, CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, IncorrectGenericParameterCount, SyntaxError, CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, InternalError,
CannotCallNonFunction, ExtraInformation, DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning, CannotCallNonFunction, ExtraInformation, DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning,
DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty>; DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty,
TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch, DynamicPropertyLookupOnClassesUnsafe>;
struct TypeErrorSummary
{
Location location;
ModuleName moduleName;
int code;
TypeErrorSummary(const Location& location, const ModuleName& moduleName, int code)
: location(location)
, moduleName(moduleName)
, code(code)
{
}
};
struct TypeError struct TypeError
{ {
@ -279,6 +356,7 @@ struct TypeError
ModuleName moduleName; ModuleName moduleName;
TypeErrorData data; TypeErrorData data;
static int minCode();
int code() const; int code() const;
TypeError() = default; TypeError() = default;
@ -296,6 +374,8 @@ struct TypeError
} }
bool operator==(const TypeError& rhs) const; bool operator==(const TypeError& rhs) const;
TypeErrorSummary summary() const;
}; };
template<typename T> template<typename T>
@ -312,7 +392,13 @@ T* get(TypeError& e)
using ErrorVec = std::vector<TypeError>; using ErrorVec = std::vector<TypeError>;
struct TypeErrorToStringOptions
{
FileResolver* fileResolver = nullptr;
};
std::string toString(const TypeError& error); std::string toString(const TypeError& error);
std::string toString(const TypeError& error, TypeErrorToStringOptions options);
bool containsParseErrorName(const TypeError& error); bool containsParseErrorName(const TypeError& error);
@ -329,4 +415,29 @@ struct InternalErrorReporter
[[noreturn]] void ice(const std::string& message); [[noreturn]] void ice(const std::string& message);
}; };
class InternalCompilerError : public std::exception
{
public:
explicit InternalCompilerError(const std::string& message)
: message(message)
{
}
explicit InternalCompilerError(const std::string& message, const std::string& moduleName)
: message(message)
, moduleName(moduleName)
{
}
explicit InternalCompilerError(const std::string& message, const std::string& moduleName, const Location& location)
: message(message)
, moduleName(moduleName)
, location(location)
{
}
virtual const char* what() const throw();
const std::string message;
const std::optional<std::string> moduleName;
const std::optional<Location> location;
};
} // namespace Luau } // namespace Luau

View file

@ -25,51 +25,32 @@ struct SourceCode
Type type; Type type;
}; };
struct ModuleInfo
{
ModuleName name;
bool optional = false;
};
struct FileResolver struct FileResolver
{ {
virtual ~FileResolver() {} virtual ~FileResolver() {}
/** Fetch the source code associated with the provided ModuleName.
*
* FIXME: This requires a string copy!
*
* @returns The actual Lua code on success.
* @returns std::nullopt if no such file exists. When this occurs, type inference will report an UnknownRequire error.
*/
virtual std::optional<SourceCode> readSource(const ModuleName& name) = 0; virtual std::optional<SourceCode> readSource(const ModuleName& name) = 0;
/** Does the module exist? virtual std::optional<ModuleInfo> resolveModule(const ModuleInfo* context, AstExpr* expr)
* {
* Saves a string copy over reading the source and throwing it away. return std::nullopt;
*/ }
virtual bool moduleExists(const ModuleName& name) const = 0;
virtual std::optional<ModuleName> fromAstFragment(AstExpr* expr) const = 0; virtual std::string getHumanReadableModuleName(const ModuleName& name) const
/** Given a valid module name and a string of arbitrary data, figure out the concatenation.
*/
virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0;
/** Goes "up" a level in the hierarchy that the ModuleName represents.
*
* For instances, this is analogous to someInstance.Parent; for paths, this is equivalent to removing the last
* element of the path. Other ModuleName representations may have other ways of doing this.
*
* @returns The parent ModuleName, if one exists.
* @returns std::nullopt if there is no parent for this module name.
*/
virtual std::optional<ModuleName> getParentModuleName(const ModuleName& name) const = 0;
virtual std::optional<std::string> getHumanReadableModuleName_(const ModuleName& name) const
{ {
return name; return name;
} }
virtual std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const = 0; virtual std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const
{
/** LanguageService only: return std::nullopt;
* std::optional<ModuleName> fromInstance(Instance* inst) }
*/
}; };
struct NullFileResolver : FileResolver struct NullFileResolver : FileResolver
@ -78,26 +59,6 @@ struct NullFileResolver : FileResolver
{ {
return std::nullopt; return std::nullopt;
} }
bool moduleExists(const ModuleName& name) const override
{
return false;
}
std::optional<ModuleName> fromAstFragment(AstExpr* expr) const override
{
return std::nullopt;
}
ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override
{
return lhs;
}
std::optional<ModuleName> getParentModuleName(const ModuleName& name) const override
{
return std::nullopt;
}
std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const override
{
return std::nullopt;
}
}; };
} // namespace Luau } // namespace Luau

View file

@ -5,6 +5,7 @@
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/ModuleResolver.h" #include "Luau/ModuleResolver.h"
#include "Luau/RequireTracer.h" #include "Luau/RequireTracer.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
@ -24,6 +25,7 @@ struct TypeChecker;
struct FileResolver; struct FileResolver;
struct ModuleResolver; struct ModuleResolver;
struct ParseResult; struct ParseResult;
struct HotComment;
struct LoadDefinitionFileResult struct LoadDefinitionFileResult
{ {
@ -35,7 +37,7 @@ struct LoadDefinitionFileResult
LoadDefinitionFileResult loadDefinitionFile( LoadDefinitionFileResult loadDefinitionFile(
TypeChecker& typeChecker, ScopePtr targetScope, std::string_view definition, const std::string& packageName); TypeChecker& typeChecker, ScopePtr targetScope, std::string_view definition, const std::string& packageName);
std::optional<Mode> parseMode(const std::vector<std::string>& hotcomments); std::optional<Mode> parseMode(const std::vector<HotComment>& hotcomments);
std::vector<std::string_view> parsePathExpr(const AstExpr& pathExpr); std::vector<std::string_view> parsePathExpr(const AstExpr& pathExpr);
@ -54,10 +56,23 @@ std::optional<ModuleName> pathExprToModuleName(const ModuleName& currentModuleNa
struct SourceNode struct SourceNode
{ {
bool hasDirtySourceModule() const
{
return dirtySourceModule;
}
bool hasDirtyModule(bool forAutocomplete) const
{
return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule;
}
ModuleName name; ModuleName name;
std::unordered_set<ModuleName> requires; std::unordered_set<ModuleName> requireSet;
std::vector<std::pair<ModuleName, Location>> requireLocations; std::vector<std::pair<ModuleName, Location>> requireLocations;
bool dirty = true; bool dirtySourceModule = true;
bool dirtyModule = true;
bool dirtyModuleForAutocomplete = true;
double autocompleteLimitsMult = 1.0;
}; };
struct FrontendOptions struct FrontendOptions
@ -68,14 +83,19 @@ struct FrontendOptions
// is complete. // is complete.
bool retainFullTypeGraphs = false; bool retainFullTypeGraphs = false;
// When true, we run typechecking twice, one in the regular mode, ond once in strict mode // Run typechecking only in mode required for autocomplete (strict mode in
// in order to get more precise type information (e.g. for autocomplete). // order to get more precise type information)
bool typecheckTwice = false; bool forAutocomplete = false;
// If not empty, randomly shuffle the constraint set before attempting to
// solve. Use this value to seed the random number generator.
std::optional<unsigned> randomizeConstraintResolutionSeed;
}; };
struct CheckResult struct CheckResult
{ {
std::vector<TypeError> errors; std::vector<TypeError> errors;
std::vector<ModuleName> timeoutHits;
}; };
struct FrontendModuleResolver : ModuleResolver struct FrontendModuleResolver : ModuleResolver
@ -109,20 +129,12 @@ struct Frontend
Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {}); Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {});
CheckResult check(const ModuleName& name); // new shininess CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess
LintResult lint(const ModuleName& name, std::optional<Luau::LintOptions> enabledLintWarnings = {}); LintResult lint(const ModuleName& name, std::optional<LintOptions> enabledLintWarnings = {});
/** Lint some code that has no associated DataModel object LintResult lint(const SourceModule& module, std::optional<LintOptions> enabledLintWarnings = {});
*
* Since this source fragment has no name, we cannot cache its AST. Instead,
* we return it to the caller to use as they wish.
*/
std::pair<SourceModule, LintResult> lintFragment(std::string_view source, std::optional<Luau::LintOptions> enabledLintWarnings = {});
CheckResult check(const SourceModule& module); // OLD. TODO KILL bool isDirty(const ModuleName& name, bool forAutocomplete = false) const;
LintResult lint(const SourceModule& module, std::optional<Luau::LintOptions> enabledLintWarnings = {});
bool isDirty(const ModuleName& name) const;
void markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty = nullptr); void markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty = nullptr);
/** Borrow a pointer into the SourceModule cache. /** Borrow a pointer into the SourceModule cache.
@ -145,20 +157,31 @@ struct Frontend
void registerBuiltinDefinition(const std::string& name, std::function<void(TypeChecker&, ScopePtr)>); void registerBuiltinDefinition(const std::string& name, std::function<void(TypeChecker&, ScopePtr)>);
void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName);
LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName);
ScopePtr getGlobalScope();
private: private:
ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope, std::vector<RequireCycle> requireCycles,
bool forAutocomplete = false);
std::pair<SourceNode*, SourceModule*> getSourceNode(CheckResult& checkResult, const ModuleName& name); std::pair<SourceNode*, SourceModule*> getSourceNode(CheckResult& checkResult, const ModuleName& name);
SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions);
bool parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& checkResult, const ModuleName& root); bool parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete);
static LintResult classifyLints(const std::vector<LintWarning>& warnings, const Config& config); static LintResult classifyLints(const std::vector<LintWarning>& warnings, const Config& config);
ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config); ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete = false);
std::unordered_map<std::string, ScopePtr> environments; std::unordered_map<std::string, ScopePtr> environments;
std::unordered_map<std::string, std::function<void(TypeChecker&, ScopePtr)>> builtinDefinitions; std::unordered_map<std::string, std::function<void(TypeChecker&, ScopePtr)>> builtinDefinitions;
SingletonTypes singletonTypes_;
public: public:
const NotNull<SingletonTypes> singletonTypes;
FileResolver* fileResolver; FileResolver* fileResolver;
FrontendModuleResolver moduleResolver; FrontendModuleResolver moduleResolver;
FrontendModuleResolver moduleResolverForAutocomplete; FrontendModuleResolver moduleResolverForAutocomplete;
@ -167,13 +190,16 @@ public:
ConfigResolver* configResolver; ConfigResolver* configResolver;
FrontendOptions options; FrontendOptions options;
InternalErrorReporter iceHandler; InternalErrorReporter iceHandler;
TypeArena arenaForAutocomplete; TypeArena globalTypes;
std::unordered_map<ModuleName, SourceNode> sourceNodes; std::unordered_map<ModuleName, SourceNode> sourceNodes;
std::unordered_map<ModuleName, SourceModule> sourceModules; std::unordered_map<ModuleName, SourceModule> sourceModules;
std::unordered_map<ModuleName, RequireTraceResult> requires; std::unordered_map<ModuleName, RequireTraceResult> requireTrace;
Stats stats = {}; Stats stats = {};
private:
ScopePtr globalScope;
}; };
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,57 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Substitution.h"
#include "Luau/TypeVar.h"
#include "Luau/Unifiable.h"
namespace Luau
{
struct TypeArena;
struct TxnLog;
// A substitution which replaces generic types in a given set by free types.
struct ReplaceGenerics : Substitution
{
ReplaceGenerics(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope, const std::vector<TypeId>& generics,
const std::vector<TypePackId>& genericPacks)
: Substitution(log, arena)
, level(level)
, scope(scope)
, generics(generics)
, genericPacks(genericPacks)
{
}
TypeLevel level;
Scope* scope;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
bool ignoreChildren(TypeId ty) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
};
// A substitution which replaces generic functions by monomorphic functions
struct Instantiation : Substitution
{
Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope)
: Substitution(log, arena)
, level(level)
, scope(scope)
{
}
TypeLevel level;
Scope* scope;
bool ignoreChildren(TypeId ty) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
};
} // namespace Luau

View file

@ -30,12 +30,17 @@ std::ostream& operator<<(std::ostream& lhs, const OccursCheckFailed& error);
std::ostream& operator<<(std::ostream& lhs, const UnknownRequire& error); std::ostream& operator<<(std::ostream& lhs, const UnknownRequire& error);
std::ostream& operator<<(std::ostream& lhs, const UnknownPropButFoundLikeProp& e); std::ostream& operator<<(std::ostream& lhs, const UnknownPropButFoundLikeProp& e);
std::ostream& operator<<(std::ostream& lhs, const GenericError& error); std::ostream& operator<<(std::ostream& lhs, const GenericError& error);
std::ostream& operator<<(std::ostream& lhs, const InternalError& error);
std::ostream& operator<<(std::ostream& lhs, const FunctionExitsWithoutReturning& error); std::ostream& operator<<(std::ostream& lhs, const FunctionExitsWithoutReturning& error);
std::ostream& operator<<(std::ostream& lhs, const MissingProperties& error); std::ostream& operator<<(std::ostream& lhs, const MissingProperties& error);
std::ostream& operator<<(std::ostream& lhs, const IllegalRequire& error); std::ostream& operator<<(std::ostream& lhs, const IllegalRequire& error);
std::ostream& operator<<(std::ostream& lhs, const ModuleHasCyclicDependency& error); std::ostream& operator<<(std::ostream& lhs, const ModuleHasCyclicDependency& error);
std::ostream& operator<<(std::ostream& lhs, const DuplicateGenericParameter& error); std::ostream& operator<<(std::ostream& lhs, const DuplicateGenericParameter& error);
std::ostream& operator<<(std::ostream& lhs, const CannotInferBinaryOperation& error); std::ostream& operator<<(std::ostream& lhs, const CannotInferBinaryOperation& error);
std::ostream& operator<<(std::ostream& lhs, const SwappedGenericTypeParameter& error);
std::ostream& operator<<(std::ostream& lhs, const OptionalValueAccess& error);
std::ostream& operator<<(std::ostream& lhs, const MissingUnionProperty& error);
std::ostream& operator<<(std::ostream& lhs, const TypesAreUnrelated& error);
std::ostream& operator<<(std::ostream& lhs, const TableState& tv); std::ostream& operator<<(std::ostream& lhs, const TableState& tv);
std::ostream& operator<<(std::ostream& lhs, const TypeVar& tv); std::ostream& operator<<(std::ostream& lhs, const TypeVar& tv);

View file

@ -0,0 +1,247 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <type_traits>
#include <string>
#include <optional>
#include <unordered_map>
#include <vector>
#include "Luau/NotNull.h"
namespace Luau::Json
{
struct JsonEmitter;
/// Writes a value to the JsonEmitter. Note that this can produce invalid JSON
/// if you do not insert commas or appropriate object / array syntax.
template<typename T>
void write(JsonEmitter&, T) = delete;
/// Writes a boolean to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param b the boolean to write.
void write(JsonEmitter& emitter, bool b);
/// Writes an integer to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param i the integer to write.
void write(JsonEmitter& emitter, int i);
/// Writes an integer to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param i the integer to write.
void write(JsonEmitter& emitter, long i);
/// Writes an integer to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param i the integer to write.
void write(JsonEmitter& emitter, long long i);
/// Writes an integer to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param i the integer to write.
void write(JsonEmitter& emitter, unsigned int i);
/// Writes an integer to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param i the integer to write.
void write(JsonEmitter& emitter, unsigned long i);
/// Writes an integer to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param i the integer to write.
void write(JsonEmitter& emitter, unsigned long long i);
/// Writes a double to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param d the double to write.
void write(JsonEmitter& emitter, double d);
/// Writes a string to a JsonEmitter. The string will be escaped.
/// @param emitter the emitter to write to.
/// @param sv the string to write.
void write(JsonEmitter& emitter, std::string_view sv);
/// Writes a character to a JsonEmitter as a single-character string. The
/// character will be escaped.
/// @param emitter the emitter to write to.
/// @param c the string to write.
void write(JsonEmitter& emitter, char c);
/// Writes a string to a JsonEmitter. The string will be escaped.
/// @param emitter the emitter to write to.
/// @param str the string to write.
void write(JsonEmitter& emitter, const char* str);
/// Writes a string to a JsonEmitter. The string will be escaped.
/// @param emitter the emitter to write to.
/// @param str the string to write.
void write(JsonEmitter& emitter, const std::string& str);
/// Writes null to a JsonEmitter.
/// @param emitter the emitter to write to.
void write(JsonEmitter& emitter, std::nullptr_t);
/// Writes null to a JsonEmitter.
/// @param emitter the emitter to write to.
void write(JsonEmitter& emitter, std::nullopt_t);
struct ObjectEmitter;
struct ArrayEmitter;
struct JsonEmitter
{
JsonEmitter();
/// Converts the current contents of the JsonEmitter to a string value. This
/// does not invalidate the emitter, but it does not clear it either.
std::string str();
/// Returns the current comma state and resets it to false. Use popComma to
/// restore the old state.
/// @returns the previous comma state.
bool pushComma();
/// Restores a previous comma state.
/// @param c the comma state to restore.
void popComma(bool c);
/// Writes a raw sequence of characters to the buffer, without escaping or
/// other processing.
/// @param sv the character sequence to write.
void writeRaw(std::string_view sv);
/// Writes a character to the buffer, without escaping or other processing.
/// @param c the character to write.
void writeRaw(char c);
/// Writes a comma if this wasn't the first time writeComma has been
/// invoked. Otherwise, sets the comma state to true.
/// @see pushComma
/// @see popComma
void writeComma();
/// Begins writing an object to the emitter.
/// @returns an ObjectEmitter that can be used to write key-value pairs.
ObjectEmitter writeObject();
/// Begins writing an array to the emitter.
/// @returns an ArrayEmitter that can be used to write values.
ArrayEmitter writeArray();
private:
bool comma = false;
std::vector<std::string> chunks;
void newChunk();
};
/// An interface for writing an object into a JsonEmitter instance.
/// @see JsonEmitter::writeObject
struct ObjectEmitter
{
ObjectEmitter(NotNull<JsonEmitter> emitter);
~ObjectEmitter();
NotNull<JsonEmitter> emitter;
bool comma;
bool finished;
/// Writes a key-value pair to the associated JsonEmitter. Keys will be escaped.
/// @param name the name of the key-value pair.
/// @param value the value to write.
template<typename T>
void writePair(std::string_view name, T value)
{
if (finished)
{
return;
}
emitter->writeComma();
write(*emitter, name);
emitter->writeRaw(':');
write(*emitter, value);
}
/// Finishes writing the object, appending a closing `}` character and
/// resetting the comma state of the associated emitter. This can only be
/// called once, and once called will render the emitter unusable. This
/// method is also called when the ObjectEmitter is destructed.
void finish();
};
/// An interface for writing an array into a JsonEmitter instance. Array values
/// do not need to be the same type.
/// @see JsonEmitter::writeArray
struct ArrayEmitter
{
ArrayEmitter(NotNull<JsonEmitter> emitter);
~ArrayEmitter();
NotNull<JsonEmitter> emitter;
bool comma;
bool finished;
/// Writes a value to the array.
/// @param value the value to write.
template<typename T>
void writeValue(T value)
{
if (finished)
{
return;
}
emitter->writeComma();
write(*emitter, value);
}
/// Finishes writing the object, appending a closing `]` character and
/// resetting the comma state of the associated emitter. This can only be
/// called once, and once called will render the emitter unusable. This
/// method is also called when the ArrayEmitter is destructed.
void finish();
};
/// Writes a vector as an array to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param vec the vector to write.
template<typename T>
void write(JsonEmitter& emitter, const std::vector<T>& vec)
{
ArrayEmitter a = emitter.writeArray();
for (const T& value : vec)
a.writeValue(value);
a.finish();
}
/// Writes an optional to a JsonEmitter. Will write the contained value, if
/// present, or null, if no value is present.
/// @param emitter the emitter to write to.
/// @param v the value to write.
template<typename T>
void write(JsonEmitter& emitter, const std::optional<T>& v)
{
if (v.has_value())
write(emitter, *v);
else
emitter.writeRaw("null");
}
template<typename T>
void write(JsonEmitter& emitter, const std::unordered_map<std::string, T>& map)
{
ObjectEmitter o = emitter.writeObject();
for (const auto& [k, v] : map)
o.writePair(k, v);
o.finish();
}
} // namespace Luau::Json

View file

@ -0,0 +1,53 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Variant.h"
#include "Luau/Symbol.h"
#include <memory>
#include <unordered_map>
namespace Luau
{
struct TypeVar;
using TypeId = const TypeVar*;
struct Field;
// Deprecated. Do not use in new work.
using LValue = Variant<Symbol, Field>;
struct Field
{
std::shared_ptr<LValue> parent;
std::string key;
bool operator==(const Field& rhs) const;
bool operator!=(const Field& rhs) const;
};
struct LValueHasher
{
size_t operator()(const LValue& lvalue) const;
};
const LValue* baseof(const LValue& lvalue);
std::optional<LValue> tryGetLValue(const class AstExpr& expr);
// Utility function: breaks down an LValue to get at the Symbol
Symbol getBaseSymbol(const LValue& lvalue);
template<typename T>
const T* get(const LValue& lvalue)
{
return get_if<T>(&lvalue);
}
using RefinementMap = std::unordered_map<LValue, TypeId, LValueHasher>;
void merge(RefinementMap& l, const RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f);
void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty);
} // namespace Luau

View file

@ -4,6 +4,7 @@
#include "Luau/Location.h" #include "Luau/Location.h"
#include <memory> #include <memory>
#include <string>
#include <vector> #include <vector>
namespace Luau namespace Luau
@ -14,6 +15,7 @@ class AstStat;
class AstNameTable; class AstNameTable;
struct TypeChecker; struct TypeChecker;
struct Module; struct Module;
struct HotComment;
using ScopePtr = std::shared_ptr<struct Scope>; using ScopePtr = std::shared_ptr<struct Scope>;
@ -49,6 +51,10 @@ struct LintWarning
Code_DeprecatedApi = 22, Code_DeprecatedApi = 22,
Code_TableOperations = 23, Code_TableOperations = 23,
Code_DuplicateCondition = 24, Code_DuplicateCondition = 24,
Code_MisleadingAndOr = 25,
Code_CommentDirective = 26,
Code_IntegerParsing = 27,
Code_ComparisonPrecedence = 28,
Code__Count Code__Count
}; };
@ -59,7 +65,7 @@ struct LintWarning
static const char* getName(Code code); static const char* getName(Code code);
static Code parseName(const char* name); static Code parseName(const char* name);
static uint64_t parseMask(const std::vector<std::string>& hotcomments); static uint64_t parseMask(const std::vector<HotComment>& hotcomments);
}; };
struct LintResult struct LintResult
@ -89,7 +95,8 @@ struct LintOptions
void setDefaults(); void setDefaults();
}; };
std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const LintOptions& options); std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module,
const std::vector<HotComment>& hotcomments, const LintOptions& options);
std::vector<AstName> getDeprecatedGlobals(const AstNameTable& names); std::vector<AstName> getDeprecatedGlobals(const AstNameTable& names);

View file

@ -0,0 +1,32 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include <unordered_map>
namespace Luau
{
static const std::unordered_map<AstExprBinary::Op, const char*> kBinaryOpMetamethods{
{AstExprBinary::Op::CompareEq, "__eq"},
{AstExprBinary::Op::CompareNe, "__eq"},
{AstExprBinary::Op::CompareGe, "__lt"},
{AstExprBinary::Op::CompareGt, "__le"},
{AstExprBinary::Op::CompareLe, "__le"},
{AstExprBinary::Op::CompareLt, "__lt"},
{AstExprBinary::Op::Add, "__add"},
{AstExprBinary::Op::Sub, "__sub"},
{AstExprBinary::Op::Mul, "__mul"},
{AstExprBinary::Op::Div, "__div"},
{AstExprBinary::Op::Pow, "__pow"},
{AstExprBinary::Op::Mod, "__mod"},
{AstExprBinary::Op::Concat, "__concat"},
};
static const std::unordered_map<AstExprUnary::Op, const char*> kUnaryOpMetamethods{
{AstExprUnary::Op::Minus, "__unm"},
{AstExprUnary::Op::Len, "__len"},
};
} // namespace Luau

View file

@ -1,12 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/FileResolver.h"
#include "Luau/TypePack.h"
#include "Luau/TypedAllocator.h"
#include "Luau/ParseOptions.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Parser.h" #include "Luau/FileResolver.h"
#include "Luau/ParseOptions.h"
#include "Luau/ParseResult.h"
#include "Luau/Scope.h"
#include "Luau/TypeArena.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
@ -21,6 +21,9 @@ struct Module;
using ScopePtr = std::shared_ptr<struct Scope>; using ScopePtr = std::shared_ptr<struct Scope>;
using ModulePtr = std::shared_ptr<Module>; using ModulePtr = std::shared_ptr<Module>;
class AstType;
class AstTypePack;
/// Root of the AST of a parsed source file /// Root of the AST of a parsed source file
struct SourceModule struct SourceModule
{ {
@ -29,14 +32,14 @@ struct SourceModule
std::optional<std::string> environmentName; std::optional<std::string> environmentName;
bool cyclic = false; bool cyclic = false;
std::unique_ptr<Allocator> allocator; std::shared_ptr<Allocator> allocator;
std::unique_ptr<AstNameTable> names; std::shared_ptr<AstNameTable> names;
std::vector<ParseError> parseErrors; std::vector<ParseError> parseErrors;
AstStatBlock* root = nullptr; AstStatBlock* root = nullptr;
std::optional<Mode> mode; std::optional<Mode> mode;
uint64_t ignoreLints = 0;
std::vector<HotComment> hotcomments;
std::vector<Comment> commentLocations; std::vector<Comment> commentLocations;
SourceModule() SourceModule()
@ -48,40 +51,12 @@ struct SourceModule
bool isWithinComment(const SourceModule& sourceModule, Position pos); bool isWithinComment(const SourceModule& sourceModule, Position pos);
struct TypeArena struct RequireCycle
{ {
TypedAllocator<TypeVar> typeVars; Location location;
TypedAllocator<TypePackVar> typePacks; std::vector<ModuleName> path; // one of the paths for a require() to go all the way back to the originating module
void clear();
template<typename T>
TypeId addType(T tv)
{
return addTV(TypeVar(std::move(tv)));
}
TypeId addTV(TypeVar&& tv);
TypeId freshType(TypeLevel level);
TypePackId addTypePack(std::initializer_list<TypeId> types);
TypePackId addTypePack(std::vector<TypeId> types);
TypePackId addTypePack(TypePack pack);
TypePackId addTypePack(TypePackVar pack);
}; };
void freeze(TypeArena& arena);
void unfreeze(TypeArena& arena);
// Only exposed so they can be unit tested.
using SeenTypes = std::unordered_map<TypeId, TypeId>;
using SeenTypePacks = std::unordered_map<TypePackId, TypePackId>;
TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr);
TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr);
struct Module struct Module
{ {
~Module(); ~Module();
@ -89,23 +64,33 @@ struct Module
TypeArena interfaceTypes; TypeArena interfaceTypes;
TypeArena internalTypes; TypeArena internalTypes;
// Scopes and AST types refer to parse data, so we need to keep that alive
std::shared_ptr<Allocator> allocator;
std::shared_ptr<AstNameTable> names;
std::vector<std::pair<Location, ScopePtr>> scopes; // never empty std::vector<std::pair<Location, ScopePtr>> scopes; // never empty
std::unordered_map<const AstExpr*, TypeId> astTypes;
std::unordered_map<const AstExpr*, TypeId> astExpectedTypes; DenseHashMap<const AstExpr*, TypeId> astTypes{nullptr};
std::unordered_map<const AstExpr*, TypeId> astOriginalCallTypes; DenseHashMap<const AstExpr*, TypePackId> astTypePacks{nullptr};
std::unordered_map<const AstExpr*, TypeId> astOverloadResolvedTypes; DenseHashMap<const AstExpr*, TypeId> astExpectedTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astOriginalCallTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astOverloadResolvedTypes{nullptr};
DenseHashMap<const AstType*, TypeId> astResolvedTypes{nullptr};
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
// Map AST nodes to the scope they create. Cannot be NotNull<Scope> because we need a sentinel value for the map.
DenseHashMap<const AstNode*, Scope*> astScopes{nullptr};
std::unordered_map<Name, TypeId> declaredGlobals; std::unordered_map<Name, TypeId> declaredGlobals;
ErrorVec errors; ErrorVec errors;
Mode mode; Mode mode;
SourceCode::Type type; SourceCode::Type type;
bool timeout = false;
ScopePtr getModuleScope() const; ScopePtr getModuleScope() const;
// Once a module has been typechecked, we clone its public interface into a separate arena. // Once a module has been typechecked, we clone its public interface into a separate arena.
// This helps us to force TypeVar ownership into a DAG rather than a DCG. // This helps us to force TypeVar ownership into a DAG rather than a DCG.
// Returns true if there were any free types encountered in the public interface. This void clonePublicInterface(NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
// indicates a bug in the type checker that we want to surface.
bool clonePublicInterface();
}; };
} // namespace Luau } // namespace Luau

View file

@ -15,12 +15,6 @@ struct Module;
using ModulePtr = std::shared_ptr<Module>; using ModulePtr = std::shared_ptr<Module>;
struct ModuleInfo
{
ModuleName name;
bool optional = false;
};
struct ModuleResolver struct ModuleResolver
{ {
virtual ~ModuleResolver() {} virtual ~ModuleResolver() {}

View file

@ -0,0 +1,329 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/NotNull.h"
#include "Luau/TypeVar.h"
#include "Luau/UnifierSharedState.h"
#include <memory>
namespace Luau
{
struct InternalErrorReporter;
struct Module;
struct Scope;
struct SingletonTypes;
using ModulePtr = std::shared_ptr<Module>;
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice);
class TypeIds
{
private:
std::unordered_set<TypeId> types;
std::vector<TypeId> order;
std::size_t hash = 0;
public:
using iterator = std::vector<TypeId>::iterator;
using const_iterator = std::vector<TypeId>::const_iterator;
TypeIds(const TypeIds&) = delete;
TypeIds(TypeIds&&) = default;
TypeIds() = default;
~TypeIds() = default;
TypeIds& operator=(TypeIds&&) = default;
void insert(TypeId ty);
/// Erase every element that does not also occur in tys
void retain(const TypeIds& tys);
void clear();
TypeId front() const;
iterator begin();
iterator end();
const_iterator begin() const;
const_iterator end() const;
iterator erase(const_iterator it);
size_t size() const;
bool empty() const;
size_t count(TypeId ty) const;
template<class Iterator>
void insert(Iterator begin, Iterator end)
{
for (Iterator it = begin; it != end; ++it)
insert(*it);
}
bool operator==(const TypeIds& there) const;
size_t getHash() const;
};
} // namespace Luau
template<>
struct std::hash<Luau::TypeIds>
{
std::size_t operator()(const Luau::TypeIds& tys) const
{
return tys.getHash();
}
};
template<>
struct std::hash<const Luau::TypeIds*>
{
std::size_t operator()(const Luau::TypeIds* tys) const
{
return tys->getHash();
}
};
template<>
struct std::equal_to<Luau::TypeIds>
{
bool operator()(const Luau::TypeIds& here, const Luau::TypeIds& there) const
{
return here == there;
}
};
template<>
struct std::equal_to<const Luau::TypeIds*>
{
bool operator()(const Luau::TypeIds* here, const Luau::TypeIds* there) const
{
return *here == *there;
}
};
namespace Luau
{
/** A normalized string type is either `string` (represented by `nullopt`) or a
* union of string singletons.
*
* The representation is as follows:
*
* * A union of string singletons is finite and includes the singletons named by
* the `singletons` field.
* * An intersection of negated string singletons is cofinite and includes the
* singletons excluded by the `singletons` field. It is implied that cofinite
* values are exclusions from `string` itself.
* * The `string` data type is a cofinite set minus zero elements.
* * The `never` data type is a finite set plus zero elements.
*/
struct NormalizedStringType
{
// When false, this type represents a union of singleton string types.
// eg "a" | "b" | "c"
//
// When true, this type represents string intersected with negated string
// singleton types.
// eg string & ~"a" & ~"b" & ...
bool isCofinite = false;
std::map<std::string, TypeId> singletons;
void resetToString();
void resetToNever();
bool isNever() const;
bool isString() const;
/// Returns true if the string has finite domain.
///
/// Important subtlety: This method returns true for `never`. The empty set
/// is indeed an empty set.
bool isUnion() const;
/// Returns true if the string has infinite domain.
bool isIntersection() const;
bool includes(const std::string& str) const;
static const NormalizedStringType never;
NormalizedStringType();
NormalizedStringType(bool isCofinite, std::map<std::string, TypeId> singletons);
};
bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr);
// A normalized function type can be `never`, the top function type `function`,
// or an intersection of function types.
//
// NOTE: type normalization can fail on function types with generics (e.g.
// because we do not support unions and intersections of generic type packs), so
// this type may contain `error`.
struct NormalizedFunctionType
{
NormalizedFunctionType();
bool isTop = false;
// TODO: Remove this wrapping optional when clipping
// FFlagLuauNegatedFunctionTypes.
std::optional<TypeIds> parts;
void resetToNever();
void resetToTop();
bool isNever() const;
};
// A normalized generic/free type is a union, where each option is of the form (X & T) where
// * X is either a free type or a generic
// * T is a normalized type.
struct NormalizedType;
using NormalizedTyvars = std::unordered_map<TypeId, std::unique_ptr<NormalizedType>>;
bool isInhabited_DEPRECATED(const NormalizedType& norm);
// A normalized type is either any, unknown, or one of the form P | T | F | G where
// * P is a union of primitive types (including singletons, classes and the error type)
// * T is a union of table types
// * F is a union of an intersection of function types
// * G is a union of generic/free normalized types, intersected with a normalized type
struct NormalizedType
{
// The top part of the type.
// This type is either never, unknown, or any.
// If this type is not never, all the other fields are null.
TypeId tops;
// The boolean part of the type.
// This type is either never, boolean type, or a boolean singleton.
TypeId booleans;
// The class part of the type.
// Each element of this set is a class, and none of the classes are subclasses of each other.
TypeIds classes;
// The error part of the type.
// This type is either never or the error type.
TypeId errors;
// The nil part of the type.
// This type is either never or nil.
TypeId nils;
// The number part of the type.
// This type is either never or number.
TypeId numbers;
// The string part of the type.
// This may be the `string` type, or a union of singletons.
NormalizedStringType strings;
// The thread part of the type.
// This type is either never or thread.
TypeId threads;
// The (meta)table part of the type.
// Each element of this set is a (meta)table type.
TypeIds tables;
// The function part of the type.
NormalizedFunctionType functions;
// The generic/free part of the type.
NormalizedTyvars tyvars;
NormalizedType(NotNull<SingletonTypes> singletonTypes);
NormalizedType() = delete;
~NormalizedType() = default;
NormalizedType(const NormalizedType&) = delete;
NormalizedType& operator=(const NormalizedType&) = delete;
NormalizedType(NormalizedType&&) = default;
NormalizedType& operator=(NormalizedType&&) = default;
};
class Normalizer
{
std::unordered_map<TypeId, std::unique_ptr<NormalizedType>> cachedNormals;
std::unordered_map<const TypeIds*, TypeId> cachedIntersections;
std::unordered_map<const TypeIds*, TypeId> cachedUnions;
std::unordered_map<const TypeIds*, std::unique_ptr<TypeIds>> cachedTypeIds;
bool withinResourceLimits();
public:
TypeArena* arena;
NotNull<SingletonTypes> singletonTypes;
NotNull<UnifierSharedState> sharedState;
Normalizer(TypeArena* arena, NotNull<SingletonTypes> singletonTypes, NotNull<UnifierSharedState> sharedState);
Normalizer(const Normalizer&) = delete;
Normalizer(Normalizer&&) = delete;
Normalizer() = delete;
~Normalizer() = default;
Normalizer& operator=(Normalizer&&) = delete;
Normalizer& operator=(Normalizer&) = delete;
// If this returns null, the typechecker should emit a "too complex" error
const NormalizedType* normalize(TypeId ty);
void clearNormal(NormalizedType& norm);
// ------- Cached TypeIds
TypeId unionType(TypeId here, TypeId there);
TypeId intersectionType(TypeId here, TypeId there);
const TypeIds* cacheTypeIds(TypeIds tys);
void clearCaches();
// ------- Normalizing unions
void unionTysWithTy(TypeIds& here, TypeId there);
TypeId unionOfTops(TypeId here, TypeId there);
TypeId unionOfBools(TypeId here, TypeId there);
void unionClassesWithClass(TypeIds& heres, TypeId there);
void unionClasses(TypeIds& heres, const TypeIds& theres);
void unionStrings(NormalizedStringType& here, const NormalizedStringType& there);
std::optional<TypePackId> unionOfTypePacks(TypePackId here, TypePackId there);
std::optional<TypeId> unionOfFunctions(TypeId here, TypeId there);
std::optional<TypeId> unionSaturatedFunctions(TypeId here, TypeId there);
void unionFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there);
void unionFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress);
void unionTablesWithTable(TypeIds& heres, TypeId there);
void unionTables(TypeIds& heres, const TypeIds& theres);
bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1);
// ------- Negations
std::optional<NormalizedType> negateNormal(const NormalizedType& here);
TypeIds negateAll(const TypeIds& theres);
TypeId negate(TypeId there);
void subtractPrimitive(NormalizedType& here, TypeId ty);
void subtractSingleton(NormalizedType& here, TypeId ty);
// ------- Normalizing intersections
TypeId intersectionOfTops(TypeId here, TypeId there);
TypeId intersectionOfBools(TypeId here, TypeId there);
void intersectClasses(TypeIds& heres, const TypeIds& theres);
void intersectClassesWithClass(TypeIds& heres, TypeId there);
void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there);
std::optional<TypePackId> intersectionOfTypePacks(TypePackId here, TypePackId there);
std::optional<TypeId> intersectionOfTables(TypeId here, TypeId there);
void intersectTablesWithTable(TypeIds& heres, TypeId there);
void intersectTables(TypeIds& heres, const TypeIds& theres);
std::optional<TypeId> intersectionOfFunctions(TypeId here, TypeId there);
void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there);
void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress);
bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there);
bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
bool intersectNormalWithTy(NormalizedType& here, TypeId there);
// Check for inhabitance
bool isInhabited(TypeId ty, std::unordered_set<TypeId> seen = {});
bool isInhabited(const NormalizedType* norm, std::unordered_set<TypeId> seen = {});
// -------- Convert back from a normalized type to a type
TypeId typeFromNormal(const NormalizedType& norm);
};
} // namespace Luau

View file

@ -0,0 +1,104 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include <functional>
namespace Luau
{
/** A non-owning, non-null pointer to a T.
*
* A NotNull<T> is notionally identical to a T* with the added restriction that
* it can never store nullptr.
*
* The sole conversion rule from T* to NotNull<T> is the single-argument
* constructor, which is intentionally marked explicit. This constructor
* performs a runtime test to verify that the passed pointer is never nullptr.
*
* Pointer arithmetic, increment, decrement, and array indexing are all
* forbidden.
*
* An implicit coersion from NotNull<T> to T* is afforded, as are the pointer
* indirection and member access operators. (*p and p->prop)
*
* The explicit delete statement is permitted (but not recommended) on a
* NotNull<T> through this implicit conversion.
*/
template<typename T>
struct NotNull
{
explicit NotNull(T* t)
: ptr(t)
{
LUAU_ASSERT(t);
}
explicit NotNull(std::nullptr_t) = delete;
void operator=(std::nullptr_t) = delete;
template<typename U>
NotNull(NotNull<U> other)
: ptr(other.get())
{
}
operator T*() const noexcept
{
return ptr;
}
T& operator*() const noexcept
{
return *ptr;
}
T* operator->() const noexcept
{
return ptr;
}
template<typename U>
bool operator==(NotNull<U> other) const noexcept
{
return get() == other.get();
}
template<typename U>
bool operator!=(NotNull<U> other) const noexcept
{
return get() != other.get();
}
operator bool() const noexcept = delete;
T& operator[](int) = delete;
T& operator+(int) = delete;
T& operator-(int) = delete;
T* get() const noexcept
{
return ptr;
}
private:
T* ptr;
};
} // namespace Luau
namespace std
{
template<typename T>
struct hash<Luau::NotNull<T>>
{
size_t operator()(const Luau::NotNull<T>& p) const
{
return std::hash<T*>()(p.get());
}
};
} // namespace std

View file

@ -1,12 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/Variant.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/Symbol.h" #include "Luau/LValue.h"
#include "Luau/Variant.h"
#include <map>
#include <memory>
#include <vector> #include <vector>
namespace Luau namespace Luau
@ -15,34 +13,6 @@ namespace Luau
struct TypeVar; struct TypeVar;
using TypeId = const TypeVar*; using TypeId = const TypeVar*;
struct Field;
using LValue = Variant<Symbol, Field>;
struct Field
{
std::shared_ptr<LValue> parent; // TODO: Eventually use unique_ptr to enforce non-copyable trait.
std::string key;
};
std::optional<LValue> tryGetLValue(const class AstExpr& expr);
// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys.
std::pair<Symbol, std::vector<std::string>> getFullName(const LValue& lvalue);
std::string toString(const LValue& lvalue);
template<typename T>
const T* get(const LValue& lvalue)
{
return get_if<T>(&lvalue);
}
// Key is a stringified encoding of an LValue.
using RefinementMap = std::map<std::string, TypeId>;
void merge(RefinementMap& l, const RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f);
void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty);
struct TruthyPredicate; struct TruthyPredicate;
struct IsAPredicate; struct IsAPredicate;
struct TypeGuardPredicate; struct TypeGuardPredicate;

View file

@ -0,0 +1,15 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeVar.h"
namespace Luau
{
struct TypeArena;
struct Scope;
void quantify(TypeId ty, TypeLevel level);
TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope);
} // namespace Luau

View file

@ -2,12 +2,22 @@
#pragma once #pragma once
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Error.h"
#include <stdexcept> #include <stdexcept>
#include <exception>
namespace Luau namespace Luau
{ {
struct RecursionLimitException : public InternalCompilerError
{
RecursionLimitException()
: InternalCompilerError("Internal recursion counter limit exceeded")
{
}
};
struct RecursionCounter struct RecursionCounter
{ {
RecursionCounter(int* count) RecursionCounter(int* count)
@ -32,7 +42,9 @@ struct RecursionLimiter : RecursionCounter
: RecursionCounter(count) : RecursionCounter(count)
{ {
if (limit > 0 && *count > limit) if (limit > 0 && *count > limit)
throw std::runtime_error("Internal recursion counter limit exceeded"); {
throw RecursionLimitException();
}
} }
}; };

View file

@ -6,6 +6,7 @@
#include "Luau/Location.h" #include "Luau/Location.h"
#include <string> #include <string>
#include <vector>
namespace Luau namespace Luau
{ {
@ -17,12 +18,11 @@ struct AstLocal;
struct RequireTraceResult struct RequireTraceResult
{ {
DenseHashMap<const AstExpr*, ModuleName> exprs{0}; DenseHashMap<const AstExpr*, ModuleInfo> exprs{nullptr};
DenseHashMap<const AstExpr*, bool> optional{0};
std::vector<std::pair<ModuleName, Location>> requires; std::vector<std::pair<ModuleName, Location>> requireList;
}; };
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName); RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName);
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,84 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Location.h"
#include "Luau/NotNull.h"
#include "Luau/TypeVar.h"
#include <unordered_map>
#include <optional>
#include <memory>
namespace Luau
{
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
struct Binding
{
TypeId typeId;
Location location;
bool deprecated = false;
std::string deprecatedSuggestion;
std::optional<std::string> documentationSymbol;
};
struct Scope
{
explicit Scope(TypePackId returnType); // root scope
explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr.
const ScopePtr parent; // null for the root
// All the children of this scope.
std::vector<NotNull<Scope>> children;
std::unordered_map<Symbol, Binding> bindings;
TypePackId returnType;
std::optional<TypePackId> varargPack;
TypeLevel level;
std::unordered_map<Name, TypeFun> exportedTypeBindings;
std::unordered_map<Name, TypeFun> privateTypeBindings;
std::unordered_map<Name, Location> typeAliasLocations;
std::unordered_map<Name, std::unordered_map<Name, TypeFun>> importedTypeBindings;
DenseHashSet<Name> builtinTypeNames{""};
void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun);
std::optional<TypeId> lookup(Symbol sym) const;
std::optional<TypeId> lookup(DefId def) const;
std::optional<std::pair<TypeId, Scope*>> lookupEx(Symbol sym);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
std::unordered_map<Name, TypePackId> privateTypePackBindings;
std::optional<TypePackId> lookupPack(const Name& name);
// WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2)
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const;
RefinementMap refinements;
DenseHashMap<const Def*, TypeId> dcrRefinements{nullptr};
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.
// For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}`
// we need that the generic type `T` in both cases is the same, so we use a cache.
std::unordered_map<Name, TypeId> typeAliasTypeParameters;
std::unordered_map<Name, TypePackId> typeAliasTypePackParameters;
};
// Returns true iff the left scope encloses the right scope. A Scope* equal to
// nullptr is considered to be the outermost-possible scope.
bool subsumesStrict(Scope* left, Scope* right);
// Returns true if the left scope encloses the right scope, or if they are the
// same scope. As in subsumesStrict(), nullptr is considered to be the
// outermost-possible scope.
bool subsumes(Scope* left, Scope* right);
} // namespace Luau

View file

@ -1,8 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/Module.h" #include "Luau/TypeArena.h"
#include "Luau/ModuleResolver.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
@ -52,11 +51,11 @@
// `T`, and the type of `f` are in the same SCC, which is why `f` gets // `T`, and the type of `f` are in the same SCC, which is why `f` gets
// replaced. // replaced.
LUAU_FASTFLAG(DebugLuauTrackOwningArena)
namespace Luau namespace Luau
{ {
struct TxnLog;
enum class TarjanResult enum class TarjanResult
{ {
TooManyChildren, TooManyChildren,
@ -90,6 +89,11 @@ struct Tarjan
std::vector<int> lowlink; std::vector<int> lowlink;
int childCount = 0; int childCount = 0;
int childLimit = 0;
// This should never be null; ensure you initialize it before calling
// substitution methods.
const TxnLog* log = nullptr;
std::vector<TypeId> edgesTy; std::vector<TypeId> edgesTy;
std::vector<TypePackId> edgesTp; std::vector<TypePackId> edgesTp;
@ -97,9 +101,6 @@ struct Tarjan
// This is hot code, so we optimize recursion to a stack. // This is hot code, so we optimize recursion to a stack.
TarjanResult loop(); TarjanResult loop();
// Clear the state
void clear();
// Find or create the index for a vertex. // Find or create the index for a vertex.
// Return a boolean which is `true` if it's a freshly created index. // Return a boolean which is `true` if it's a freshly created index.
std::pair<int, bool> indexify(TypeId ty); std::pair<int, bool> indexify(TypeId ty);
@ -138,6 +139,8 @@ struct FindDirty : Tarjan
{ {
std::vector<bool> dirty; std::vector<bool> dirty;
void clearTarjan();
// Get/set the dirty bit for an index (grows the vector if needed) // Get/set the dirty bit for an index (grows the vector if needed)
bool getDirty(int index); bool getDirty(int index);
void setDirty(int index, bool d); void setDirty(int index, bool d);
@ -162,9 +165,21 @@ struct FindDirty : Tarjan
// and replaces them with clean ones. // and replaces them with clean ones.
struct Substitution : FindDirty struct Substitution : FindDirty
{ {
ModulePtr currentModule; protected:
Substitution(const TxnLog* log_, TypeArena* arena)
: arena(arena)
{
log = log_;
LUAU_ASSERT(log);
LUAU_ASSERT(arena);
}
public:
TypeArena* arena;
DenseHashMap<TypeId, TypeId> newTypes{nullptr}; DenseHashMap<TypeId, TypeId> newTypes{nullptr};
DenseHashMap<TypePackId, TypePackId> newPacks{nullptr}; DenseHashMap<TypePackId, TypePackId> newPacks{nullptr};
DenseHashSet<TypeId> replacedTypes{nullptr};
DenseHashSet<TypePackId> replacedTypePacks{nullptr};
std::optional<TypeId> substitute(TypeId ty); std::optional<TypeId> substitute(TypeId ty);
std::optional<TypePackId> substitute(TypePackId tp); std::optional<TypePackId> substitute(TypePackId tp);
@ -188,20 +203,13 @@ struct Substitution : FindDirty
template<typename T> template<typename T>
TypeId addType(const T& tv) TypeId addType(const T& tv)
{ {
TypeId allocated = currentModule->internalTypes.typeVars.allocate(tv); return arena->addType(tv);
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = &currentModule->internalTypes;
return allocated;
} }
template<typename T> template<typename T>
TypePackId addTypePack(const T& tp) TypePackId addTypePack(const T& tp)
{ {
TypePackId allocated = currentModule->internalTypes.typePacks.allocate(tp); return arena->addTypePack(TypePackVar{tp});
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = &currentModule->internalTypes;
return allocated;
} }
}; };

View file

@ -6,10 +6,11 @@
#include <string> #include <string>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau namespace Luau
{ {
// TODO Rename this to Name once the old type alias is gone.
struct Symbol struct Symbol
{ {
Symbol() Symbol()
@ -30,6 +31,9 @@ struct Symbol
{ {
} }
template<typename T>
Symbol(const T&) = delete;
AstLocal* local; AstLocal* local;
AstName global; AstName global;
@ -37,9 +41,12 @@ struct Symbol
{ {
if (local) if (local)
return local == rhs.local; return local == rhs.local;
if (global.value) else if (global.value)
return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity.
return false; else if (FFlag::DebugLuauDeferredConstraintResolution)
return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is.
else
return false;
} }
bool operator!=(const Symbol& rhs) const bool operator!=(const Symbol& rhs) const
@ -55,8 +62,8 @@ struct Symbol
return global < rhs.global; return global < rhs.global;
else if (local) else if (local)
return true; return true;
else
return false; return false;
} }
AstName astName() const AstName astName() const

View file

@ -0,0 +1,31 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include <string>
namespace Luau
{
struct TypeVar;
using TypeId = const TypeVar*;
struct TypePackVar;
using TypePackId = const TypePackVar*;
struct ToDotOptions
{
bool showPointers = true; // Show pointer value in the node label
bool duplicatePrimitives = true; // Display primitive types and 'any' as separate nodes
};
std::string toDot(TypeId ty, const ToDotOptions& opts);
std::string toDot(TypePackId tp, const ToDotOptions& opts);
std::string toDot(TypeId ty);
std::string toDot(TypePackId tp);
void dumpDot(TypeId ty);
void dumpDot(TypePackId tp);
} // namespace Luau

View file

@ -2,12 +2,12 @@
#pragma once #pragma once
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/TypeVar.h"
#include <unordered_map>
#include <optional>
#include <memory> #include <memory>
#include <optional>
#include <string> #include <string>
#include <unordered_map>
#include <vector>
LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTableTypeMaximumStringifierLength)
LUAU_FASTINT(LuauTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength)
@ -15,6 +15,22 @@ LUAU_FASTINT(LuauTypeMaximumStringifierLength)
namespace Luau namespace Luau
{ {
class AstExpr;
struct Scope;
struct TypeVar;
using TypeId = const TypeVar*;
struct TypePackVar;
using TypePackId = const TypePackVar*;
struct FunctionTypeVar;
struct Constraint;
struct Position;
struct Location;
struct ToStringNameMap struct ToStringNameMap
{ {
std::unordered_map<TypeId, std::string> typeVars; std::unordered_map<TypeId, std::string> typeVars;
@ -23,20 +39,23 @@ struct ToStringNameMap
struct ToStringOptions struct ToStringOptions
{ {
bool exhaustive = false; // If true, we produce complete output rather than comprehensible output bool exhaustive = false; // If true, we produce complete output rather than comprehensible output
bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable.
bool functionTypeArguments = false; // If true, output function type argument names when they are available bool functionTypeArguments = false; // If true, output function type argument names when they are available
bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}'
bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level.
bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self
bool DEPRECATED_indent = false; // TODO Deprecated field, prune when clipping flag FFlagLuauLineBreaksDeterminIndents
size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars
size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength);
std::optional<ToStringNameMap> nameMap; ToStringNameMap nameMap;
std::shared_ptr<Scope> scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid' std::shared_ptr<Scope> scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid'
std::vector<std::string> namedFunctionOverrideArgNames; // If present, named function argument names will be overridden
}; };
struct ToStringResult struct ToStringResult
{ {
std::string name; std::string name;
ToStringNameMap nameMap;
bool invalid = false; bool invalid = false;
bool error = false; bool error = false;
@ -44,11 +63,24 @@ struct ToStringResult
bool truncated = false; bool truncated = false;
}; };
ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts = {}); ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts);
ToStringResult toStringDetailed(TypePackId ty, const ToStringOptions& opts = {}); ToStringResult toStringDetailed(TypePackId ty, ToStringOptions& opts);
std::string toString(TypeId ty, const ToStringOptions& opts); std::string toString(TypeId ty, ToStringOptions& opts);
std::string toString(TypePackId ty, const ToStringOptions& opts); std::string toString(TypePackId ty, ToStringOptions& opts);
// These overloads are selected when a temporary ToStringOptions is passed. (eg
// via an initializer list)
inline std::string toString(TypePackId ty, ToStringOptions&& opts)
{
// Delegate to the overload (TypePackId, ToStringOptions&)
return toString(ty, opts);
}
inline std::string toString(TypeId ty, ToStringOptions&& opts)
{
// Delegate to the overload (TypeId, ToStringOptions&)
return toString(ty, opts);
}
// These are offered as overloads rather than a default parameter so that they can be easily invoked from within the MSVC debugger. // These are offered as overloads rather than a default parameter so that they can be easily invoked from within the MSVC debugger.
// You can use them in watch expressions! // You can use them in watch expressions!
@ -61,12 +93,54 @@ inline std::string toString(TypePackId ty)
return toString(ty, ToStringOptions{}); return toString(ty, ToStringOptions{});
} }
std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); std::string toString(const Constraint& c, ToStringOptions& opts);
std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {});
inline std::string toString(const Constraint& c, ToStringOptions&& opts)
{
return toString(c, opts);
}
inline std::string toString(const Constraint& c)
{
return toString(c, ToStringOptions{});
}
std::string toString(const TypeVar& tv, ToStringOptions& opts);
std::string toString(const TypePackVar& tp, ToStringOptions& opts);
inline std::string toString(const TypeVar& tv)
{
ToStringOptions opts;
return toString(tv, opts);
}
inline std::string toString(const TypePackVar& tp)
{
ToStringOptions opts;
return toString(tp, opts);
}
std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, ToStringOptions& opts);
inline std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv)
{
ToStringOptions opts;
return toStringNamedFunction(funcName, ftv, opts);
}
std::optional<std::string> getFunctionNameAsString(const AstExpr& expr);
// It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class
// These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression
void dump(TypeId ty); std::string dump(TypeId ty);
void dump(TypePackId ty); std::string dump(TypePackId ty);
std::string dump(const Constraint& c);
std::string dump(const std::shared_ptr<Scope>& scope, const char* name);
std::string generateName(size_t n);
std::string toString(const Position& position);
std::string toString(const Location& location);
} // namespace Luau } // namespace Luau

View file

@ -12,6 +12,7 @@ struct AstArray;
class AstStat; class AstStat;
bool containsFunctionCall(const AstStat& stat); bool containsFunctionCall(const AstStat& stat);
bool containsFunctionCallOrReturn(const AstStat& stat);
bool isFunction(const AstStat& stat); bool isFunction(const AstStat& stat);
void toposort(std::vector<AstStat*>& stats); void toposort(std::vector<AstStat*>& stats);

View file

@ -18,6 +18,7 @@ struct TranspileResult
std::string parseError; // Nonempty if the transpile failed std::string parseError; // Nonempty if the transpile failed
}; };
std::string toString(AstNode* node);
void dump(AstNode* node); void dump(AstNode* node);
// Never fails on a well-formed AST // Never fails on a well-formed AST
@ -25,6 +26,6 @@ std::string transpile(AstStatBlock& ast);
std::string transpileWithTypes(AstStatBlock& block); std::string transpileWithTypes(AstStatBlock& block);
// Only fails when parsing fails // Only fails when parsing fails
TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{}); TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{}, bool withTypes = false);
} // namespace Luau } // namespace Luau

View file

@ -2,17 +2,92 @@
#pragma once #pragma once
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
#include <memory>
#include <unordered_map>
namespace Luau namespace Luau
{ {
// Log of where what TypeIds we are rebinding and what they used to be using TypeOrPackId = const void*;
// Pending state for a TypeVar. Generated by a TxnLog and committed via
// TxnLog::commit.
struct PendingType
{
// The pending TypeVar state.
TypeVar pending;
explicit PendingType(TypeVar state)
: pending(std::move(state))
{
}
};
std::string toString(PendingType* pending);
std::string dump(PendingType* pending);
// Pending state for a TypePackVar. Generated by a TxnLog and committed via
// TxnLog::commit.
struct PendingTypePack
{
// The pending TypePackVar state.
TypePackVar pending;
explicit PendingTypePack(TypePackVar state)
: pending(std::move(state))
{
}
};
std::string toString(PendingTypePack* pending);
std::string dump(PendingTypePack* pending);
template<typename T>
T* getMutable(PendingType* pending)
{
// We use getMutable here because this state is intended to be mutated freely.
return getMutable<T>(&pending->pending);
}
template<typename T>
T* getMutable(PendingTypePack* pending)
{
// We use getMutable here because this state is intended to be mutated freely.
return getMutable<T>(&pending->pending);
}
// Log of what TypeIds we are rebinding, to be committed later.
struct TxnLog struct TxnLog
{ {
TxnLog() = default; TxnLog()
: typeVarChanges(nullptr)
, typePackChanges(nullptr)
, ownedSeen()
, sharedSeen(&ownedSeen)
{
}
explicit TxnLog(const std::vector<std::pair<TypeId, TypeId>>& seen) explicit TxnLog(TxnLog* parent)
: seen(seen) : typeVarChanges(nullptr)
, typePackChanges(nullptr)
, parent(parent)
{
if (parent)
{
sharedSeen = parent->sharedSeen;
}
else
{
sharedSeen = &ownedSeen;
}
}
explicit TxnLog(std::vector<std::pair<TypeOrPackId, TypeOrPackId>>* sharedSeen)
: typeVarChanges(nullptr)
, typePackChanges(nullptr)
, sharedSeen(sharedSeen)
{ {
} }
@ -22,25 +97,209 @@ struct TxnLog
TxnLog(TxnLog&&) = default; TxnLog(TxnLog&&) = default;
TxnLog& operator=(TxnLog&&) = default; TxnLog& operator=(TxnLog&&) = default;
void operator()(TypeId a); // Gets an empty TxnLog pointer. This is useful for constructs that
void operator()(TypePackId a); // take a TxnLog, like TypePackIterator - use the empty log if you
void operator()(TableTypeVar* a); // don't have a TxnLog to give it.
static const TxnLog* empty();
void rollback();
// Joins another TxnLog onto this one. You should use std::move to avoid
// copying the rhs TxnLog.
//
// If both logs talk about the same type, pack, or table, the rhs takes
// priority.
void concat(TxnLog rhs); void concat(TxnLog rhs);
void concatAsIntersections(TxnLog rhs, NotNull<TypeArena> arena);
void concatAsUnion(TxnLog rhs, NotNull<TypeArena> arena);
bool haveSeen(TypeId lhs, TypeId rhs); // Commits the TxnLog, rebinding all type pointers to their pending states.
// Clears the TxnLog afterwards.
void commit();
// Clears the TxnLog without committing any pending changes.
void clear();
// Computes an inverse of this TxnLog at the current time.
// This method should be called before commit is called in order to give an
// accurate result. Committing the inverse of a TxnLog will undo the changes
// made by commit, assuming the inverse log is accurate.
TxnLog inverse();
bool haveSeen(TypeId lhs, TypeId rhs) const;
void pushSeen(TypeId lhs, TypeId rhs); void pushSeen(TypeId lhs, TypeId rhs);
void popSeen(TypeId lhs, TypeId rhs); void popSeen(TypeId lhs, TypeId rhs);
bool haveSeen(TypePackId lhs, TypePackId rhs) const;
void pushSeen(TypePackId lhs, TypePackId rhs);
void popSeen(TypePackId lhs, TypePackId rhs);
// Queues a type for modification. The original type will not change until commit
// is called. Use pending to get the pending state.
//
// The pointer returned lives until `commit` or `clear` is called.
PendingType* queue(TypeId ty);
// Queues a type pack for modification. The original type pack will not change
// until commit is called. Use pending to get the pending state.
//
// The pointer returned lives until `commit` or `clear` is called.
PendingTypePack* queue(TypePackId tp);
// Returns the pending state of a type, or nullptr if there isn't any. It is important
// to note that this pending state is not transitive: the pending state may reference
// non-pending types freely, so you may need to call pending multiple times to view the
// entire pending state of a type graph.
//
// The pointer returned lives until `commit` or `clear` is called.
PendingType* pending(TypeId ty) const;
// Returns the pending state of a type pack, or nullptr if there isn't any. It is
// important to note that this pending state is not transitive: the pending state may
// reference non-pending types freely, so you may need to call pending multiple times
// to view the entire pending state of a type graph.
//
// The pointer returned lives until `commit` or `clear` is called.
PendingTypePack* pending(TypePackId tp) const;
// Queues a replacement of a type with another type.
//
// The pointer returned lives until `commit` or `clear` is called.
PendingType* replace(TypeId ty, TypeVar replacement);
// Queues a replacement of a type pack with another type pack.
//
// The pointer returned lives until `commit` or `clear` is called.
PendingTypePack* replace(TypePackId tp, TypePackVar replacement);
// Queues a replacement of a table type with another table type that is bound
// to a specific value.
//
// The pointer returned lives until `commit` or `clear` is called.
PendingType* bindTable(TypeId ty, std::optional<TypeId> newBoundTo);
// Queues a replacement of a type with a level with a duplicate of that type
// with a new type level.
//
// The pointer returned lives until `commit` or `clear` is called.
PendingType* changeLevel(TypeId ty, TypeLevel newLevel);
// Queues a replacement of a type pack with a level with a duplicate of that
// type pack with a new type level.
//
// The pointer returned lives until `commit` or `clear` is called.
PendingTypePack* changeLevel(TypePackId tp, TypeLevel newLevel);
// Queues the replacement of a type's scope with the provided scope.
//
// The pointer returned lives until `commit` or `clear` is called.
PendingType* changeScope(TypeId ty, NotNull<Scope> scope);
// Queues the replacement of a type pack's scope with the provided scope.
//
// The pointer returned lives until `commit` or `clear` is called.
PendingTypePack* changeScope(TypePackId tp, NotNull<Scope> scope);
// Queues a replacement of a table type with another table type with a new
// indexer.
//
// The pointer returned lives until `commit` or `clear` is called.
PendingType* changeIndexer(TypeId ty, std::optional<TableIndexer> indexer);
// Returns the type level of the pending state of the type, or the level of that
// type, if no pending state exists. If the type doesn't have a notion of a level,
// returns nullopt. If the pending state doesn't have a notion of a level, but the
// original state does, returns nullopt.
std::optional<TypeLevel> getLevel(TypeId ty) const;
// Follows a type, accounting for pending type states. The returned type may have
// pending state; you should use `pending` or `get` to find out.
TypeId follow(TypeId ty) const;
// Follows a type pack, accounting for pending type states. The returned type pack
// may have pending state; you should use `pending` or `get` to find out.
TypePackId follow(TypePackId tp) const;
// Replaces a given type's state with a new variant. Returns the new pending state
// of that type.
//
// The pointer returned lives until `commit` or `clear` is called.
template<typename T>
PendingType* replace(TypeId ty, T replacement)
{
return replace(ty, TypeVar(replacement));
}
// Replaces a given type pack's state with a new variant. Returns the new
// pending state of that type pack.
//
// The pointer returned lives until `commit` or `clear` is called.
template<typename T>
PendingTypePack* replace(TypePackId tp, T replacement)
{
return replace(tp, TypePackVar(replacement));
}
// Returns T if a given type or type pack is this variant, respecting the
// log's pending state.
//
// Do not retain this pointer; it has the potential to be invalidated when
// commit or clear is called.
template<typename T, typename TID>
T* getMutable(TID ty) const
{
auto* pendingTy = pending(ty);
if (pendingTy)
return Luau::getMutable<T>(pendingTy);
return Luau::getMutable<T>(ty);
}
template<typename T, typename TID>
const T* get(TID ty) const
{
return this->getMutable<T>(ty);
}
// Returns whether a given type or type pack is a given state, respecting the
// log's pending state.
//
// This method will not assert if called on a BoundTypeVar or BoundTypePack.
template<typename T, typename TID>
bool is(TID ty) const
{
// We do not use getMutable here because this method can be called on
// BoundTypeVars, which triggers an assertion.
auto* pendingTy = pending(ty);
if (pendingTy)
return Luau::get_if<T>(&pendingTy->pending.ty) != nullptr;
return Luau::get_if<T>(&ty->ty) != nullptr;
}
std::pair<std::vector<TypeId>, std::vector<TypePackId>> getChanges() const;
private: private:
std::vector<std::pair<TypeId, TypeVar>> typeVarChanges; // unique_ptr is used to give us stable pointers across insertions into the
std::vector<std::pair<TypePackId, TypePackVar>> typePackChanges; // map. Otherwise, it would be really easy to accidentally invalidate the
std::vector<std::pair<TableTypeVar*, std::optional<TypeId>>> tableChanges; // pointers returned from queue/pending.
DenseHashMap<TypeId, std::unique_ptr<PendingType>> typeVarChanges;
DenseHashMap<TypePackId, std::unique_ptr<PendingTypePack>> typePackChanges;
TxnLog* parent = nullptr;
// Owned version of sharedSeen. This should not be accessed directly in
// TxnLogs; use sharedSeen instead. This field exists because in the tree
// of TxnLogs, the root must own its seen set. In all descendant TxnLogs,
// this is an empty vector.
std::vector<std::pair<TypeOrPackId, TypeOrPackId>> ownedSeen;
bool haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const;
void pushSeen(TypeOrPackId lhs, TypeOrPackId rhs);
void popSeen(TypeOrPackId lhs, TypeOrPackId rhs);
public: public:
std::vector<std::pair<TypeId, TypeId>> seen; // used to avoid infinite recursion when types are cyclic // Used to avoid infinite recursion when types are cyclic.
// Shared with all the descendent TxnLogs.
std::vector<std::pair<TypeOrPackId, TypeOrPackId>>* sharedSeen;
}; };
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,52 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypedAllocator.h"
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
#include <vector>
namespace Luau
{
struct TypeArena
{
TypedAllocator<TypeVar> typeVars;
TypedAllocator<TypePackVar> typePacks;
void clear();
template<typename T>
TypeId addType(T tv)
{
if constexpr (std::is_same_v<T, UnionTypeVar>)
LUAU_ASSERT(tv.options.size() >= 2);
return addTV(TypeVar(std::move(tv)));
}
TypeId addTV(TypeVar&& tv);
TypeId freshType(TypeLevel level);
TypeId freshType(Scope* scope);
TypeId freshType(Scope* scope, TypeLevel level);
TypePackId freshTypePack(Scope* scope);
TypePackId addTypePack(std::initializer_list<TypeId> types);
TypePackId addTypePack(std::vector<TypeId> types, std::optional<TypePackId> tail = {});
TypePackId addTypePack(TypePack pack);
TypePackId addTypePack(TypePackVar pack);
template<typename T>
TypePackId addTypePack(T tp)
{
return addTypePack(TypePackVar(std::move(tp)));
}
};
void freeze(TypeArena& arena);
void unfreeze(TypeArena& arena);
} // namespace Luau

View file

@ -0,0 +1,17 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Module.h"
#include "Luau/NotNull.h"
namespace Luau
{
struct DcrLogger;
struct SingletonTypes;
void check(NotNull<SingletonTypes> singletonTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module);
} // namespace Luau

View file

@ -1,21 +1,24 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/Anyification.h"
#include "Luau/Predicate.h" #include "Luau/Predicate.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/Symbol.h" #include "Luau/Symbol.h"
#include "Luau/Parser.h"
#include "Luau/Substitution.h" #include "Luau/Substitution.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/Unifier.h" #include "Luau/Unifier.h"
#include "Luau/UnifierSharedState.h"
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
namespace Luau namespace Luau
{ {
@ -31,77 +34,39 @@ bool doesCallError(const AstExprCall* call);
bool hasBreak(AstStat* node); bool hasBreak(AstStat* node);
const AstStat* getFallthrough(const AstStat* node); const AstStat* getFallthrough(const AstStat* node);
struct UnifierOptions;
struct Unifier; struct Unifier;
// A substitution which replaces generic types in a given set by free types. struct GenericTypeDefinitions
struct ReplaceGenerics : Substitution
{ {
TypeLevel level; std::vector<GenericTypeDefinition> genericTypes;
std::vector<TypeId> generics; std::vector<GenericTypePackDefinition> genericPacks;
std::vector<TypePackId> genericPacks;
bool ignoreChildren(TypeId ty) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
}; };
// A substitution which replaces generic functions by monomorphic functions struct HashBoolNamePair
struct Instantiation : Substitution
{ {
TypeLevel level; size_t operator()(const std::pair<bool, Name>& pair) const;
ReplaceGenerics replaceGenerics;
bool ignoreChildren(TypeId ty) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
}; };
// A substitution which replaces free types by generic types. class TimeLimitError : public InternalCompilerError
struct Quantification : Substitution
{ {
TypeLevel level; public:
std::vector<TypeId> generics; explicit TimeLimitError(const std::string& moduleName)
std::vector<TypePackId> genericPacks; : InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName)
bool isDirty(TypeId ty) override; {
bool isDirty(TypePackId tp) override; }
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
};
// A substitution which replaces free types by any
struct Anyification : Substitution
{
TypeId anyType;
TypePackId anyTypePack;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
};
// A substitution which replaces the type parameters of a type function by arguments
struct ApplyTypeFunction : Substitution
{
TypeLevel level;
bool encounteredForwardedType;
std::unordered_map<TypeId, TypeId> arguments;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
}; };
// All TypeVars are retained via Environment::typeVars. All TypeIds // All TypeVars are retained via Environment::typeVars. All TypeIds
// within a program are borrowed pointers into this set. // within a program are borrowed pointers into this set.
struct TypeChecker struct TypeChecker
{ {
explicit TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler); explicit TypeChecker(ModuleResolver* resolver, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter* iceHandler);
TypeChecker(const TypeChecker&) = delete; TypeChecker(const TypeChecker&) = delete;
TypeChecker& operator=(const TypeChecker&) = delete; TypeChecker& operator=(const TypeChecker&) = delete;
ModulePtr check(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope = std::nullopt); ModulePtr check(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope = std::nullopt);
ModulePtr checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope = std::nullopt);
std::vector<std::pair<Location, ScopePtr>> getScopes() const; std::vector<std::pair<Location, ScopePtr>> getScopes() const;
@ -118,31 +83,37 @@ struct TypeChecker
void check(const ScopePtr& scope, const AstStatForIn& forin); void check(const ScopePtr& scope, const AstStatForIn& forin);
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function);
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function);
void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare = false); void check(const ScopePtr& scope, const AstStatTypeAlias& typealias);
void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0);
void prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); void checkBlock(const ScopePtr& scope, const AstStatBlock& statement);
void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement);
void checkBlockTypeAliases(const ScopePtr& scope, std::vector<AstStat*>& sorted); void checkBlockTypeAliases(const ScopePtr& scope, std::vector<AstStat*>& sorted);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional<TypeId> expectedType = std::nullopt); WithPredicate<TypeId> checkExpr(
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprLocal& expr); const ScopePtr& scope, const AstExpr& expr, std::optional<TypeId> expectedType = std::nullopt, bool forceSingleton = false);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprLocal& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprGlobal& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprCall& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprVarargs& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprIndexName& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprCall& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprIndexName& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional<TypeId> expectedType = std::nullopt); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional<TypeId> expectedType = std::nullopt); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional<TypeId> expectedType = std::nullopt);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprUnary& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional<TypeId> expectedType = std::nullopt);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprUnary& expr);
TypeId checkRelationalOperation( TypeId checkRelationalOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {});
TypeId checkBinaryOperation( TypeId checkBinaryOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {});
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprBinary& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional<TypeId> expectedType = std::nullopt);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprError& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprError& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprIfElse& expr); WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional<TypeId> expectedType = std::nullopt);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprInterpString& expr);
TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector<std::pair<TypeId, TypeId>>& fieldTypes, TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector<std::pair<TypeId, TypeId>>& fieldTypes,
std::optional<TypeId> expectedType); std::optional<TypeId> expectedType);
@ -150,35 +121,38 @@ struct TypeChecker
// Returns the type of the lvalue. // Returns the type of the lvalue.
TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr); TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr);
// Returns both the type of the lvalue and its binding (if the caller wants to mutate the binding). // Returns the type of the lvalue.
// Note: the binding may be null. TypeId checkLValueBinding(const ScopePtr& scope, const AstExpr& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); TypeId checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); TypeId checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr);
TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName); TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level);
std::pair<TypeId, ScopePtr> checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::pair<TypeId, ScopePtr> checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr,
std::optional<Location> originalNameLoc, std::optional<TypeId> expectedType); std::optional<Location> originalNameLoc, std::optional<TypeId> selfType, std::optional<TypeId> expectedType);
void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function);
void checkArgumentList( void checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId paramPack, TypePackId argPack,
const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector<Location>& argLocations); const std::vector<Location>& argLocations);
WithPredicate<TypePackId> checkExprPack(const ScopePtr& scope, const AstExpr& expr);
WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr);
WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr);
ExprResult<TypePackId> checkExprPack(const ScopePtr& scope, const AstExpr& expr);
ExprResult<TypePackId> checkExprPack(const ScopePtr& scope, const AstExprCall& expr);
std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall); std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall);
std::optional<ExprResult<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>& argLocations, const ExprResult<TypePackId>& argListResult, std::optional<WithPredicate<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<OverloadErrorEntry>& errors); TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors);
bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations, bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations,
const std::vector<OverloadErrorEntry>& errors); const std::vector<OverloadErrorEntry>& errors);
ExprResult<TypePackId> reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount, const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount,
const std::vector<OverloadErrorEntry>& errors); const std::vector<OverloadErrorEntry>& errors);
ExprResult<TypePackId> checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs, WithPredicate<TypePackId> checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs,
bool substituteFreeForNil = false, const std::vector<bool>& lhsAnnotations = {}, bool substituteFreeForNil = false, const std::vector<bool>& lhsAnnotations = {},
const std::vector<std::optional<TypeId>>& expectedTypes = {}); const std::vector<std::optional<TypeId>>& expectedTypes = {});
@ -193,51 +167,44 @@ struct TypeChecker
*/ */
TypeId anyIfNonstrict(TypeId ty) const; TypeId anyIfNonstrict(TypeId ty) const;
/** Attempt to unify the types left and right. Treat any failures as type errors /** Attempt to unify the types.
* in the final typecheck report. * Treat any failures as type errors in the final typecheck report.
*/ */
bool unify(TypeId left, TypeId right, const Location& location); bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
bool unify(TypePackId left, TypePackId right, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options);
bool unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location,
CountMismatch::Context ctx = CountMismatch::Context::Arg);
/** Attempt to unify the types left and right. /** Attempt to unify the types.
* If this fails, and the right type can be instantiated, do so and try unification again. * If this fails, and the subTy type can be instantiated, do so and try unification again.
*/ */
bool unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location); bool unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
void unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state); void unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, Unifier& state);
/** Attempt to unify left with right. /** Attempt to unify.
* If there are errors, undo everything and return the errors. * If there are errors, undo everything and return the errors.
* If there are no errors, commit and return an empty error vector. * If there are no errors, commit and return an empty error vector.
*/ */
ErrorVec tryUnify(TypeId left, TypeId right, const Location& location); template<typename Id>
ErrorVec tryUnify(TypePackId left, TypePackId right, const Location& location); ErrorVec tryUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location);
ErrorVec tryUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
ErrorVec tryUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location);
// Test whether the two type vars unify. Never commits the result. // Test whether the two type vars unify. Never commits the result.
ErrorVec canUnify(TypeId superTy, TypeId subTy, const Location& location); template<typename Id>
ErrorVec canUnify(TypePackId superTy, TypePackId subTy, const Location& location); ErrorVec canUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location);
ErrorVec canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location);
// Variant that takes a preexisting 'seen' set. We need this in certain cases to avoid infinitely recursing std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors);
// into cyclic types. std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors);
ErrorVec canUnify(const std::vector<std::pair<TypeId, TypeId>>& seen, TypeId left, TypeId right, const Location& location);
std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry, const Location& location);
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location);
std::optional<TypeId> getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); std::optional<TypeId> getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors);
std::optional<TypeId> getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors);
// Reduces the union to its simplest possible shape.
// (A | B) | B | C yields A | B | C
std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types);
std::optional<TypeId> tryStripUnionFromNil(TypeId ty); std::optional<TypeId> tryStripUnionFromNil(TypeId ty);
TypeId stripFromNilAndReport(TypeId ty, const Location& location); TypeId stripFromNilAndReport(TypeId ty, const Location& location);
template<typename Id>
ErrorVec tryUnify_(Id left, Id right, const Location& location);
template<typename Id>
ErrorVec canUnify_(Id left, Id right, const Location& location);
public: public:
/* /*
* Convert monotype into a a polytype, by replacing any metavariables in descendant scopes * Convert monotype into a a polytype, by replacing any metavariables in descendant scopes
@ -258,49 +225,66 @@ public:
* {method: ({method: (<CYCLE>) -> a}) -> a} * {method: ({method: (<CYCLE>) -> a}) -> a}
* *
*/ */
TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location); TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log = TxnLog::empty());
// Removed by FFlag::LuauRankNTypes
TypePackId DEPRECATED_instantiate(const ScopePtr& scope, TypePackId ty, Location location);
// Replace any free types or type packs by `any`. // Replace any free types or type packs by `any`.
// This is used when exporting types from modules, to make sure free types don't leak. // This is used when exporting types from modules, to make sure free types don't leak.
TypeId anyify(const ScopePtr& scope, TypeId ty, Location location); TypeId anyify(const ScopePtr& scope, TypeId ty, Location location);
TypePackId anyify(const ScopePtr& scope, TypePackId ty, Location location); TypePackId anyify(const ScopePtr& scope, TypePackId ty, Location location);
TypePackId anyifyModuleReturnTypePackGenerics(TypePackId ty);
void reportError(const TypeError& error); void reportError(const TypeError& error);
void reportError(const Location& location, TypeErrorData error); void reportError(const Location& location, TypeErrorData error);
void reportErrors(const ErrorVec& errors); void reportErrors(const ErrorVec& errors);
[[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message, const Location& location);
[[noreturn]] void ice(const std::string& message); [[noreturn]] void ice(const std::string& message);
[[noreturn]] void throwTimeLimitError();
ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0);
ScopePtr childScope(const ScopePtr& parent, const Location& location, int subLevel = 0); ScopePtr childScope(const ScopePtr& parent, const Location& location);
// Wrapper for merge(l, r, toUnion) but without the lambda junk. // Wrapper for merge(l, r, toUnion) but without the lambda junk.
void merge(RefinementMap& l, const RefinementMap& r); void merge(RefinementMap& l, const RefinementMap& r);
// Produce an "emergency backup type" for recovery from type errors.
// This comes in two flavours, depening on whether or not we can make a good guess
// for an error recovery type.
TypeId errorRecoveryType(TypeId guess);
TypePackId errorRecoveryTypePack(TypePackId guess);
TypeId errorRecoveryType(const ScopePtr& scope);
TypePackId errorRecoveryTypePack(const ScopePtr& scope);
private: private:
void prepareErrorsForDisplay(ErrorVec& errVec); void prepareErrorsForDisplay(ErrorVec& errVec);
void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data); void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data);
void reportErrorCodeTooComplex(const Location& location); void reportErrorCodeTooComplex(const Location& location);
private: private:
Unifier mkUnifier(const Location& location); Unifier mkUnifier(const ScopePtr& scope, const Location& location);
Unifier mkUnifier(const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location);
// These functions are only safe to call when we are in the process of typechecking a module. // These functions are only safe to call when we are in the process of typechecking a module.
// Produce a new free type var. // Produce a new free type var.
TypeId freshType(const ScopePtr& scope); TypeId freshType(const ScopePtr& scope);
TypeId freshType(TypeLevel level); TypeId freshType(TypeLevel level);
TypeId DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric = false);
TypeId DEPRECATED_freshType(TypeLevel level, bool canBeGeneric = false);
// Returns nullopt if the predicate filters down the TypeId to 0 options. // Produce a new singleton type var.
std::optional<TypeId> filterMap(TypeId type, TypeIdPredicate predicate); TypeId singletonType(bool value);
TypeId singletonType(std::string value);
TypeId unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes = true); TypeIdPredicate mkTruthyPredicate(bool sense, TypeId emptySetTy);
// TODO: Return TypeId only.
std::optional<TypeId> filterMapImpl(TypeId type, TypeIdPredicate predicate);
std::pair<std::optional<TypeId>, bool> filterMap(TypeId type, TypeIdPredicate predicate);
public:
std::pair<std::optional<TypeId>, bool> pickTypesFromSense(TypeId type, bool sense, TypeId emptySetTy);
private:
TypeId unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes = true);
// ex // ex
// TypeId id = addType(FreeTypeVar()); // TypeId id = addType(FreeTypeVar());
@ -310,8 +294,6 @@ private:
return addTV(TypeVar(tv)); return addTV(TypeVar(tv));
} }
TypeId addType(const UnionTypeVar& utv);
TypeId addTV(TypeVar&& tv); TypeId addTV(TypeVar&& tv);
TypePackId addTypePack(TypePackVar&& tp); TypePackId addTypePack(TypePackVar&& tp);
@ -322,36 +304,38 @@ private:
TypePackId addTypePack(std::initializer_list<TypeId>&& ty); TypePackId addTypePack(std::initializer_list<TypeId>&& ty);
TypePackId freshTypePack(const ScopePtr& scope); TypePackId freshTypePack(const ScopePtr& scope);
TypePackId freshTypePack(TypeLevel level); TypePackId freshTypePack(TypeLevel level);
TypePackId DEPRECATED_freshTypePack(const ScopePtr& scope, bool canBeGeneric = false);
TypePackId DEPRECATED_freshTypePack(TypeLevel level, bool canBeGeneric = false);
TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false); TypeId resolveType(const ScopePtr& scope, const AstType& annotation);
TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation);
TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams, const Location& location); TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams,
const std::vector<TypePackId>& typePackParams, const Location& location);
// Note: `scope` must be a fresh scope. // Note: `scope` must be a fresh scope.
std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes( GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node,
const ScopePtr& scope, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames); const AstArray<AstGenericType>& genericNames, const AstArray<AstGenericTypePack>& genericPackNames, bool useCache = false);
public: public:
ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); void resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);
private: private:
void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate);
std::optional<TypeId> resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional<TypeId> resolveLValue(const ScopePtr& scope, const LValue& lvalue);
std::optional<TypeId> resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); std::optional<TypeId> resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue);
void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); void resolve(const PredicateVec& predicates, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false);
void resolve(const Predicate& predicate, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); void resolve(const Predicate& predicate, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr);
void resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); void resolve(const TruthyPredicate& truthyP, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr);
void resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const OrPredicate& orP, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const TypeGuardPredicate& typeguardP, RefinementMap& refis, const ScopePtr& scope, bool sense);
void DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
bool isNonstrictMode() const; bool isNonstrictMode() const;
bool useConstrainedIntersections() const;
public: public:
/** Extract the types in a type pack, given the assumption that the pack must have some exact length. /** Extract the types in a type pack, given the assumption that the pack must have some exact length.
@ -371,14 +355,20 @@ public:
ModulePtr currentModule; ModulePtr currentModule;
ModuleName currentModuleName; ModuleName currentModuleName;
Instantiation instantiation;
Quantification quantification;
Anyification anyification;
ApplyTypeFunction applyTypeFunction;
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope; std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
NotNull<SingletonTypes> singletonTypes;
InternalErrorReporter* iceHandler; InternalErrorReporter* iceHandler;
UnifierSharedState unifierState;
Normalizer normalizer;
std::vector<RequireCycle> requireCycles;
// Type inference limits
std::optional<double> finishTime;
std::optional<int> instantiationChildLimit;
std::optional<int> unifierIterationLimit;
public: public:
const TypeId nilType; const TypeId nilType;
const TypeId numberType; const TypeId numberType;
@ -386,68 +376,37 @@ public:
const TypeId booleanType; const TypeId booleanType;
const TypeId threadType; const TypeId threadType;
const TypeId anyType; const TypeId anyType;
const TypeId unknownType;
const TypeId errorType; const TypeId neverType;
const TypeId optionalNumberType;
const TypePackId anyTypePack; const TypePackId anyTypePack;
const TypePackId errorTypePack; const TypePackId neverTypePack;
const TypePackId uninhabitableTypePack;
private: private:
int checkRecursionCount = 0; int checkRecursionCount = 0;
int recursionCount = 0; int recursionCount = 0;
/**
* We use this to avoid doing second-pass analysis of type aliases that are duplicates. We record a pair
* (exported, name) to properly deal with the case where the two duplicates do not have the same export status.
*/
DenseHashSet<std::pair<bool, Name>, HashBoolNamePair> duplicateTypeAliases;
/**
* A set of incorrect class definitions which is used to avoid a second-pass analysis.
*/
DenseHashSet<const AstStatDeclareClass*> incorrectClassDefinitions{nullptr};
std::vector<std::pair<TypeId, ScopePtr>> deferredQuantification;
}; };
struct Binding using PrintLineProc = void (*)(const std::string&);
{
TypeId typeId;
Location location;
bool deprecated = false;
std::string deprecatedSuggestion;
std::optional<std::string> documentationSymbol;
};
struct Scope extern PrintLineProc luauPrintLine;
{
explicit Scope(TypePackId returnType); // root scope
explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr.
const ScopePtr parent; // null for the root
std::unordered_map<Symbol, Binding> bindings;
TypePackId returnType;
bool breakOk = false;
std::optional<TypePackId> varargPack;
TypeLevel level;
std::unordered_map<Name, TypeFun> exportedTypeBindings;
std::unordered_map<Name, TypeFun> privateTypeBindings;
std::unordered_map<Name, Location> typeAliasLocations;
std::unordered_map<Name, std::unordered_map<Name, TypeFun>> importedTypeBindings;
std::optional<TypeId> lookup(const Symbol& name);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
std::unordered_map<Name, TypePackId> privateTypePackBindings;
std::optional<TypePackId> lookupPack(const Name& name);
// WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2)
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true);
RefinementMap refinements;
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.
// For instance, in `type Tree<T> { data: T, children: Forest<T> } type Forest<T> = {Tree<T>}`
// we need that the generic type `T` in both cases is the same, so we use a cache.
std::unordered_map<Name, TypeId> typeAliasParameters;
};
// Unit test hook // Unit test hook
void setPrintLine(void (*pl)(const std::string& s)); void setPrintLine(PrintLineProc pl);
void resetPrintLine(); void resetPrintLine();
} // namespace Luau } // namespace Luau

View file

@ -8,8 +8,6 @@
#include <optional> #include <optional>
#include <set> #include <set>
LUAU_FASTFLAG(LuauAddMissingFollow)
namespace Luau namespace Luau
{ {
@ -17,14 +15,17 @@ struct TypeArena;
struct TypePack; struct TypePack;
struct VariadicTypePack; struct VariadicTypePack;
struct BlockedTypePack;
struct TypePackVar; struct TypePackVar;
struct TxnLog;
using TypePackId = const TypePackVar*; using TypePackId = const TypePackVar*;
using FreeTypePack = Unifiable::Free; using FreeTypePack = Unifiable::Free;
using BoundTypePack = Unifiable::Bound<TypePackId>; using BoundTypePack = Unifiable::Bound<TypePackId>;
using GenericTypePack = Unifiable::Generic; using GenericTypePack = Unifiable::Generic;
using TypePackVariant = Unifiable::Variant<TypePackId, TypePack, VariadicTypePack>; using TypePackVariant = Unifiable::Variant<TypePackId, TypePack, VariadicTypePack, BlockedTypePack>;
/* A TypePack is a rope-like string of TypeIds. We use this structure to encode /* A TypePack is a rope-like string of TypeIds. We use this structure to encode
* notions like packs of unknown length and packs of any length, as well as more * notions like packs of unknown length and packs of any length, as well as more
@ -40,6 +41,18 @@ struct TypePack
struct VariadicTypePack struct VariadicTypePack
{ {
TypeId ty; TypeId ty;
bool hidden = false; // if true, we don't display this when toString()ing a pack with this variadic as its tail.
};
/**
* Analogous to a BlockedTypeVar.
*/
struct BlockedTypePack
{
BlockedTypePack();
size_t index;
static size_t nextIndex;
}; };
struct TypePackVar struct TypePackVar
@ -47,16 +60,24 @@ struct TypePackVar
explicit TypePackVar(const TypePackVariant& ty); explicit TypePackVar(const TypePackVariant& ty);
explicit TypePackVar(TypePackVariant&& ty); explicit TypePackVar(TypePackVariant&& ty);
TypePackVar(TypePackVariant&& ty, bool persistent); TypePackVar(TypePackVariant&& ty, bool persistent);
bool operator==(const TypePackVar& rhs) const; bool operator==(const TypePackVar& rhs) const;
TypePackVar& operator=(TypePackVariant&& tp); TypePackVar& operator=(TypePackVariant&& tp);
TypePackVar& operator=(const TypePackVar& rhs);
// Re-assignes the content of the pack, but doesn't change the owning arena and can't make pack persistent.
void reassign(const TypePackVar& rhs)
{
ty = rhs.ty;
}
TypePackVariant ty; TypePackVariant ty;
bool persistent = false; bool persistent = false;
// Pointer to the type arena that allocated this type. // Pointer to the type arena that allocated this pack.
// Do not depend on the value of this under any circumstances. This is for
// debugging purposes only. This is only set in debug builds; it is nullptr
// in all other environments.
TypeArena* owningArena = nullptr; TypeArena* owningArena = nullptr;
}; };
@ -86,6 +107,7 @@ struct TypePackIterator
TypePackIterator() = default; TypePackIterator() = default;
explicit TypePackIterator(TypePackId tp); explicit TypePackIterator(TypePackId tp);
TypePackIterator(TypePackId tp, const TxnLog* log);
TypePackIterator& operator++(); TypePackIterator& operator++();
TypePackIterator operator++(int); TypePackIterator operator++(int);
@ -106,20 +128,25 @@ private:
TypePackId currentTypePack = nullptr; TypePackId currentTypePack = nullptr;
const TypePack* tp = nullptr; const TypePack* tp = nullptr;
size_t currentIndex = 0; size_t currentIndex = 0;
const TxnLog* log;
}; };
TypePackIterator begin(TypePackId tp); TypePackIterator begin(TypePackId tp);
TypePackIterator begin(TypePackId tp, const TxnLog* log);
TypePackIterator end(TypePackId tp); TypePackIterator end(TypePackId tp);
using SeenSet = std::set<std::pair<void*, void*>>; using SeenSet = std::set<std::pair<const void*, const void*>>;
bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs);
TypePackId follow(TypePackId tp); TypePackId follow(TypePackId tp);
TypePackId follow(TypePackId tp, std::function<TypePackId(TypePackId)> mapper);
size_t size(const TypePackId tp); size_t size(TypePackId tp, TxnLog* log = nullptr);
size_t size(const TypePack& tp); bool finite(TypePackId tp, TxnLog* log = nullptr);
std::optional<TypeId> first(TypePackId tp); size_t size(const TypePack& tp, TxnLog* log = nullptr);
std::optional<TypeId> first(TypePackId tp, bool ignoreHiddenVariadics = true);
TypePackVar* asMutable(TypePackId tp); TypePackVar* asMutable(TypePackId tp);
TypePack* asMutable(const TypePack* tp); TypePack* asMutable(const TypePack* tp);
@ -127,13 +154,10 @@ TypePack* asMutable(const TypePack* tp);
template<typename T> template<typename T>
const T* get(TypePackId tp) const T* get(TypePackId tp)
{ {
if (FFlag::LuauAddMissingFollow) LUAU_ASSERT(tp);
{
LUAU_ASSERT(tp);
if constexpr (!std::is_same_v<T, BoundTypePack>) if constexpr (!std::is_same_v<T, BoundTypePack>)
LUAU_ASSERT(get_if<BoundTypePack>(&tp->ty) == nullptr); LUAU_ASSERT(get_if<BoundTypePack>(&tp->ty) == nullptr);
}
return get_if<T>(&(tp->ty)); return get_if<T>(&(tp->ty));
} }
@ -141,13 +165,10 @@ const T* get(TypePackId tp)
template<typename T> template<typename T>
T* getMutable(TypePackId tp) T* getMutable(TypePackId tp)
{ {
if (FFlag::LuauAddMissingFollow) LUAU_ASSERT(tp);
{
LUAU_ASSERT(tp);
if constexpr (!std::is_same_v<T, BoundTypePack>) if constexpr (!std::is_same_v<T, BoundTypePack>)
LUAU_ASSERT(get_if<BoundTypePack>(&tp->ty) == nullptr); LUAU_ASSERT(get_if<BoundTypePack>(&tp->ty) == nullptr);
}
return get_if<T>(&(asMutable(tp)->ty)); return get_if<T>(&(asMutable(tp)->ty));
} }
@ -157,5 +178,16 @@ bool isEmpty(TypePackId tp);
/// Flattens out a type pack. Also returns a valid TypePackId tail if the type pack's full size is not known /// Flattens out a type pack. Also returns a valid TypePackId tail if the type pack's full size is not known
std::pair<std::vector<TypeId>, std::optional<TypePackId>> flatten(TypePackId tp); std::pair<std::vector<TypeId>, std::optional<TypePackId>> flatten(TypePackId tp);
std::pair<std::vector<TypeId>, std::optional<TypePackId>> flatten(TypePackId tp, const TxnLog& log);
/// Returs true if the type pack arose from a function that is declared to be variadic.
/// Returns *false* for function argument packs that are inferred to be safe to oversaturate!
bool isVariadic(TypePackId tp);
bool isVariadic(TypePackId tp, const TxnLog& log);
// Returns true if the TypePack is Generic or Variadic. Does not walk TypePacks!!
bool isVariadicTail(TypePackId tp, const TxnLog& log, bool includeHiddenVariadics = false);
bool containsNever(TypePackId tp);
} // namespace Luau } // namespace Luau

View file

@ -4,6 +4,7 @@
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
#include <memory> #include <memory>
#include <optional> #include <optional>
@ -11,9 +12,40 @@
namespace Luau namespace Luau
{ {
struct TxnLog;
struct TypeArena;
using ScopePtr = std::shared_ptr<struct Scope>; using ScopePtr = std::shared_ptr<struct Scope>;
std::optional<TypeId> findMetatableEntry(ErrorVec& errors, const ScopePtr& globalScope, TypeId type, std::string entry, Location location); std::optional<TypeId> findMetatableEntry(
std::optional<TypeId> findTablePropertyRespectingMeta(ErrorVec& errors, const ScopePtr& globalScope, TypeId ty, Name name, Location location); NotNull<SingletonTypes> singletonTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location);
std::optional<TypeId> findTablePropertyRespectingMeta(
NotNull<SingletonTypes> singletonTypes, ErrorVec& errors, TypeId ty, const std::string& name, Location location);
// Returns the minimum and maximum number of types the argument list can accept.
std::pair<size_t, std::optional<size_t>> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics = false);
// Extend the provided pack to at least `length` types.
// Returns a temporary TypePack that contains those types plus a tail.
TypePack extendTypePack(TypeArena& arena, NotNull<SingletonTypes> singletonTypes, TypePackId pack, size_t length);
/**
* Reduces a union by decomposing to the any/error type if it appears in the
* type list, and by merging child unions. Also strips out duplicate (by pointer
* identity) types.
* @param types the input type list to reduce.
* @returns the reduced type list.
*/
std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types);
/**
* Tries to remove nil from a union type, if there's another option. T | nil
* reduces to T, but nil itself does not reduce.
* @param singletonTypes the singleton types to use
* @param arena the type arena to allocate the new type in, if necessary
* @param ty the type to remove nil from
* @returns a type with nil removed, or nil itself if that were the only option.
*/
TypeId stripNil(NotNull<SingletonTypes> singletonTypes, TypeArena& arena, TypeId ty);
} // namespace Luau } // namespace Luau

View file

@ -1,29 +1,36 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/Ast.h"
#include "Luau/Common.h"
#include "Luau/Connective.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/DenseHash.h"
#include "Luau/Def.h"
#include "Luau/NotNull.h"
#include "Luau/Predicate.h" #include "Luau/Predicate.h"
#include "Luau/Unifiable.h" #include "Luau/Unifiable.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/Common.h"
#include <set>
#include <string>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include <deque> #include <deque>
#include <map>
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTableTypeMaximumStringifierLength)
LUAU_FASTINT(LuauTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength)
LUAU_FASTFLAG(LuauAddMissingFollow)
namespace Luau namespace Luau
{ {
struct TypeArena; struct TypeArena;
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
/** /**
* There are three kinds of type variables: * There are three kinds of type variables:
@ -83,6 +90,24 @@ using Tags = std::vector<std::string>;
using ModuleName = std::string; using ModuleName = std::string;
/** A TypeVar that cannot be computed.
*
* BlockedTypeVars essentially serve as a way to encode partial ordering on the
* constraint graph. Until a BlockedTypeVar is unblocked by its owning
* constraint, nothing at all can be said about it. Constraints that need to
* process a BlockedTypeVar cannot be dispatched.
*
* Whenever a BlockedTypeVar is added to the graph, we also record a constraint
* that will eventually unblock it.
*/
struct BlockedTypeVar
{
BlockedTypeVar();
int index;
static int nextIndex;
};
struct PrimitiveTypeVar struct PrimitiveTypeVar
{ {
enum Type enum Type
@ -92,6 +117,7 @@ struct PrimitiveTypeVar
Number, Number,
String, String,
Thread, Thread,
Function,
}; };
Type type; Type type;
@ -109,6 +135,95 @@ struct PrimitiveTypeVar
} }
}; };
// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md
// Types for true and false
struct BooleanSingleton
{
bool value;
bool operator==(const BooleanSingleton& rhs) const
{
return value == rhs.value;
}
bool operator!=(const BooleanSingleton& rhs) const
{
return !(*this == rhs);
}
};
// Types for "foo", "bar" etc.
struct StringSingleton
{
std::string value;
bool operator==(const StringSingleton& rhs) const
{
return value == rhs.value;
}
bool operator!=(const StringSingleton& rhs) const
{
return !(*this == rhs);
}
};
// No type for float singletons, partly because === isn't any equalivalence on floats
// (NaN != NaN).
using SingletonVariant = Luau::Variant<BooleanSingleton, StringSingleton>;
struct SingletonTypeVar
{
explicit SingletonTypeVar(const SingletonVariant& variant)
: variant(variant)
{
}
explicit SingletonTypeVar(SingletonVariant&& variant)
: variant(std::move(variant))
{
}
// Default operator== is C++20.
bool operator==(const SingletonTypeVar& rhs) const
{
return variant == rhs.variant;
}
bool operator!=(const SingletonTypeVar& rhs) const
{
return !(*this == rhs);
}
SingletonVariant variant;
};
template<typename T>
const T* get(const SingletonTypeVar* stv)
{
if (stv)
return get_if<T>(&stv->variant);
else
return nullptr;
}
struct GenericTypeDefinition
{
TypeId ty;
std::optional<TypeId> defaultValue;
bool operator==(const GenericTypeDefinition& rhs) const;
};
struct GenericTypePackDefinition
{
TypePackId tp;
std::optional<TypePackId> defaultValue;
bool operator==(const GenericTypePackDefinition& rhs) const;
};
struct FunctionArgument struct FunctionArgument
{ {
Name name; Name name;
@ -127,42 +242,72 @@ struct FunctionDefinition
// TODO: Do we actually need this? We'll find out later if we can delete this. // TODO: Do we actually need this? We'll find out later if we can delete this.
// Does not exactly belong in TypeVar.h, but this is the only way to appease the compiler. // Does not exactly belong in TypeVar.h, but this is the only way to appease the compiler.
template<typename T> template<typename T>
struct ExprResult struct WithPredicate
{ {
T type; T type;
PredicateVec predicates; PredicateVec predicates;
}; };
using MagicFunction = std::function<std::optional<ExprResult<TypePackId>>( using MagicFunction = std::function<std::optional<WithPredicate<TypePackId>>(
struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, ExprResult<TypePackId>)>; struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)>;
struct MagicFunctionCallContext
{
NotNull<struct ConstraintSolver> solver;
const class AstExprCall* callSite;
TypePackId arguments;
TypePackId result;
};
using DcrMagicFunction = bool (*)(MagicFunctionCallContext);
struct MagicRefinementContext
{
ScopePtr scope;
NotNull<struct ConstraintGraphBuilder> cgb;
NotNull<const DataFlowGraph> dfg;
NotNull<ConnectiveArena> connectiveArena;
std::vector<ConnectiveId> argumentConnectives;
const class AstExprCall* callSite;
};
using DcrMagicRefinement = std::vector<ConnectiveId> (*)(const MagicRefinementContext&);
struct FunctionTypeVar struct FunctionTypeVar
{ {
// Global monomorphic function // Global monomorphic function
FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false); FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
// Global polymorphic function // Global polymorphic function
FunctionTypeVar(std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retType, FunctionTypeVar(std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retTypes,
std::optional<FunctionDefinition> defn = {}, bool hasSelf = false); std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
// Local monomorphic function // Local monomorphic function
FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false); FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
FunctionTypeVar(
TypeLevel level, Scope* scope, TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
// Local polymorphic function // Local polymorphic function
FunctionTypeVar(TypeLevel level, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retType, FunctionTypeVar(TypeLevel level, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retTypes,
std::optional<FunctionDefinition> defn = {}, bool hasSelf = false); std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
FunctionTypeVar(TypeLevel level, Scope* scope, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes,
TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
TypeLevel level; std::optional<FunctionDefinition> definition;
/// These should all be generic /// These should all be generic
std::vector<TypeId> generics; std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks; std::vector<TypePackId> genericPacks;
TypePackId argTypes;
std::vector<std::optional<FunctionArgument>> argNames; std::vector<std::optional<FunctionArgument>> argNames;
TypePackId retType;
std::optional<FunctionDefinition> definition;
MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr.
bool hasSelf;
Tags tags; Tags tags;
TypeLevel level;
Scope* scope = nullptr;
TypePackId argTypes;
TypePackId retTypes;
MagicFunction magicFunction = nullptr;
DcrMagicFunction dcrMagicFunction = nullptr; // Fired only while solving constraints
DcrMagicRefinement dcrMagicRefinement = nullptr; // Fired only while generating constraints
bool hasSelf;
bool hasNoGenerics = false;
}; };
enum class TableState enum class TableState
@ -212,26 +357,31 @@ struct TableTypeVar
using Props = std::map<Name, Property>; using Props = std::map<Name, Property>;
TableTypeVar() = default; TableTypeVar() = default;
explicit TableTypeVar(TableState state, TypeLevel level); explicit TableTypeVar(TableState state, TypeLevel level, Scope* scope = nullptr);
TableTypeVar(const Props& props, const std::optional<TableIndexer>& indexer, TypeLevel level, TableState state = TableState::Unsealed); TableTypeVar(const Props& props, const std::optional<TableIndexer>& indexer, TypeLevel level, TableState state);
TableTypeVar(const Props& props, const std::optional<TableIndexer>& indexer, TypeLevel level, Scope* scope, TableState state);
Props props; Props props;
std::optional<TableIndexer> indexer; std::optional<TableIndexer> indexer;
TableState state = TableState::Unsealed; TableState state = TableState::Unsealed;
TypeLevel level; TypeLevel level;
Scope* scope = nullptr;
std::optional<std::string> name; std::optional<std::string> name;
// Sometimes we throw a type on a name to make for nicer error messages, but without creating any entry in the type namespace // Sometimes we throw a type on a name to make for nicer error messages, but without creating any entry in the type namespace
// We need to know which is which when we stringify types. // We need to know which is which when we stringify types.
std::optional<std::string> syntheticName; std::optional<std::string> syntheticName;
std::map<Name, Location> methodDefinitionLocations;
std::vector<TypeId> instantiatedTypeParams; std::vector<TypeId> instantiatedTypeParams;
std::vector<TypePackId> instantiatedTypePackParams;
ModuleName definitionModuleName; ModuleName definitionModuleName;
std::optional<TypeId> boundTo; std::optional<TypeId> boundTo;
Tags tags; Tags tags;
// Methods of this table that have an untyped self will use the same shared self type.
std::optional<TypeId> selfTy;
}; };
// Represents a metatable attached to a table typevar. Somewhat analogous to a bound typevar. // Represents a metatable attached to a table typevar. Somewhat analogous to a bound typevar.
@ -269,23 +419,26 @@ struct ClassTypeVar
std::optional<TypeId> metatable; // metaclass? std::optional<TypeId> metatable; // metaclass?
Tags tags; Tags tags;
std::shared_ptr<ClassUserData> userData; std::shared_ptr<ClassUserData> userData;
ModuleName definitionModuleName;
ClassTypeVar( ClassTypeVar(Name name, Props props, std::optional<TypeId> parent, std::optional<TypeId> metatable, Tags tags,
Name name, Props props, std::optional<TypeId> parent, std::optional<TypeId> metatable, Tags tags, std::shared_ptr<ClassUserData> userData) std::shared_ptr<ClassUserData> userData, ModuleName definitionModuleName)
: name(name) : name(name)
, props(props) , props(props)
, parent(parent) , parent(parent)
, metatable(metatable) , metatable(metatable)
, tags(tags) , tags(tags)
, userData(userData) , userData(userData)
, definitionModuleName(definitionModuleName)
{ {
} }
}; };
struct TypeFun struct TypeFun
{ {
/// These should all be generic // These should all be generic
std::vector<TypeId> typeParams; std::vector<GenericTypeDefinition> typeParams;
std::vector<GenericTypePackDefinition> typePackParams;
/** The underlying type. /** The underlying type.
* *
@ -293,6 +446,48 @@ struct TypeFun
* You must first use TypeChecker::instantiateTypeFun to turn it into a real type. * You must first use TypeChecker::instantiateTypeFun to turn it into a real type.
*/ */
TypeId type; TypeId type;
TypeFun() = default;
explicit TypeFun(TypeId ty)
: type(ty)
{
}
TypeFun(std::vector<GenericTypeDefinition> typeParams, TypeId type)
: typeParams(std::move(typeParams))
, type(type)
{
}
TypeFun(std::vector<GenericTypeDefinition> typeParams, std::vector<GenericTypePackDefinition> typePackParams, TypeId type)
: typeParams(std::move(typeParams))
, typePackParams(std::move(typePackParams))
, type(type)
{
}
bool operator==(const TypeFun& rhs) const;
};
/** Represents a pending type alias instantiation.
*
* In order to afford (co)recursive type aliases, we need to reason about a
* partially-complete instantiation. This requires encoding more information in
* a type variable than a BlockedTypeVar affords, hence this. Each
* PendingExpansionTypeVar has a corresponding TypeAliasExpansionConstraint
* enqueued in the solver to convert it to an actual instantiated type
*/
struct PendingExpansionTypeVar
{
PendingExpansionTypeVar(std::optional<AstName> prefix, AstName name, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments);
std::optional<AstName> prefix;
AstName name;
std::vector<TypeId> typeArguments;
std::vector<TypePackId> packArguments;
size_t index;
static size_t nextIndex;
}; };
// Anything! All static checking is off. // Anything! All static checking is off.
@ -300,11 +495,13 @@ struct AnyTypeVar
{ {
}; };
// T | U
struct UnionTypeVar struct UnionTypeVar
{ {
std::vector<TypeId> options; std::vector<TypeId> options;
}; };
// T & U
struct IntersectionTypeVar struct IntersectionTypeVar
{ {
std::vector<TypeId> parts; std::vector<TypeId> parts;
@ -315,10 +512,27 @@ struct LazyTypeVar
std::function<TypeId()> thunk; std::function<TypeId()> thunk;
}; };
struct UnknownTypeVar
{
};
struct NeverTypeVar
{
};
// ~T
// TODO: Some simplification step that overwrites the type graph to make sure negation
// types disappear from the user's view, and (?) a debug flag to disable that
struct NegationTypeVar
{
TypeId ty;
};
using ErrorTypeVar = Unifiable::Error; using ErrorTypeVar = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, FunctionTypeVar, TableTypeVar, MetatableTypeVar, ClassTypeVar, AnyTypeVar, using TypeVariant =
UnionTypeVar, IntersectionTypeVar, LazyTypeVar>; Unifiable::Variant<TypeId, PrimitiveTypeVar, BlockedTypeVar, PendingExpansionTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar,
MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar, NegationTypeVar>;
struct TypeVar final struct TypeVar final
{ {
@ -338,6 +552,13 @@ struct TypeVar final
{ {
} }
// Re-assignes the content of the type, but doesn't change the owning arena and can't make type persistent.
void reassign(const TypeVar& rhs)
{
ty = rhs.ty;
documentationSymbol = rhs.documentationSymbol;
}
TypeVariant ty; TypeVariant ty;
// Kludge: A persistent TypeVar is one that belongs to the global scope. // Kludge: A persistent TypeVar is one that belongs to the global scope.
@ -348,9 +569,6 @@ struct TypeVar final
std::optional<std::string> documentationSymbol; std::optional<std::string> documentationSymbol;
// Pointer to the type arena that allocated this type. // Pointer to the type arena that allocated this type.
// Do not depend on the value of this under any circumstances. This is for
// debugging purposes only. This is only set in debug builds; it is nullptr
// in all other environments.
TypeArena* owningArena = nullptr; TypeArena* owningArena = nullptr;
bool operator==(const TypeVar& rhs) const; bool operator==(const TypeVar& rhs) const;
@ -358,13 +576,16 @@ struct TypeVar final
TypeVar& operator=(const TypeVariant& rhs); TypeVar& operator=(const TypeVariant& rhs);
TypeVar& operator=(TypeVariant&& rhs); TypeVar& operator=(TypeVariant&& rhs);
TypeVar& operator=(const TypeVar& rhs);
}; };
using SeenSet = std::set<std::pair<void*, void*>>; using SeenSet = std::set<std::pair<const void*, const void*>>;
bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs); bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs);
// Follow BoundTypeVars until we get to something real // Follow BoundTypeVars until we get to something real
TypeId follow(TypeId t); TypeId follow(TypeId t);
TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper);
std::vector<TypeId> flattenIntersection(TypeId ty); std::vector<TypeId> flattenIntersection(TypeId ty);
@ -378,7 +599,10 @@ bool isOptional(TypeId ty);
bool isTableIntersection(TypeId ty); bool isTableIntersection(TypeId ty);
bool isOverloadedFunction(TypeId ty); bool isOverloadedFunction(TypeId ty);
std::optional<TypeId> getMetatable(TypeId type); // True when string is a subtype of ty
bool maybeString(TypeId ty);
std::optional<TypeId> getMetatable(TypeId type, NotNull<struct SingletonTypes> singletonTypes);
TableTypeVar* getMutableTableType(TypeId type); TableTypeVar* getMutableTableType(TypeId type);
const TableTypeVar* getTableType(TypeId type); const TableTypeVar* getTableType(TypeId type);
@ -386,83 +610,84 @@ const TableTypeVar* getTableType(TypeId type);
// Returns nullptr if the type has no name. // Returns nullptr if the type has no name.
const std::string* getName(TypeId type); const std::string* getName(TypeId type);
// Returns name of the module where type was defined if type has that information
std::optional<ModuleName> getDefinitionModuleName(TypeId type);
// Checks whether a union contains all types of another union. // Checks whether a union contains all types of another union.
bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub); bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub);
// Checks if a type conains generic type binders // Checks if a type contains generic type binders
bool isGeneric(const TypeId ty); bool isGeneric(const TypeId ty);
// Checks if a type may be instantiated to one containing generic type binders // Checks if a type may be instantiated to one containing generic type binders
bool maybeGeneric(const TypeId ty); bool maybeGeneric(const TypeId ty);
// Checks if a type is of the form T1|...|Tn where one of the Ti is a singleton
bool maybeSingleton(TypeId ty);
// Checks if the length operator can be applied on the value of type
bool hasLength(TypeId ty, DenseHashSet<TypeId>& seen, int* recursionCount);
struct SingletonTypes struct SingletonTypes
{ {
const TypeId nilType = &nilType_;
const TypeId numberType = &numberType_;
const TypeId stringType = &stringType_;
const TypeId booleanType = &booleanType_;
const TypeId threadType = &threadType_;
const TypeId anyType = &anyType_;
const TypeId errorType = &errorType_;
SingletonTypes(); SingletonTypes();
~SingletonTypes();
SingletonTypes(const SingletonTypes&) = delete; SingletonTypes(const SingletonTypes&) = delete;
void operator=(const SingletonTypes&) = delete; void operator=(const SingletonTypes&) = delete;
TypeId errorRecoveryType(TypeId guess);
TypePackId errorRecoveryTypePack(TypePackId guess);
TypeId errorRecoveryType();
TypePackId errorRecoveryTypePack();
private: private:
std::unique_ptr<struct TypeArena> arena; std::unique_ptr<struct TypeArena> arena;
TypeVar nilType_; bool debugFreezeArena = false;
TypeVar numberType_;
TypeVar stringType_;
TypeVar booleanType_;
TypeVar threadType_;
TypeVar anyType_;
TypeVar errorType_;
TypeId makeStringMetatable(); TypeId makeStringMetatable();
};
extern SingletonTypes singletonTypes; public:
const TypeId nilType;
const TypeId numberType;
const TypeId stringType;
const TypeId booleanType;
const TypeId threadType;
const TypeId functionType;
const TypeId trueType;
const TypeId falseType;
const TypeId anyType;
const TypeId unknownType;
const TypeId neverType;
const TypeId errorType;
const TypeId falsyType; // No type binding!
const TypeId truthyType; // No type binding!
const TypePackId anyTypePack;
const TypePackId neverTypePack;
const TypePackId uninhabitableTypePack;
const TypePackId errorTypePack;
};
void persist(TypeId ty); void persist(TypeId ty);
void persist(TypePackId tp); void persist(TypePackId tp);
struct ToDotOptions
{
bool showPointers = true; // Show pointer value in the node label
bool duplicatePrimitives = true; // Display primitive types and 'any' as separate nodes
};
std::string toDot(TypeId ty, const ToDotOptions& opts);
std::string toDot(TypePackId tp, const ToDotOptions& opts);
std::string toDot(TypeId ty);
std::string toDot(TypePackId tp);
void dumpDot(TypeId ty);
void dumpDot(TypePackId tp);
const TypeLevel* getLevel(TypeId ty); const TypeLevel* getLevel(TypeId ty);
TypeLevel* getMutableLevel(TypeId ty); TypeLevel* getMutableLevel(TypeId ty);
std::optional<TypeLevel> getLevel(TypePackId tp);
const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name); const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name);
bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent); bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent);
bool hasGeneric(TypeId ty);
bool hasGeneric(TypePackId tp);
TypeVar* asMutable(TypeId ty); TypeVar* asMutable(TypeId ty);
template<typename T> template<typename T>
const T* get(TypeId tv) const T* get(TypeId tv)
{ {
if (FFlag::LuauAddMissingFollow) LUAU_ASSERT(tv);
{
LUAU_ASSERT(tv);
if constexpr (!std::is_same_v<T, BoundTypeVar>) if constexpr (!std::is_same_v<T, BoundTypeVar>)
LUAU_ASSERT(get_if<BoundTypeVar>(&tv->ty) == nullptr); LUAU_ASSERT(get_if<BoundTypeVar>(&tv->ty) == nullptr);
}
return get_if<T>(&tv->ty); return get_if<T>(&tv->ty);
} }
@ -470,23 +695,33 @@ const T* get(TypeId tv)
template<typename T> template<typename T>
T* getMutable(TypeId tv) T* getMutable(TypeId tv)
{ {
if (FFlag::LuauAddMissingFollow) LUAU_ASSERT(tv);
{
LUAU_ASSERT(tv);
if constexpr (!std::is_same_v<T, BoundTypeVar>) if constexpr (!std::is_same_v<T, BoundTypeVar>)
LUAU_ASSERT(get_if<BoundTypeVar>(&tv->ty) == nullptr); LUAU_ASSERT(get_if<BoundTypeVar>(&tv->ty) == nullptr);
}
return get_if<T>(&asMutable(tv)->ty); return get_if<T>(&asMutable(tv)->ty);
} }
/* Traverses the UnionTypeVar yielding each TypeId. const std::vector<TypeId>& getTypes(const UnionTypeVar* utv);
* If the iterator encounters a nested UnionTypeVar, it will instead yield each TypeId within. const std::vector<TypeId>& getTypes(const IntersectionTypeVar* itv);
*
* Beware: the iterator does not currently filter for unique TypeIds. This may change in the future. template<typename T>
struct TypeIterator;
using UnionTypeVarIterator = TypeIterator<UnionTypeVar>;
UnionTypeVarIterator begin(const UnionTypeVar* utv);
UnionTypeVarIterator end(const UnionTypeVar* utv);
using IntersectionTypeVarIterator = TypeIterator<IntersectionTypeVar>;
IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv);
IntersectionTypeVarIterator end(const IntersectionTypeVar* itv);
/* Traverses the type T yielding each TypeId.
* If the iterator encounters a nested type T, it will instead yield each TypeId within.
*/ */
struct UnionTypeVarIterator template<typename T>
struct TypeIterator
{ {
using value_type = Luau::TypeId; using value_type = Luau::TypeId;
using pointer = value_type*; using pointer = value_type*;
@ -494,38 +729,126 @@ struct UnionTypeVarIterator
using difference_type = size_t; using difference_type = size_t;
using iterator_category = std::input_iterator_tag; using iterator_category = std::input_iterator_tag;
explicit UnionTypeVarIterator(const UnionTypeVar* utv); explicit TypeIterator(const T* t)
{
LUAU_ASSERT(t);
UnionTypeVarIterator& operator++(); const std::vector<TypeId>& types = getTypes(t);
UnionTypeVarIterator operator++(int); if (!types.empty())
bool operator!=(const UnionTypeVarIterator& rhs); stack.push_front({t, 0});
bool operator==(const UnionTypeVarIterator& rhs);
const TypeId& operator*(); seen.insert(t);
descend();
}
friend UnionTypeVarIterator end(const UnionTypeVar* utv); TypeIterator<T>& operator++()
{
advance();
descend();
return *this;
}
TypeIterator<T> operator++(int)
{
TypeIterator<T> copy = *this;
++copy;
return copy;
}
bool operator==(const TypeIterator<T>& rhs) const
{
if (!stack.empty() && !rhs.stack.empty())
return stack.front() == rhs.stack.front();
return stack.empty() && rhs.stack.empty();
}
bool operator!=(const TypeIterator<T>& rhs) const
{
return !(*this == rhs);
}
const TypeId& operator*()
{
descend();
LUAU_ASSERT(!stack.empty());
auto [t, currentIndex] = stack.front();
LUAU_ASSERT(t);
const std::vector<TypeId>& types = getTypes(t);
LUAU_ASSERT(currentIndex < types.size());
const TypeId& ty = types[currentIndex];
LUAU_ASSERT(!get<T>(follow(ty)));
return ty;
}
// Normally, we'd have `begin` and `end` be a template but there's too much trouble
// with templates portability in this area, so not worth it. Thanks MSVC.
friend UnionTypeVarIterator end(const UnionTypeVar*);
friend IntersectionTypeVarIterator end(const IntersectionTypeVar*);
private: private:
UnionTypeVarIterator() = default; TypeIterator() = default;
// (UnionTypeVar* utv, size_t currentIndex) // (T* t, size_t currentIndex)
using SavedIterInfo = std::pair<const UnionTypeVar*, size_t>; using SavedIterInfo = std::pair<const T*, size_t>;
std::deque<SavedIterInfo> stack; std::deque<SavedIterInfo> stack;
std::unordered_set<const UnionTypeVar*> seen; // Only needed to protect the iterator from hanging the thread. std::unordered_set<const T*> seen; // Only needed to protect the iterator from hanging the thread.
void advance(); void advance()
void descend(); {
while (!stack.empty())
{
auto& [t, currentIndex] = stack.front();
++currentIndex;
const std::vector<TypeId>& types = getTypes(t);
if (currentIndex >= types.size())
stack.pop_front();
else
break;
}
}
void descend()
{
while (!stack.empty())
{
auto [current, currentIndex] = stack.front();
const std::vector<TypeId>& types = getTypes(current);
if (auto inner = get<T>(follow(types[currentIndex])))
{
// If we're about to descend into a cyclic type, we should skip over this.
// Ideally this should never happen, but alas it does from time to time. :(
if (seen.find(inner) != seen.end())
advance();
else
{
seen.insert(inner);
stack.push_front({inner, 0});
}
continue;
}
break;
}
}
}; };
UnionTypeVarIterator begin(const UnionTypeVar* utv);
UnionTypeVarIterator end(const UnionTypeVar* utv);
using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>; using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>;
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate); std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
// TEMP: Clip this prototype with FFlag::LuauStringMetatable void attachTag(TypeId ty, const std::string& tagName);
std::optional<ExprResult<TypePackId>> magicFunctionFormat( void attachTag(Property& prop, const std::string& tagName);
struct TypeChecker& typechecker, const std::shared_ptr<struct Scope>& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult);
bool hasTag(TypeId ty, const std::string& tagName);
bool hasTag(const Property& prop, const std::string& tagName);
bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work.
} // namespace Luau } // namespace Luau

View file

@ -10,7 +10,7 @@ namespace Luau
{ {
void* pagedAllocate(size_t size); void* pagedAllocate(size_t size);
void pagedDeallocate(void* ptr); void pagedDeallocate(void* ptr, size_t size);
void pagedFreeze(void* ptr, size_t size); void pagedFreeze(void* ptr, size_t size);
void pagedUnfreeze(void* ptr, size_t size); void pagedUnfreeze(void* ptr, size_t size);
@ -20,9 +20,15 @@ class TypedAllocator
public: public:
TypedAllocator() TypedAllocator()
{ {
appendBlock(); currentBlockSize = kBlockSize;
} }
TypedAllocator(const TypedAllocator&) = delete;
TypedAllocator& operator=(const TypedAllocator&) = delete;
TypedAllocator(TypedAllocator&&) = default;
TypedAllocator& operator=(TypedAllocator&&) = default;
~TypedAllocator() ~TypedAllocator()
{ {
if (frozen) if (frozen)
@ -59,12 +65,12 @@ public:
bool empty() const bool empty() const
{ {
return stuff.size() == 1 && currentBlockSize == 0; return stuff.empty();
} }
size_t size() const size_t size() const
{ {
return kBlockSize * (stuff.size() - 1) + currentBlockSize; return stuff.empty() ? 0 : kBlockSize * (stuff.size() - 1) + currentBlockSize;
} }
void clear() void clear()
@ -72,7 +78,8 @@ public:
if (frozen) if (frozen)
unfreeze(); unfreeze();
free(); free();
appendBlock();
currentBlockSize = kBlockSize;
} }
void freeze() void freeze()
@ -106,7 +113,7 @@ private:
for (size_t i = 0; i < blockSize; ++i) for (size_t i = 0; i < blockSize; ++i)
block[i].~T(); block[i].~T();
pagedDeallocate(block); pagedDeallocate(block, kBlockSizeBytes);
} }
stuff.clear(); stuff.clear();

View file

@ -8,6 +8,8 @@
namespace Luau namespace Luau
{ {
struct Scope;
/** /**
* The 'level' of a TypeVar is an indirect way to talk about the scope that it 'belongs' too. * The 'level' of a TypeVar is an indirect way to talk about the scope that it 'belongs' too.
* To start, read http://okmij.org/ftp/ML/generalization.html * To start, read http://okmij.org/ftp/ML/generalization.html
@ -24,7 +26,7 @@ struct TypeLevel
int level = 0; int level = 0;
int subLevel = 0; int subLevel = 0;
// Returns true if the typelevel "this" is "bigger" than rhs // Returns true if the level of "this" belongs to an equal or larger scope than that of rhs
bool subsumes(const TypeLevel& rhs) const bool subsumes(const TypeLevel& rhs) const
{ {
if (level < rhs.level) if (level < rhs.level)
@ -38,6 +40,15 @@ struct TypeLevel
return false; return false;
} }
// Returns true if the level of "this" belongs to a larger (not equal) scope than that of rhs
bool subsumesStrict(const TypeLevel& rhs) const
{
if (level == rhs.level && subLevel == rhs.subLevel)
return false;
else
return subsumes(rhs);
}
TypeLevel incr() const TypeLevel incr() const
{ {
TypeLevel result; TypeLevel result;
@ -47,6 +58,14 @@ struct TypeLevel
} }
}; };
inline TypeLevel max(const TypeLevel& a, const TypeLevel& b)
{
if (a.subsumes(b))
return b;
else
return a;
}
inline TypeLevel min(const TypeLevel& a, const TypeLevel& b) inline TypeLevel min(const TypeLevel& a, const TypeLevel& b)
{ {
if (a.subsumes(b)) if (a.subsumes(b))
@ -55,7 +74,9 @@ inline TypeLevel min(const TypeLevel& a, const TypeLevel& b)
return b; return b;
} }
namespace Unifiable } // namespace Luau
namespace Luau::Unifiable
{ {
using Name = std::string; using Name = std::string;
@ -63,19 +84,19 @@ using Name = std::string;
struct Free struct Free
{ {
explicit Free(TypeLevel level); explicit Free(TypeLevel level);
Free(TypeLevel level, bool DEPRECATED_canBeGeneric); explicit Free(Scope* scope);
explicit Free(Scope* scope, TypeLevel level);
int index; int index;
TypeLevel level; TypeLevel level;
// Removed by FFlag::LuauRankNTypes Scope* scope = nullptr;
bool DEPRECATED_canBeGeneric = false;
// True if this free type variable is part of a mutually // True if this free type variable is part of a mutually
// recursive type alias whose definitions haven't been // recursive type alias whose definitions haven't been
// resolved yet. // resolved yet.
bool forwardedTypeAlias = false; bool forwardedTypeAlias = false;
private: private:
static int nextIndex; static int DEPRECATED_nextIndex;
}; };
template<typename Id> template<typename Id>
@ -95,19 +116,24 @@ struct Generic
Generic(); Generic();
explicit Generic(TypeLevel level); explicit Generic(TypeLevel level);
explicit Generic(const Name& name); explicit Generic(const Name& name);
explicit Generic(Scope* scope);
Generic(TypeLevel level, const Name& name); Generic(TypeLevel level, const Name& name);
Generic(Scope* scope, const Name& name);
int index; int index;
TypeLevel level; TypeLevel level;
Scope* scope = nullptr;
Name name; Name name;
bool explicitName; bool explicitName = false;
private: private:
static int nextIndex; static int DEPRECATED_nextIndex;
}; };
struct Error struct Error
{ {
// This constructor has to be public, since it's used in TypeVar and TypePack,
// but shouldn't be called directly. Please use errorRecoveryType() instead.
Error(); Error();
int index; int index;
@ -117,7 +143,6 @@ private:
}; };
template<typename Id, typename... Value> template<typename Id, typename... Value>
using Variant = Variant<Free, Bound<Id>, Generic, Error, Value...>; using Variant = Luau::Variant<Free, Bound<Id>, Generic, Error, Value...>;
} // namespace Unifiable } // namespace Luau::Unifiable
} // namespace Luau

View file

@ -3,9 +3,13 @@
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/ParseOptions.h"
#include "Luau/Scope.h"
#include "Luau/Substitution.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeArena.h"
#include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header. #include "Luau/UnifierSharedState.h"
#include "Normalize.h"
#include <unordered_set> #include <unordered_set>
@ -18,81 +22,133 @@ enum Variance
Invariant Invariant
}; };
struct UnifierCounters // A substitution which replaces singleton types by their wider types
struct Widen : Substitution
{ {
int recursionCount = 0; Widen(TypeArena* arena, NotNull<SingletonTypes> singletonTypes)
int iterationCount = 0; : Substitution(TxnLog::empty(), arena)
, singletonTypes(singletonTypes)
{
}
NotNull<SingletonTypes> singletonTypes;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId ty) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId ty) override;
bool ignoreChildren(TypeId ty) override;
TypeId operator()(TypeId ty);
TypePackId operator()(TypePackId ty);
};
// TODO: Use this more widely.
struct UnifierOptions
{
bool isFunctionCall = false;
}; };
struct Unifier struct Unifier
{ {
TypeArena* const types; TypeArena* const types;
NotNull<SingletonTypes> singletonTypes;
NotNull<Normalizer> normalizer;
Mode mode; Mode mode;
ScopePtr globalScope; // sigh. Needed solely to get at string's metatable.
NotNull<Scope> scope; // const Scope maybe
TxnLog log; TxnLog log;
ErrorVec errors; ErrorVec errors;
Location location; Location location;
Variance variance = Covariant; Variance variance = Covariant;
bool normalize; // Normalize unions and intersections if necessary
bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels
CountMismatch::Context ctx = CountMismatch::Arg; CountMismatch::Context ctx = CountMismatch::Arg;
std::shared_ptr<UnifierCounters> counters; UnifierSharedState& sharedState;
InternalErrorReporter* iceHandler;
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler); Unifier(
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location, NotNull<Normalizer> normalizer, Mode mode, NotNull<Scope> scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr);
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters = nullptr);
// Test whether the two type vars unify. Never commits the result. // Test whether the two type vars unify. Never commits the result.
ErrorVec canUnify(TypeId superTy, TypeId subTy); ErrorVec canUnify(TypeId subTy, TypeId superTy);
ErrorVec canUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false);
/** Attempt to unify left with right. /** Attempt to unify.
* Populate the vector errors with any type errors that may arise. * Populate the vector errors with any type errors that may arise.
* Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt. * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt.
*/ */
void tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false);
private: private:
void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false);
void tryUnifyPrimitives(TypeId superTy, TypeId subTy); void tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId superTy);
void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTypeVar* uv, bool cacheEnabled, bool isFunctionCall);
void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionTypeVar* uv);
void tryUnifyFreeTable(TypeId free, TypeId other); void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall);
void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection); void tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason,
void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); std::optional<TypeError> error = std::nullopt);
void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); void tryUnifyPrimitives(TypeId subTy, TypeId superTy);
void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); void tryUnifySingletons(TypeId subTy, TypeId superTy);
void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false);
void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false);
void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyTypeWithNegation(TypeId subTy, TypeId superTy);
void tryUnifyNegationWithType(TypeId subTy, TypeId superTy);
TypePackId tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args);
TypeId widen(TypeId ty);
TypePackId widen(TypePackId tp);
TypeId deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> seen = {});
bool canCacheResult(TypeId subTy, TypeId superTy);
void cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount);
public: public:
void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); void tryUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false);
private: private:
void tryUnify_(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); void tryUnify_(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false);
void tryUnifyVariadics(TypePackId superTy, TypePackId subTy, bool reversed, int subOffset = 0); void tryUnifyVariadics(TypePackId subTy, TypePackId superTy, bool reversed, int subOffset = 0);
void tryUnifyWithAny(TypeId any, TypeId ty); void tryUnifyWithAny(TypeId subTy, TypeId anyTy);
void tryUnifyWithAny(TypePackId any, TypePackId ty); void tryUnifyWithAny(TypePackId subTy, TypePackId anyTp);
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name); std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name);
std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry);
TxnLog combineLogsIntoIntersection(std::vector<TxnLog> logs);
TxnLog combineLogsIntoUnion(std::vector<TxnLog> logs);
public: public:
// Report an "infinite type error" if the type "needle" already occurs within "haystack" // Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error"
void occursCheck(TypeId needle, TypeId haystack); bool occursCheck(TypeId needle, TypeId haystack);
void occursCheck(std::unordered_set<TypeId>& seen, TypeId needle, TypeId haystack); bool occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack);
void occursCheck(TypePackId needle, TypePackId haystack); bool occursCheck(TypePackId needle, TypePackId haystack);
void occursCheck(std::unordered_set<TypePackId>& seen, TypePackId needle, TypePackId haystack); bool occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack);
Unifier makeChildUnifier(); Unifier makeChildUnifier();
void reportError(TypeError err);
LUAU_NOINLINE void reportError(Location location, TypeErrorData data);
private: private:
bool isNonstrictMode() const; bool isNonstrictMode() const;
TypeMismatch::Context mismatchContext();
void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType); void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType);
void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType);
[[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message, const Location& location);
[[noreturn]] void ice(const std::string& message); [[noreturn]] void ice(const std::string& message);
// Available after regular type pack unification errors
std::optional<int> firstPackErrorPos;
}; };
void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp);
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,55 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/Error.h"
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
#include <utility>
namespace Luau
{
struct InternalErrorReporter;
struct TypeIdPairHash
{
size_t hashOne(Luau::TypeId key) const
{
return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9);
}
size_t operator()(const std::pair<Luau::TypeId, Luau::TypeId>& x) const
{
return hashOne(x.first) ^ (hashOne(x.second) << 1);
}
};
struct UnifierCounters
{
int recursionCount = 0;
int recursionLimit = 0;
int iterationCount = 0;
int iterationLimit = 0;
};
struct UnifierSharedState
{
UnifierSharedState(InternalErrorReporter* iceHandler)
: iceHandler(iceHandler)
{
}
InternalErrorReporter* iceHandler;
DenseHashMap<TypeId, bool> skipCacheForType{nullptr};
DenseHashSet<std::pair<TypeId, TypeId>, TypeIdPairHash> cachedUnify{{nullptr, nullptr}};
DenseHashMap<std::pair<TypeId, TypeId>, TypeErrorData, TypeIdPairHash> cachedUnifyError{{nullptr, nullptr}};
DenseHashSet<TypeId> tempSeenTy{nullptr};
DenseHashSet<TypePackId> tempSeenTp{nullptr};
UnifierCounters counters;
};
} // namespace Luau

View file

@ -2,45 +2,15 @@
#pragma once #pragma once
#include "Luau/Common.h" #include "Luau/Common.h"
#ifndef LUAU_USE_STD_VARIANT
#define LUAU_USE_STD_VARIANT 0
#endif
#if LUAU_USE_STD_VARIANT
#include <variant>
#else
#include <new> #include <new>
#include <type_traits> #include <type_traits>
#include <initializer_list> #include <initializer_list>
#include <stddef.h> #include <stddef.h>
#endif #include <utility>
namespace Luau namespace Luau
{ {
#if LUAU_USE_STD_VARIANT
template<typename... Ts>
using Variant = std::variant<Ts...>;
template<class Visitor, class Variant>
auto visit(Visitor&& vis, Variant&& var)
{
// This change resolves the ABI issues with std::variant on libc++; std::visit normally throws bad_variant_access
// but it requires an update to libc++.dylib which ships with macOS 10.14. To work around this, we assert on valueless
// variants since we will never generate them and call into a libc++ function that doesn't throw.
LUAU_ASSERT(!var.valueless_by_exception());
#ifdef __APPLE__
// See https://stackoverflow.com/a/53868971/503215
return std::__variant_detail::__visitation::__variant::__visit_value(vis, var);
#else
return std::visit(vis, var);
#endif
}
using std::get_if;
#else
template<typename... Ts> template<typename... Ts>
class Variant class Variant
{ {
@ -88,13 +58,15 @@ public:
constexpr int tid = getTypeId<T>(); constexpr int tid = getTypeId<T>();
typeId = tid; typeId = tid;
new (&storage) TT(value); new (&storage) TT(std::forward<T>(value));
} }
Variant(const Variant& other) Variant(const Variant& other)
{ {
static constexpr FnCopy table[sizeof...(Ts)] = {&fnCopy<Ts>...};
typeId = other.typeId; typeId = other.typeId;
tableCopy[typeId](&storage, &other.storage); table[typeId](&storage, &other.storage);
} }
Variant(Variant&& other) Variant(Variant&& other)
@ -126,6 +98,20 @@ public:
return *this; return *this;
} }
template<typename T, typename... Args>
T& emplace(Args&&... args)
{
using TT = std::decay_t<T>;
constexpr int tid = getTypeId<T>();
static_assert(tid >= 0, "unsupported T");
tableDtor[typeId](&storage);
typeId = tid;
new (&storage) TT{std::forward<Args>(args)...};
return *reinterpret_cast<T*>(&storage);
}
template<typename T> template<typename T>
const T* get_if() const const T* get_if() const
{ {
@ -208,7 +194,6 @@ private:
return *static_cast<const T*>(lhs) == *static_cast<const T*>(rhs); return *static_cast<const T*>(lhs) == *static_cast<const T*>(rhs);
} }
static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy<Ts>...};
static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove<Ts>...}; static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove<Ts>...};
static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor<Ts>...}; static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor<Ts>...};
@ -248,6 +233,8 @@ static void fnVisitV(Visitor& vis, std::conditional_t<std::is_const_v<T>, const
template<class Visitor, typename... Ts> template<class Visitor, typename... Ts>
auto visit(Visitor&& vis, const Variant<Ts...>& var) auto visit(Visitor&& vis, const Variant<Ts...>& var)
{ {
static_assert(std::conjunction_v<std::is_invocable<Visitor, Ts>...>, "visitor must accept every alternative as an argument");
using Result = std::invoke_result_t<Visitor, typename Variant<Ts...>::first_alternative>; using Result = std::invoke_result_t<Visitor, typename Variant<Ts...>::first_alternative>;
static_assert(std::conjunction_v<std::is_same<Result, std::invoke_result_t<Visitor, Ts>>...>, static_assert(std::conjunction_v<std::is_same<Result, std::invoke_result_t<Visitor, Ts>>...>,
"visitor result type must be consistent between alternatives"); "visitor result type must be consistent between alternatives");
@ -273,6 +260,8 @@ auto visit(Visitor&& vis, const Variant<Ts...>& var)
template<class Visitor, typename... Ts> template<class Visitor, typename... Ts>
auto visit(Visitor&& vis, Variant<Ts...>& var) auto visit(Visitor&& vis, Variant<Ts...>& var)
{ {
static_assert(std::conjunction_v<std::is_invocable<Visitor, Ts&>...>, "visitor must accept every alternative as an argument");
using Result = std::invoke_result_t<Visitor, typename Variant<Ts...>::first_alternative&>; using Result = std::invoke_result_t<Visitor, typename Variant<Ts...>::first_alternative&>;
static_assert(std::conjunction_v<std::is_same<Result, std::invoke_result_t<Visitor, Ts&>>...>, static_assert(std::conjunction_v<std::is_same<Result, std::invoke_result_t<Visitor, Ts&>>...>,
"visitor result type must be consistent between alternatives"); "visitor result type must be consistent between alternatives");
@ -294,7 +283,6 @@ auto visit(Visitor&& vis, Variant<Ts...>& var)
return res; return res;
} }
} }
#endif
template<class> template<class>
inline constexpr bool always_false_v = false; inline constexpr bool always_false_v = false;

View file

@ -1,8 +1,15 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/TypeVar.h" #include <unordered_set>
#include "Luau/DenseHash.h"
#include "Luau/RecursionCounter.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
LUAU_FASTINT(LuauVisitRecursionLimit)
LUAU_FASTFLAG(LuauCompleteVisitor);
namespace Luau namespace Luau
{ {
@ -32,169 +39,365 @@ inline bool hasSeen(std::unordered_set<void*>& seen, const void* tv)
return !seen.insert(ttv).second; return !seen.insert(ttv).second;
} }
inline bool hasSeen(DenseHashSet<void*>& seen, const void* tv)
{
void* ttv = const_cast<void*>(tv);
if (seen.contains(ttv))
return true;
seen.insert(ttv);
return false;
}
inline void unsee(std::unordered_set<void*>& seen, const void* tv) inline void unsee(std::unordered_set<void*>& seen, const void* tv)
{ {
void* ttv = const_cast<void*>(tv); void* ttv = const_cast<void*>(tv);
seen.erase(ttv); seen.erase(ttv);
} }
template<typename F> inline void unsee(DenseHashSet<void*>& seen, const void* tv)
void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen);
template<typename F>
void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
{ {
if (visit_detail::hasSeen(seen, ty)) // When DenseHashSet is used for 'visitTypeVarOnce', where don't forget visited elements
{
f.cycle(ty);
return;
}
if (auto btv = get<BoundTypeVar>(ty))
{
if (apply(ty, *btv, seen, f))
visit(btv->boundTo, f, seen);
}
else if (auto ftv = get<FreeTypeVar>(ty))
apply(ty, *ftv, seen, f);
else if (auto gtv = get<GenericTypeVar>(ty))
apply(ty, *gtv, seen, f);
else if (auto etv = get<ErrorTypeVar>(ty))
apply(ty, *etv, seen, f);
else if (auto ptv = get<PrimitiveTypeVar>(ty))
apply(ty, *ptv, seen, f);
else if (auto ftv = get<FunctionTypeVar>(ty))
{
if (apply(ty, *ftv, seen, f))
{
visit(ftv->argTypes, f, seen);
visit(ftv->retType, f, seen);
}
}
else if (auto ttv = get<TableTypeVar>(ty))
{
if (apply(ty, *ttv, seen, f))
{
for (auto& [_name, prop] : ttv->props)
visit(prop.type, f, seen);
if (ttv->indexer)
{
visit(ttv->indexer->indexType, f, seen);
visit(ttv->indexer->indexResultType, f, seen);
}
}
}
else if (auto mtv = get<MetatableTypeVar>(ty))
{
if (apply(ty, *mtv, seen, f))
{
visit(mtv->table, f, seen);
visit(mtv->metatable, f, seen);
}
}
else if (auto ctv = get<ClassTypeVar>(ty))
{
if (apply(ty, *ctv, seen, f))
{
for (const auto& [name, prop] : ctv->props)
visit(prop.type, f, seen);
if (ctv->parent)
visit(*ctv->parent, f, seen);
if (ctv->metatable)
visit(*ctv->metatable, f, seen);
}
}
else if (auto atv = get<AnyTypeVar>(ty))
apply(ty, *atv, seen, f);
else if (auto utv = get<UnionTypeVar>(ty))
{
if (apply(ty, *utv, seen, f))
{
for (TypeId optTy : utv->options)
visit(optTy, f, seen);
}
}
else if (auto itv = get<IntersectionTypeVar>(ty))
{
if (apply(ty, *itv, seen, f))
{
for (TypeId partTy : itv->parts)
visit(partTy, f, seen);
}
}
visit_detail::unsee(seen, ty);
} }
template<typename F>
void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen)
{
if (visit_detail::hasSeen(seen, tp))
{
f.cycle(tp);
return;
}
if (auto btv = get<BoundTypePack>(tp))
{
if (apply(tp, *btv, seen, f))
visit(btv->boundTo, f, seen);
}
else if (auto ftv = get<Unifiable::Free>(tp))
apply(tp, *ftv, seen, f);
else if (auto gtv = get<Unifiable::Generic>(tp))
apply(tp, *gtv, seen, f);
else if (auto etv = get<Unifiable::Error>(tp))
apply(tp, *etv, seen, f);
else if (auto pack = get<TypePack>(tp))
{
apply(tp, *pack, seen, f);
for (TypeId ty : pack->head)
visit(ty, f, seen);
if (pack->tail)
visit(*pack->tail, f, seen);
}
else if (auto pack = get<VariadicTypePack>(tp))
{
apply(tp, *pack, seen, f);
visit(pack->ty, f, seen);
}
visit_detail::unsee(seen, tp);
}
} // namespace visit_detail } // namespace visit_detail
template<typename TID, typename F> template<typename S>
void visitTypeVar(TID ty, F& f, std::unordered_set<void*>& seen) struct GenericTypeVarVisitor
{ {
visit_detail::visit(ty, f, seen); using Set = S;
}
template<typename TID, typename F> Set seen;
void visitTypeVar(TID ty, F& f) bool skipBoundTypes = false;
int recursionCounter = 0;
GenericTypeVarVisitor() = default;
explicit GenericTypeVarVisitor(Set seen, bool skipBoundTypes = false)
: seen(std::move(seen))
, skipBoundTypes(skipBoundTypes)
{
}
virtual void cycle(TypeId) {}
virtual void cycle(TypePackId) {}
virtual bool visit(TypeId ty)
{
return true;
}
virtual bool visit(TypeId ty, const BoundTypeVar& btv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const FreeTypeVar& ftv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const GenericTypeVar& gtv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const ErrorTypeVar& etv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const FunctionTypeVar& ftv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const TableTypeVar& ttv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const MetatableTypeVar& mtv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const ClassTypeVar& ctv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const AnyTypeVar& atv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const UnknownTypeVar& utv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const NeverTypeVar& ntv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const UnionTypeVar& utv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const IntersectionTypeVar& itv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const BlockedTypeVar& btv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const PendingExpansionTypeVar& petv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const SingletonTypeVar& stv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const NegationTypeVar& ntv)
{
return visit(ty);
}
virtual bool visit(TypePackId tp)
{
return true;
}
virtual bool visit(TypePackId tp, const BoundTypePack& btp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const FreeTypePack& ftp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const GenericTypePack& gtp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const Unifiable::Error& etp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const TypePack& pack)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const VariadicTypePack& vtp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const BlockedTypePack& btp)
{
return visit(tp);
}
void traverse(TypeId ty)
{
RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit};
if (visit_detail::hasSeen(seen, ty))
{
cycle(ty);
return;
}
if (auto btv = get<BoundTypeVar>(ty))
{
if (skipBoundTypes)
traverse(btv->boundTo);
else if (visit(ty, *btv))
traverse(btv->boundTo);
}
else if (auto ftv = get<FreeTypeVar>(ty))
visit(ty, *ftv);
else if (auto gtv = get<GenericTypeVar>(ty))
visit(ty, *gtv);
else if (auto etv = get<ErrorTypeVar>(ty))
visit(ty, *etv);
else if (auto ptv = get<PrimitiveTypeVar>(ty))
visit(ty, *ptv);
else if (auto ftv = get<FunctionTypeVar>(ty))
{
if (visit(ty, *ftv))
{
traverse(ftv->argTypes);
traverse(ftv->retTypes);
}
}
else if (auto ttv = get<TableTypeVar>(ty))
{
// Some visitors want to see bound tables, that's why we traverse the original type
if (skipBoundTypes && ttv->boundTo)
{
traverse(*ttv->boundTo);
}
else if (visit(ty, *ttv))
{
if (ttv->boundTo)
{
traverse(*ttv->boundTo);
}
else
{
for (auto& [_name, prop] : ttv->props)
traverse(prop.type);
if (ttv->indexer)
{
traverse(ttv->indexer->indexType);
traverse(ttv->indexer->indexResultType);
}
}
}
}
else if (auto mtv = get<MetatableTypeVar>(ty))
{
if (visit(ty, *mtv))
{
traverse(mtv->table);
traverse(mtv->metatable);
}
}
else if (auto ctv = get<ClassTypeVar>(ty))
{
if (visit(ty, *ctv))
{
for (const auto& [name, prop] : ctv->props)
traverse(prop.type);
if (ctv->parent)
traverse(*ctv->parent);
if (ctv->metatable)
traverse(*ctv->metatable);
}
}
else if (auto atv = get<AnyTypeVar>(ty))
visit(ty, *atv);
else if (auto utv = get<UnionTypeVar>(ty))
{
if (visit(ty, *utv))
{
for (TypeId optTy : utv->options)
traverse(optTy);
}
}
else if (auto itv = get<IntersectionTypeVar>(ty))
{
if (visit(ty, *itv))
{
for (TypeId partTy : itv->parts)
traverse(partTy);
}
}
else if (get<LazyTypeVar>(ty))
{
// Visiting into LazyTypeVar may necessarily cause infinite expansion, so we don't do that on purpose.
// Asserting also makes no sense, because the type _will_ happen here, most likely as a property of some ClassTypeVar
// that doesn't need to be expanded.
}
else if (auto stv = get<SingletonTypeVar>(ty))
visit(ty, *stv);
else if (auto btv = get<BlockedTypeVar>(ty))
visit(ty, *btv);
else if (auto utv = get<UnknownTypeVar>(ty))
visit(ty, *utv);
else if (auto ntv = get<NeverTypeVar>(ty))
visit(ty, *ntv);
else if (auto petv = get<PendingExpansionTypeVar>(ty))
{
if (visit(ty, *petv))
{
for (TypeId a : petv->typeArguments)
traverse(a);
for (TypePackId a : petv->packArguments)
traverse(a);
}
}
else if (auto ntv = get<NegationTypeVar>(ty))
visit(ty, *ntv);
else if (!FFlag::LuauCompleteVisitor)
return visit_detail::unsee(seen, ty);
else
LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypeId) is not exhaustive!");
visit_detail::unsee(seen, ty);
}
void traverse(TypePackId tp)
{
if (visit_detail::hasSeen(seen, tp))
{
cycle(tp);
return;
}
if (auto btv = get<BoundTypePack>(tp))
{
if (visit(tp, *btv))
traverse(btv->boundTo);
}
else if (auto ftv = get<Unifiable::Free>(tp))
visit(tp, *ftv);
else if (auto gtv = get<Unifiable::Generic>(tp))
visit(tp, *gtv);
else if (auto etv = get<Unifiable::Error>(tp))
visit(tp, *etv);
else if (auto pack = get<TypePack>(tp))
{
bool res = visit(tp, *pack);
if (res)
{
for (TypeId ty : pack->head)
traverse(ty);
if (pack->tail)
traverse(*pack->tail);
}
}
else if (auto pack = get<VariadicTypePack>(tp))
{
bool res = visit(tp, *pack);
if (res)
traverse(pack->ty);
}
else if (auto btp = get<BlockedTypePack>(tp))
visit(tp, *btp);
else
LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypePackId) is not exhaustive!");
visit_detail::unsee(seen, tp);
}
};
/** Visit each type under a given type. Skips over cycles and keeps recursion depth under control.
*
* The same type may be visited multiple times if there are multiple distinct paths to it. If this is undesirable, use
* TypeVarOnceVisitor.
*/
struct TypeVarVisitor : GenericTypeVarVisitor<std::unordered_set<void*>>
{ {
std::unordered_set<void*> seen; explicit TypeVarVisitor(bool skipBoundTypes = false)
visit_detail::visit(ty, f, seen); : GenericTypeVarVisitor{{}, skipBoundTypes}
} {
}
};
/// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it.
struct TypeVarOnceVisitor : GenericTypeVarVisitor<DenseHashSet<void*>>
{
explicit TypeVarOnceVisitor(bool skipBoundTypes = false)
: GenericTypeVarVisitor{DenseHashSet<void*>{nullptr}, skipBoundTypes}
{
}
};
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,90 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Anyification.h"
#include "Luau/Common.h"
#include "Luau/Normalize.h"
#include "Luau/TxnLog.h"
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
namespace Luau
{
Anyification::Anyification(TypeArena* arena, NotNull<Scope> scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter* iceHandler,
TypeId anyType, TypePackId anyTypePack)
: Substitution(TxnLog::empty(), arena)
, scope(scope)
, singletonTypes(singletonTypes)
, iceHandler(iceHandler)
, anyType(anyType)
, anyTypePack(anyTypePack)
{
}
Anyification::Anyification(TypeArena* arena, const ScopePtr& scope, NotNull<SingletonTypes> singletonTypes, InternalErrorReporter* iceHandler,
TypeId anyType, TypePackId anyTypePack)
: Anyification(arena, NotNull{scope.get()}, singletonTypes, iceHandler, anyType, anyTypePack)
{
}
bool Anyification::isDirty(TypeId ty)
{
if (ty->persistent)
return false;
if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed);
else if (log->getMutable<FreeTypeVar>(ty))
return true;
else
return false;
}
bool Anyification::isDirty(TypePackId tp)
{
if (tp->persistent)
return false;
if (log->getMutable<FreeTypePack>(tp))
return true;
else
return false;
}
TypeId Anyification::clean(TypeId ty)
{
LUAU_ASSERT(isDirty(ty));
if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
{
TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed};
clone.definitionModuleName = ttv->definitionModuleName;
clone.name = ttv->name;
clone.syntheticName = ttv->syntheticName;
clone.tags = ttv->tags;
TypeId res = addType(std::move(clone));
return res;
}
else
return anyType;
}
TypePackId Anyification::clean(TypePackId tp)
{
LUAU_ASSERT(isDirty(tp));
return anyTypePack;
}
bool Anyification::ignoreChildren(TypeId ty)
{
if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassTypeVar>(ty))
return true;
return ty->persistent;
}
bool Anyification::ignoreChildren(TypePackId ty)
{
return ty->persistent;
}
} // namespace Luau

View file

@ -0,0 +1,64 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ApplyTypeFunction.h"
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
namespace Luau
{
bool ApplyTypeFunction::isDirty(TypeId ty)
{
if (typeArguments.count(ty))
return true;
else if (const FreeTypeVar* ftv = get<FreeTypeVar>(ty))
{
if (ftv->forwardedTypeAlias)
encounteredForwardedType = true;
return false;
}
else
return false;
}
bool ApplyTypeFunction::isDirty(TypePackId tp)
{
if (typePackArguments.count(tp))
return true;
else
return false;
}
bool ApplyTypeFunction::ignoreChildren(TypeId ty)
{
if (get<GenericTypeVar>(ty))
return true;
else if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassTypeVar>(ty))
return true;
else
return false;
}
bool ApplyTypeFunction::ignoreChildren(TypePackId tp)
{
if (get<GenericTypePack>(tp))
return true;
else
return false;
}
TypeId ApplyTypeFunction::clean(TypeId ty)
{
TypeId& arg = typeArguments[ty];
LUAU_ASSERT(arg);
return arg;
}
TypePackId ApplyTypeFunction::clean(TypePackId tp)
{
TypePackId& arg = typePackArguments[tp];
LUAU_ASSERT(arg);
return arg;
}
} // namespace Luau

View file

@ -1,8 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/JsonEncoder.h" #include "Luau/AstJsonEncoder.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/ParseResult.h"
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
#include "Luau/Common.h"
#include <math.h>
namespace Luau namespace Luau
{ {
@ -74,6 +78,11 @@ struct AstJsonEncoder : public AstVisitor
writeRaw(std::string_view{&c, 1}); writeRaw(std::string_view{&c, 1});
} }
void writeType(std::string_view propValue)
{
write("type", propValue);
}
template<typename T> template<typename T>
void write(std::string_view propName, const T& value) void write(std::string_view propName, const T& value)
{ {
@ -96,9 +105,28 @@ struct AstJsonEncoder : public AstVisitor
void write(double d) void write(double d)
{ {
char b[256]; switch (fpclassify(d))
sprintf(b, "%g", d); {
writeRaw(b); case FP_INFINITE:
if (d < 0)
writeRaw("-Infinity");
else
writeRaw("Infinity");
break;
case FP_NAN:
writeRaw("NaN");
break;
case FP_NORMAL:
case FP_SUBNORMAL:
case FP_ZERO:
default:
char b[32];
snprintf(b, sizeof(b), "%.17g", d);
writeRaw(b);
break;
}
} }
void writeString(std::string_view sv) void writeString(std::string_view sv)
@ -110,8 +138,12 @@ struct AstJsonEncoder : public AstVisitor
{ {
if (c == '"') if (c == '"')
writeRaw("\\\""); writeRaw("\\\"");
else if (c == '\0') else if (c == '\\')
writeRaw("\\\0"); writeRaw("\\\\");
else if (c < ' ')
writeRaw(format("\\u%04x", c));
else if (c == '\n')
writeRaw("\\n");
else else
writeRaw(c); writeRaw(c);
} }
@ -147,10 +179,21 @@ struct AstJsonEncoder : public AstVisitor
{ {
writeRaw(std::to_string(i)); writeRaw(std::to_string(i));
} }
void write(std::nullptr_t)
{
writeRaw("null");
}
void write(std::string_view str) void write(std::string_view str)
{ {
writeString(str); writeString(str);
} }
void write(std::optional<AstName> name)
{
if (name)
write(*name);
else
writeRaw("null");
}
void write(AstName name) void write(AstName name)
{ {
writeString(name.value ? name.value : ""); writeString(name.value ? name.value : "");
@ -174,7 +217,17 @@ struct AstJsonEncoder : public AstVisitor
void write(AstLocal* local) void write(AstLocal* local)
{ {
write(local->name); writeRaw("{");
bool c = pushComma();
if (local->annotation != nullptr)
write("luauType", local->annotation);
else
write("luauType", nullptr);
write("name", local->name);
writeType("AstLocal");
write("location", local->location);
popComma(c);
writeRaw("}");
} }
void writeNode(AstNode* node) void writeNode(AstNode* node)
@ -187,7 +240,7 @@ struct AstJsonEncoder : public AstVisitor
{ {
writeRaw("{"); writeRaw("{");
bool c = pushComma(); bool c = pushComma();
write("type", name); writeType(name);
writeNode(node); writeNode(node);
f(); f();
popComma(c); popComma(c);
@ -261,7 +314,7 @@ struct AstJsonEncoder : public AstVisitor
if (comma) if (comma)
writeRaw(","); writeRaw(",");
else else
comma = false; comma = true;
write(a); write(a);
} }
@ -311,7 +364,7 @@ struct AstJsonEncoder : public AstVisitor
if (node->self) if (node->self)
PROP(self); PROP(self);
PROP(args); PROP(args);
if (node->hasReturnAnnotation) if (node->returnAnnotation)
PROP(returnAnnotation); PROP(returnAnnotation);
PROP(vararg); PROP(vararg);
PROP(varargLocation); PROP(varargLocation);
@ -325,10 +378,19 @@ struct AstJsonEncoder : public AstVisitor
}); });
} }
void write(const std::optional<AstTypeList>& typeList)
{
if (typeList)
write(*typeList);
else
writeRaw("null");
}
void write(const AstTypeList& typeList) void write(const AstTypeList& typeList)
{ {
writeRaw("{"); writeRaw("{");
bool c = pushComma(); bool c = pushComma();
writeType("AstTypeList");
write("types", typeList.types); write("types", typeList.types);
if (typeList.tailType) if (typeList.tailType)
write("tailType", typeList.tailType); write("tailType", typeList.tailType);
@ -336,6 +398,30 @@ struct AstJsonEncoder : public AstVisitor
writeRaw("}"); writeRaw("}");
} }
void write(const AstGenericType& genericType)
{
writeRaw("{");
bool c = pushComma();
writeType("AstGenericType");
write("name", genericType.name);
if (genericType.defaultValue)
write("luauType", genericType.defaultValue);
popComma(c);
writeRaw("}");
}
void write(const AstGenericTypePack& genericTypePack)
{
writeRaw("{");
bool c = pushComma();
writeType("AstGenericTypePack");
write("name", genericTypePack.name);
if (genericTypePack.defaultValue)
write("luauType", genericTypePack.defaultValue);
popComma(c);
writeRaw("}");
}
void write(AstExprTable::Item::Kind kind) void write(AstExprTable::Item::Kind kind)
{ {
switch (kind) switch (kind)
@ -352,35 +438,46 @@ struct AstJsonEncoder : public AstVisitor
void write(const AstExprTable::Item& item) void write(const AstExprTable::Item& item)
{ {
writeRaw("{"); writeRaw("{");
bool comma = pushComma(); bool c = pushComma();
writeType("AstExprTableItem");
write("kind", item.kind); write("kind", item.kind);
switch (item.kind) switch (item.kind)
{ {
case AstExprTable::Item::List: case AstExprTable::Item::List:
write(item.value); write("value", item.value);
break; break;
default: default:
write(item.key); write("key", item.key);
writeRaw(","); write("value", item.value);
write(item.value);
break; break;
} }
popComma(comma); popComma(c);
writeRaw("}"); writeRaw("}");
} }
void write(class AstExprIfElse* node)
{
writeNode(node, "AstExprIfElse", [&]() {
PROP(condition);
PROP(hasThen);
PROP(trueExpr);
PROP(hasElse);
PROP(falseExpr);
});
}
void write(class AstExprInterpString* node)
{
writeNode(node, "AstExprInterpString", [&]() {
PROP(strings);
PROP(expressions);
});
}
void write(class AstExprTable* node) void write(class AstExprTable* node)
{ {
writeNode(node, "AstExprTable", [&]() { writeNode(node, "AstExprTable", [&]() {
bool comma = false; PROP(items);
for (const auto& prop : node->items)
{
if (comma)
writeRaw(",");
else
comma = false;
write(prop);
}
}); });
} }
@ -389,11 +486,11 @@ struct AstJsonEncoder : public AstVisitor
switch (op) switch (op)
{ {
case AstExprUnary::Not: case AstExprUnary::Not:
return writeString("not"); return writeString("Not");
case AstExprUnary::Minus: case AstExprUnary::Minus:
return writeString("minus"); return writeString("Minus");
case AstExprUnary::Len: case AstExprUnary::Len:
return writeString("len"); return writeString("Len");
} }
} }
@ -492,14 +589,14 @@ struct AstJsonEncoder : public AstVisitor
PROP(thenbody); PROP(thenbody);
if (node->elsebody) if (node->elsebody)
PROP(elsebody); PROP(elsebody);
PROP(hasThen); write("hasThen", node->thenLocation.has_value());
PROP(hasEnd); PROP(hasEnd);
}); });
} }
void write(class AstStatWhile* node) void write(class AstStatWhile* node)
{ {
writeNode(node, "AtStatWhile", [&]() { writeNode(node, "AstStatWhile", [&]() {
PROP(condition); PROP(condition);
PROP(body); PROP(body);
PROP(hasDo); PROP(hasDo);
@ -612,6 +709,7 @@ struct AstJsonEncoder : public AstVisitor
writeNode(node, "AstStatTypeAlias", [&]() { writeNode(node, "AstStatTypeAlias", [&]() {
PROP(name); PROP(name);
PROP(generics); PROP(generics);
PROP(genericPacks);
PROP(type); PROP(type);
PROP(exported); PROP(exported);
}); });
@ -641,7 +739,8 @@ struct AstJsonEncoder : public AstVisitor
writeRaw("{"); writeRaw("{");
bool c = pushComma(); bool c = pushComma();
write("name", prop.name); write("name", prop.name);
write("type", prop.ty); writeType("AstDeclaredClassProp");
write("luauType", prop.ty);
popComma(c); popComma(c);
writeRaw("}"); writeRaw("}");
} }
@ -664,13 +763,21 @@ struct AstJsonEncoder : public AstVisitor
}); });
} }
void write(struct AstTypeOrPack node)
{
if (node.type)
write(node.type);
else
write(node.typePack);
}
void write(class AstTypeReference* node) void write(class AstTypeReference* node)
{ {
writeNode(node, "AstTypeReference", [&]() { writeNode(node, "AstTypeReference", [&]() {
if (node->hasPrefix) if (node->prefix)
PROP(prefix); PROP(prefix);
PROP(name); PROP(name);
PROP(generics); PROP(parameters);
}); });
} }
@ -680,8 +787,9 @@ struct AstJsonEncoder : public AstVisitor
bool c = pushComma(); bool c = pushComma();
write("name", prop.name); write("name", prop.name);
writeType("AstTableProp");
write("location", prop.location); write("location", prop.location);
write("type", prop.type); write("propType", prop.type);
popComma(c); popComma(c);
writeRaw("}"); writeRaw("}");
@ -695,6 +803,24 @@ struct AstJsonEncoder : public AstVisitor
}); });
} }
void write(struct AstTableIndexer* indexer)
{
if (indexer)
{
writeRaw("{");
bool c = pushComma();
write("location", indexer->location);
write("indexType", indexer->indexType);
write("resultType", indexer->resultType);
popComma(c);
writeRaw("}");
}
else
{
writeRaw("null");
}
}
void write(class AstTypeFunction* node) void write(class AstTypeFunction* node)
{ {
writeNode(node, "AstTypeFunction", [&]() { writeNode(node, "AstTypeFunction", [&]() {
@ -734,6 +860,13 @@ struct AstJsonEncoder : public AstVisitor
}); });
} }
void write(class AstTypePackExplicit* node)
{
writeNode(node, "AstTypePackExplicit", [&]() {
PROP(typeList);
});
}
void write(class AstTypePackVariadic* node) void write(class AstTypePackVariadic* node)
{ {
writeNode(node, "AstTypePackVariadic", [&]() { writeNode(node, "AstTypePackVariadic", [&]() {
@ -778,6 +911,18 @@ struct AstJsonEncoder : public AstVisitor
return false; return false;
} }
bool visit(class AstExprIfElse* node) override
{
write(node);
return false;
}
bool visit(class AstExprInterpString* node) override
{
write(node);
return false;
}
bool visit(class AstExprLocal* node) override bool visit(class AstExprLocal* node) override
{ {
write(node); write(node);
@ -1018,6 +1163,12 @@ struct AstJsonEncoder : public AstVisitor
return false; return false;
} }
bool visit(class AstTypePackExplicit* node) override
{
write(node);
return false;
}
bool visit(class AstTypePackVariadic* node) override bool visit(class AstTypePackVariadic* node) override
{ {
write(node); write(node);
@ -1029,6 +1180,41 @@ struct AstJsonEncoder : public AstVisitor
write(node); write(node);
return false; return false;
} }
void writeComments(std::vector<Comment> commentLocations)
{
bool commentComma = false;
for (Comment comment : commentLocations)
{
if (commentComma)
{
writeRaw(",");
}
else
{
commentComma = true;
}
writeRaw("{");
bool c = pushComma();
switch (comment.type)
{
case Lexeme::Comment:
writeType("Comment");
break;
case Lexeme::BlockComment:
writeType("BlockComment");
break;
case Lexeme::BrokenComment:
writeType("BrokenComment");
break;
default:
break;
}
write("location", comment.location);
popComma(c);
writeRaw("}");
}
}
}; };
std::string toJson(AstNode* node) std::string toJson(AstNode* node)
@ -1038,4 +1224,15 @@ std::string toJson(AstNode* node)
return encoder.str(); return encoder.str();
} }
std::string toJson(AstNode* node, const std::vector<Comment>& commentLocations)
{
AstJsonEncoder encoder;
encoder.writeRaw(R"({"root":)");
node->visit(&encoder);
encoder.writeRaw(R"(,"commentLocations":[)");
encoder.writeComments(commentLocations);
encoder.writeRaw("]}");
return encoder.str();
}
} // namespace Luau } // namespace Luau

View file

@ -2,6 +2,7 @@
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
@ -10,12 +11,123 @@
#include <algorithm> #include <algorithm>
LUAU_FASTFLAG(LuauCompleteTableKeysBetter);
namespace Luau namespace Luau
{ {
namespace namespace
{ {
struct AutocompleteNodeFinder : public AstVisitor
{
const Position pos;
std::vector<AstNode*> ancestry;
explicit AutocompleteNodeFinder(Position pos, AstNode* root)
: pos(pos)
{
}
bool visit(AstExpr* expr) override
{
if (FFlag::LuauCompleteTableKeysBetter)
{
if (expr->location.begin <= pos && pos <= expr->location.end)
{
ancestry.push_back(expr);
return true;
}
return false;
}
else
{
if (expr->location.begin < pos && pos <= expr->location.end)
{
ancestry.push_back(expr);
return true;
}
return false;
}
}
bool visit(AstStat* stat) override
{
if (stat->location.begin < pos && pos <= stat->location.end)
{
ancestry.push_back(stat);
return true;
}
return false;
}
bool visit(AstType* type) override
{
if (type->location.begin < pos && pos <= type->location.end)
{
ancestry.push_back(type);
return true;
}
return false;
}
bool visit(AstTypeError* type) override
{
// For a missing type, match the whole range including the start position
if (type->isMissing && type->location.containsClosed(pos))
{
ancestry.push_back(type);
return true;
}
return false;
}
bool visit(class AstTypePack* typePack) override
{
return true;
}
bool visit(AstStatBlock* block) override
{
// If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite.
if (ancestry.empty())
{
ancestry.push_back(block);
return true;
}
// AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes.
// ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz}
if (!ancestry.empty() && ancestry.back()->is<AstExprIndexName>())
return false;
// Type annotation error might intersect the block statement when the function header is being written,
// annotation takes priority
if (!ancestry.empty() && ancestry.back()->is<AstTypeError>())
return false;
// If the cursor is at the end of an expression or type and simultaneously at the beginning of a block,
// the expression or type wins out.
// The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to
// be within the block.
if (block->location.begin == pos && !ancestry.empty())
{
if (ancestry.back()->asExpr() && !ancestry.back()->is<AstExprFunction>())
return false;
if (ancestry.back()->asType())
return false;
}
if (block->location.begin <= pos && pos <= block->location.end)
{
ancestry.push_back(block);
return true;
}
return false;
}
};
struct FindNode : public AstVisitor struct FindNode : public AstVisitor
{ {
const Position pos; const Position pos;
@ -70,9 +182,11 @@ struct FindFullAncestry final : public AstVisitor
{ {
std::vector<AstNode*> nodes; std::vector<AstNode*> nodes;
Position pos; Position pos;
Position documentEnd;
explicit FindFullAncestry(Position pos) explicit FindFullAncestry(Position pos, Position documentEnd)
: pos(pos) : pos(pos)
, documentEnd(documentEnd)
{ {
} }
@ -83,17 +197,38 @@ struct FindFullAncestry final : public AstVisitor
nodes.push_back(node); nodes.push_back(node);
return true; return true;
} }
// Edge case: If we ask for the node at the position that is the very end of the document
// return the innermost AST element that ends at that position.
if (node->location.end == documentEnd && pos >= documentEnd)
{
nodes.push_back(node);
return true;
}
return false; return false;
} }
}; };
} // namespace } // namespace
std::vector<AstNode*> findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos)
{
AutocompleteNodeFinder finder{pos, source.root};
source.root->visit(&finder);
return finder.ancestry;
}
std::vector<AstNode*> findAstAncestryOfPosition(const SourceModule& source, Position pos) std::vector<AstNode*> findAstAncestryOfPosition(const SourceModule& source, Position pos)
{ {
FindFullAncestry finder(pos); const Position end = source.root->location.end;
if (pos > end)
pos = end;
FindFullAncestry finder(pos, end);
source.root->visit(&finder); source.root->visit(&finder);
return std::move(finder.nodes); return finder.nodes;
} }
AstNode* findNodeAtPosition(const SourceModule& source, Position pos) AstNode* findNodeAtPosition(const SourceModule& source, Position pos)
@ -143,8 +278,8 @@ std::optional<TypeId> findTypeAtPosition(const Module& module, const SourceModul
{ {
if (auto expr = findExprAtPosition(sourceModule, pos)) if (auto expr = findExprAtPosition(sourceModule, pos))
{ {
if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) if (auto it = module.astTypes.find(expr))
return it->second; return *it;
} }
return std::nullopt; return std::nullopt;
@ -154,8 +289,8 @@ std::optional<TypeId> findExpectedTypeAtPosition(const Module& module, const Sou
{ {
if (auto expr = findExprAtPosition(sourceModule, pos)) if (auto expr = findExprAtPosition(sourceModule, pos))
{ {
if (auto it = module.astExpectedTypes.find(expr); it != module.astExpectedTypes.end()) if (auto it = module.astExpectedTypes.find(expr))
return it->second; return *it;
} }
return std::nullopt; return std::nullopt;
@ -192,7 +327,7 @@ std::optional<Binding> findBindingAtPosition(const Module& module, const SourceM
auto iter = currentScope->bindings.find(name); auto iter = currentScope->bindings.find(name);
if (iter != currentScope->bindings.end() && iter->second.location.begin <= pos) if (iter != currentScope->bindings.end() && iter->second.location.begin <= pos)
{ {
/* Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope */ // Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope
std::optional<AstStatLocal*> bindingStatement = findBindingLocalStatement(source, iter->second); std::optional<AstStatLocal*> bindingStatement = findBindingLocalStatement(source, iter->second);
if (!bindingStatement || !(*bindingStatement)->location.contains(pos)) if (!bindingStatement || !(*bindingStatement)->location.contains(pos))
return iter->second; return iter->second;
@ -305,6 +440,36 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos)
return findVisitor.result; return findVisitor.result;
} }
static std::optional<DocumentationSymbol> checkOverloadedDocumentationSymbol(
const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional<DocumentationSymbol> documentationSymbol)
{
if (!documentationSymbol)
return std::nullopt;
// This might be an overloaded function.
if (get<IntersectionTypeVar>(follow(ty)))
{
TypeId matchingOverload = nullptr;
if (parentExpr && parentExpr->is<AstExprCall>())
{
if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
{
matchingOverload = *it;
}
}
if (matchingOverload)
{
std::string overloadSymbol = *documentationSymbol + "/overload/";
// Default toString options are fine for this purpose.
overloadSymbol += toString(matchingOverload);
return overloadSymbol;
}
}
return documentationSymbol;
}
std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position)
{ {
std::vector<AstNode*> ancestry = findAstAncestryOfPosition(source, position); std::vector<AstNode*> ancestry = findAstAncestryOfPosition(source, position);
@ -313,54 +478,24 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
AstExpr* parentExpr = ancestry.size() >= 2 ? ancestry[ancestry.size() - 2]->asExpr() : nullptr; AstExpr* parentExpr = ancestry.size() >= 2 ? ancestry[ancestry.size() - 2]->asExpr() : nullptr;
if (std::optional<Binding> binding = findBindingAtPosition(module, source, position)) if (std::optional<Binding> binding = findBindingAtPosition(module, source, position))
{ return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol);
if (binding->documentationSymbol)
{
// This might be an overloaded function binding.
if (get<IntersectionTypeVar>(follow(binding->typeId)))
{
TypeId matchingOverload = nullptr;
if (parentExpr && parentExpr->is<AstExprCall>())
{
if (auto it = module.astOverloadResolvedTypes.find(parentExpr); it != module.astOverloadResolvedTypes.end())
{
matchingOverload = it->second;
}
}
if (matchingOverload)
{
std::string overloadSymbol = *binding->documentationSymbol + "/overload/";
// Default toString options are fine for this purpose.
overloadSymbol += toString(matchingOverload);
return overloadSymbol;
}
}
}
return binding->documentationSymbol;
}
if (targetExpr) if (targetExpr)
{ {
if (AstExprIndexName* indexName = targetExpr->as<AstExprIndexName>()) if (AstExprIndexName* indexName = targetExpr->as<AstExprIndexName>())
{ {
if (auto it = module.astTypes.find(indexName->expr); it != module.astTypes.end()) if (auto it = module.astTypes.find(indexName->expr))
{ {
TypeId parentTy = follow(it->second); TypeId parentTy = follow(*it);
if (const TableTypeVar* ttv = get<TableTypeVar>(parentTy)) if (const TableTypeVar* ttv = get<TableTypeVar>(parentTy))
{ {
if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end())
{ return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol);
return propIt->second.documentationSymbol;
}
} }
else if (const ClassTypeVar* ctv = get<ClassTypeVar>(parentTy)) else if (const ClassTypeVar* ctv = get<ClassTypeVar>(parentTy))
{ {
if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end())
{ return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol);
return propIt->second.documentationSymbol;
}
} }
} }
} }

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

504
Analysis/src/Clone.cpp Normal file
View file

@ -0,0 +1,504 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Clone.h"
#include "Luau/RecursionCounter.h"
#include "Luau/TxnLog.h"
#include "Luau/TypePack.h"
#include "Luau/Unifiable.h"
LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing)
LUAU_FASTFLAG(LuauClonePublicInterfaceLess)
LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300)
namespace Luau
{
namespace
{
struct TypePackCloner;
/*
* Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set.
* They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage.
*/
struct TypeCloner
{
TypeCloner(TypeArena& dest, TypeId typeId, CloneState& cloneState)
: dest(dest)
, typeId(typeId)
, seenTypes(cloneState.seenTypes)
, seenTypePacks(cloneState.seenTypePacks)
, cloneState(cloneState)
{
}
TypeArena& dest;
TypeId typeId;
SeenTypes& seenTypes;
SeenTypePacks& seenTypePacks;
CloneState& cloneState;
template<typename T>
void defaultClone(const T& t);
void operator()(const Unifiable::Free& t);
void operator()(const Unifiable::Generic& t);
void operator()(const Unifiable::Bound<TypeId>& t);
void operator()(const Unifiable::Error& t);
void operator()(const BlockedTypeVar& t);
void operator()(const PendingExpansionTypeVar& t);
void operator()(const PrimitiveTypeVar& t);
void operator()(const SingletonTypeVar& t);
void operator()(const FunctionTypeVar& t);
void operator()(const TableTypeVar& t);
void operator()(const MetatableTypeVar& t);
void operator()(const ClassTypeVar& t);
void operator()(const AnyTypeVar& t);
void operator()(const UnionTypeVar& t);
void operator()(const IntersectionTypeVar& t);
void operator()(const LazyTypeVar& t);
void operator()(const UnknownTypeVar& t);
void operator()(const NeverTypeVar& t);
void operator()(const NegationTypeVar& t);
};
struct TypePackCloner
{
TypeArena& dest;
TypePackId typePackId;
SeenTypes& seenTypes;
SeenTypePacks& seenTypePacks;
CloneState& cloneState;
TypePackCloner(TypeArena& dest, TypePackId typePackId, CloneState& cloneState)
: dest(dest)
, typePackId(typePackId)
, seenTypes(cloneState.seenTypes)
, seenTypePacks(cloneState.seenTypePacks)
, cloneState(cloneState)
{
}
template<typename T>
void defaultClone(const T& t)
{
TypePackId cloned = dest.addTypePack(TypePackVar{t});
seenTypePacks[typePackId] = cloned;
}
void operator()(const Unifiable::Free& t)
{
defaultClone(t);
}
void operator()(const Unifiable::Generic& t)
{
defaultClone(t);
}
void operator()(const Unifiable::Error& t)
{
defaultClone(t);
}
void operator()(const BlockedTypePack& t)
{
defaultClone(t);
}
// While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter.
// We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer.
void operator()(const Unifiable::Bound<TypePackId>& t)
{
TypePackId cloned = clone(t.boundTo, dest, cloneState);
if (FFlag::DebugLuauCopyBeforeNormalizing)
cloned = dest.addTypePack(TypePackVar{BoundTypePack{cloned}});
seenTypePacks[typePackId] = cloned;
}
void operator()(const VariadicTypePack& t)
{
TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, cloneState), /*hidden*/ t.hidden}});
seenTypePacks[typePackId] = cloned;
}
void operator()(const TypePack& t)
{
TypePackId cloned = dest.addTypePack(TypePack{});
TypePack* destTp = getMutable<TypePack>(cloned);
LUAU_ASSERT(destTp != nullptr);
seenTypePacks[typePackId] = cloned;
for (TypeId ty : t.head)
destTp->head.push_back(clone(ty, dest, cloneState));
if (t.tail)
destTp->tail = clone(*t.tail, dest, cloneState);
}
};
template<typename T>
void TypeCloner::defaultClone(const T& t)
{
TypeId cloned = dest.addType(t);
seenTypes[typeId] = cloned;
}
void TypeCloner::operator()(const Unifiable::Free& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const Unifiable::Generic& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const Unifiable::Bound<TypeId>& t)
{
TypeId boundTo = clone(t.boundTo, dest, cloneState);
if (FFlag::DebugLuauCopyBeforeNormalizing)
boundTo = dest.addType(BoundTypeVar{boundTo});
seenTypes[typeId] = boundTo;
}
void TypeCloner::operator()(const Unifiable::Error& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const BlockedTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const PendingExpansionTypeVar& t)
{
TypeId res = dest.addType(PendingExpansionTypeVar{t.prefix, t.name, t.typeArguments, t.packArguments});
PendingExpansionTypeVar* petv = getMutable<PendingExpansionTypeVar>(res);
LUAU_ASSERT(petv);
seenTypes[typeId] = res;
std::vector<TypeId> typeArguments;
for (TypeId arg : t.typeArguments)
typeArguments.push_back(clone(arg, dest, cloneState));
std::vector<TypePackId> packArguments;
for (TypePackId arg : t.packArguments)
packArguments.push_back(clone(arg, dest, cloneState));
petv->typeArguments = std::move(typeArguments);
petv->packArguments = std::move(packArguments);
}
void TypeCloner::operator()(const PrimitiveTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const SingletonTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const FunctionTypeVar& t)
{
// FISHY: We always erase the scope when we clone things. clone() was
// originally written so that we could copy a module's type surface into an
// export arena. This probably dates to that.
TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf});
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(result);
LUAU_ASSERT(ftv != nullptr);
seenTypes[typeId] = result;
for (TypeId generic : t.generics)
ftv->generics.push_back(clone(generic, dest, cloneState));
for (TypePackId genericPack : t.genericPacks)
ftv->genericPacks.push_back(clone(genericPack, dest, cloneState));
ftv->tags = t.tags;
ftv->argTypes = clone(t.argTypes, dest, cloneState);
ftv->argNames = t.argNames;
ftv->retTypes = clone(t.retTypes, dest, cloneState);
ftv->hasNoGenerics = t.hasNoGenerics;
}
void TypeCloner::operator()(const TableTypeVar& t)
{
// If table is now bound to another one, we ignore the content of the original
if (!FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo)
{
TypeId boundTo = clone(*t.boundTo, dest, cloneState);
seenTypes[typeId] = boundTo;
return;
}
TypeId result = dest.addType(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(result);
LUAU_ASSERT(ttv != nullptr);
*ttv = t;
seenTypes[typeId] = result;
ttv->level = TypeLevel{0, 0};
if (FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, cloneState);
for (const auto& [name, prop] : t.props)
ttv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags};
if (t.indexer)
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)};
for (TypeId& arg : ttv->instantiatedTypeParams)
arg = clone(arg, dest, cloneState);
for (TypePackId& arg : ttv->instantiatedTypePackParams)
arg = clone(arg, dest, cloneState);
ttv->definitionModuleName = t.definitionModuleName;
ttv->tags = t.tags;
}
void TypeCloner::operator()(const MetatableTypeVar& t)
{
TypeId result = dest.addType(MetatableTypeVar{});
MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(result);
seenTypes[typeId] = result;
mtv->table = clone(t.table, dest, cloneState);
mtv->metatable = clone(t.metatable, dest, cloneState);
}
void TypeCloner::operator()(const ClassTypeVar& t)
{
TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData, t.definitionModuleName});
ClassTypeVar* ctv = getMutable<ClassTypeVar>(result);
seenTypes[typeId] = result;
for (const auto& [name, prop] : t.props)
ctv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags};
if (t.parent)
ctv->parent = clone(*t.parent, dest, cloneState);
if (t.metatable)
ctv->metatable = clone(*t.metatable, dest, cloneState);
}
void TypeCloner::operator()(const AnyTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const UnionTypeVar& t)
{
std::vector<TypeId> options;
options.reserve(t.options.size());
for (TypeId ty : t.options)
options.push_back(clone(ty, dest, cloneState));
TypeId result = dest.addType(UnionTypeVar{std::move(options)});
seenTypes[typeId] = result;
}
void TypeCloner::operator()(const IntersectionTypeVar& t)
{
TypeId result = dest.addType(IntersectionTypeVar{});
seenTypes[typeId] = result;
IntersectionTypeVar* option = getMutable<IntersectionTypeVar>(result);
LUAU_ASSERT(option != nullptr);
for (TypeId ty : t.parts)
option->parts.push_back(clone(ty, dest, cloneState));
}
void TypeCloner::operator()(const LazyTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const UnknownTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const NeverTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const NegationTypeVar& t)
{
TypeId result = dest.addType(AnyTypeVar{});
seenTypes[typeId] = result;
TypeId ty = clone(t.ty, dest, cloneState);
asMutable(result)->ty = NegationTypeVar{ty};
}
} // anonymous namespace
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
{
if (tp->persistent)
return tp;
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
TypePackId& res = cloneState.seenTypePacks[tp];
if (res == nullptr)
{
TypePackCloner cloner{dest, tp, cloneState};
Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into.
}
return res;
}
TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
{
if (typeId->persistent)
return typeId;
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
TypeId& res = cloneState.seenTypes[typeId];
if (res == nullptr)
{
TypeCloner cloner{dest, typeId, cloneState};
Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into.
// Persistent types are not being cloned and we get the original type back which might be read-only
if (!res->persistent)
{
asMutable(res)->documentationSymbol = typeId->documentationSymbol;
}
}
return res;
}
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
{
TypeFun result;
for (auto param : typeFun.typeParams)
{
TypeId ty = clone(param.ty, dest, cloneState);
std::optional<TypeId> defaultValue;
if (param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, cloneState);
result.typeParams.push_back({ty, defaultValue});
}
for (auto param : typeFun.typePackParams)
{
TypePackId tp = clone(param.tp, dest, cloneState);
std::optional<TypePackId> defaultValue;
if (param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, cloneState);
result.typePackParams.push_back({tp, defaultValue});
}
result.type = clone(typeFun.type, dest, cloneState);
return result;
}
TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone)
{
ty = log->follow(ty);
TypeId result = ty;
if (auto pty = log->pending(ty))
ty = &pty->pending;
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{
FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf};
clone.generics = ftv->generics;
clone.genericPacks = ftv->genericPacks;
clone.magicFunction = ftv->magicFunction;
clone.dcrMagicFunction = ftv->dcrMagicFunction;
clone.tags = ftv->tags;
clone.argNames = ftv->argNames;
result = dest.addType(std::move(clone));
}
else if (const TableTypeVar* ttv = get<TableTypeVar>(ty))
{
LUAU_ASSERT(!ttv->boundTo);
TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state};
clone.definitionModuleName = ttv->definitionModuleName;
clone.name = ttv->name;
clone.syntheticName = ttv->syntheticName;
clone.instantiatedTypeParams = ttv->instantiatedTypeParams;
clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams;
clone.tags = ttv->tags;
result = dest.addType(std::move(clone));
}
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
{
MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable};
clone.syntheticName = mtv->syntheticName;
result = dest.addType(std::move(clone));
}
else if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
{
UnionTypeVar clone;
clone.options = utv->options;
result = dest.addType(std::move(clone));
}
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(ty))
{
IntersectionTypeVar clone;
clone.parts = itv->parts;
result = dest.addType(std::move(clone));
}
else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty))
{
PendingExpansionTypeVar clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments};
result = dest.addType(std::move(clone));
}
else if (const ClassTypeVar* ctv = get<ClassTypeVar>(ty); FFlag::LuauClonePublicInterfaceLess && ctv && alwaysClone)
{
ClassTypeVar clone{ctv->name, ctv->props, ctv->parent, ctv->metatable, ctv->tags, ctv->userData, ctv->definitionModuleName};
result = dest.addType(std::move(clone));
}
else if (FFlag::LuauClonePublicInterfaceLess && alwaysClone)
{
result = dest.addType(*ty);
}
else if (const NegationTypeVar* ntv = get<NegationTypeVar>(ty))
{
result = dest.addType(NegationTypeVar{ntv->ty});
}
else
return result;
asMutable(result)->documentationSymbol = ty->documentationSymbol;
return result;
}
TypeId shallowClone(TypeId ty, NotNull<TypeArena> dest)
{
return shallowClone(ty, *dest, TxnLog::empty());
}
} // namespace Luau

View file

@ -1,18 +1,21 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Config.h" #include "Luau/Config.h"
#include "Luau/Parser.h" #include "Luau/Lexer.h"
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
namespace LUAU_FASTFLAGVARIABLE(LuauEnableNonstrictByDefaultForLuauConfig, false)
namespace Luau
{ {
using Error = std::optional<std::string>; using Error = std::optional<std::string>;
} Config::Config()
: mode(FFlag::LuauEnableNonstrictByDefaultForLuauConfig ? Mode::Nonstrict : Mode::NoCheck)
namespace Luau
{ {
enabledLint.setDefaults();
}
static Error parseBoolean(bool& result, const std::string& value) static Error parseBoolean(bool& result, const std::string& value)
{ {

View file

@ -0,0 +1,32 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Connective.h"
namespace Luau
{
ConnectiveId ConnectiveArena::negation(ConnectiveId connective)
{
return NotNull{allocator.allocate(Negation{connective})};
}
ConnectiveId ConnectiveArena::conjunction(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Conjunction{lhs, rhs})};
}
ConnectiveId ConnectiveArena::disjunction(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Disjunction{lhs, rhs})};
}
ConnectiveId ConnectiveArena::equivalence(ConnectiveId lhs, ConnectiveId rhs)
{
return NotNull{allocator.allocate(Equivalence{lhs, rhs})};
}
ConnectiveId ConnectiveArena::proposition(DefId def, TypeId discriminantTy)
{
return NotNull{allocator.allocate(Proposition{def, discriminantTy})};
}
} // namespace Luau

View file

@ -0,0 +1,15 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Constraint.h"
namespace Luau
{
Constraint::Constraint(NotNull<Scope> scope, const Location& location, ConstraintV&& c)
: scope(scope)
, location(location)
, c(std::move(c))
{
}
} // namespace Luau

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,451 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/DataFlowGraph.h"
#include "Luau/Error.h"
LUAU_FASTFLAG(DebugLuauFreezeArena)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau
{
std::optional<DefId> DataFlowGraph::getDef(const AstExpr* expr) const
{
// We need to skip through AstExprGroup because DFG doesn't try its best to transitively
while (auto group = expr->as<AstExprGroup>())
expr = group->expr;
if (auto def = astDefs.find(expr))
return NotNull{*def};
return std::nullopt;
}
std::optional<DefId> DataFlowGraph::getDef(const AstLocal* local) const
{
if (auto def = localDefs.find(local))
return NotNull{*def};
return std::nullopt;
}
std::optional<DefId> DataFlowGraph::getDef(const Symbol& symbol) const
{
if (symbol.local)
return getDef(symbol.local);
else
return std::nullopt;
}
DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalErrorReporter> handle)
{
LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution);
DataFlowGraphBuilder builder;
builder.handle = handle;
builder.visit(nullptr, block); // nullptr is the root DFG scope.
if (FFlag::DebugLuauFreezeArena)
builder.arena->allocator.freeze();
return std::move(builder.graph);
}
DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope)
{
return scopes.emplace_back(new DfgScope{scope}).get();
}
std::optional<DefId> DataFlowGraphBuilder::use(DfgScope* scope, Symbol symbol, AstExpr* e)
{
for (DfgScope* current = scope; current; current = current->parent)
{
if (auto def = current->bindings.find(symbol))
{
graph.astDefs[e] = *def;
return NotNull{*def};
}
}
return std::nullopt;
}
DefId DataFlowGraphBuilder::use(DefId def, AstExprIndexName* e)
{
auto& propertyDef = props[def][e->index.value];
if (!propertyDef)
propertyDef = arena->freshCell(def, e->index.value);
graph.astDefs[e] = propertyDef;
return NotNull{propertyDef};
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b)
{
DfgScope* child = childScope(scope);
return visitBlockWithoutChildScope(child, b);
}
void DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b)
{
for (AstStat* s : b->body)
visit(scope, s);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s)
{
if (auto b = s->as<AstStatBlock>())
return visit(scope, b);
else if (auto i = s->as<AstStatIf>())
return visit(scope, i);
else if (auto w = s->as<AstStatWhile>())
return visit(scope, w);
else if (auto r = s->as<AstStatRepeat>())
return visit(scope, r);
else if (auto b = s->as<AstStatBreak>())
return visit(scope, b);
else if (auto c = s->as<AstStatContinue>())
return visit(scope, c);
else if (auto r = s->as<AstStatReturn>())
return visit(scope, r);
else if (auto e = s->as<AstStatExpr>())
return visit(scope, e);
else if (auto l = s->as<AstStatLocal>())
return visit(scope, l);
else if (auto f = s->as<AstStatFor>())
return visit(scope, f);
else if (auto f = s->as<AstStatForIn>())
return visit(scope, f);
else if (auto a = s->as<AstStatAssign>())
return visit(scope, a);
else if (auto c = s->as<AstStatCompoundAssign>())
return visit(scope, c);
else if (auto f = s->as<AstStatFunction>())
return visit(scope, f);
else if (auto l = s->as<AstStatLocalFunction>())
return visit(scope, l);
else if (auto t = s->as<AstStatTypeAlias>())
return; // ok
else if (auto d = s->as<AstStatDeclareFunction>())
return; // ok
else if (auto d = s->as<AstStatDeclareGlobal>())
return; // ok
else if (auto d = s->as<AstStatDeclareFunction>())
return; // ok
else if (auto d = s->as<AstStatDeclareClass>())
return; // ok
else if (auto _ = s->as<AstStatError>())
return; // ok
else
handle->ice("Unknown AstStat in DataFlowGraphBuilder");
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i)
{
DfgScope* condScope = childScope(scope);
visitExpr(condScope, i->condition);
visit(condScope, i->thenbody);
if (i->elsebody)
visit(scope, i->elsebody);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w)
{
// TODO(controlflow): entry point has a back edge from exit point
DfgScope* whileScope = childScope(scope);
visitExpr(whileScope, w->condition);
visit(whileScope, w->body);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r)
{
// TODO(controlflow): entry point has a back edge from exit point
DfgScope* repeatScope = childScope(scope); // TODO: loop scope.
visitBlockWithoutChildScope(repeatScope, r->body);
visitExpr(repeatScope, r->condition);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b)
{
// TODO: Control flow analysis
return; // ok
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c)
{
// TODO: Control flow analysis
return; // ok
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r)
{
// TODO: Control flow analysis
for (AstExpr* e : r->list)
visitExpr(scope, e);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e)
{
visitExpr(scope, e->expr);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l)
{
// TODO: alias tracking
for (AstExpr* e : l->values)
visitExpr(scope, e);
for (AstLocal* local : l->vars)
{
DefId def = arena->freshCell();
graph.localDefs[local] = def;
scope->bindings[local] = def;
}
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f)
{
DfgScope* forScope = childScope(scope); // TODO: loop scope.
DefId def = arena->freshCell();
graph.localDefs[f->var] = def;
scope->bindings[f->var] = def;
// TODO(controlflow): entry point has a back edge from exit point
visit(forScope, f->body);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f)
{
DfgScope* forScope = childScope(scope); // TODO: loop scope.
for (AstLocal* local : f->vars)
{
DefId def = arena->freshCell();
graph.localDefs[local] = def;
forScope->bindings[local] = def;
}
// TODO(controlflow): entry point has a back edge from exit point
for (AstExpr* e : f->values)
visitExpr(forScope, e);
visit(forScope, f->body);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a)
{
for (AstExpr* r : a->values)
visitExpr(scope, r);
for (AstExpr* l : a->vars)
{
AstExpr* root = l;
bool isUpdatable = true;
while (true)
{
if (root->is<AstExprLocal>() || root->is<AstExprGlobal>())
break;
AstExprIndexName* indexName = root->as<AstExprIndexName>();
if (!indexName)
{
isUpdatable = false;
break;
}
root = indexName->expr;
}
if (isUpdatable)
{
// TODO global?
if (auto exprLocal = root->as<AstExprLocal>())
{
DefId def = arena->freshCell();
graph.astDefs[exprLocal] = def;
// Update the def in the scope that introduced the local. Not
// the current scope.
AstLocal* local = exprLocal->local;
DfgScope* s = scope;
while (s && !s->bindings.find(local))
s = s->parent;
LUAU_ASSERT(s && s->bindings.find(local));
s->bindings[local] = def;
}
}
visitExpr(scope, l); // TODO: they point to a new def!!
}
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c)
{
// TODO(typestates): The lhs is being read and written to. This might or might not be annoying.
visitExpr(scope, c->value);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f)
{
visitExpr(scope, f->name);
visitExpr(scope, f->func);
}
void DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l)
{
DefId def = arena->freshCell();
graph.localDefs[l->name] = def;
scope->bindings[l->name] = def;
visitExpr(scope, l->func);
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e)
{
if (auto g = e->as<AstExprGroup>())
return visitExpr(scope, g->expr);
else if (auto c = e->as<AstExprConstantNil>())
return {}; // ok
else if (auto c = e->as<AstExprConstantBool>())
return {}; // ok
else if (auto c = e->as<AstExprConstantNumber>())
return {}; // ok
else if (auto c = e->as<AstExprConstantString>())
return {}; // ok
else if (auto l = e->as<AstExprLocal>())
return visitExpr(scope, l);
else if (auto g = e->as<AstExprGlobal>())
return visitExpr(scope, g);
else if (auto v = e->as<AstExprVarargs>())
return {}; // ok
else if (auto c = e->as<AstExprCall>())
return visitExpr(scope, c);
else if (auto i = e->as<AstExprIndexName>())
return visitExpr(scope, i);
else if (auto i = e->as<AstExprIndexExpr>())
return visitExpr(scope, i);
else if (auto f = e->as<AstExprFunction>())
return visitExpr(scope, f);
else if (auto t = e->as<AstExprTable>())
return visitExpr(scope, t);
else if (auto u = e->as<AstExprUnary>())
return visitExpr(scope, u);
else if (auto b = e->as<AstExprBinary>())
return visitExpr(scope, b);
else if (auto t = e->as<AstExprTypeAssertion>())
return visitExpr(scope, t);
else if (auto i = e->as<AstExprIfElse>())
return visitExpr(scope, i);
else if (auto i = e->as<AstExprInterpString>())
return visitExpr(scope, i);
else if (auto _ = e->as<AstExprError>())
return {}; // ok
else
handle->ice("Unknown AstExpr in DataFlowGraphBuilder");
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l)
{
return {use(scope, l->local, l)};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g)
{
return {use(scope, g->name, g)};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c)
{
visitExpr(scope, c->func);
for (AstExpr* arg : c->args)
visitExpr(scope, arg);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i)
{
std::optional<DefId> def = visitExpr(scope, i->expr).def;
if (!def)
return {};
return {use(*def, i)};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i)
{
visitExpr(scope, i->expr);
visitExpr(scope, i->expr);
if (i->index->as<AstExprConstantString>())
{
// TODO: properties for the def
}
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f)
{
if (AstLocal* self = f->self)
{
DefId def = arena->freshCell();
graph.localDefs[self] = def;
scope->bindings[self] = def;
}
for (AstLocal* param : f->args)
{
DefId def = arena->freshCell();
graph.localDefs[param] = def;
scope->bindings[param] = def;
}
visit(scope, f->body);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t)
{
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u)
{
visitExpr(scope, u->expr);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b)
{
visitExpr(scope, b->left);
visitExpr(scope, b->right);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t)
{
ExpressionFlowGraph result = visitExpr(scope, t->expr);
// TODO: visit type
return result;
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i)
{
DfgScope* condScope = childScope(scope);
visitExpr(condScope, i->condition);
visitExpr(condScope, i->trueExpr);
visitExpr(scope, i->falseExpr);
return {};
}
ExpressionFlowGraph DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i)
{
for (AstExpr* e : i->expressions)
visitExpr(scope, e);
return {};
}
} // namespace Luau

396
Analysis/src/DcrLogger.cpp Normal file
View file

@ -0,0 +1,396 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/DcrLogger.h"
#include <algorithm>
#include "Luau/JsonEmitter.h"
namespace Luau
{
namespace Json
{
void write(JsonEmitter& emitter, const Location& location)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("beginLine", location.begin.line);
o.writePair("beginColumn", location.begin.column);
o.writePair("endLine", location.end.line);
o.writePair("endColumn", location.end.column);
o.finish();
}
void write(JsonEmitter& emitter, const ErrorSnapshot& snapshot)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("message", snapshot.message);
o.writePair("location", snapshot.location);
o.finish();
}
void write(JsonEmitter& emitter, const BindingSnapshot& snapshot)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("typeId", snapshot.typeId);
o.writePair("typeString", snapshot.typeString);
o.writePair("location", snapshot.location);
o.finish();
}
void write(JsonEmitter& emitter, const TypeBindingSnapshot& snapshot)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("typeId", snapshot.typeId);
o.writePair("typeString", snapshot.typeString);
o.finish();
}
void write(JsonEmitter& emitter, const ConstraintGenerationLog& log)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("source", log.source);
emitter.writeComma();
write(emitter, "constraintLocations");
emitter.writeRaw(":");
ObjectEmitter locationEmitter = emitter.writeObject();
for (const auto& [id, location] : log.constraintLocations)
{
locationEmitter.writePair(id, location);
}
locationEmitter.finish();
o.writePair("errors", log.errors);
o.finish();
}
void write(JsonEmitter& emitter, const ScopeSnapshot& snapshot)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("bindings", snapshot.bindings);
o.writePair("typeBindings", snapshot.typeBindings);
o.writePair("typePackBindings", snapshot.typePackBindings);
o.writePair("children", snapshot.children);
o.finish();
}
void write(JsonEmitter& emitter, const ConstraintBlockKind& kind)
{
switch (kind)
{
case ConstraintBlockKind::TypeId:
return write(emitter, "type");
case ConstraintBlockKind::TypePackId:
return write(emitter, "typePack");
case ConstraintBlockKind::ConstraintId:
return write(emitter, "constraint");
default:
LUAU_ASSERT(0);
}
}
void write(JsonEmitter& emitter, const ConstraintBlock& block)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("kind", block.kind);
o.writePair("stringification", block.stringification);
o.finish();
}
void write(JsonEmitter& emitter, const ConstraintSnapshot& snapshot)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("stringification", snapshot.stringification);
o.writePair("blocks", snapshot.blocks);
o.finish();
}
void write(JsonEmitter& emitter, const BoundarySnapshot& snapshot)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("rootScope", snapshot.rootScope);
o.writePair("constraints", snapshot.constraints);
o.finish();
}
void write(JsonEmitter& emitter, const StepSnapshot& snapshot)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("currentConstraint", snapshot.currentConstraint);
o.writePair("forced", snapshot.forced);
o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints);
o.writePair("rootScope", snapshot.rootScope);
o.finish();
}
void write(JsonEmitter& emitter, const TypeSolveLog& log)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("initialState", log.initialState);
o.writePair("stepStates", log.stepStates);
o.writePair("finalState", log.finalState);
o.finish();
}
void write(JsonEmitter& emitter, const TypeCheckLog& log)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("errors", log.errors);
o.finish();
}
} // namespace Json
static std::string toPointerId(NotNull<const Constraint> ptr)
{
return std::to_string(reinterpret_cast<size_t>(ptr.get()));
}
static ScopeSnapshot snapshotScope(const Scope* scope, ToStringOptions& opts)
{
std::unordered_map<Name, BindingSnapshot> bindings;
std::unordered_map<Name, TypeBindingSnapshot> typeBindings;
std::unordered_map<Name, TypeBindingSnapshot> typePackBindings;
std::vector<ScopeSnapshot> children;
for (const auto& [name, binding] : scope->bindings)
{
std::string id = std::to_string(reinterpret_cast<size_t>(binding.typeId));
ToStringResult result = toStringDetailed(binding.typeId, opts);
bindings[name.c_str()] = BindingSnapshot{
id,
result.name,
binding.location,
};
}
for (const auto& [name, tf] : scope->exportedTypeBindings)
{
std::string id = std::to_string(reinterpret_cast<size_t>(tf.type));
typeBindings[name] = TypeBindingSnapshot{
id,
toString(tf.type, opts),
};
}
for (const auto& [name, tf] : scope->privateTypeBindings)
{
std::string id = std::to_string(reinterpret_cast<size_t>(tf.type));
typeBindings[name] = TypeBindingSnapshot{
id,
toString(tf.type, opts),
};
}
for (const auto& [name, tp] : scope->privateTypePackBindings)
{
std::string id = std::to_string(reinterpret_cast<size_t>(tp));
typePackBindings[name] = TypeBindingSnapshot{
id,
toString(tp, opts),
};
}
for (const auto& child : scope->children)
{
children.push_back(snapshotScope(child.get(), opts));
}
return ScopeSnapshot{
bindings,
typeBindings,
typePackBindings,
children,
};
}
std::string DcrLogger::compileOutput()
{
Json::JsonEmitter emitter;
Json::ObjectEmitter o = emitter.writeObject();
o.writePair("generation", generationLog);
o.writePair("solve", solveLog);
o.writePair("check", checkLog);
o.finish();
return emitter.str();
}
void DcrLogger::captureSource(std::string source)
{
generationLog.source = std::move(source);
}
void DcrLogger::captureGenerationError(const TypeError& error)
{
std::string stringifiedError = toString(error);
generationLog.errors.push_back(ErrorSnapshot{
/* message */ stringifiedError,
/* location */ error.location,
});
}
void DcrLogger::captureConstraintLocation(NotNull<const Constraint> constraint, Location location)
{
std::string id = toPointerId(constraint);
generationLog.constraintLocations[id] = location;
}
void DcrLogger::pushBlock(NotNull<const Constraint> constraint, TypeId block)
{
constraintBlocks[constraint].push_back(block);
}
void DcrLogger::pushBlock(NotNull<const Constraint> constraint, TypePackId block)
{
constraintBlocks[constraint].push_back(block);
}
void DcrLogger::pushBlock(NotNull<const Constraint> constraint, NotNull<const Constraint> block)
{
constraintBlocks[constraint].push_back(block);
}
void DcrLogger::popBlock(TypeId block)
{
for (auto& [_, list] : constraintBlocks)
{
list.erase(std::remove(list.begin(), list.end(), block), list.end());
}
}
void DcrLogger::popBlock(TypePackId block)
{
for (auto& [_, list] : constraintBlocks)
{
list.erase(std::remove(list.begin(), list.end(), block), list.end());
}
}
void DcrLogger::popBlock(NotNull<const Constraint> block)
{
for (auto& [_, list] : constraintBlocks)
{
list.erase(std::remove(list.begin(), list.end(), block), list.end());
}
}
void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
solveLog.initialState.rootScope = snapshotScope(rootScope, opts);
solveLog.initialState.constraints.clear();
for (NotNull<const Constraint> c : unsolvedConstraints)
{
std::string id = toPointerId(c);
solveLog.initialState.constraints[id] = {
toString(*c.get(), opts),
snapshotBlocks(c),
};
}
}
StepSnapshot DcrLogger::prepareStepSnapshot(
const Scope* rootScope, NotNull<const Constraint> current, bool force, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts);
std::string currentId = toPointerId(current);
std::unordered_map<std::string, ConstraintSnapshot> constraints;
for (NotNull<const Constraint> c : unsolvedConstraints)
{
std::string id = toPointerId(c);
constraints[id] = {
toString(*c.get(), opts),
snapshotBlocks(c),
};
}
return StepSnapshot{
currentId,
force,
constraints,
scopeSnapshot,
};
}
void DcrLogger::commitStepSnapshot(StepSnapshot snapshot)
{
solveLog.stepStates.push_back(std::move(snapshot));
}
void DcrLogger::captureFinalSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
solveLog.finalState.rootScope = snapshotScope(rootScope, opts);
solveLog.finalState.constraints.clear();
for (NotNull<const Constraint> c : unsolvedConstraints)
{
std::string id = toPointerId(c);
solveLog.finalState.constraints[id] = {
toString(*c.get(), opts),
snapshotBlocks(c),
};
}
}
void DcrLogger::captureTypeCheckError(const TypeError& error)
{
std::string stringifiedError = toString(error);
checkLog.errors.push_back(ErrorSnapshot{
/* message */ stringifiedError,
/* location */ error.location,
});
}
std::vector<ConstraintBlock> DcrLogger::snapshotBlocks(NotNull<const Constraint> c)
{
auto it = constraintBlocks.find(c);
if (it == constraintBlocks.end())
{
return {};
}
std::vector<ConstraintBlock> snapshot;
for (const ConstraintBlockTarget& target : it->second)
{
if (const TypeId* ty = get_if<TypeId>(&target))
{
snapshot.push_back({
ConstraintBlockKind::TypeId,
toString(*ty, opts),
});
}
else if (const TypePackId* tp = get_if<TypePackId>(&target))
{
snapshot.push_back({
ConstraintBlockKind::TypePackId,
toString(*tp, opts),
});
}
else if (const NotNull<const Constraint>* c = get_if<NotNull<const Constraint>>(&target))
{
snapshot.push_back({
ConstraintBlockKind::ConstraintId,
toString(*(c->get()), opts),
});
}
else
{
LUAU_ASSERT(0);
}
}
return snapshot;
}
} // namespace Luau

17
Analysis/src/Def.cpp Normal file
View file

@ -0,0 +1,17 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Def.h"
namespace Luau
{
DefId DefArena::freshCell()
{
return NotNull{allocator.allocate<Def>(Def{Cell{std::nullopt}})};
}
DefId DefArena::freshCell(DefId parent, const std::string& prop)
{
return NotNull{allocator.allocate<Def>(Def{Cell{FieldMetadata{parent, prop}}})};
}
} // namespace Luau

View file

@ -1,8 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BuiltinDefinitions.h" #include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAG(LuauParseGenericFunctions) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAG(LuauOptionalNextKey)
namespace Luau namespace Luau
{ {
@ -10,59 +10,65 @@ namespace Luau
static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC( static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC(
declare bit32: { declare bit32: {
-- band, bor, bxor, and btest are declared in C++ band: (...number) -> number,
rrotate: (number, number) -> number, bor: (...number) -> number,
lrotate: (number, number) -> number, bxor: (...number) -> number,
lshift: (number, number) -> number, btest: (number, ...number) -> boolean,
arshift: (number, number) -> number, rrotate: (x: number, disp: number) -> number,
rshift: (number, number) -> number, lrotate: (x: number, disp: number) -> number,
bnot: (number) -> number, lshift: (x: number, disp: number) -> number,
extract: (number, number, number?) -> number, arshift: (x: number, disp: number) -> number,
replace: (number, number, number, number?) -> number, rshift: (x: number, disp: number) -> number,
bnot: (x: number) -> number,
extract: (n: number, field: number, width: number?) -> number,
replace: (n: number, v: number, field: number, width: number?) -> number,
countlz: (n: number) -> number,
countrz: (n: number) -> number,
} }
declare math: { declare math: {
frexp: (number) -> (number, number), frexp: (n: number) -> (number, number),
ldexp: (number, number) -> number, ldexp: (s: number, e: number) -> number,
fmod: (number, number) -> number, fmod: (x: number, y: number) -> number,
modf: (number) -> (number, number), modf: (n: number) -> (number, number),
pow: (number, number) -> number, pow: (x: number, y: number) -> number,
exp: (number) -> number, exp: (n: number) -> number,
ceil: (number) -> number, ceil: (n: number) -> number,
floor: (number) -> number, floor: (n: number) -> number,
abs: (number) -> number, abs: (n: number) -> number,
sqrt: (number) -> number, sqrt: (n: number) -> number,
log: (number, number?) -> number, log: (n: number, base: number?) -> number,
log10: (number) -> number, log10: (n: number) -> number,
rad: (number) -> number, rad: (n: number) -> number,
deg: (number) -> number, deg: (n: number) -> number,
sin: (number) -> number, sin: (n: number) -> number,
cos: (number) -> number, cos: (n: number) -> number,
tan: (number) -> number, tan: (n: number) -> number,
sinh: (number) -> number, sinh: (n: number) -> number,
cosh: (number) -> number, cosh: (n: number) -> number,
tanh: (number) -> number, tanh: (n: number) -> number,
atan: (number) -> number, atan: (n: number) -> number,
acos: (number) -> number, acos: (n: number) -> number,
asin: (number) -> number, asin: (n: number) -> number,
atan2: (number, number) -> number, atan2: (y: number, x: number) -> number,
-- min and max are declared in C++. min: (number, ...number) -> number,
max: (number, ...number) -> number,
pi: number, pi: number,
huge: number, huge: number,
randomseed: (number) -> (), randomseed: (seed: number) -> (),
random: (number?, number?) -> number, random: (number?, number?) -> number,
sign: (number) -> number, sign: (n: number) -> number,
clamp: (number, number, number) -> number, clamp: (n: number, min: number, max: number) -> number,
noise: (number, number?, number?) -> number, noise: (x: number, y: number?, z: number?) -> number,
round: (number) -> number, round: (n: number) -> number,
} }
type DateTypeArg = { type DateTypeArg = {
@ -88,151 +94,126 @@ type DateTypeResult = {
} }
declare os: { declare os: {
time: (DateTypeArg?) -> number, time: (time: DateTypeArg?) -> number,
date: (string?, number?) -> DateTypeResult | string, date: (formatString: string?, time: number?) -> DateTypeResult | string,
difftime: (DateTypeResult | number, DateTypeResult | number) -> number, difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number, clock: () -> number,
} }
declare function require(target: any): any declare function require(target: any): any
declare function getfenv(target: any?): { [string]: any } declare function getfenv(target: any): { [string]: any }
declare _G: any declare _G: any
declare _VERSION: string declare _VERSION: string
declare function gcinfo(): number declare function gcinfo(): number
declare function print<T...>(...: T...)
declare function type<T>(value: T): string
declare function typeof<T>(value: T): string
-- `assert` has a magic function attached that will give more detailed type information
declare function assert<T>(value: T, errorMessage: string?): T
declare function tostring<T>(value: T): string
declare function tonumber<T>(value: T, radix: number?): number?
declare function rawequal<T1, T2>(a: T1, b: T2): boolean
declare function rawget<K, V>(tab: {[K]: V}, k: K): V
declare function rawset<K, V>(tab: {[K]: V}, k: K, v: V): {[K]: V}
declare function rawlen<K, V>(obj: {[K]: V} | string): number
declare function setfenv<T..., R...>(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)?
-- TODO: place ipairs definition here with removal of FFlagLuauOptionalNextKey
declare function pcall<A..., R...>(f: (A...) -> R..., ...: A...): (boolean, R...)
-- FIXME: The actual type of `xpcall` is:
-- <E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...)
-- Since we can't represent the return value, we use (boolean, R1...).
declare function xpcall<E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...)
-- `select` has a magic function attached to provide more detailed type information
declare function select<A...>(i: string | number, ...: A...): ...any
-- FIXME: This type is not entirely correct - `loadstring` returns a function or
-- (nil, string).
declare function loadstring<A...>(src: string, chunkname: string?): (((A...) -> any)?, string?)
declare function newproxy(mt: boolean?): any
declare coroutine: {
create: <A..., R...>(f: (A...) -> R...) -> thread,
resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
running: () -> thread,
status: (co: thread) -> "dead" | "running" | "normal" | "suspended",
-- FIXME: This technically returns a function, but we can't represent this yet.
wrap: <A..., R...>(f: (A...) -> R...) -> any,
yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean,
close: (co: thread) -> (boolean, any)
}
declare table: {
concat: <V>(t: {V}, sep: string?, i: number?, j: number?) -> string,
insert: (<V>(t: {V}, value: V) -> ()) & (<V>(t: {V}, pos: number, value: V) -> ()),
maxn: <V>(t: {V}) -> number,
remove: <V>(t: {V}, number?) -> V?,
sort: <V>(t: {V}, comp: ((V, V) -> boolean)?) -> (),
create: <V>(count: number, value: V?) -> {V},
find: <V>(haystack: {V}, needle: V, init: number?) -> number?,
unpack: <V>(list: {V}, i: number?, j: number?) -> ...V,
pack: <V>(...V) -> { n: number, [number]: V },
getn: <V>(t: {V}) -> number,
foreach: <K, V>(t: {[K]: V}, f: (K, V) -> ()) -> (),
foreachi: <V>({V}, (number, V) -> ()) -> (),
move: <V>(src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V},
clear: <K, V>(table: {[K]: V}) -> (),
isfrozen: <K, V>(t: {[K]: V}) -> boolean,
}
declare debug: {
info: (<R...>(thread: thread, level: number, options: string) -> R...) & (<R...>(level: number, options: string) -> R...) & (<A..., R1..., R2...>(func: (A...) -> R1..., options: string) -> R2...),
traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string),
}
declare utf8: {
char: (...number) -> string,
charpattern: string,
codes: (str: string) -> ((string, number) -> (number, number), string, number),
codepoint: (str: string, i: number?, j: number?) -> ...number,
len: (s: string, i: number?, j: number?) -> (number?, number?),
offset: (s: string, n: number?, i: number?) -> number,
}
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
)BUILTIN_SRC"; )BUILTIN_SRC";
std::string getBuiltinDefinitionSource() std::string getBuiltinDefinitionSource()
{ {
std::string src = kBuiltinDefinitionLuaSrc;
if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions) std::string result = kBuiltinDefinitionLuaSrc;
{
src += R"(
declare function print<T...>(...: T...)
declare function type<T>(value: T): string if (FFlag::LuauUnknownAndNeverType)
declare function typeof<T>(value: T): string result += "declare function error<T>(message: T, level: number?): never\n";
else
-- `assert` has a magic function attached that will give more detailed type information result += "declare function error<T>(message: T, level: number?)\n";
declare function assert<T>(value: T, errorMessage: string?): T
declare function error<T>(message: T, level: number?) if (FFlag::LuauOptionalNextKey)
result += "declare function ipairs<V>(tab: {V}): (({V}, number) -> (number?, V), {V}, number)\n";
else
result += "declare function ipairs<V>(tab: {V}): (({V}, number) -> (number, V), {V}, number)\n";
declare function tostring<T>(value: T): string return result;
declare function tonumber<T>(value: T, radix: number?): number
declare function rawequal<T1, T2>(a: T1, b: T2): boolean
declare function rawget<K, V>(tab: {[K]: V}, k: K): V
declare function rawset<K, V>(tab: {[K]: V}, k: K, v: V): {[K]: V}
declare function setfenv<T..., R...>(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)?
declare function ipairs<V>(tab: {V}): (({V}, number) -> (number, V), {V}, number)
declare function pcall<A..., R...>(f: (A...) -> R..., ...: A...): (boolean, R...)
-- FIXME: The actual type of `xpcall` is:
-- <E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...)
-- Since we can't represent the return value, we use (boolean, R1...).
declare function xpcall<E, A..., R1..., R2...>(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...)
-- `select` has a magic function attached to provide more detailed type information
declare function select<A...>(i: string | number, ...: A...): ...any
-- FIXME: This type is not entirely correct - `loadstring` returns a function or
-- (nil, string).
declare function loadstring<A...>(src: string, chunkname: string?): (((A...) -> any)?, string?)
-- a userdata object is "roughly" the same as a sealed empty table
-- except `type(newproxy(false))` evaluates to "userdata" so we may need another special type here too.
-- another important thing to note: the value passed in conditionally creates an empty metatable, and you have to use getmetatable, NOT
-- setmetatable.
-- FIXME: change this to something Luau can understand how to reject `setmetatable(newproxy(false or true), {})`.
declare function newproxy(mt: boolean?): {}
declare coroutine: {
create: <A..., R...>((A...) -> R...) -> thread,
resume: <A..., R...>(thread, A...) -> (boolean, R...),
running: () -> thread,
status: (thread) -> string,
-- FIXME: This technically returns a function, but we can't represent this yet.
wrap: <A..., R...>((A...) -> R...) -> any,
yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean,
}
declare table: {
concat: <V>({V}, string?, number?, number?) -> string,
insert: (<V>({V}, V) -> ()) & (<V>({V}, number, V) -> ()),
maxn: <V>({V}) -> number,
remove: <V>({V}, number?) -> V?,
sort: <V>({V}, ((V, V) -> boolean)?) -> (),
create: <V>(number, V?) -> {V},
find: <V>({V}, V, number?) -> number?,
unpack: <V>({V}, number?, number?) -> ...V,
pack: <V>(...V) -> { n: number, [number]: V },
getn: <V>({V}) -> number,
foreach: <K, V>({[K]: V}, (K, V) -> ()) -> (),
foreachi: <V>({V}, (number, V) -> ()) -> (),
move: <V>({V}, number, number, number, {V}?) -> (),
clear: <K, V>({[K]: V}) -> (),
freeze: <K, V>({[K]: V}) -> {[K]: V},
isfrozen: <K, V>({[K]: V}) -> boolean,
}
declare debug: {
info: (<R...>(thread, number, string) -> R...) & (<R...>(number, string) -> R...) & (<A..., R1..., R2...>((A...) -> R1..., string) -> R2...),
traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string),
}
declare utf8: {
char: (number, ...number) -> string,
charpattern: string,
codes: (string) -> ((string, number) -> (number, number), string, number),
-- FIXME
codepoint: (string, number?, number?) -> (number, ...number),
len: (string, number?, number?) -> (number?, number?),
offset: (string, number?, number?) -> number,
nfdnormalize: (string) -> string,
nfcnormalize: (string) -> string,
graphemes: (string, number?, number?) -> (() -> (number, number)),
}
declare string: {
byte: (string, number?, number?) -> ...number,
char: (number, ...number) -> string,
find: (string, string, number?, boolean?) -> (number?, number?),
-- `string.format` has a magic function attached that will provide more type information for literal format strings.
format: <A...>(string, A...) -> string,
gmatch: (string, string) -> () -> (...string),
-- gsub is defined in C++ because we don't have syntax for describing a generic table.
len: (string) -> number,
lower: (string) -> string,
match: (string, string, number?) -> string?,
rep: (string, number) -> string,
reverse: (string) -> string,
sub: (string, number, number?) -> string,
upper: (string) -> string,
split: (string, string, string?) -> {string},
pack: <A...>(string, A...) -> string,
packsize: (string) -> number,
unpack: <R...>(string, string, number?) -> R...,
}
-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype.
declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
)";
}
return src;
} }
} // namespace Luau } // namespace Luau

View file

@ -1,23 +1,35 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Module.h" #include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/FileResolver.h"
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include <stdexcept> #include <stdexcept>
#include <type_traits>
LUAU_FASTFLAG(LuauFasterStringifier) LUAU_FASTFLAGVARIABLE(LuauTypeMismatchInvarianceInError, false)
static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) static std::string wrongNumberOfArgsString(
size_t expectedCount, std::optional<size_t> maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
{ {
std::string s = "expects " + std::to_string(expectedCount) + " "; std::string s = "expects ";
if (isTypeArgs) if (isVariadic)
s += "type "; s += "at least ";
s += std::to_string(expectedCount) + " ";
if (maximumCount && expectedCount != *maximumCount)
s += "to " + std::to_string(*maximumCount) + " ";
if (argPrefix)
s += std::string(argPrefix) + " ";
s += "argument"; s += "argument";
if (expectedCount != 1) if ((maximumCount ? *maximumCount : expectedCount) != 1)
s += "s"; s += "s";
s += ", but "; s += ", but ";
@ -46,10 +58,60 @@ namespace Luau
struct ErrorConverter struct ErrorConverter
{ {
FileResolver* fileResolver = nullptr;
std::string operator()(const Luau::TypeMismatch& tm) const std::string operator()(const Luau::TypeMismatch& tm) const
{ {
ToStringOptions opts; std::string givenTypeName = Luau::toString(tm.givenType);
return "Type '" + Luau::toString(tm.givenType, opts) + "' could not be converted into '" + Luau::toString(tm.wantedType, opts) + "'"; std::string wantedTypeName = Luau::toString(tm.wantedType);
std::string result;
if (givenTypeName == wantedTypeName)
{
if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType))
{
if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType))
{
if (fileResolver != nullptr)
{
std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule);
std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule);
result = "Type '" + givenTypeName + "' from '" + givenModuleName + "' could not be converted into '" + wantedTypeName +
"' from '" + wantedModuleName + "'";
}
else
{
result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName +
"' from '" + *wantedDefinitionModule + "'";
}
}
}
}
if (result.empty())
result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'";
if (tm.error)
{
result += "\ncaused by:\n ";
if (!tm.reason.empty())
result += tm.reason + " ";
result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver});
}
else if (!tm.reason.empty())
{
result += "; " + tm.reason;
}
else if (FFlag::LuauTypeMismatchInvarianceInError && tm.context == TypeMismatch::InvariantContext)
{
result += " in an invariant context";
}
return result;
} }
std::string operator()(const Luau::UnknownSymbol& e) const std::string operator()(const Luau::UnknownSymbol& e) const
@ -60,8 +122,6 @@ struct ErrorConverter
return "Unknown global '" + e.name + "'"; return "Unknown global '" + e.name + "'";
case UnknownSymbol::Type: case UnknownSymbol::Type:
return "Unknown type '" + e.name + "'"; return "Unknown type '" + e.name + "'";
case UnknownSymbol::Generic:
return "Unknown generic '" + e.name + "'";
} }
LUAU_ASSERT(!"Unexpected context for UnknownSymbol"); LUAU_ASSERT(!"Unexpected context for UnknownSymbol");
@ -107,28 +167,38 @@ struct ErrorConverter
std::string operator()(const Luau::DuplicateTypeDefinition& e) const std::string operator()(const Luau::DuplicateTypeDefinition& e) const
{ {
return "Redefinition of type '" + e.name + "', previously defined at line " + std::to_string(e.previousLocation.begin.line + 1); std::string s = "Redefinition of type '" + e.name + "'";
if (e.previousLocation)
s += ", previously defined at line " + std::to_string(e.previousLocation->begin.line + 1);
return s;
} }
std::string operator()(const Luau::CountMismatch& e) const std::string operator()(const Luau::CountMismatch& e) const
{ {
const std::string expectedS = e.expected == 1 ? "" : "s";
const std::string actualS = e.actual == 1 ? "" : "s";
const std::string actualVerb = e.actual == 1 ? "is" : "are";
switch (e.context) switch (e.context)
{ {
case CountMismatch::Return: case CountMismatch::Return:
{ return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " +
const std::string expectedS = e.expected == 1 ? "" : "s"; actualVerb + " returned here";
const std::string actualS = e.actual == 1 ? "is" : "are"; case CountMismatch::FunctionResult:
return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + actualS + // It is alright if right hand side produces more values than the
" returned here"; // left hand side accepts. In this context consider only the opposite case.
} return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " +
case CountMismatch::Result: actualVerb + " required here";
if (e.expected > e.actual) case CountMismatch::ExprListResult:
return "Function returns " + std::to_string(e.expected) + " values but there are only " + std::to_string(e.expected) + return "Expression list has " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " +
" values to unpack them into."; actualVerb + " required here";
else
return "Function only returns " + std::to_string(e.expected) + " values. " + std::to_string(e.actual) + " are required here";
case CountMismatch::Arg: case CountMismatch::Arg:
return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); if (!e.function.empty())
return "Argument count mismatch. Function '" + e.function + "' " +
wrongNumberOfArgsString(e.expected, e.maximum, e.actual, /*argPrefix*/ nullptr, e.isVariadic);
else
return "Argument count mismatch. Function " +
wrongNumberOfArgsString(e.expected, e.maximum, e.actual, /*argPrefix*/ nullptr, e.isVariadic);
} }
LUAU_ASSERT(!"Unknown context"); LUAU_ASSERT(!"Unknown context");
@ -142,15 +212,7 @@ struct ErrorConverter
std::string operator()(const Luau::FunctionRequiresSelf& e) const std::string operator()(const Luau::FunctionRequiresSelf& e) const
{ {
if (e.requiredExtraNils) return "This function must be called with self. Did you mean to use a colon instead of a dot?";
{
const char* plural = e.requiredExtraNils == 1 ? "" : "s";
return format("This function was declared to accept self, but you did not pass enough arguments. Use a colon instead of a dot or "
"pass %i extra nil%s to suppress this warning",
e.requiredExtraNils, plural);
}
else
return "This function must be called with self. Did you mean to use a colon instead of a dot?";
} }
std::string operator()(const Luau::OccursCheckFailed&) const std::string operator()(const Luau::OccursCheckFailed&) const
@ -160,34 +222,53 @@ struct ErrorConverter
std::string operator()(const Luau::UnknownRequire& e) const std::string operator()(const Luau::UnknownRequire& e) const
{ {
return "Unknown require: " + e.modulePath; if (e.modulePath.empty())
return "Unknown require: unsupported path";
else
return "Unknown require: " + e.modulePath;
} }
std::string operator()(const Luau::IncorrectGenericParameterCount& e) const std::string operator()(const Luau::IncorrectGenericParameterCount& e) const
{ {
std::string name = e.name; std::string name = e.name;
if (!e.typeFun.typeParams.empty()) if (!e.typeFun.typeParams.empty() || !e.typeFun.typePackParams.empty())
{ {
name += "<"; name += "<";
bool first = true; bool first = true;
for (TypeId t : e.typeFun.typeParams) for (auto param : e.typeFun.typeParams)
{ {
if (first) if (first)
first = false; first = false;
else else
name += ", "; name += ", ";
name += toString(t); name += toString(param.ty);
} }
for (auto param : e.typeFun.typePackParams)
{
if (first)
first = false;
else
name += ", ";
name += toString(param.tp);
}
name += ">"; name += ">";
} }
return "Generic type '" + name + "' " + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); if (e.typeFun.typeParams.size() != e.actualParameters)
return "Generic type '" + name + "' " +
wrongNumberOfArgsString(e.typeFun.typeParams.size(), std::nullopt, e.actualParameters, "type", !e.typeFun.typePackParams.empty());
return "Generic type '" + name + "' " +
wrongNumberOfArgsString(e.typeFun.typePackParams.size(), std::nullopt, e.actualPackParameters, "type pack", /*isVariadic*/ false);
} }
std::string operator()(const Luau::SyntaxError& e) const std::string operator()(const Luau::SyntaxError& e) const
{ {
return "Syntax error: " + e.message; return e.message;
} }
std::string operator()(const Luau::CodeTooComplex&) const std::string operator()(const Luau::CodeTooComplex&) const
@ -234,6 +315,11 @@ struct ErrorConverter
return e.message; return e.message;
} }
std::string operator()(const Luau::InternalError& e) const
{
return e.message;
}
std::string operator()(const Luau::CannotCallNonFunction& e) const std::string operator()(const Luau::CannotCallNonFunction& e) const
{ {
return "Cannot call non-function " + toString(e.ty); return "Cannot call non-function " + toString(e.ty);
@ -374,6 +460,26 @@ struct ErrorConverter
return ss + " in the type '" + toString(e.type) + "'"; return ss + " in the type '" + toString(e.type) + "'";
} }
std::string operator()(const TypesAreUnrelated& e) const
{
return "Cannot cast '" + toString(e.left) + "' into '" + toString(e.right) + "' because the types are unrelated";
}
std::string operator()(const NormalizationTooComplex&) const
{
return "Code is too complex to typecheck! Consider simplifying the code around this area";
}
std::string operator()(const TypePackMismatch& e) const
{
return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'";
}
std::string operator()(const DynamicPropertyLookupOnClassesUnsafe& e) const
{
return "Attempting a dynamic property access on type '" + Luau::toString(e.ty) + "' is unsafe and may cause exceptions at runtime";
}
}; };
struct InvalidNameChecker struct InvalidNameChecker
@ -400,9 +506,60 @@ struct InvalidNameChecker
} }
}; };
TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType)
: wantedType(wantedType)
, givenType(givenType)
{
}
TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason)
: wantedType(wantedType)
, givenType(givenType)
, reason(reason)
{
}
TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional<TypeError> error)
: wantedType(wantedType)
, givenType(givenType)
, reason(reason)
, error(error ? std::make_shared<TypeError>(std::move(*error)) : nullptr)
{
}
TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, TypeMismatch::Context context)
: wantedType(wantedType)
, givenType(givenType)
, context(context)
{
}
TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeMismatch::Context context)
: wantedType(wantedType)
, givenType(givenType)
, context(context)
, reason(reason)
{
}
TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional<TypeError> error, TypeMismatch::Context context)
: wantedType(wantedType)
, givenType(givenType)
, context(context)
, reason(reason)
, error(error ? std::make_shared<TypeError>(std::move(*error)) : nullptr)
{
}
bool TypeMismatch::operator==(const TypeMismatch& rhs) const bool TypeMismatch::operator==(const TypeMismatch& rhs) const
{ {
return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType; if (!!error != !!rhs.error)
return false;
if (error && !(*error == *rhs.error))
return false;
return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType && reason == rhs.reason && context == rhs.context;
} }
bool UnknownSymbol::operator==(const UnknownSymbol& rhs) const bool UnknownSymbol::operator==(const UnknownSymbol& rhs) const
@ -437,7 +594,7 @@ bool DuplicateTypeDefinition::operator==(const DuplicateTypeDefinition& rhs) con
bool CountMismatch::operator==(const CountMismatch& rhs) const bool CountMismatch::operator==(const CountMismatch& rhs) const
{ {
return expected == rhs.expected && actual == rhs.actual && context == rhs.context; return expected == rhs.expected && maximum == rhs.maximum && actual == rhs.actual && context == rhs.context && function == rhs.function;
} }
bool FunctionDoesNotTakeSelf::operator==(const FunctionDoesNotTakeSelf&) const bool FunctionDoesNotTakeSelf::operator==(const FunctionDoesNotTakeSelf&) const
@ -447,7 +604,7 @@ bool FunctionDoesNotTakeSelf::operator==(const FunctionDoesNotTakeSelf&) const
bool FunctionRequiresSelf::operator==(const FunctionRequiresSelf& e) const bool FunctionRequiresSelf::operator==(const FunctionRequiresSelf& e) const
{ {
return requiredExtraNils == e.requiredExtraNils; return true;
} }
bool OccursCheckFailed::operator==(const OccursCheckFailed&) const bool OccursCheckFailed::operator==(const OccursCheckFailed&) const
@ -471,9 +628,20 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC
if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size()) if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size())
return false; return false;
if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size())
return false;
for (size_t i = 0; i < typeFun.typeParams.size(); ++i) for (size_t i = 0; i < typeFun.typeParams.size(); ++i)
if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i]) {
if (typeFun.typeParams[i].ty != rhs.typeFun.typeParams[i].ty)
return false; return false;
}
for (size_t i = 0; i < typeFun.typePackParams.size(); ++i)
{
if (typeFun.typePackParams[i].tp != rhs.typeFun.typePackParams[i].tp)
return false;
}
return true; return true;
} }
@ -504,6 +672,11 @@ bool GenericError::operator==(const GenericError& rhs) const
return message == rhs.message; return message == rhs.message;
} }
bool InternalError::operator==(const InternalError& rhs) const
{
return message == rhs.message;
}
bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const
{ {
return ty == rhs.ty; return ty == rhs.ty;
@ -526,7 +699,17 @@ bool FunctionExitsWithoutReturning::operator==(const FunctionExitsWithoutReturni
int TypeError::code() const int TypeError::code() const
{ {
return 1000 + int(data.index()); return minCode() + int(data.index());
}
int TypeError::minCode()
{
return 1000;
}
TypeErrorSummary TypeError::summary() const
{
return TypeErrorSummary{location, moduleName, code()};
} }
bool TypeError::operator==(const TypeError& rhs) const bool TypeError::operator==(const TypeError& rhs) const
@ -584,9 +767,29 @@ bool MissingUnionProperty::operator==(const MissingUnionProperty& rhs) const
return *type == *rhs.type && key == rhs.key; return *type == *rhs.type && key == rhs.key;
} }
bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const
{
return left == rhs.left && right == rhs.right;
}
bool TypePackMismatch::operator==(const TypePackMismatch& rhs) const
{
return *wantedTp == *rhs.wantedTp && *givenTp == *rhs.givenTp;
}
bool DynamicPropertyLookupOnClassesUnsafe::operator==(const DynamicPropertyLookupOnClassesUnsafe& rhs) const
{
return ty == rhs.ty;
}
std::string toString(const TypeError& error) std::string toString(const TypeError& error)
{ {
ErrorConverter converter; return toString(error, TypeErrorToStringOptions{});
}
std::string toString(const TypeError& error, TypeErrorToStringOptions options)
{
ErrorConverter converter{options.fileResolver};
return Luau::visit(converter, error.data); return Luau::visit(converter, error.data);
} }
@ -595,130 +798,158 @@ bool containsParseErrorName(const TypeError& error)
return Luau::visit(InvalidNameChecker{}, error.data); return Luau::visit(InvalidNameChecker{}, error.data);
} }
void copyErrors(ErrorVec& errors, struct TypeArena& destArena) template<typename T>
void copyError(T& e, TypeArena& destArena, CloneState cloneState)
{ {
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
auto clone = [&](auto&& ty) { auto clone = [&](auto&& ty) {
return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks); return ::Luau::clone(ty, destArena, cloneState);
}; };
auto visitErrorData = [&](auto&& e) { auto visitErrorData = [&](auto&& e) {
using T = std::decay_t<decltype(e)>; copyError(e, destArena, cloneState);
};
if constexpr (false) if constexpr (false)
{ {
} }
else if constexpr (std::is_same_v<T, TypeMismatch>) else if constexpr (std::is_same_v<T, TypeMismatch>)
{ {
e.wantedType = clone(e.wantedType); e.wantedType = clone(e.wantedType);
e.givenType = clone(e.givenType); e.givenType = clone(e.givenType);
}
else if constexpr (std::is_same_v<T, UnknownSymbol>)
{
}
else if constexpr (std::is_same_v<T, UnknownProperty>)
{
e.table = clone(e.table);
}
else if constexpr (std::is_same_v<T, NotATable>)
{
e.ty = clone(e.ty);
}
else if constexpr (std::is_same_v<T, CannotExtendTable>)
{
e.tableType = clone(e.tableType);
}
else if constexpr (std::is_same_v<T, OnlyTablesCanHaveMethods>)
{
e.tableType = clone(e.tableType);
}
else if constexpr (std::is_same_v<T, DuplicateTypeDefinition>)
{
}
else if constexpr (std::is_same_v<T, CountMismatch>)
{
}
else if constexpr (std::is_same_v<T, FunctionDoesNotTakeSelf>)
{
}
else if constexpr (std::is_same_v<T, FunctionRequiresSelf>)
{
}
else if constexpr (std::is_same_v<T, OccursCheckFailed>)
{
}
else if constexpr (std::is_same_v<T, UnknownRequire>)
{
}
else if constexpr (std::is_same_v<T, IncorrectGenericParameterCount>)
{
e.typeFun = clone(e.typeFun);
}
else if constexpr (std::is_same_v<T, SyntaxError>)
{
}
else if constexpr (std::is_same_v<T, CodeTooComplex>)
{
}
else if constexpr (std::is_same_v<T, UnificationTooComplex>)
{
}
else if constexpr (std::is_same_v<T, UnknownPropButFoundLikeProp>)
{
e.table = clone(e.table);
}
else if constexpr (std::is_same_v<T, GenericError>)
{
}
else if constexpr (std::is_same_v<T, CannotCallNonFunction>)
{
e.ty = clone(e.ty);
}
else if constexpr (std::is_same_v<T, ExtraInformation>)
{
}
else if constexpr (std::is_same_v<T, DeprecatedApiUsed>)
{
}
else if constexpr (std::is_same_v<T, ModuleHasCyclicDependency>)
{
}
else if constexpr (std::is_same_v<T, IllegalRequire>)
{
}
else if constexpr (std::is_same_v<T, FunctionExitsWithoutReturning>)
{
e.expectedReturnType = clone(e.expectedReturnType);
}
else if constexpr (std::is_same_v<T, DuplicateGenericParameter>)
{
}
else if constexpr (std::is_same_v<T, CannotInferBinaryOperation>)
{
}
else if constexpr (std::is_same_v<T, MissingProperties>)
{
e.superType = clone(e.superType);
e.subType = clone(e.subType);
}
else if constexpr (std::is_same_v<T, SwappedGenericTypeParameter>)
{
}
else if constexpr (std::is_same_v<T, OptionalValueAccess>)
{
e.optional = clone(e.optional);
}
else if constexpr (std::is_same_v<T, MissingUnionProperty>)
{
e.type = clone(e.type);
for (auto& ty : e.missing) if (e.error)
ty = clone(ty); visit(visitErrorData, e.error->data);
} }
else else if constexpr (std::is_same_v<T, UnknownSymbol>)
static_assert(always_false_v<T>, "Non-exhaustive type switch"); {
}
else if constexpr (std::is_same_v<T, UnknownProperty>)
{
e.table = clone(e.table);
}
else if constexpr (std::is_same_v<T, NotATable>)
{
e.ty = clone(e.ty);
}
else if constexpr (std::is_same_v<T, CannotExtendTable>)
{
e.tableType = clone(e.tableType);
}
else if constexpr (std::is_same_v<T, OnlyTablesCanHaveMethods>)
{
e.tableType = clone(e.tableType);
}
else if constexpr (std::is_same_v<T, DuplicateTypeDefinition>)
{
}
else if constexpr (std::is_same_v<T, CountMismatch>)
{
}
else if constexpr (std::is_same_v<T, FunctionDoesNotTakeSelf>)
{
}
else if constexpr (std::is_same_v<T, FunctionRequiresSelf>)
{
}
else if constexpr (std::is_same_v<T, OccursCheckFailed>)
{
}
else if constexpr (std::is_same_v<T, UnknownRequire>)
{
}
else if constexpr (std::is_same_v<T, IncorrectGenericParameterCount>)
{
e.typeFun = clone(e.typeFun);
}
else if constexpr (std::is_same_v<T, SyntaxError>)
{
}
else if constexpr (std::is_same_v<T, CodeTooComplex>)
{
}
else if constexpr (std::is_same_v<T, UnificationTooComplex>)
{
}
else if constexpr (std::is_same_v<T, UnknownPropButFoundLikeProp>)
{
e.table = clone(e.table);
}
else if constexpr (std::is_same_v<T, GenericError>)
{
}
else if constexpr (std::is_same_v<T, InternalError>)
{
}
else if constexpr (std::is_same_v<T, CannotCallNonFunction>)
{
e.ty = clone(e.ty);
}
else if constexpr (std::is_same_v<T, ExtraInformation>)
{
}
else if constexpr (std::is_same_v<T, DeprecatedApiUsed>)
{
}
else if constexpr (std::is_same_v<T, ModuleHasCyclicDependency>)
{
}
else if constexpr (std::is_same_v<T, IllegalRequire>)
{
}
else if constexpr (std::is_same_v<T, FunctionExitsWithoutReturning>)
{
e.expectedReturnType = clone(e.expectedReturnType);
}
else if constexpr (std::is_same_v<T, DuplicateGenericParameter>)
{
}
else if constexpr (std::is_same_v<T, CannotInferBinaryOperation>)
{
}
else if constexpr (std::is_same_v<T, MissingProperties>)
{
e.superType = clone(e.superType);
e.subType = clone(e.subType);
}
else if constexpr (std::is_same_v<T, SwappedGenericTypeParameter>)
{
}
else if constexpr (std::is_same_v<T, OptionalValueAccess>)
{
e.optional = clone(e.optional);
}
else if constexpr (std::is_same_v<T, MissingUnionProperty>)
{
e.type = clone(e.type);
for (auto& ty : e.missing)
ty = clone(ty);
}
else if constexpr (std::is_same_v<T, TypesAreUnrelated>)
{
e.left = clone(e.left);
e.right = clone(e.right);
}
else if constexpr (std::is_same_v<T, NormalizationTooComplex>)
{
}
else if constexpr (std::is_same_v<T, TypePackMismatch>)
{
e.wantedTp = clone(e.wantedTp);
e.givenTp = clone(e.givenTp);
}
else if constexpr (std::is_same_v<T, DynamicPropertyLookupOnClassesUnsafe>)
e.ty = clone(e.ty);
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}
void copyErrors(ErrorVec& errors, TypeArena& destArena)
{
CloneState cloneState;
auto visitErrorData = [&](auto&& e) {
copyError(e, destArena, cloneState);
}; };
LUAU_ASSERT(!destArena.typeVars.isFrozen()); LUAU_ASSERT(!destArena.typeVars.isFrozen());
@ -730,7 +961,7 @@ void copyErrors(ErrorVec& errors, struct TypeArena& destArena)
void InternalErrorReporter::ice(const std::string& message, const Location& location) void InternalErrorReporter::ice(const std::string& message, const Location& location)
{ {
std::runtime_error error("Internal error in " + moduleName + " at " + toString(location) + ": " + message); InternalCompilerError error(message, moduleName, location);
if (onInternalError) if (onInternalError)
onInternalError(error.what()); onInternalError(error.what());
@ -740,7 +971,7 @@ void InternalErrorReporter::ice(const std::string& message, const Location& loca
void InternalErrorReporter::ice(const std::string& message) void InternalErrorReporter::ice(const std::string& message)
{ {
std::runtime_error error("Internal error in " + moduleName + ": " + message); InternalCompilerError error(message, moduleName);
if (onInternalError) if (onInternalError)
onInternalError(error.what()); onInternalError(error.what());
@ -748,4 +979,9 @@ void InternalErrorReporter::ice(const std::string& message)
throw error; throw error;
} }
const char* InternalCompilerError::what() const throw()
{
return this->message.data();
}
} // namespace Luau } // namespace Luau

View file

@ -1,39 +1,53 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Frontend.h" #include "Luau/Frontend.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/Config.h" #include "Luau/Config.h"
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/DcrLogger.h"
#include "Luau/FileResolver.h" #include "Luau/FileResolver.h"
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
#include "Luau/TimeTrace.h"
#include "Luau/TypeChecker2.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/Variant.h" #include "Luau/Variant.h"
#include "Luau/Common.h"
#include <algorithm> #include <algorithm>
#include <chrono> #include <chrono>
#include <stdexcept> #include <stdexcept>
#include <string>
LUAU_FASTINT(LuauTypeInferIterationLimit)
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100)
LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAG(DebugLuauLogSolverToJson);
LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false)
namespace Luau namespace Luau
{ {
std::optional<Mode> parseMode(const std::vector<std::string>& hotcomments) std::optional<Mode> parseMode(const std::vector<HotComment>& hotcomments)
{ {
for (const std::string& hc : hotcomments) for (const HotComment& hc : hotcomments)
{ {
if (hc == "nocheck") if (!hc.header)
continue;
if (hc.content == "nocheck")
return Mode::NoCheck; return Mode::NoCheck;
if (hc == "nonstrict") if (hc.content == "nonstrict")
return Mode::Nonstrict; return Mode::Nonstrict;
if (hc == "strict") if (hc.content == "strict")
return Mode::Strict; return Mode::Strict;
} }
@ -67,8 +81,70 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName)
} }
} }
LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, const std::string& packageName)
{
if (!FFlag::DebugLuauDeferredConstraintResolution)
return Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, source, packageName);
LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend");
Luau::Allocator allocator;
Luau::AstNameTable names(allocator);
ParseOptions options;
options.allowDeclarationSyntax = true;
Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options);
if (parseResult.errors.size() > 0)
return LoadDefinitionFileResult{false, parseResult, nullptr};
Luau::SourceModule module;
module.root = parseResult.root;
module.mode = Mode::Definition;
ModulePtr checkedModule = check(module, Mode::Definition, globalScope, {});
if (checkedModule->errors.size() > 0)
return LoadDefinitionFileResult{false, parseResult, checkedModule};
CloneState cloneState;
std::vector<TypeId> typesToPersist;
typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size());
for (const auto& [name, ty] : checkedModule->declaredGlobals)
{
TypeId globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
typesToPersist.push_back(globalTy);
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
globalScope->exportedTypeBindings[name] = globalTy;
typesToPersist.push_back(globalTy.type);
}
for (TypeId ty : typesToPersist)
{
persist(ty);
}
return LoadDefinitionFileResult{true, parseResult, checkedModule};
}
LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName) LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName)
{ {
LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend");
Luau::Allocator allocator; Luau::Allocator allocator;
Luau::AstNameTable names(allocator); Luau::AstNameTable names(allocator);
@ -89,29 +165,34 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
if (checkedModule->errors.size() > 0) if (checkedModule->errors.size() > 0)
return LoadDefinitionFileResult{false, parseResult, checkedModule}; return LoadDefinitionFileResult{false, parseResult, checkedModule};
SeenTypes seenTypes; CloneState cloneState;
SeenTypePacks seenTypePacks;
std::vector<TypeId> typesToPersist;
typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->getModuleScope()->exportedTypeBindings.size());
for (const auto& [name, ty] : checkedModule->declaredGlobals) for (const auto& [name, ty] : checkedModule->declaredGlobals)
{ {
TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks); TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name; std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol); generateDocumentationSymbols(globalTy, documentationSymbol);
targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
if (FFlag::LuauPersistDefinitionFileTypes) typesToPersist.push_back(globalTy);
persist(globalTy);
} }
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{ {
TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks); TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name; std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol); generateDocumentationSymbols(globalTy.type, documentationSymbol);
targetScope->exportedTypeBindings[name] = globalTy; targetScope->exportedTypeBindings[name] = globalTy;
if (FFlag::LuauPersistDefinitionFileTypes) typesToPersist.push_back(globalTy.type);
persist(globalTy.type); }
for (TypeId ty : typesToPersist)
{
persist(ty);
} }
return LoadDefinitionFileResult{true, parseResult, checkedModule}; return LoadDefinitionFileResult{true, parseResult, checkedModule};
@ -208,7 +289,7 @@ ErrorVec accumulateErrors(
continue; continue;
const SourceNode& sourceNode = it->second; const SourceNode& sourceNode = it->second;
queue.insert(queue.end(), sourceNode.requires.begin(), sourceNode.requires.end()); queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end());
// FIXME: If a module has a syntax error, we won't be able to re-report it here. // FIXME: If a module has a syntax error, we won't be able to re-report it here.
// The solution is probably to move errors from Module to SourceNode // The solution is probably to move errors from Module to SourceNode
@ -231,18 +312,12 @@ ErrorVec accumulateErrors(
return result; return result;
} }
struct RequireCycle
{
Location location;
std::vector<ModuleName> path; // one of the paths for a require() to go all the way back to the originating module
};
// Given a source node (start), find all requires that start a transitive dependency path that ends back at start // Given a source node (start), find all requires that start a transitive dependency path that ends back at start
// For each such path, record the full path and the location of the require in the starting module. // For each such path, record the full path and the location of the require in the starting module.
// Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V)
// However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true) // However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true)
std::vector<RequireCycle> getRequireCycles( std::vector<RequireCycle> getRequireCycles(
const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false) const FileResolver* resolver, const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
{ {
std::vector<RequireCycle> result; std::vector<RequireCycle> result;
@ -276,9 +351,9 @@ std::vector<RequireCycle> getRequireCycles(
if (top == start) if (top == start)
{ {
for (const SourceNode* node : path) for (const SourceNode* node : path)
cycle.push_back(node->name); cycle.push_back(resolver->getHumanReadableModuleName(node->name));
cycle.push_back(top->name); cycle.push_back(resolver->getHumanReadableModuleName(top->name));
break; break;
} }
} }
@ -333,13 +408,15 @@ double getTimestamp()
} // namespace } // namespace
Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options) Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options)
: fileResolver(fileResolver) : singletonTypes(NotNull{&singletonTypes_})
, fileResolver(fileResolver)
, moduleResolver(this) , moduleResolver(this)
, moduleResolverForAutocomplete(this) , moduleResolverForAutocomplete(this)
, typeChecker(&moduleResolver, &iceHandler) , typeChecker(&moduleResolver, singletonTypes, &iceHandler)
, typeCheckerForAutocomplete(&moduleResolverForAutocomplete, &iceHandler) , typeCheckerForAutocomplete(&moduleResolverForAutocomplete, singletonTypes, &iceHandler)
, configResolver(configResolver) , configResolver(configResolver)
, options(options) , options(options)
, globalScope(typeChecker.globalScope)
{ {
} }
@ -348,33 +425,49 @@ FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend)
{ {
} }
CheckResult Frontend::check(const ModuleName& name) CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOptions> optionOverride)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
FrontendOptions frontendOptions = optionOverride.value_or(options);
CheckResult checkResult; CheckResult checkResult;
auto it = sourceNodes.find(name); auto it = sourceNodes.find(name);
if (it != sourceNodes.end() && !it->second.dirty) if (it != sourceNodes.end() && !it->second.hasDirtyModule(frontendOptions.forAutocomplete))
{ {
// No recheck required. // No recheck required.
auto it2 = moduleResolver.modules.find(name); if (frontendOptions.forAutocomplete)
if (it2 == moduleResolver.modules.end() || it2->second == nullptr) {
throw std::runtime_error("Frontend::modules does not have data for " + name); auto it2 = moduleResolverForAutocomplete.modules.find(name);
if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr)
throw InternalCompilerError("Frontend::modules does not have data for " + name, name);
}
else
{
auto it2 = moduleResolver.modules.find(name);
if (it2 == moduleResolver.modules.end() || it2->second == nullptr)
throw InternalCompilerError("Frontend::modules does not have data for " + name, name);
}
return CheckResult{accumulateErrors(sourceNodes, moduleResolver.modules, name)}; return CheckResult{
accumulateErrors(sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)};
} }
std::vector<ModuleName> buildQueue; std::vector<ModuleName> buildQueue;
bool cycleDetected = parseGraph(buildQueue, checkResult, name); bool cycleDetected = parseGraph(buildQueue, checkResult, name, frontendOptions.forAutocomplete);
// Keep track of which AST nodes we've reported cycles in // Keep track of which AST nodes we've reported cycles in
std::unordered_set<AstNode*> reportedCycles; std::unordered_set<AstNode*> reportedCycles;
double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0;
for (const ModuleName& moduleName : buildQueue) for (const ModuleName& moduleName : buildQueue)
{ {
LUAU_ASSERT(sourceNodes.count(moduleName)); LUAU_ASSERT(sourceNodes.count(moduleName));
SourceNode& sourceNode = sourceNodes[moduleName]; SourceNode& sourceNode = sourceNodes[moduleName];
if (!sourceNode.dirty) if (!sourceNode.hasDirtyModule(frontendOptions.forAutocomplete))
continue; continue;
LUAU_ASSERT(sourceModules.count(moduleName)); LUAU_ASSERT(sourceModules.count(moduleName));
@ -384,7 +477,7 @@ CheckResult Frontend::check(const ModuleName& name)
Mode mode = sourceModule.mode.value_or(config.mode); Mode mode = sourceModule.mode.value_or(config.mode);
ScopePtr environmentScope = getModuleEnvironment(sourceModule, config); ScopePtr environmentScope = getModuleEnvironment(sourceModule, config, frontendOptions.forAutocomplete);
double timestamp = getTimestamp(); double timestamp = getTimestamp();
@ -395,49 +488,75 @@ CheckResult Frontend::check(const ModuleName& name)
// however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term
// all correct programs must be acyclic so this code triggers rarely // all correct programs must be acyclic so this code triggers rarely
if (cycleDetected) if (cycleDetected)
requireCycles = getRequireCycles(sourceNodes, &sourceNode, mode == Mode::NoCheck); requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck);
// This is used by the type checker to replace the resulting type of cyclic modules with any // This is used by the type checker to replace the resulting type of cyclic modules with any
sourceModule.cyclic = !requireCycles.empty(); sourceModule.cyclic = !requireCycles.empty();
ModulePtr module = typeChecker.check(sourceModule, mode, environmentScope); if (frontendOptions.forAutocomplete)
// If we're typechecking twice, we do so.
// The second typecheck is always in strict mode with DM awareness
// to provide better typen information for IDE features.
if (options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel)
{ {
ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); // The autocomplete typecheck is always in strict mode with DM awareness
// to provide better type information for IDE features
typeCheckerForAutocomplete.requireCycles = requireCycles;
if (autocompleteTimeLimit != 0.0)
typeCheckerForAutocomplete.finishTime = TimeTrace::getClock() + autocompleteTimeLimit;
else
typeCheckerForAutocomplete.finishTime = std::nullopt;
// TODO: This is a dirty ad hoc solution for autocomplete timeouts
// We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit
// so that we'll have type information for the whole file at lower quality instead of a full abort in the middle
if (FInt::LuauTarjanChildLimit > 0)
typeCheckerForAutocomplete.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt;
if (FInt::LuauTypeInferIterationLimit > 0)
typeCheckerForAutocomplete.unifierIterationLimit =
std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt;
ModulePtr moduleForAutocomplete = FFlag::DebugLuauDeferredConstraintResolution
? check(sourceModule, mode, environmentScope, requireCycles, /*forAutocomplete*/ true)
: typeCheckerForAutocomplete.check(sourceModule, Mode::Strict, environmentScope);
moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete;
double duration = getTimestamp() - timestamp;
if (moduleForAutocomplete->timeout)
{
checkResult.timeoutHits.push_back(moduleName);
sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0;
}
else if (duration < autocompleteTimeLimit / 2.0)
{
sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0);
}
stats.timeCheck += duration;
stats.filesStrict += 1;
sourceNode.dirtyModuleForAutocomplete = false;
continue;
} }
else if (options.retainFullTypeGraphs && options.typecheckTwice && mode != Mode::Strict)
{
ModulePtr strictModule = typeChecker.check(sourceModule, Mode::Strict, environmentScope);
module->astTypes.clear();
module->astOriginalCallTypes.clear();
module->astExpectedTypes.clear();
SeenTypes seenTypes; typeChecker.requireCycles = requireCycles;
SeenTypePacks seenTypePacks;
for (const auto& [expr, strictTy] : strictModule->astTypes) ModulePtr module = FFlag::DebugLuauDeferredConstraintResolution ? check(sourceModule, mode, environmentScope, requireCycles)
module->astTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); : typeChecker.check(sourceModule, mode, environmentScope);
for (const auto& [expr, strictTy] : strictModule->astOriginalCallTypes)
module->astOriginalCallTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks);
for (const auto& [expr, strictTy] : strictModule->astExpectedTypes)
module->astExpectedTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks);
}
stats.timeCheck += getTimestamp() - timestamp; stats.timeCheck += getTimestamp() - timestamp;
stats.filesStrict += mode == Mode::Strict; stats.filesStrict += mode == Mode::Strict;
stats.filesNonstrict += mode == Mode::Nonstrict; stats.filesNonstrict += mode == Mode::Nonstrict;
if (module == nullptr) if (module == nullptr)
throw std::runtime_error("Frontend::check produced a nullptr module for " + moduleName); throw InternalCompilerError("Frontend::check produced a nullptr module for " + moduleName, moduleName);
if (!options.retainFullTypeGraphs) if (!frontendOptions.retainFullTypeGraphs)
{ {
// copyErrors needs to allocate into interfaceTypes as it copies // copyErrors needs to allocate into interfaceTypes as it copies
// types out of internalTypes, so we unfreeze it here. // types out of internalTypes, so we unfreeze it here.
@ -449,6 +568,9 @@ CheckResult Frontend::check(const ModuleName& name)
module->astTypes.clear(); module->astTypes.clear();
module->astExpectedTypes.clear(); module->astExpectedTypes.clear();
module->astOriginalCallTypes.clear(); module->astOriginalCallTypes.clear();
module->astResolvedTypes.clear();
module->astResolvedTypePacks.clear();
module->scopes.resize(1);
} }
if (mode != Mode::NoCheck) if (mode != Mode::NoCheck)
@ -471,14 +593,17 @@ CheckResult Frontend::check(const ModuleName& name)
checkResult.errors.insert(checkResult.errors.end(), module->errors.begin(), module->errors.end()); checkResult.errors.insert(checkResult.errors.end(), module->errors.begin(), module->errors.end());
moduleResolver.modules[moduleName] = std::move(module); moduleResolver.modules[moduleName] = std::move(module);
sourceNode.dirty = false; sourceNode.dirtyModule = false;
} }
return checkResult; return checkResult;
} }
bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& checkResult, const ModuleName& root) bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend");
LUAU_TIMETRACE_ARGUMENT("root", root.c_str());
// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search // https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
enum Mark enum Mark
{ {
@ -536,7 +661,7 @@ bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& chec
path.push_back(top); path.push_back(top);
// push children // push children
for (const ModuleName& dep : top->requires) for (const ModuleName& dep : top->requireSet)
{ {
auto it = sourceNodes.find(dep); auto it = sourceNodes.find(dep);
if (it != sourceNodes.end()) if (it != sourceNodes.end())
@ -545,7 +670,7 @@ bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& chec
// this relies on the fact that markDirty marks reverse-dependencies dirty as well // this relies on the fact that markDirty marks reverse-dependencies dirty as well
// thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need
// to be built, *and* can't form a cycle with any nodes we did process. // to be built, *and* can't form a cycle with any nodes we did process.
if (!it->second.dirty) if (!it->second.hasDirtyModule(forAutocomplete))
continue; continue;
// note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization
@ -572,9 +697,13 @@ bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& chec
return cyclic; return cyclic;
} }
ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config) ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete)
{ {
ScopePtr result = typeChecker.globalScope; ScopePtr result;
if (forAutocomplete)
result = typeCheckerForAutocomplete.globalScope;
else
result = typeChecker.globalScope;
if (module.environmentName) if (module.environmentName)
result = getEnvironmentScope(*module.environmentName); result = getEnvironmentScope(*module.environmentName);
@ -597,6 +726,9 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config
LintResult Frontend::lint(const ModuleName& name, std::optional<Luau::LintOptions> enabledLintWarnings) LintResult Frontend::lint(const ModuleName& name, std::optional<Luau::LintOptions> enabledLintWarnings)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
CheckResult checkResult; CheckResult checkResult;
auto [_sourceNode, sourceModule] = getSourceNode(checkResult, name); auto [_sourceNode, sourceModule] = getSourceNode(checkResult, name);
@ -606,52 +738,17 @@ LintResult Frontend::lint(const ModuleName& name, std::optional<Luau::LintOption
return lint(*sourceModule, enabledLintWarnings); return lint(*sourceModule, enabledLintWarnings);
} }
std::pair<SourceModule, LintResult> Frontend::lintFragment(std::string_view source, std::optional<Luau::LintOptions> enabledLintWarnings)
{
const Config& config = configResolver->getConfig("");
SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions);
Luau::LintOptions lintOptions = enabledLintWarnings.value_or(config.enabledLint);
lintOptions.warningMask &= sourceModule.ignoreLints;
double timestamp = getTimestamp();
std::vector<LintWarning> warnings =
Luau::lint(sourceModule.root, *sourceModule.names.get(), typeChecker.globalScope, nullptr, enabledLintWarnings.value_or(config.enabledLint));
stats.timeLint += getTimestamp() - timestamp;
return {std::move(sourceModule), classifyLints(warnings, config)};
}
CheckResult Frontend::check(const SourceModule& module)
{
const Config& config = configResolver->getConfig(module.name);
Mode mode = module.mode.value_or(config.mode);
double timestamp = getTimestamp();
ModulePtr checkedModule = typeChecker.check(module, mode);
stats.timeCheck += getTimestamp() - timestamp;
stats.filesStrict += mode == Mode::Strict;
stats.filesNonstrict += mode == Mode::Nonstrict;
if (checkedModule == nullptr)
throw std::runtime_error("Frontend::check produced a nullptr module for module " + module.name);
moduleResolver.modules[module.name] = checkedModule;
return CheckResult{checkedModule->errors};
}
LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOptions> enabledLintWarnings) LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOptions> enabledLintWarnings)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend");
LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str());
const Config& config = configResolver->getConfig(module.name); const Config& config = configResolver->getConfig(module.name);
uint64_t ignoreLints = LintWarning::parseMask(module.hotcomments);
LintOptions options = enabledLintWarnings.value_or(config.enabledLint); LintOptions options = enabledLintWarnings.value_or(config.enabledLint);
options.warningMask &= ~module.ignoreLints; options.warningMask &= ~ignoreLints;
Mode mode = module.mode.value_or(config.mode); Mode mode = module.mode.value_or(config.mode);
if (mode != Mode::NoCheck) if (mode != Mode::NoCheck)
@ -670,17 +767,17 @@ LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOp
double timestamp = getTimestamp(); double timestamp = getTimestamp();
std::vector<LintWarning> warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), options); std::vector<LintWarning> warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), module.hotcomments, options);
stats.timeLint += getTimestamp() - timestamp; stats.timeLint += getTimestamp() - timestamp;
return classifyLints(warnings, config); return classifyLints(warnings, config);
} }
bool Frontend::isDirty(const ModuleName& name) const bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
{ {
auto it = sourceNodes.find(name); auto it = sourceNodes.find(name);
return it == sourceNodes.end() || it->second.dirty; return it == sourceNodes.end() || it->second.hasDirtyModule(forAutocomplete);
} }
/* /*
@ -691,13 +788,13 @@ bool Frontend::isDirty(const ModuleName& name) const
*/ */
void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty) void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty)
{ {
if (!moduleResolver.modules.count(name)) if (!moduleResolver.modules.count(name) && !moduleResolverForAutocomplete.modules.count(name))
return; return;
std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps; std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps;
for (const auto& module : sourceNodes) for (const auto& module : sourceNodes)
{ {
for (const auto& dep : module.second.requires) for (const auto& dep : module.second.requireSet)
reverseDeps[dep].push_back(module.first); reverseDeps[dep].push_back(module.first);
} }
@ -714,17 +811,19 @@ void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* marked
if (markedDirty) if (markedDirty)
markedDirty->push_back(next); markedDirty->push_back(next);
if (sourceNode.dirty) if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete)
continue; continue;
sourceNode.dirty = true; sourceNode.dirtySourceModule = true;
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
if (0 == reverseDeps.count(name)) if (0 == reverseDeps.count(next))
continue; continue;
sourceModules.erase(name); sourceModules.erase(next);
const std::vector<ModuleName>& dependents = reverseDeps[name]; const std::vector<ModuleName>& dependents = reverseDeps[next];
queue.insert(queue.end(), dependents.begin(), dependents.end()); queue.insert(queue.end(), dependents.begin(), dependents.end());
} }
} }
@ -743,11 +842,95 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons
return const_cast<Frontend*>(this)->getSourceModule(moduleName); return const_cast<Frontend*>(this)->getSourceModule(moduleName);
} }
ScopePtr Frontend::getGlobalScope()
{
if (!globalScope)
{
globalScope = typeChecker.globalScope;
}
return globalScope;
}
ModulePtr Frontend::check(
const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope, std::vector<RequireCycle> requireCycles, bool forAutocomplete)
{
ModulePtr result = std::make_shared<Module>();
std::unique_ptr<DcrLogger> logger;
if (FFlag::DebugLuauLogSolverToJson)
{
logger = std::make_unique<DcrLogger>();
std::optional<SourceCode> source = fileResolver->readSource(sourceModule.name);
if (source)
{
logger->captureSource(source->source);
}
}
DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&iceHandler});
const NotNull<ModuleResolver> mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver};
const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope};
Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}};
ConstraintGraphBuilder cgb{
sourceModule.name,
result,
&result->internalTypes,
mr,
singletonTypes,
NotNull(&iceHandler),
globalScope,
logger.get(),
NotNull{&dfg},
};
cgb.visit(sourceModule.root);
result->errors = std::move(cgb.errors);
ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, NotNull(&moduleResolver),
requireCycles, logger.get()};
if (options.randomizeConstraintResolutionSeed)
cs.randomize(*options.randomizeConstraintResolutionSeed);
cs.run();
for (TypeError& e : cs.errors)
result->errors.emplace_back(std::move(e));
result->scopes = std::move(cgb.scopes);
result->astTypes = std::move(cgb.astTypes);
result->astTypePacks = std::move(cgb.astTypePacks);
result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes);
result->astOverloadResolvedTypes = std::move(cgb.astOverloadResolvedTypes);
result->astResolvedTypes = std::move(cgb.astResolvedTypes);
result->astResolvedTypePacks = std::move(cgb.astResolvedTypePacks);
result->type = sourceModule.type;
Luau::check(singletonTypes, logger.get(), sourceModule, result.get());
if (FFlag::DebugLuauLogSolverToJson)
{
std::string output = logger->compileOutput();
printf("%s\n", output.c_str());
}
result->clonePublicInterface(singletonTypes, iceHandler);
return result;
}
// Read AST into sourceModules if necessary. Trace require()s. Report parse errors. // Read AST into sourceModules if necessary. Trace require()s. Report parse errors.
std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
auto it = sourceNodes.find(name); auto it = sourceNodes.find(name);
if (it != sourceNodes.end() && !it->second.dirty) if (it != sourceNodes.end() && !it->second.hasDirtySourceModule())
{ {
auto moduleIt = sourceModules.find(name); auto moduleIt = sourceModules.find(name);
if (moduleIt != sourceModules.end()) if (moduleIt != sourceModules.end())
@ -778,8 +961,8 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& check
SourceModule result = parse(name, source->source, opts); SourceModule result = parse(name, source->source, opts);
result.type = source->type; result.type = source->type;
RequireTraceResult& requireTrace = requires[name]; RequireTraceResult& require = requireTrace[name];
requireTrace = traceRequires(fileResolver, result.root, name); require = traceRequires(fileResolver, result.root, name);
SourceNode& sourceNode = sourceNodes[name]; SourceNode& sourceNode = sourceNodes[name];
SourceModule& sourceModule = sourceModules[name]; SourceModule& sourceModule = sourceModules[name];
@ -788,14 +971,20 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& check
sourceModule.environmentName = environmentName; sourceModule.environmentName = environmentName;
sourceNode.name = name; sourceNode.name = name;
sourceNode.requires.clear(); sourceNode.requireSet.clear();
sourceNode.requireLocations.clear(); sourceNode.requireLocations.clear();
sourceNode.dirty = true; sourceNode.dirtySourceModule = false;
for (const auto& [moduleName, location] : requireTrace.requires) if (it == sourceNodes.end())
sourceNode.requires.insert(moduleName); {
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
}
sourceNode.requireLocations = requireTrace.requires; for (const auto& [moduleName, location] : require.requireList)
sourceNode.requireSet.insert(moduleName);
sourceNode.requireLocations = require.requireList;
return {&sourceNode, &sourceModule}; return {&sourceNode, &sourceModule};
} }
@ -815,15 +1004,18 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& check
*/ */
SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions) SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
SourceModule sourceModule; SourceModule sourceModule;
double timestamp = getTimestamp(); double timestamp = getTimestamp();
auto parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions); Luau::ParseResult parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions);
stats.timeParse += getTimestamp() - timestamp; stats.timeParse += getTimestamp() - timestamp;
stats.files++; stats.files++;
stats.lines += std::count(src.begin(), src.end(), '\n') + (src.size() && src.back() != '\n'); stats.lines += parseResult.lines;
if (!parseResult.errors.empty()) if (!parseResult.errors.empty())
sourceModule.parseErrors.insert(sourceModule.parseErrors.end(), parseResult.errors.begin(), parseResult.errors.end()); sourceModule.parseErrors.insert(sourceModule.parseErrors.end(), parseResult.errors.begin(), parseResult.errors.end());
@ -832,7 +1024,6 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const
{ {
sourceModule.root = parseResult.root; sourceModule.root = parseResult.root;
sourceModule.mode = parseMode(parseResult.hotcomments); sourceModule.mode = parseMode(parseResult.hotcomments);
sourceModule.ignoreLints = LintWarning::parseMask(parseResult.hotcomments);
} }
else else
{ {
@ -841,43 +1032,36 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const
} }
sourceModule.name = name; sourceModule.name = name;
if (parseOptions.captureComments) if (parseOptions.captureComments)
{
sourceModule.commentLocations = std::move(parseResult.commentLocations); sourceModule.commentLocations = std::move(parseResult.commentLocations);
sourceModule.hotcomments = std::move(parseResult.hotcomments);
}
return sourceModule; return sourceModule;
} }
std::optional<ModuleInfo> FrontendModuleResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) std::optional<ModuleInfo> FrontendModuleResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr)
{ {
// FIXME I think this can be pushed into the FileResolver. // FIXME I think this can be pushed into the FileResolver.
auto it = frontend->requires.find(currentModuleName); auto it = frontend->requireTrace.find(currentModuleName);
if (it == frontend->requires.end()) if (it == frontend->requireTrace.end())
{ {
// CLI-43699 // CLI-43699
// If we can't find the current module name, that's because we bypassed the frontend's initializer // If we can't find the current module name, that's because we bypassed the frontend's initializer
// and called typeChecker.check directly. (This is done by autocompleteSource, for example). // and called typeChecker.check directly.
// In that case, requires will always fail. // In that case, requires will always fail.
if (FFlag::LuauResolveModuleNameWithoutACurrentModule) return std::nullopt;
return std::nullopt;
else
throw std::runtime_error("Frontend::resolveModuleName: Unknown currentModuleName '" + currentModuleName + "'");
} }
const auto& exprs = it->second.exprs; const auto& exprs = it->second.exprs;
const ModuleName* relativeName = exprs.find(&pathExpr); const ModuleInfo* info = exprs.find(&pathExpr);
if (!relativeName || relativeName->empty()) if (!info)
return std::nullopt; return std::nullopt;
if (FFlag::LuauTraceRequireLookupChild) return *info;
{
const bool* optional = it->second.optional.find(&pathExpr);
return {{*relativeName, optional ? *optional : false}};
}
else
{
return {{*relativeName, false}};
}
} }
const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const
@ -891,12 +1075,12 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName)
bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const
{ {
return frontend->fileResolver->moduleExists(moduleName); return frontend->sourceNodes.count(moduleName) != 0;
} }
std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const
{ {
return frontend->fileResolver->getHumanReadableModuleName_(moduleName).value_or(moduleName); return frontend->fileResolver->getHumanReadableModuleName(moduleName);
} }
ScopePtr Frontend::addEnvironment(const std::string& environmentName) ScopePtr Frontend::addEnvironment(const std::string& environmentName)
@ -961,7 +1145,7 @@ void Frontend::clear()
sourceModules.clear(); sourceModules.clear();
moduleResolver.modules.clear(); moduleResolver.modules.clear();
moduleResolverForAutocomplete.modules.clear(); moduleResolverForAutocomplete.modules.clear();
requires.clear(); requireTrace.clear();
} }
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,131 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Common.h"
#include "Luau/Instantiation.h"
#include "Luau/TxnLog.h"
#include "Luau/TypeArena.h"
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
namespace Luau
{
bool Instantiation::isDirty(TypeId ty)
{
if (const FunctionTypeVar* ftv = log->getMutable<FunctionTypeVar>(ty))
{
if (ftv->hasNoGenerics)
return false;
return true;
}
else
{
return false;
}
}
bool Instantiation::isDirty(TypePackId tp)
{
return false;
}
bool Instantiation::ignoreChildren(TypeId ty)
{
if (log->getMutable<FunctionTypeVar>(ty))
return true;
else if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassTypeVar>(ty))
return true;
else
return false;
}
TypeId Instantiation::clean(TypeId ty)
{
const FunctionTypeVar* ftv = log->getMutable<FunctionTypeVar>(ty);
LUAU_ASSERT(ftv);
FunctionTypeVar clone = FunctionTypeVar{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf};
clone.magicFunction = ftv->magicFunction;
clone.dcrMagicFunction = ftv->dcrMagicFunction;
clone.tags = ftv->tags;
clone.argNames = ftv->argNames;
TypeId result = addType(std::move(clone));
// Annoyingly, we have to do this even if there are no generics,
// to replace any generic tables.
ReplaceGenerics replaceGenerics{log, arena, level, scope, ftv->generics, ftv->genericPacks};
// TODO: What to do if this returns nullopt?
// We don't have access to the error-reporting machinery
result = replaceGenerics.substitute(result).value_or(result);
asMutable(result)->documentationSymbol = ty->documentationSymbol;
return result;
}
TypePackId Instantiation::clean(TypePackId tp)
{
LUAU_ASSERT(false);
return tp;
}
bool ReplaceGenerics::ignoreChildren(TypeId ty)
{
if (const FunctionTypeVar* ftv = log->getMutable<FunctionTypeVar>(ty))
{
if (ftv->hasNoGenerics)
return true;
// We aren't recursing in the case of a generic function which
// binds the same generics. This can happen if, for example, there's recursive types.
// If T = <a>(a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'.
// It's OK to use vector equality here, since we always generate fresh generics
// whenever we quantify, so the vectors overlap if and only if they are equal.
return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks);
}
else if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassTypeVar>(ty))
return true;
else
{
return false;
}
}
bool ReplaceGenerics::isDirty(TypeId ty)
{
if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
return ttv->state == TableState::Generic;
else if (log->getMutable<GenericTypeVar>(ty))
return std::find(generics.begin(), generics.end(), ty) != generics.end();
else
return false;
}
bool ReplaceGenerics::isDirty(TypePackId tp)
{
if (log->getMutable<GenericTypePack>(tp))
return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end();
else
return false;
}
TypeId ReplaceGenerics::clean(TypeId ty)
{
LUAU_ASSERT(isDirty(ty));
if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
{
TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, scope, TableState::Free};
clone.definitionModuleName = ttv->definitionModuleName;
return addType(std::move(clone));
}
else
return addType(FreeTypeVar{scope, level});
}
TypePackId ReplaceGenerics::clean(TypePackId tp)
{
LUAU_ASSERT(isDirty(tp));
return addTypePack(TypePackVar(FreeTypePack{level}));
}
} // namespace Luau

View file

@ -23,232 +23,191 @@ std::ostream& operator<<(std::ostream& stream, const AstName& name)
return stream << "<empty>"; return stream << "<empty>";
} }
std::ostream& operator<<(std::ostream& stream, const TypeMismatch& tm) template<typename T>
static void errorToString(std::ostream& stream, const T& err)
{ {
return stream << "TypeMismatch { " << toString(tm.wantedType) << ", " << toString(tm.givenType) << " }"; if constexpr (false)
}
std::ostream& operator<<(std::ostream& stream, const TypeError& error)
{
return stream << "TypeError { \"" << error.moduleName << "\", " << error.location << ", " << error.data << " }";
}
std::ostream& operator<<(std::ostream& stream, const UnknownSymbol& error)
{
return stream << "UnknownSymbol { " << error.name << " , context " << error.context << " }";
}
std::ostream& operator<<(std::ostream& stream, const UnknownProperty& error)
{
return stream << "UnknownProperty { " << toString(error.table) << ", key = " << error.key << " }";
}
std::ostream& operator<<(std::ostream& stream, const NotATable& ge)
{
return stream << "NotATable { " << toString(ge.ty) << " }";
}
std::ostream& operator<<(std::ostream& stream, const CannotExtendTable& error)
{
return stream << "CannotExtendTable { " << toString(error.tableType) << ", context " << error.context << ", prop \"" << error.prop << "\" }";
}
std::ostream& operator<<(std::ostream& stream, const OnlyTablesCanHaveMethods& error)
{
return stream << "OnlyTablesCanHaveMethods { " << toString(error.tableType) << " }";
}
std::ostream& operator<<(std::ostream& stream, const DuplicateTypeDefinition& error)
{
return stream << "DuplicateTypeDefinition { " << error.name << " }";
}
std::ostream& operator<<(std::ostream& stream, const CountMismatch& error)
{
return stream << "CountMismatch { expected " << error.expected << ", got " << error.actual << ", context " << error.context << " }";
}
std::ostream& operator<<(std::ostream& stream, const FunctionDoesNotTakeSelf&)
{
return stream << "FunctionDoesNotTakeSelf { }";
}
std::ostream& operator<<(std::ostream& stream, const FunctionRequiresSelf& error)
{
return stream << "FunctionRequiresSelf { extraNils " << error.requiredExtraNils << " }";
}
std::ostream& operator<<(std::ostream& stream, const OccursCheckFailed&)
{
return stream << "OccursCheckFailed { }";
}
std::ostream& operator<<(std::ostream& stream, const UnknownRequire& error)
{
return stream << "UnknownRequire { " << error.modulePath << " }";
}
std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCount& error)
{
stream << "IncorrectGenericParameterCount { name = " << error.name;
if (!error.typeFun.typeParams.empty())
{ {
stream << "<"; }
else if constexpr (std::is_same_v<T, TypeMismatch>)
stream << "TypeMismatch { " << toString(err.wantedType) << ", " << toString(err.givenType) << " }";
else if constexpr (std::is_same_v<T, UnknownSymbol>)
stream << "UnknownSymbol { " << err.name << " , context " << err.context << " }";
else if constexpr (std::is_same_v<T, UnknownProperty>)
stream << "UnknownProperty { " << toString(err.table) << ", key = " << err.key << " }";
else if constexpr (std::is_same_v<T, NotATable>)
stream << "NotATable { " << toString(err.ty) << " }";
else if constexpr (std::is_same_v<T, CannotExtendTable>)
stream << "CannotExtendTable { " << toString(err.tableType) << ", context " << err.context << ", prop \"" << err.prop << "\" }";
else if constexpr (std::is_same_v<T, OnlyTablesCanHaveMethods>)
stream << "OnlyTablesCanHaveMethods { " << toString(err.tableType) << " }";
else if constexpr (std::is_same_v<T, DuplicateTypeDefinition>)
stream << "DuplicateTypeDefinition { " << err.name << " }";
else if constexpr (std::is_same_v<T, CountMismatch>)
stream << "CountMismatch { expected " << err.expected << ", got " << err.actual << ", context " << err.context << " }";
else if constexpr (std::is_same_v<T, FunctionDoesNotTakeSelf>)
stream << "FunctionDoesNotTakeSelf { }";
else if constexpr (std::is_same_v<T, FunctionRequiresSelf>)
stream << "FunctionRequiresSelf { }";
else if constexpr (std::is_same_v<T, OccursCheckFailed>)
stream << "OccursCheckFailed { }";
else if constexpr (std::is_same_v<T, UnknownRequire>)
stream << "UnknownRequire { " << err.modulePath << " }";
else if constexpr (std::is_same_v<T, IncorrectGenericParameterCount>)
{
stream << "IncorrectGenericParameterCount { name = " << err.name;
if (!err.typeFun.typeParams.empty() || !err.typeFun.typePackParams.empty())
{
stream << "<";
bool first = true;
for (auto param : err.typeFun.typeParams)
{
if (first)
first = false;
else
stream << ", ";
stream << toString(param.ty);
}
for (auto param : err.typeFun.typePackParams)
{
if (first)
first = false;
else
stream << ", ";
stream << toString(param.tp);
}
stream << ">";
}
stream << ", typeFun = " << toString(err.typeFun.type) << ", actualCount = " << err.actualParameters << " }";
}
else if constexpr (std::is_same_v<T, SyntaxError>)
stream << "SyntaxError { " << err.message << " }";
else if constexpr (std::is_same_v<T, CodeTooComplex>)
stream << "CodeTooComplex {}";
else if constexpr (std::is_same_v<T, UnificationTooComplex>)
stream << "UnificationTooComplex {}";
else if constexpr (std::is_same_v<T, UnknownPropButFoundLikeProp>)
{
stream << "UnknownPropButFoundLikeProp { key = '" << err.key << "', suggested = { ";
bool first = true; bool first = true;
for (TypeId t : error.typeFun.typeParams) for (Name name : err.candidates)
{ {
if (first) if (first)
first = false; first = false;
else else
stream << ", "; stream << ", ";
stream << toString(t); stream << "'" << name << "'";
} }
stream << ">";
}
stream << ", typeFun = " << toString(error.typeFun.type) << ", actualCount = " << error.actualParameters << " }"; stream << " }, table = " << toString(err.table) << " } ";
}
else if constexpr (std::is_same_v<T, GenericError>)
stream << "GenericError { " << err.message << " }";
else if constexpr (std::is_same_v<T, InternalError>)
stream << "InternalError { " << err.message << " }";
else if constexpr (std::is_same_v<T, CannotCallNonFunction>)
stream << "CannotCallNonFunction { " << toString(err.ty) << " }";
else if constexpr (std::is_same_v<T, ExtraInformation>)
stream << "ExtraInformation { " << err.message << " }";
else if constexpr (std::is_same_v<T, DeprecatedApiUsed>)
stream << "DeprecatedApiUsed { " << err.symbol << ", useInstead = " << err.useInstead << " }";
else if constexpr (std::is_same_v<T, ModuleHasCyclicDependency>)
{
stream << "ModuleHasCyclicDependency {";
bool first = true;
for (const ModuleName& name : err.cycle)
{
if (first)
first = false;
else
stream << ", ";
stream << name;
}
stream << "}";
}
else if constexpr (std::is_same_v<T, IllegalRequire>)
stream << "IllegalRequire { " << err.moduleName << ", reason = " << err.reason << " }";
else if constexpr (std::is_same_v<T, FunctionExitsWithoutReturning>)
stream << "FunctionExitsWithoutReturning {" << toString(err.expectedReturnType) << "}";
else if constexpr (std::is_same_v<T, DuplicateGenericParameter>)
stream << "DuplicateGenericParameter { " + err.parameterName + " }";
else if constexpr (std::is_same_v<T, CannotInferBinaryOperation>)
stream << "CannotInferBinaryOperation { op = " + toString(err.op) + ", suggested = '" +
(err.suggestedToAnnotate ? *err.suggestedToAnnotate : "") + "', kind "
<< err.kind << "}";
else if constexpr (std::is_same_v<T, MissingProperties>)
{
stream << "MissingProperties { superType = '" << toString(err.superType) << "', subType = '" << toString(err.subType) << "', properties = { ";
bool first = true;
for (Name name : err.properties)
{
if (first)
first = false;
else
stream << ", ";
stream << "'" << name << "'";
}
stream << " }, context " << err.context << " } ";
}
else if constexpr (std::is_same_v<T, SwappedGenericTypeParameter>)
stream << "SwappedGenericTypeParameter { name = '" + err.name + "', kind = " + std::to_string(err.kind) + " }";
else if constexpr (std::is_same_v<T, OptionalValueAccess>)
stream << "OptionalValueAccess { optional = '" + toString(err.optional) + "' }";
else if constexpr (std::is_same_v<T, MissingUnionProperty>)
{
stream << "MissingUnionProperty { type = '" + toString(err.type) + "', missing = { ";
bool first = true;
for (auto ty : err.missing)
{
if (first)
first = false;
else
stream << ", ";
stream << "'" << toString(ty) << "'";
}
stream << " }, key = '" + err.key + "' }";
}
else if constexpr (std::is_same_v<T, TypesAreUnrelated>)
stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }";
else if constexpr (std::is_same_v<T, NormalizationTooComplex>)
stream << "NormalizationTooComplex { }";
else if constexpr (std::is_same_v<T, TypePackMismatch>)
stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }";
else if constexpr (std::is_same_v<T, DynamicPropertyLookupOnClassesUnsafe>)
stream << "DynamicPropertyLookupOnClassesUnsafe { " << toString(err.ty) << " }";
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}
std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data)
{
auto cb = [&](const auto& e) {
return errorToString(stream, e);
};
visit(cb, data);
return stream; return stream;
} }
std::ostream& operator<<(std::ostream& stream, const SyntaxError& ge) std::ostream& operator<<(std::ostream& stream, const TypeError& error)
{ {
return stream << "SyntaxError { " << ge.message << " }"; return stream << "TypeError { \"" << error.moduleName << "\", " << error.location << ", " << error.data << " }";
}
std::ostream& operator<<(std::ostream& stream, const CodeTooComplex&)
{
return stream << "CodeTooComplex {}";
}
std::ostream& operator<<(std::ostream& stream, const UnificationTooComplex&)
{
return stream << "UnificationTooComplex {}";
}
std::ostream& operator<<(std::ostream& stream, const UnknownPropButFoundLikeProp& e)
{
stream << "UnknownPropButFoundLikeProp { key = '" << e.key << "', suggested = { ";
bool first = true;
for (Name name : e.candidates)
{
if (first)
first = false;
else
stream << ", ";
stream << "'" << name << "'";
}
return stream << " }, table = " << toString(e.table) << " } ";
}
std::ostream& operator<<(std::ostream& stream, const GenericError& ge)
{
return stream << "GenericError { " << ge.message << " }";
}
std::ostream& operator<<(std::ostream& stream, const CannotCallNonFunction& e)
{
return stream << "CannotCallNonFunction { " << toString(e.ty) << " }";
}
std::ostream& operator<<(std::ostream& stream, const FunctionExitsWithoutReturning& error)
{
return stream << "FunctionExitsWithoutReturning {" << toString(error.expectedReturnType) << "}";
}
std::ostream& operator<<(std::ostream& stream, const ExtraInformation& e)
{
return stream << "ExtraInformation { " << e.message << " }";
}
std::ostream& operator<<(std::ostream& stream, const DeprecatedApiUsed& e)
{
return stream << "DeprecatedApiUsed { " << e.symbol << ", useInstead = " << e.useInstead << " }";
}
std::ostream& operator<<(std::ostream& stream, const ModuleHasCyclicDependency& e)
{
stream << "ModuleHasCyclicDependency {";
bool first = true;
for (const ModuleName& name : e.cycle)
{
if (first)
first = false;
else
stream << ", ";
stream << name;
}
return stream << "}";
}
std::ostream& operator<<(std::ostream& stream, const IllegalRequire& e)
{
return stream << "IllegalRequire { " << e.moduleName << ", reason = " << e.reason << " }";
}
std::ostream& operator<<(std::ostream& stream, const MissingProperties& e)
{
stream << "MissingProperties { superType = '" << toString(e.superType) << "', subType = '" << toString(e.subType) << "', properties = { ";
bool first = true;
for (Name name : e.properties)
{
if (first)
first = false;
else
stream << ", ";
stream << "'" << name << "'";
}
return stream << " }, context " << e.context << " } ";
}
std::ostream& operator<<(std::ostream& stream, const DuplicateGenericParameter& error)
{
return stream << "DuplicateGenericParameter { " + error.parameterName + " }";
}
std::ostream& operator<<(std::ostream& stream, const CannotInferBinaryOperation& error)
{
return stream << "CannotInferBinaryOperation { op = " + toString(error.op) + ", suggested = '" +
(error.suggestedToAnnotate ? *error.suggestedToAnnotate : "") + "', kind "
<< error.kind << "}";
}
std::ostream& operator<<(std::ostream& stream, const SwappedGenericTypeParameter& error)
{
return stream << "SwappedGenericTypeParameter { name = '" + error.name + "', kind = " + std::to_string(error.kind) + " }";
}
std::ostream& operator<<(std::ostream& stream, const OptionalValueAccess& error)
{
return stream << "OptionalValueAccess { optional = '" + toString(error.optional) + "' }";
}
std::ostream& operator<<(std::ostream& stream, const MissingUnionProperty& error)
{
stream << "MissingUnionProperty { type = '" + toString(error.type) + "', missing = { ";
bool first = true;
for (auto ty : error.missing)
{
if (first)
first = false;
else
stream << ", ";
stream << "'" << toString(ty) << "'";
}
return stream << " }, key = '" + error.key + "' }";
} }
std::ostream& operator<<(std::ostream& stream, const TableState& tv) std::ostream& operator<<(std::ostream& stream, const TableState& tv)
@ -266,15 +225,4 @@ std::ostream& operator<<(std::ostream& stream, const TypePackVar& tv)
return stream << toString(tv); return stream << toString(tv);
} }
std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted)
{
Luau::visit(
[&](const auto& a) {
lhs << a;
},
ted);
return lhs;
}
} // namespace Luau } // namespace Luau

View file

@ -0,0 +1,222 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/JsonEmitter.h"
#include "Luau/StringUtils.h"
#include <string.h>
namespace Luau::Json
{
static constexpr int CHUNK_SIZE = 1024;
ObjectEmitter::ObjectEmitter(NotNull<JsonEmitter> emitter)
: emitter(emitter)
, finished(false)
{
comma = emitter->pushComma();
emitter->writeRaw('{');
}
ObjectEmitter::~ObjectEmitter()
{
finish();
}
void ObjectEmitter::finish()
{
if (finished)
return;
emitter->writeRaw('}');
emitter->popComma(comma);
finished = true;
}
ArrayEmitter::ArrayEmitter(NotNull<JsonEmitter> emitter)
: emitter(emitter)
, finished(false)
{
comma = emitter->pushComma();
emitter->writeRaw('[');
}
ArrayEmitter::~ArrayEmitter()
{
finish();
}
void ArrayEmitter::finish()
{
if (finished)
return;
emitter->writeRaw(']');
emitter->popComma(comma);
finished = true;
}
JsonEmitter::JsonEmitter()
{
newChunk();
}
std::string JsonEmitter::str()
{
return join(chunks, "");
}
bool JsonEmitter::pushComma()
{
bool current = comma;
comma = false;
return current;
}
void JsonEmitter::popComma(bool c)
{
comma = c;
}
void JsonEmitter::writeRaw(std::string_view sv)
{
if (sv.size() > CHUNK_SIZE)
{
chunks.emplace_back(sv);
newChunk();
return;
}
auto& chunk = chunks.back();
if (chunk.size() + sv.size() < CHUNK_SIZE)
{
chunk.append(sv.data(), sv.size());
return;
}
size_t prefix = CHUNK_SIZE - chunk.size();
chunk.append(sv.data(), prefix);
newChunk();
chunks.back().append(sv.data() + prefix, sv.size() - prefix);
}
void JsonEmitter::writeRaw(char c)
{
writeRaw(std::string_view{&c, 1});
}
void write(JsonEmitter& emitter, bool b)
{
if (b)
emitter.writeRaw("true");
else
emitter.writeRaw("false");
}
void write(JsonEmitter& emitter, double d)
{
emitter.writeRaw(std::to_string(d));
}
void write(JsonEmitter& emitter, int i)
{
emitter.writeRaw(std::to_string(i));
}
void write(JsonEmitter& emitter, long i)
{
emitter.writeRaw(std::to_string(i));
}
void write(JsonEmitter& emitter, long long i)
{
emitter.writeRaw(std::to_string(i));
}
void write(JsonEmitter& emitter, unsigned int i)
{
emitter.writeRaw(std::to_string(i));
}
void write(JsonEmitter& emitter, unsigned long i)
{
emitter.writeRaw(std::to_string(i));
}
void write(JsonEmitter& emitter, unsigned long long i)
{
emitter.writeRaw(std::to_string(i));
}
void write(JsonEmitter& emitter, std::string_view sv)
{
emitter.writeRaw('\"');
for (char c : sv)
{
if (c == '"')
emitter.writeRaw("\\\"");
else if (c == '\\')
emitter.writeRaw("\\\\");
else if (c == '\n')
emitter.writeRaw("\\n");
else if (c < ' ')
emitter.writeRaw(format("\\u%04x", c));
else
emitter.writeRaw(c);
}
emitter.writeRaw('\"');
}
void write(JsonEmitter& emitter, char c)
{
write(emitter, std::string_view{&c, 1});
}
void write(JsonEmitter& emitter, const char* str)
{
write(emitter, std::string_view{str, strlen(str)});
}
void write(JsonEmitter& emitter, const std::string& str)
{
write(emitter, std::string_view{str});
}
void write(JsonEmitter& emitter, std::nullptr_t)
{
emitter.writeRaw("null");
}
void write(JsonEmitter& emitter, std::nullopt_t)
{
emitter.writeRaw("null");
}
void JsonEmitter::writeComma()
{
if (comma)
writeRaw(',');
else
comma = true;
}
ObjectEmitter JsonEmitter::writeObject()
{
return ObjectEmitter{NotNull(this)};
}
ArrayEmitter JsonEmitter::writeArray()
{
return ArrayEmitter{NotNull(this)};
}
void JsonEmitter::newChunk()
{
chunks.emplace_back();
chunks.back().reserve(CHUNK_SIZE);
}
} // namespace Luau::Json

107
Analysis/src/LValue.cpp Normal file
View file

@ -0,0 +1,107 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/LValue.h"
#include "Luau/Ast.h"
#include <vector>
namespace Luau
{
bool Field::operator==(const Field& rhs) const
{
LUAU_ASSERT(parent && rhs.parent);
return key == rhs.key && (parent == rhs.parent || *parent == *rhs.parent);
}
bool Field::operator!=(const Field& rhs) const
{
return !(*this == rhs);
}
size_t LValueHasher::operator()(const LValue& lvalue) const
{
// Most likely doesn't produce high quality hashes, but we're probably ok enough with it.
// When an evidence is shown that operator==(LValue) is used more often than it should, we can have a look at improving the hash quality.
size_t acc = 0;
size_t offset = 0;
const LValue* current = &lvalue;
while (current)
{
if (auto field = get<Field>(*current))
acc ^= (std::hash<std::string>{}(field->key) << 1) >> ++offset;
else if (auto symbol = get<Symbol>(*current))
acc ^= std::hash<Symbol>{}(*symbol) << 1;
else
LUAU_ASSERT(!"Hash not accumulated for this new LValue alternative.");
current = baseof(*current);
}
return acc;
}
const LValue* baseof(const LValue& lvalue)
{
if (auto field = get<Field>(lvalue))
return field->parent.get();
auto symbol = get<Symbol>(lvalue);
LUAU_ASSERT(symbol);
return nullptr; // Base of root is null.
}
std::optional<LValue> tryGetLValue(const AstExpr& node)
{
const AstExpr* expr = &node;
while (auto e = expr->as<AstExprGroup>())
expr = e->expr;
if (auto local = expr->as<AstExprLocal>())
return Symbol{local->local};
else if (auto global = expr->as<AstExprGlobal>())
return Symbol{global->name};
else if (auto indexname = expr->as<AstExprIndexName>())
{
if (auto lvalue = tryGetLValue(*indexname->expr))
return Field{std::make_shared<LValue>(*lvalue), indexname->index.value};
}
else if (auto indexexpr = expr->as<AstExprIndexExpr>())
{
if (auto lvalue = tryGetLValue(*indexexpr->expr))
if (auto string = indexexpr->index->as<AstExprConstantString>())
return Field{std::make_shared<LValue>(*lvalue), std::string(string->value.data, string->value.size)};
}
return std::nullopt;
}
Symbol getBaseSymbol(const LValue& lvalue)
{
const LValue* current = &lvalue;
while (auto field = get<Field>(*current))
current = baseof(*current);
const Symbol* symbol = get<Symbol>(*current);
LUAU_ASSERT(symbol);
return *symbol;
}
void merge(RefinementMap& l, const RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f)
{
for (const auto& [k, a] : r)
{
if (auto it = l.find(k); it != l.end())
l[k] = f(it->second, a);
else
l[k] = a;
}
}
void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty)
{
refis[lvalue] = ty;
}
} // namespace Luau

View file

@ -3,6 +3,7 @@
#include "Luau/AstQuery.h" #include "Luau/AstQuery.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/StringUtils.h" #include "Luau/StringUtils.h"
#include "Luau/Common.h" #include "Luau/Common.h"
@ -11,7 +12,8 @@
#include <math.h> #include <math.h>
#include <limits.h> #include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false) LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4)
LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false)
namespace Luau namespace Luau
{ {
@ -44,6 +46,10 @@ static const char* kWarningNames[] = {
"DeprecatedApi", "DeprecatedApi",
"TableOperations", "TableOperations",
"DuplicateCondition", "DuplicateCondition",
"MisleadingAndOr",
"CommentDirective",
"IntegerParsing",
"ComparisonPrecedence",
}; };
// clang-format on // clang-format on
@ -85,10 +91,10 @@ struct LintContext
return std::nullopt; return std::nullopt;
auto it = module->astTypes.find(expr); auto it = module->astTypes.find(expr);
if (it == module->astTypes.end()) if (!it)
return std::nullopt; return std::nullopt;
return it->second; return *it;
} }
}; };
@ -198,6 +204,26 @@ static bool similar(AstExpr* lhs, AstExpr* rhs)
return true; return true;
} }
CASE(AstExprIfElse) return similar(le->condition, re->condition) && similar(le->trueExpr, re->trueExpr) && similar(le->falseExpr, re->falseExpr);
CASE(AstExprInterpString)
{
if (le->strings.size != re->strings.size)
return false;
if (le->expressions.size != re->expressions.size)
return false;
for (size_t i = 0; i < le->strings.size; ++i)
if (le->strings.data[i].size != re->strings.data[i].size ||
memcmp(le->strings.data[i].data, re->strings.data[i].data, le->strings.data[i].size) != 0)
return false;
for (size_t i = 0; i < le->expressions.size; ++i)
if (!similar(le->expressions.data[i], re->expressions.data[i]))
return false;
return true;
}
else else
{ {
LUAU_ASSERT(!"Unknown expression type"); LUAU_ASSERT(!"Unknown expression type");
@ -229,6 +255,20 @@ public:
} }
private: private:
struct FunctionInfo
{
explicit FunctionInfo(AstExprFunction* ast)
: ast(ast)
, dominatedGlobals({})
, conditionalExecution(false)
{
}
AstExprFunction* ast;
DenseHashSet<AstName> dominatedGlobals;
bool conditionalExecution;
};
struct Global struct Global
{ {
AstExprGlobal* firstRef = nullptr; AstExprGlobal* firstRef = nullptr;
@ -237,6 +277,9 @@ private:
bool assigned = false; bool assigned = false;
bool builtin = false; bool builtin = false;
bool definedInModuleScope = false;
bool definedAsFunction = false;
bool readBeforeWritten = false;
std::optional<const char*> deprecated; std::optional<const char*> deprecated;
}; };
@ -244,7 +287,8 @@ private:
DenseHashMap<AstName, Global> globals; DenseHashMap<AstName, Global> globals;
std::vector<AstExprGlobal*> globalRefs; std::vector<AstExprGlobal*> globalRefs;
std::vector<AstExprFunction*> functionStack; std::vector<FunctionInfo> functionStack;
LintGlobalLocal() LintGlobalLocal()
: globals(AstName()) : globals(AstName())
@ -262,9 +306,9 @@ private:
emitWarning(*context, LintWarning::Code_UnknownGlobal, gv->location, "Unknown global '%s'", gv->name.value); emitWarning(*context, LintWarning::Code_UnknownGlobal, gv->location, "Unknown global '%s'", gv->name.value);
else if (g->deprecated) else if (g->deprecated)
{ {
if (*g->deprecated) if (const char* replacement = *g->deprecated; replacement && strlen(replacement))
emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead", emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead",
gv->name.value, *g->deprecated); gv->name.value, replacement);
else else
emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value); emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value);
} }
@ -287,12 +331,18 @@ private:
"Global '%s' is only used in the enclosing function defined at line %d; consider changing it to local", "Global '%s' is only used in the enclosing function defined at line %d; consider changing it to local",
g.firstRef->name.value, top->location.begin.line + 1); g.firstRef->name.value, top->location.begin.line + 1);
} }
else if (FFlag::LuauLintGlobalNeverReadBeforeWritten && g.assigned && !g.readBeforeWritten && !g.definedInModuleScope &&
g.firstRef->name != context->placeholder)
{
emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location,
"Global '%s' is never read before being written. Consider changing it to local", g.firstRef->name.value);
}
} }
} }
bool visit(AstExprFunction* node) override bool visit(AstExprFunction* node) override
{ {
functionStack.push_back(node); functionStack.emplace_back(node);
node->body->visit(this); node->body->visit(this);
@ -303,6 +353,11 @@ private:
bool visit(AstExprGlobal* node) override bool visit(AstExprGlobal* node) override
{ {
if (FFlag::LuauLintGlobalNeverReadBeforeWritten && !functionStack.empty() && !functionStack.back().dominatedGlobals.contains(node->name))
{
Global& g = globals[node->name];
g.readBeforeWritten = true;
}
trackGlobalRef(node); trackGlobalRef(node);
if (node->name == context->placeholder) if (node->name == context->placeholder)
@ -331,6 +386,21 @@ private:
{ {
Global& g = globals[gv->name]; Global& g = globals[gv->name];
if (FFlag::LuauLintGlobalNeverReadBeforeWritten)
{
if (functionStack.empty())
{
g.definedInModuleScope = true;
}
else
{
if (!functionStack.back().conditionalExecution)
{
functionStack.back().dominatedGlobals.insert(gv->name);
}
}
}
if (g.builtin) if (g.builtin)
emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location,
"Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value);
@ -365,7 +435,14 @@ private:
emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location,
"Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value);
else else
{
g.assigned = true; g.assigned = true;
if (FFlag::LuauLintGlobalNeverReadBeforeWritten)
{
g.definedAsFunction = true;
g.definedInModuleScope = functionStack.empty();
}
}
trackGlobalRef(gv); trackGlobalRef(gv);
} }
@ -373,6 +450,98 @@ private:
return true; return true;
} }
class HoldConditionalExecution
{
public:
HoldConditionalExecution(LintGlobalLocal& p)
: p(p)
{
if (!p.functionStack.empty() && !p.functionStack.back().conditionalExecution)
{
resetToFalse = true;
p.functionStack.back().conditionalExecution = true;
}
}
~HoldConditionalExecution()
{
if (resetToFalse)
p.functionStack.back().conditionalExecution = false;
}
private:
bool resetToFalse = false;
LintGlobalLocal& p;
};
bool visit(AstStatIf* node) override
{
if (!FFlag::LuauLintGlobalNeverReadBeforeWritten)
return true;
HoldConditionalExecution ce(*this);
node->condition->visit(this);
node->thenbody->visit(this);
if (node->elsebody)
node->elsebody->visit(this);
return false;
}
bool visit(AstStatWhile* node) override
{
if (!FFlag::LuauLintGlobalNeverReadBeforeWritten)
return true;
HoldConditionalExecution ce(*this);
node->condition->visit(this);
node->body->visit(this);
return false;
}
bool visit(AstStatRepeat* node) override
{
if (!FFlag::LuauLintGlobalNeverReadBeforeWritten)
return true;
HoldConditionalExecution ce(*this);
node->condition->visit(this);
node->body->visit(this);
return false;
}
bool visit(AstStatFor* node) override
{
if (!FFlag::LuauLintGlobalNeverReadBeforeWritten)
return true;
HoldConditionalExecution ce(*this);
node->from->visit(this);
node->to->visit(this);
if (node->step)
node->step->visit(this);
node->body->visit(this);
return false;
}
bool visit(AstStatForIn* node) override
{
if (!FFlag::LuauLintGlobalNeverReadBeforeWritten)
return true;
HoldConditionalExecution ce(*this);
for (AstExpr* expr : node->values)
expr->visit(this);
node->body->visit(this);
return false;
}
void trackGlobalRef(AstExprGlobal* node) void trackGlobalRef(AstExprGlobal* node)
{ {
Global& g = globals[node->name]; Global& g = globals[node->name];
@ -386,7 +555,12 @@ private:
// to reduce the cost of tracking we only track this for user globals // to reduce the cost of tracking we only track this for user globals
if (!g.builtin) if (!g.builtin)
{ {
g.functionRef = functionStack; g.functionRef.clear();
g.functionRef.reserve(functionStack.size());
for (const FunctionInfo& entry : functionStack)
{
g.functionRef.push_back(entry.ast);
}
} }
} }
else else
@ -397,7 +571,7 @@ private:
// we need to find a common prefix between all uses of a global // we need to find a common prefix between all uses of a global
size_t prefix = 0; size_t prefix = 0;
while (prefix < g.functionRef.size() && prefix < functionStack.size() && g.functionRef[prefix] == functionStack[prefix]) while (prefix < g.functionRef.size() && prefix < functionStack.size() && g.functionRef[prefix] == functionStack[prefix].ast)
prefix++; prefix++;
g.functionRef.resize(prefix); g.functionRef.resize(prefix);
@ -732,13 +906,13 @@ private:
bool visit(AstTypeReference* node) override bool visit(AstTypeReference* node) override
{ {
if (!node->hasPrefix) if (!node->prefix)
return true; return true;
if (!imports.contains(node->prefix)) if (!imports.contains(*node->prefix))
return true; return true;
AstLocal* astLocal = imports[node->prefix]; AstLocal* astLocal = imports[*node->prefix];
Local& local = locals[astLocal]; Local& local = locals[astLocal];
LUAU_ASSERT(local.import); LUAU_ASSERT(local.import);
local.used = true; local.used = true;
@ -982,25 +1156,12 @@ private:
enum TypeKind enum TypeKind
{ {
Kind_Invalid, Kind_Unknown,
Kind_Primitive, // primitive type supported by VM - boolean/userdata/etc. No differentiation between types of userdata. Kind_Primitive, // primitive type supported by VM - boolean/userdata/etc. No differentiation between types of userdata.
Kind_Vector, // For 'vector' but only used when type is used Kind_Vector, // 'vector' but only used when type is used
Kind_Userdata, // custom userdata type - Vector3/etc. Kind_Userdata, // custom userdata type
Kind_Class, // custom userdata type that reflects Roblox Instance-derived hierarchy - Part/etc.
Kind_Enum, // custom userdata type referring to an enum item of enum classes, e.g. Enum.NormalId.Back/Enum.Axis.X/etc.
}; };
bool containsPropName(TypeId ty, const std::string& propName)
{
if (auto ctv = get<ClassTypeVar>(ty))
return lookupClassProp(ctv, propName) != nullptr;
if (auto ttv = get<TableTypeVar>(ty))
return ttv->props.find(propName) != ttv->props.end();
return false;
}
TypeKind getTypeKind(const std::string& name) TypeKind getTypeKind(const std::string& name)
{ {
if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" || if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" ||
@ -1011,12 +1172,9 @@ private:
return Kind_Vector; return Kind_Vector;
if (std::optional<TypeFun> maybeTy = context->scope->lookupType(name)) if (std::optional<TypeFun> maybeTy = context->scope->lookupType(name))
// Kind_Userdata is probably not 100% precise but is close enough return Kind_Userdata;
return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata;
else if (std::optional<TypeFun> maybeTy = context->scope->lookupImportedType("Enum", name))
return Kind_Enum;
return Kind_Invalid; return Kind_Unknown;
} }
void validateType(AstExprConstantString* expr, std::initializer_list<TypeKind> expected, const char* expectedString) void validateType(AstExprConstantString* expr, std::initializer_list<TypeKind> expected, const char* expectedString)
@ -1024,7 +1182,7 @@ private:
std::string name(expr->value.data, expr->value.size); std::string name(expr->value.data, expr->value.size);
TypeKind kind = getTypeKind(name); TypeKind kind = getTypeKind(name);
if (kind == Kind_Invalid) if (kind == Kind_Unknown)
{ {
emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s'", name.c_str()); emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s'", name.c_str());
return; return;
@ -1034,61 +1192,11 @@ private:
{ {
if (kind == ek) if (kind == ek)
return; return;
// as a special case, Instance and EnumItem are both a userdata type (as returned by typeof) and a class type
if (ek == Kind_Userdata && (name == "Instance" || name == "EnumItem"))
return;
} }
emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s' (expected %s)", name.c_str(), expectedString); emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s' (expected %s)", name.c_str(), expectedString);
} }
bool acceptsClassName(AstName method)
{
return method.value[0] == 'F' && (method == "FindFirstChildOfClass" || method == "FindFirstChildWhichIsA" ||
method == "FindFirstAncestorOfClass" || method == "FindFirstAncestorWhichIsA");
}
bool visit(AstExprCall* node) override
{
if (AstExprIndexName* index = node->func->as<AstExprIndexName>())
{
AstExprConstantString* arg0 = node->args.size > 0 ? node->args.data[0]->as<AstExprConstantString>() : NULL;
if (arg0)
{
if (node->self && index->index == "IsA" && node->args.size == 1)
{
validateType(arg0, {Kind_Class, Kind_Enum}, "class or enum type");
}
else if (node->self && (index->index == "GetService" || index->index == "FindService") && node->args.size == 1)
{
AstExprGlobal* g = index->expr->as<AstExprGlobal>();
if (g && (g->name == "game" || g->name == "Game"))
{
validateType(arg0, {Kind_Class}, "class type");
}
}
else if (node->self && acceptsClassName(index->index) && node->args.size == 1)
{
validateType(arg0, {Kind_Class}, "class type");
}
else if (!node->self && index->index == "new" && node->args.size <= 2)
{
AstExprGlobal* g = index->expr->as<AstExprGlobal>();
if (g && g->name == "Instance")
{
validateType(arg0, {Kind_Class}, "class type");
}
}
}
}
return true;
}
bool visit(AstExprBinary* node) override bool visit(AstExprBinary* node) override
{ {
if (node->op == AstExprBinary::CompareNe || node->op == AstExprBinary::CompareEq) if (node->op == AstExprBinary::CompareNe || node->op == AstExprBinary::CompareEq)
@ -1108,10 +1216,7 @@ private:
if (g && g->name == "type") if (g && g->name == "type")
{ {
if (FFlag::LuauLinterUnknownTypeVectorAware) validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type");
validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type");
else
validateType(arg, {Kind_Primitive}, "primitive type");
} }
else if (g && g->name == "typeof") else if (g && g->name == "typeof")
{ {
@ -1349,7 +1454,7 @@ private:
const char* checkStringFormat(const char* data, size_t size) const char* checkStringFormat(const char* data, size_t size)
{ {
const char* flags = "-+ #0"; const char* flags = "-+ #0";
const char* options = "cdiouxXeEfgGqs"; const char* options = "cdiouxXeEfgGqs*";
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
{ {
@ -2044,18 +2149,28 @@ private:
const Property* prop = lookupClassProp(cty, node->index.value); const Property* prop = lookupClassProp(cty, node->index.value);
if (prop && prop->deprecated) if (prop && prop->deprecated)
{ report(node->location, *prop, cty->name.c_str(), node->index.value);
if (!prop->deprecatedSuggestion.empty()) }
emitWarning(*context, LintWarning::Code_DeprecatedApi, node->location, "Member '%s.%s' is deprecated, use '%s' instead", else if (const TableTypeVar* tty = get<TableTypeVar>(follow(*ty)))
cty->name.c_str(), node->index.value, prop->deprecatedSuggestion.c_str()); {
else auto prop = tty->props.find(node->index.value);
emitWarning(*context, LintWarning::Code_DeprecatedApi, node->location, "Member '%s.%s' is deprecated", cty->name.c_str(),
node->index.value); if (prop != tty->props.end() && prop->second.deprecated)
} report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value);
} }
return true; return true;
} }
void report(const Location& location, const Property& prop, const char* container, const char* field)
{
std::string suggestion = prop.deprecatedSuggestion.empty() ? "" : format(", use '%s' instead", prop.deprecatedSuggestion.c_str());
if (container)
emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Member '%s.%s' is deprecated%s", container, field, suggestion.c_str());
else
emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Member '%s' is deprecated%s", field, suggestion.c_str());
}
}; };
class LintTableOperations : AstVisitor class LintTableOperations : AstVisitor
@ -2144,6 +2259,32 @@ private:
"wrap it in parentheses to silence"); "wrap it in parentheses to silence");
} }
if (func->index == "move" && node->args.size >= 4)
{
// table.move(t, 0, _, _)
if (isConstant(args[1], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
// table.move(t, _, _, 0)
else if (isConstant(args[3], 0.0))
emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location,
"table.move uses index 0 but arrays are 1-based; did you mean 1 instead?");
}
if (func->index == "create" && node->args.size == 2)
{
// table.create(n, {...})
if (args[1]->is<AstExprTable>())
emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location,
"table.create with a table literal will reuse the same object for all elements; consider using a for loop instead");
// table.create(n, {...} :: ?)
if (AstExprTypeAssertion* as = args[1]->as<AstExprTypeAssertion>(); as && as->expr->is<AstExprTable>())
emitWarning(*context, LintWarning::Code_TableOperations, as->expr->location,
"table.create with a table literal will reuse the same object for all elements; consider using a for loop instead");
}
return true; return true;
} }
@ -2162,7 +2303,7 @@ private:
size_t getReturnCount(TypeId ty) size_t getReturnCount(TypeId ty)
{ {
if (auto ftv = get<FunctionTypeVar>(ty)) if (auto ftv = get<FunctionTypeVar>(ty))
return size(ftv->retType); return size(ftv->retTypes);
if (auto itv = get<IntersectionTypeVar>(ty)) if (auto itv = get<IntersectionTypeVar>(ty))
{ {
@ -2171,7 +2312,7 @@ private:
for (TypeId part : itv->parts) for (TypeId part : itv->parts)
if (auto ftv = get<FunctionTypeVar>(follow(part))) if (auto ftv = get<FunctionTypeVar>(follow(part)))
result = std::max(result, size(ftv->retType)); result = std::max(result, size(ftv->retTypes));
return result; return result;
} }
@ -2235,6 +2376,39 @@ private:
return false; return false;
} }
bool visit(AstExprIfElse* expr) override
{
if (!expr->falseExpr->is<AstExprIfElse>())
return true;
// if..elseif chain detected, we need to unroll it
std::vector<AstExpr*> conditions;
conditions.reserve(2);
AstExprIfElse* head = expr;
while (head)
{
head->condition->visit(this);
head->trueExpr->visit(this);
conditions.push_back(head->condition);
if (head->falseExpr->is<AstExprIfElse>())
{
head = head->falseExpr->as<AstExprIfElse>();
continue;
}
head->falseExpr->visit(this);
break;
}
detectDuplicates(conditions);
// block recursive visits so that we only analyze each chain once
return false;
}
bool visit(AstExprBinary* expr) override bool visit(AstExprBinary* expr) override
{ {
if (expr->op != AstExprBinary::And && expr->op != AstExprBinary::Or) if (expr->op != AstExprBinary::And && expr->op != AstExprBinary::Or)
@ -2396,6 +2570,153 @@ private:
} }
}; };
class LintMisleadingAndOr : AstVisitor
{
public:
LUAU_NOINLINE static void process(LintContext& context)
{
LintMisleadingAndOr pass;
pass.context = &context;
context.root->visit(&pass);
}
private:
LintContext* context;
bool visit(AstExprBinary* node) override
{
if (node->op != AstExprBinary::Or)
return true;
AstExprBinary* and_ = node->left->as<AstExprBinary>();
if (!and_ || and_->op != AstExprBinary::And)
return true;
const char* alt = nullptr;
if (and_->right->is<AstExprConstantNil>())
alt = "nil";
else if (AstExprConstantBool* c = and_->right->as<AstExprConstantBool>(); c && c->value == false)
alt = "false";
if (alt)
emitWarning(*context, LintWarning::Code_MisleadingAndOr, node->location,
"The and-or expression always evaluates to the second alternative because the first alternative is %s; consider using if-then-else "
"expression instead",
alt);
return true;
}
};
class LintIntegerParsing : AstVisitor
{
public:
LUAU_NOINLINE static void process(LintContext& context)
{
LintIntegerParsing pass;
pass.context = &context;
context.root->visit(&pass);
}
private:
LintContext* context;
bool visit(AstExprConstantNumber* node) override
{
switch (node->parseResult)
{
case ConstantNumberParseResult::Ok:
case ConstantNumberParseResult::Malformed:
break;
case ConstantNumberParseResult::BinOverflow:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Binary number literal exceeded available precision and has been truncated to 2^64");
break;
case ConstantNumberParseResult::HexOverflow:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Hexadecimal number literal exceeded available precision and has been truncated to 2^64");
break;
case ConstantNumberParseResult::DoublePrefix:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Hexadecimal number literal has a double prefix, which will fail to parse in the future; remove the extra 0x to fix");
break;
}
return true;
}
};
class LintComparisonPrecedence : AstVisitor
{
public:
LUAU_NOINLINE static void process(LintContext& context)
{
LintComparisonPrecedence pass;
pass.context = &context;
context.root->visit(&pass);
}
private:
LintContext* context;
static bool isEquality(AstExprBinary::Op op)
{
return op == AstExprBinary::CompareNe || op == AstExprBinary::CompareEq;
}
static bool isComparison(AstExprBinary::Op op)
{
return op == AstExprBinary::CompareNe || op == AstExprBinary::CompareEq || op == AstExprBinary::CompareLt || op == AstExprBinary::CompareLe ||
op == AstExprBinary::CompareGt || op == AstExprBinary::CompareGe;
}
static bool isNot(AstExpr* node)
{
AstExprUnary* expr = node->as<AstExprUnary>();
return expr && expr->op == AstExprUnary::Not;
}
bool visit(AstExprBinary* node) override
{
if (!isComparison(node->op))
return true;
// not X == Y; we silence this for not X == not Y as it's likely an intentional boolean comparison
if (isNot(node->left) && !isNot(node->right))
{
std::string op = toString(node->op);
if (isEquality(node->op))
emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location,
"not X %s Y is equivalent to (not X) %s Y; consider using X %s Y, or add parentheses to silence", op.c_str(), op.c_str(),
node->op == AstExprBinary::CompareEq ? "~=" : "==");
else
emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location,
"not X %s Y is equivalent to (not X) %s Y; add parentheses to silence", op.c_str(), op.c_str());
}
else if (AstExprBinary* left = node->left->as<AstExprBinary>(); left && isComparison(left->op))
{
std::string lop = toString(left->op);
std::string rop = toString(node->op);
if (isEquality(left->op) || isEquality(node->op))
emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location,
"X %s Y %s Z is equivalent to (X %s Y) %s Z; add parentheses to silence", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str());
else
emitWarning(*context, LintWarning::Code_ComparisonPrecedence, node->location,
"X %s Y %s Z is equivalent to (X %s Y) %s Z; did you mean X %s Y and Y %s Z?", lop.c_str(), rop.c_str(), lop.c_str(), rop.c_str(),
lop.c_str(), rop.c_str());
}
return true;
}
};
static void fillBuiltinGlobals(LintContext& context, const AstNameTable& names, const ScopePtr& env) static void fillBuiltinGlobals(LintContext& context, const AstNameTable& names, const ScopePtr& env)
{ {
ScopePtr current = env; ScopePtr current = env;
@ -2421,13 +2742,124 @@ static void fillBuiltinGlobals(LintContext& context, const AstNameTable& names,
} }
} }
static const char* fuzzyMatch(std::string_view str, const char** array, size_t size)
{
if (FInt::LuauSuggestionDistance == 0)
return nullptr;
size_t bestDistance = FInt::LuauSuggestionDistance;
size_t bestMatch = size;
for (size_t i = 0; i < size; ++i)
{
size_t ed = editDistance(str, array[i]);
if (ed <= bestDistance)
{
bestDistance = ed;
bestMatch = i;
}
}
return bestMatch < size ? array[bestMatch] : nullptr;
}
static void lintComments(LintContext& context, const std::vector<HotComment>& hotcomments)
{
bool seenMode = false;
for (const HotComment& hc : hotcomments)
{
// We reserve --!<space> for various informational (non-directive) comments
if (hc.content.empty() || hc.content[0] == ' ' || hc.content[0] == '\t')
continue;
if (!hc.header)
{
emitWarning(context, LintWarning::Code_CommentDirective, hc.location,
"Comment directive is ignored because it is placed after the first non-comment token");
}
else
{
size_t space = hc.content.find_first_of(" \t");
std::string_view first = std::string_view(hc.content).substr(0, space);
if (first == "nolint")
{
size_t notspace = hc.content.find_first_not_of(" \t", space);
if (space == std::string::npos || notspace == std::string::npos)
{
// disables all lints
}
else if (LintWarning::parseName(hc.content.c_str() + notspace) == LintWarning::Code_Unknown)
{
const char* rule = hc.content.c_str() + notspace;
// skip Unknown
if (const char* suggestion = fuzzyMatch(rule, kWarningNames + 1, LintWarning::Code__Count - 1))
emitWarning(context, LintWarning::Code_CommentDirective, hc.location,
"nolint directive refers to unknown lint rule '%s'; did you mean '%s'?", rule, suggestion);
else
emitWarning(
context, LintWarning::Code_CommentDirective, hc.location, "nolint directive refers to unknown lint rule '%s'", rule);
}
}
else if (first == "nocheck" || first == "nonstrict" || first == "strict")
{
if (space != std::string::npos)
emitWarning(context, LintWarning::Code_CommentDirective, hc.location,
"Comment directive with the type checking mode has extra symbols at the end of the line");
else if (seenMode)
emitWarning(context, LintWarning::Code_CommentDirective, hc.location,
"Comment directive with the type checking mode has already been used");
else
seenMode = true;
}
else if (first == "optimize")
{
size_t notspace = hc.content.find_first_not_of(" \t", space);
if (space == std::string::npos || notspace == std::string::npos)
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "optimize directive requires an optimization level");
else
{
const char* level = hc.content.c_str() + notspace;
if (strcmp(level, "0") && strcmp(level, "1") && strcmp(level, "2"))
emitWarning(context, LintWarning::Code_CommentDirective, hc.location,
"optimize directive uses unknown optimization level '%s', 0..2 expected", level);
}
}
else
{
static const char* kHotComments[] = {
"nolint",
"nocheck",
"nonstrict",
"strict",
"optimize",
};
if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments)))
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'; did you mean '%s'?",
int(first.size()), first.data(), suggestion);
else
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'", int(first.size()),
first.data());
}
}
}
}
void LintOptions::setDefaults() void LintOptions::setDefaults()
{ {
// By default, we enable all warnings // By default, we enable all warnings
warningMask = ~0ull; warningMask = ~0ull;
} }
std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const LintOptions& options) std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module,
const std::vector<HotComment>& hotcomments, const LintOptions& options)
{ {
LintContext context; LintContext context;
@ -2500,6 +2932,18 @@ std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const Sc
if (context.warningEnabled(LintWarning::Code_DuplicateLocal)) if (context.warningEnabled(LintWarning::Code_DuplicateLocal))
LintDuplicateLocal::process(context); LintDuplicateLocal::process(context);
if (context.warningEnabled(LintWarning::Code_MisleadingAndOr))
LintMisleadingAndOr::process(context);
if (context.warningEnabled(LintWarning::Code_CommentDirective))
lintComments(context, hotcomments);
if (context.warningEnabled(LintWarning::Code_IntegerParsing))
LintIntegerParsing::process(context);
if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence))
LintComparisonPrecedence::process(context);
std::sort(context.result.begin(), context.result.end(), WarningComparator()); std::sort(context.result.begin(), context.result.end(), WarningComparator());
return context.result; return context.result;
@ -2521,23 +2965,30 @@ LintWarning::Code LintWarning::parseName(const char* name)
return Code_Unknown; return Code_Unknown;
} }
uint64_t LintWarning::parseMask(const std::vector<std::string>& hotcomments) uint64_t LintWarning::parseMask(const std::vector<HotComment>& hotcomments)
{ {
uint64_t result = 0; uint64_t result = 0;
for (const std::string& hc : hotcomments) for (const HotComment& hc : hotcomments)
{ {
if (hc.compare(0, 6, "nolint") != 0) if (!hc.header)
continue; continue;
std::string::size_type name = hc.find_first_not_of(" \t", 6); if (hc.content.compare(0, 6, "nolint") != 0)
continue;
size_t name = hc.content.find_first_not_of(" \t", 6);
// --!nolint disables everything // --!nolint disables everything
if (name == std::string::npos) if (name == std::string::npos)
return ~0ull; return ~0ull;
// --!nolint needs to be followed by a whitespace character
if (name == 6)
continue;
// --!nolint name disables the specific lint // --!nolint name disables the specific lint
LintWarning::Code code = LintWarning::parseName(hc.c_str() + name); LintWarning::Code code = LintWarning::parseName(hc.content.c_str() + name);
if (code != LintWarning::Code_Unknown) if (code != LintWarning::Code_Unknown)
result |= 1ull << int(code); result |= 1ull << int(code);

View file

@ -1,18 +1,24 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/Normalize.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/VisitTypeVar.h" #include "Luau/VisitTypeVar.h"
#include "Luau/Common.h"
#include <algorithm> #include <algorithm>
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false);
LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauSubstitutionReentrant);
LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution);
LUAU_FASTFLAG(LuauSubstitutionFixMissingFields);
namespace Luau namespace Luau
{ {
@ -21,7 +27,7 @@ static bool contains(Position pos, Comment comment)
{ {
if (comment.location.contains(pos)) if (comment.location.contains(pos))
return true; return true;
else if (FFlag::LuauCaptureBrokenCommentSpans && comment.type == Lexeme::BrokenComment && else if (comment.type == Lexeme::BrokenComment &&
comment.location.begin <= pos) // Broken comments are broken specifically because they don't have an end comment.location.begin <= pos) // Broken comments are broken specifically because they don't have an end
return true; return true;
else if (comment.type == Lexeme::Comment && comment.location.end == pos) else if (comment.type == Lexeme::Comment && comment.location.end == pos)
@ -52,411 +58,178 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos)
return contains(pos, *iter); return contains(pos, *iter);
} }
void TypeArena::clear() struct ClonePublicInterface : Substitution
{ {
typeVars.clear(); NotNull<SingletonTypes> singletonTypes;
typePacks.clear(); NotNull<Module> module;
}
TypeId TypeArena::addTV(TypeVar&& tv) ClonePublicInterface(const TxnLog* log, NotNull<SingletonTypes> singletonTypes, Module* module)
{ : Substitution(log, &module->interfaceTypes)
TypeId allocated = typeVars.allocate(std::move(tv)); , singletonTypes(singletonTypes)
, module(module)
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(TypeLevel level)
{
TypeId allocated = typeVars.allocate(FreeTypeVar{level});
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(std::initializer_list<TypeId> types)
{
TypePackId allocated = typePacks.allocate(TypePack{std::move(types)});
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(std::vector<TypeId> types)
{
TypePackId allocated = typePacks.allocate(TypePack{std::move(types)});
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(TypePack tp)
{
TypePackId allocated = typePacks.allocate(std::move(tp));
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(TypePackVar tp)
{
TypePackId allocated = typePacks.allocate(std::move(tp));
if (FFlag::DebugLuauTrackOwningArena)
asMutable(allocated)->owningArena = this;
return allocated;
}
using SeenTypes = std::unordered_map<TypeId, TypeId>;
using SeenTypePacks = std::unordered_map<TypePackId, TypePackId>;
TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType);
TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType);
namespace
{
struct TypePackCloner;
/*
* Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set.
* They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage.
*/
struct TypeCloner
{
TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks)
: dest(dest)
, typeId(typeId)
, seenTypes(seenTypes)
, seenTypePacks(seenTypePacks)
{ {
LUAU_ASSERT(module);
} }
TypeArena& dest; bool isDirty(TypeId ty) override
TypeId typeId;
SeenTypes& seenTypes;
SeenTypePacks& seenTypePacks;
bool* encounteredFreeType = nullptr;
template<typename T>
void defaultClone(const T& t);
void operator()(const Unifiable::Free& t);
void operator()(const Unifiable::Generic& t);
void operator()(const Unifiable::Bound<TypeId>& t);
void operator()(const Unifiable::Error& t);
void operator()(const PrimitiveTypeVar& t);
void operator()(const FunctionTypeVar& t);
void operator()(const TableTypeVar& t);
void operator()(const MetatableTypeVar& t);
void operator()(const ClassTypeVar& t);
void operator()(const AnyTypeVar& t);
void operator()(const UnionTypeVar& t);
void operator()(const IntersectionTypeVar& t);
void operator()(const LazyTypeVar& t);
};
struct TypePackCloner
{
TypeArena& dest;
TypePackId typePackId;
SeenTypes& seenTypes;
SeenTypePacks& seenTypePacks;
bool* encounteredFreeType = nullptr;
TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks)
: dest(dest)
, typePackId(typePackId)
, seenTypes(seenTypes)
, seenTypePacks(seenTypePacks)
{ {
if (ty->owningArena == &module->internalTypes)
return true;
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
return ftv->level.level != 0;
if (const TableTypeVar* ttv = get<TableTypeVar>(ty))
return ttv->level.level != 0;
return false;
} }
template<typename T> bool isDirty(TypePackId tp) override
void defaultClone(const T& t)
{ {
TypePackId cloned = dest.typePacks.allocate(t); return tp->owningArena == &module->internalTypes;
seenTypePacks[typePackId] = cloned;
} }
void operator()(const Unifiable::Free& t) TypeId clean(TypeId ty) override
{ {
if (encounteredFreeType) TypeId result = clone(ty);
*encounteredFreeType = true;
seenTypePacks[typePackId] = dest.typePacks.allocate(TypePackVar{Unifiable::Error{}}); if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(result))
ftv->level = TypeLevel{0, 0};
else if (TableTypeVar* ttv = getMutable<TableTypeVar>(result))
ttv->level = TypeLevel{0, 0};
return result;
} }
void operator()(const Unifiable::Generic& t) TypePackId clean(TypePackId tp) override
{ {
defaultClone(t); return clone(tp);
}
void operator()(const Unifiable::Error& t)
{
defaultClone(t);
} }
// While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter. TypeId cloneType(TypeId ty)
// We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer.
void operator()(const Unifiable::Bound<TypePackId>& t)
{ {
TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields);
seenTypePacks[typePackId] = cloned;
}
void operator()(const VariadicTypePack& t) std::optional<TypeId> result = substitute(ty);
{ if (result)
TypePackId cloned = dest.typePacks.allocate(VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)});
seenTypePacks[typePackId] = cloned;
}
void operator()(const TypePack& t)
{
TypePackId cloned = dest.typePacks.allocate(TypePack{});
TypePack* destTp = getMutable<TypePack>(cloned);
LUAU_ASSERT(destTp != nullptr);
seenTypePacks[typePackId] = cloned;
for (TypeId ty : t.head)
destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType));
if (t.tail)
destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, encounteredFreeType);
}
};
template<typename T>
void TypeCloner::defaultClone(const T& t)
{
TypeId cloned = dest.typeVars.allocate(t);
seenTypes[typeId] = cloned;
}
void TypeCloner::operator()(const Unifiable::Free& t)
{
if (encounteredFreeType)
*encounteredFreeType = true;
seenTypes[typeId] = dest.typeVars.allocate(ErrorTypeVar{});
}
void TypeCloner::operator()(const Unifiable::Generic& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const Unifiable::Bound<TypeId>& t)
{
TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
seenTypes[typeId] = boundTo;
}
void TypeCloner::operator()(const Unifiable::Error& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const PrimitiveTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const FunctionTypeVar& t)
{
TypeId result = dest.typeVars.allocate(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf});
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(result);
LUAU_ASSERT(ftv != nullptr);
seenTypes[typeId] = result;
for (TypeId generic : t.generics)
ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, encounteredFreeType));
for (TypePackId genericPack : t.genericPacks)
ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, encounteredFreeType));
if (FFlag::LuauSecondTypecheckKnowsTheDataModel)
ftv->tags = t.tags;
ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, encounteredFreeType);
ftv->argNames = t.argNames;
ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, encounteredFreeType);
}
void TypeCloner::operator()(const TableTypeVar& t)
{
TypeId result = dest.typeVars.allocate(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(result);
LUAU_ASSERT(ttv != nullptr);
*ttv = t;
seenTypes[typeId] = result;
ttv->level = TypeLevel{0, 0};
for (const auto& [name, prop] : t.props)
{
if (FFlag::LuauSecondTypecheckKnowsTheDataModel)
ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags};
else
ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location};
}
if (t.indexer)
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType),
clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)};
if (t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
for (TypeId& arg : ttv->instantiatedTypeParams)
arg = (clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType));
if (ttv->state == TableState::Free)
{
if (!t.boundTo)
{ {
if (encounteredFreeType) return *result;
*encounteredFreeType = true; }
else
{
module->errors.push_back(TypeError{module->scopes[0].first, UnificationTooComplex{}});
return singletonTypes->errorRecoveryType();
}
}
TypePackId cloneTypePack(TypePackId tp)
{
LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields);
std::optional<TypePackId> result = substitute(tp);
if (result)
{
return *result;
}
else
{
module->errors.push_back(TypeError{module->scopes[0].first, UnificationTooComplex{}});
return singletonTypes->errorRecoveryTypePack();
}
}
TypeFun cloneTypeFun(const TypeFun& tf)
{
LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields);
std::vector<GenericTypeDefinition> typeParams;
std::vector<GenericTypePackDefinition> typePackParams;
for (GenericTypeDefinition typeParam : tf.typeParams)
{
TypeId ty = cloneType(typeParam.ty);
std::optional<TypeId> defaultValue;
if (typeParam.defaultValue)
defaultValue = cloneType(*typeParam.defaultValue);
typeParams.push_back(GenericTypeDefinition{ty, defaultValue});
} }
ttv->state = TableState::Sealed; for (GenericTypePackDefinition typePackParam : tf.typePackParams)
{
TypePackId tp = cloneTypePack(typePackParam.tp);
std::optional<TypePackId> defaultValue;
if (typePackParam.defaultValue)
defaultValue = cloneTypePack(*typePackParam.defaultValue);
typePackParams.push_back(GenericTypePackDefinition{tp, defaultValue});
}
TypeId type = cloneType(tf.type);
return TypeFun{typeParams, typePackParams, type};
} }
};
ttv->definitionModuleName = t.definitionModuleName; Module::~Module()
ttv->methodDefinitionLocations = t.methodDefinitionLocations; {
ttv->tags = t.tags; unfreeze(interfaceTypes);
unfreeze(internalTypes);
} }
void TypeCloner::operator()(const MetatableTypeVar& t) void Module::clonePublicInterface(NotNull<SingletonTypes> singletonTypes, InternalErrorReporter& ice)
{ {
TypeId result = dest.typeVars.allocate(MetatableTypeVar{}); LUAU_ASSERT(interfaceTypes.typeVars.empty());
MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(result); LUAU_ASSERT(interfaceTypes.typePacks.empty());
seenTypes[typeId] = result;
mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, encounteredFreeType); CloneState cloneState;
mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, encounteredFreeType);
}
void TypeCloner::operator()(const ClassTypeVar& t) ScopePtr moduleScope = getModuleScope();
{
TypeId result = dest.typeVars.allocate(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData});
ClassTypeVar* ctv = getMutable<ClassTypeVar>(result);
seenTypes[typeId] = result; TypePackId returnType = moduleScope->returnType;
std::optional<TypePackId> varargPack = FFlag::DebugLuauDeferredConstraintResolution ? std::nullopt : moduleScope->varargPack;
std::unordered_map<Name, TypeFun>* exportedTypeBindings = &moduleScope->exportedTypeBindings;
for (const auto& [name, prop] : t.props) TxnLog log;
if (FFlag::LuauSecondTypecheckKnowsTheDataModel) ClonePublicInterface clonePublicInterface{&log, singletonTypes, this};
ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags};
if (FFlag::LuauClonePublicInterfaceLess)
returnType = clonePublicInterface.cloneTypePack(returnType);
else
returnType = clone(returnType, interfaceTypes, cloneState);
moduleScope->returnType = returnType;
if (varargPack)
{
if (FFlag::LuauClonePublicInterfaceLess)
varargPack = clonePublicInterface.cloneTypePack(*varargPack);
else else
ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location}; varargPack = clone(*varargPack, interfaceTypes, cloneState);
moduleScope->varargPack = varargPack;
if (t.parent)
ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, encounteredFreeType);
if (t.metatable)
ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, encounteredFreeType);
}
void TypeCloner::operator()(const AnyTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const UnionTypeVar& t)
{
TypeId result = dest.typeVars.allocate(UnionTypeVar{});
seenTypes[typeId] = result;
UnionTypeVar* option = getMutable<UnionTypeVar>(result);
LUAU_ASSERT(option != nullptr);
for (TypeId ty : t.options)
option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType));
}
void TypeCloner::operator()(const IntersectionTypeVar& t)
{
TypeId result = dest.typeVars.allocate(IntersectionTypeVar{});
seenTypes[typeId] = result;
IntersectionTypeVar* option = getMutable<IntersectionTypeVar>(result);
LUAU_ASSERT(option != nullptr);
for (TypeId ty : t.parts)
option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType));
}
void TypeCloner::operator()(const LazyTypeVar& t)
{
defaultClone(t);
}
} // anonymous namespace
TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType)
{
if (tp->persistent)
return tp;
TypePackId& res = seenTypePacks[tp];
if (res == nullptr)
{
TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks};
cloner.encounteredFreeType = encounteredFreeType;
Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into.
} }
if (FFlag::DebugLuauTrackOwningArena) if (exportedTypeBindings)
asMutable(res)->owningArena = &dest;
return res;
}
TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType)
{
if (typeId->persistent)
return typeId;
TypeId& res = seenTypes[typeId];
if (res == nullptr)
{ {
TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks}; for (auto& [name, tf] : *exportedTypeBindings)
cloner.encounteredFreeType = encounteredFreeType; {
Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. if (FFlag::LuauClonePublicInterfaceLess)
asMutable(res)->documentationSymbol = typeId->documentationSymbol; tf = clonePublicInterface.cloneTypeFun(tf);
else
tf = clone(tf, interfaceTypes, cloneState);
}
} }
if (FFlag::DebugLuauTrackOwningArena) for (auto& [name, ty] : declaredGlobals)
asMutable(res)->owningArena = &dest; {
if (FFlag::LuauClonePublicInterfaceLess)
ty = clonePublicInterface.cloneType(ty);
else
ty = clone(ty, interfaceTypes, cloneState);
}
return res; freeze(internalTypes);
} freeze(interfaceTypes);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType)
{
TypeFun result;
for (TypeId param : typeFun.typeParams)
result.typeParams.push_back(clone(param, dest, seenTypes, seenTypePacks, encounteredFreeType));
result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType);
return result;
} }
ScopePtr Module::getModuleScope() const ScopePtr Module::getModuleScope() const
@ -465,57 +238,4 @@ ScopePtr Module::getModuleScope() const
return scopes.front().second; return scopes.front().second;
} }
void freeze(TypeArena& arena)
{
if (!FFlag::DebugLuauFreezeArena)
return;
arena.typeVars.freeze();
arena.typePacks.freeze();
}
void unfreeze(TypeArena& arena)
{
if (!FFlag::DebugLuauFreezeArena)
return;
arena.typeVars.unfreeze();
arena.typePacks.unfreeze();
}
Module::~Module()
{
unfreeze(interfaceTypes);
unfreeze(internalTypes);
}
bool Module::clonePublicInterface()
{
LUAU_ASSERT(interfaceTypes.typeVars.empty());
LUAU_ASSERT(interfaceTypes.typePacks.empty());
bool encounteredFreeType = false;
SeenTypePacks seenTypePacks;
SeenTypes seenTypes;
ScopePtr moduleScope = getModuleScope();
moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType);
if (moduleScope->varargPack)
moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType);
for (auto& pair : moduleScope->exportedTypeBindings)
pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType);
for (TypeId ty : moduleScope->returnType)
if (get<GenericTypeVar>(follow(ty)))
*asMutable(ty) = AnyTypeVar{};
freeze(internalTypes);
freeze(interfaceTypes);
return encounteredFreeType;
}
} // namespace Luau } // namespace Luau

2194
Analysis/src/Normalize.cpp Normal file

File diff suppressed because it is too large Load diff

View file

@ -1,93 +0,0 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Predicate.h"
#include "Luau/Ast.h"
LUAU_FASTFLAG(LuauOrPredicate)
namespace Luau
{
std::optional<LValue> tryGetLValue(const AstExpr& node)
{
const AstExpr* expr = &node;
while (auto e = expr->as<AstExprGroup>())
expr = e->expr;
if (auto local = expr->as<AstExprLocal>())
return Symbol{local->local};
else if (auto global = expr->as<AstExprGlobal>())
return Symbol{global->name};
else if (auto indexname = expr->as<AstExprIndexName>())
{
if (auto lvalue = tryGetLValue(*indexname->expr))
return Field{std::make_shared<LValue>(*lvalue), indexname->index.value};
}
else if (auto indexexpr = expr->as<AstExprIndexExpr>())
{
if (auto lvalue = tryGetLValue(*indexexpr->expr))
if (auto string = indexexpr->expr->as<AstExprConstantString>())
return Field{std::make_shared<LValue>(*lvalue), std::string(string->value.data, string->value.size)};
}
return std::nullopt;
}
std::pair<Symbol, std::vector<std::string>> getFullName(const LValue& lvalue)
{
const LValue* current = &lvalue;
std::vector<std::string> keys;
while (auto field = get<Field>(*current))
{
keys.push_back(field->key);
current = field->parent.get();
if (!current)
LUAU_ASSERT(!"LValue root is a Field?");
}
const Symbol* symbol = get<Symbol>(*current);
return {*symbol, std::vector<std::string>(keys.rbegin(), keys.rend())};
}
std::string toString(const LValue& lvalue)
{
auto [symbol, keys] = getFullName(lvalue);
std::string s = toString(symbol);
for (std::string key : keys)
s += "." + key;
return s;
}
void merge(RefinementMap& l, const RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f)
{
LUAU_ASSERT(FFlag::LuauOrPredicate);
auto itL = l.begin();
auto itR = r.begin();
while (itL != l.end() && itR != r.end())
{
const auto& [k, a] = *itR;
if (itL->first == k)
{
l[k] = f(itL->second, a);
++itL;
++itR;
}
else if (itL->first > k)
{
l[k] = a;
++itR;
}
else
++itL;
}
l.insert(itR, r.end());
}
void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty)
{
refis[toString(lvalue)] = ty;
}
} // namespace Luau

259
Analysis/src/Quantify.cpp Normal file
View file

@ -0,0 +1,259 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Quantify.h"
#include "Luau/Scope.h"
#include "Luau/Substitution.h"
#include "Luau/TxnLog.h"
#include "Luau/TypeVar.h"
#include "Luau/VisitTypeVar.h"
LUAU_FASTFLAG(DebugLuauSharedSelf)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
namespace Luau
{
struct Quantifier final : TypeVarOnceVisitor
{
TypeLevel level;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
Scope* scope = nullptr;
bool seenGenericType = false;
bool seenMutableType = false;
explicit Quantifier(TypeLevel level)
: level(level)
{
LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution);
}
/// @return true if outer encloses inner
bool subsumes(Scope* outer, Scope* inner)
{
while (inner)
{
if (inner == outer)
return true;
inner = inner->parent.get();
}
return false;
}
bool visit(TypeId ty, const FreeTypeVar& ftv) override
{
seenMutableType = true;
if (!level.subsumes(ftv.level))
return false;
*asMutable(ty) = GenericTypeVar{level};
generics.push_back(ty);
return false;
}
bool visit(TypeId ty, const TableTypeVar&) override
{
LUAU_ASSERT(getMutable<TableTypeVar>(ty));
TableTypeVar& ttv = *getMutable<TableTypeVar>(ty);
if (ttv.state == TableState::Generic)
seenGenericType = true;
if (ttv.state == TableState::Free)
seenMutableType = true;
if (!level.subsumes(ttv.level))
{
if (ttv.state == TableState::Unsealed)
seenMutableType = true;
return false;
}
if (ttv.state == TableState::Free)
{
ttv.state = TableState::Generic;
seenGenericType = true;
}
else if (ttv.state == TableState::Unsealed)
ttv.state = TableState::Sealed;
ttv.level = level;
return true;
}
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
seenMutableType = true;
if (!level.subsumes(ftp.level))
return false;
*asMutable(tp) = GenericTypePack{level};
genericPacks.push_back(tp);
return true;
}
};
void quantify(TypeId ty, TypeLevel level)
{
if (FFlag::DebugLuauSharedSelf)
{
ty = follow(ty);
if (auto ttv = getTableType(ty); ttv && ttv->selfTy)
{
Quantifier selfQ{level};
selfQ.traverse(*ttv->selfTy);
Quantifier q{level};
q.traverse(ty);
for (const auto& [_, prop] : ttv->props)
{
auto ftv = getMutable<FunctionTypeVar>(follow(prop.type));
if (!ftv || !ftv->hasSelf)
continue;
if (Luau::first(ftv->argTypes) == ttv->selfTy)
{
ftv->generics.insert(ftv->generics.end(), selfQ.generics.begin(), selfQ.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), selfQ.genericPacks.begin(), selfQ.genericPacks.end());
}
}
}
else if (auto ftv = getMutable<FunctionTypeVar>(ty))
{
Quantifier q{level};
q.traverse(ty);
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType)
ftv->hasNoGenerics = true;
}
}
else
{
Quantifier q{level};
q.traverse(ty);
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty);
LUAU_ASSERT(ftv);
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
}
}
struct PureQuantifier : Substitution
{
Scope* scope;
std::vector<TypeId> insertedGenerics;
std::vector<TypePackId> insertedGenericPacks;
PureQuantifier(TypeArena* arena, Scope* scope)
: Substitution(TxnLog::empty(), arena)
, scope(scope)
{
}
bool isDirty(TypeId ty) override
{
LUAU_ASSERT(ty == follow(ty));
if (auto ftv = get<FreeTypeVar>(ty))
{
return subsumes(scope, ftv->scope);
}
else if (auto ttv = get<TableTypeVar>(ty))
{
return ttv->state == TableState::Free && subsumes(scope, ttv->scope);
}
return false;
}
bool isDirty(TypePackId tp) override
{
if (auto ftp = get<FreeTypePack>(tp))
{
return subsumes(scope, ftp->scope);
}
return false;
}
TypeId clean(TypeId ty) override
{
if (auto ftv = get<FreeTypeVar>(ty))
{
TypeId result = arena->addType(GenericTypeVar{scope});
insertedGenerics.push_back(result);
return result;
}
else if (auto ttv = get<TableTypeVar>(ty))
{
TypeId result = arena->addType(TableTypeVar{});
TableTypeVar* resultTable = getMutable<TableTypeVar>(result);
LUAU_ASSERT(resultTable);
*resultTable = *ttv;
resultTable->level = TypeLevel{};
resultTable->scope = scope;
resultTable->state = TableState::Generic;
return result;
}
return ty;
}
TypePackId clean(TypePackId tp) override
{
if (auto ftp = get<FreeTypePack>(tp))
{
TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{}});
insertedGenericPacks.push_back(result);
return result;
}
return tp;
}
bool ignoreChildren(TypeId ty) override
{
if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassTypeVar>(ty))
return true;
return ty->persistent;
}
bool ignoreChildren(TypePackId ty) override
{
return ty->persistent;
}
};
TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope)
{
PureQuantifier quantifier{arena, scope};
std::optional<TypeId> result = quantifier.substitute(ty);
LUAU_ASSERT(result);
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(*result);
LUAU_ASSERT(ftv);
ftv->scope = scope;
ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end());
ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty();
return *result;
}
} // namespace Luau

View file

@ -4,187 +4,163 @@
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Module.h" #include "Luau/Module.h"
LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false)
namespace Luau namespace Luau
{ {
namespace
{
struct RequireTracer : AstVisitor struct RequireTracer : AstVisitor
{ {
explicit RequireTracer(FileResolver* fileResolver, ModuleName currentModuleName) RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName)
: fileResolver(fileResolver) : result(result)
, currentModuleName(std::move(currentModuleName)) , fileResolver(fileResolver)
, currentModuleName(currentModuleName)
, locals(nullptr)
{ {
} }
FileResolver* const fileResolver;
ModuleName currentModuleName;
DenseHashMap<AstLocal*, ModuleName> locals{0};
RequireTraceResult result;
std::optional<ModuleName> fromAstFragment(AstExpr* expr)
{
if (auto g = expr->as<AstExprGlobal>(); g && g->name == "script")
return currentModuleName;
return fileResolver->fromAstFragment(expr);
}
bool visit(AstStatLocal* stat) override
{
for (size_t i = 0; i < stat->vars.size; ++i)
{
AstLocal* local = stat->vars.data[i];
if (local->annotation)
{
if (AstTypeTypeof* ann = local->annotation->as<AstTypeTypeof>())
ann->expr->visit(this);
}
if (i < stat->values.size)
{
AstExpr* expr = stat->values.data[i];
expr->visit(this);
const ModuleName* name = result.exprs.find(expr);
if (name)
locals[local] = *name;
}
}
return false;
}
bool visit(AstExprGlobal* global) override
{
std::optional<ModuleName> name = fromAstFragment(global);
if (name)
result.exprs[global] = *name;
return false;
}
bool visit(AstExprLocal* local) override
{
const ModuleName* name = locals.find(local->local);
if (name)
result.exprs[local] = *name;
return false;
}
bool visit(AstExprIndexName* indexName) override
{
indexName->expr->visit(this);
const ModuleName* name = result.exprs.find(indexName->expr);
if (name)
{
if (indexName->index == "parent" || indexName->index == "Parent")
{
if (auto parent = fileResolver->getParentModuleName(*name))
result.exprs[indexName] = *parent;
}
else
result.exprs[indexName] = fileResolver->concat(*name, indexName->index.value);
}
return false;
}
bool visit(AstExprIndexExpr* indexExpr) override
{
indexExpr->expr->visit(this);
const ModuleName* name = result.exprs.find(indexExpr->expr);
const AstExprConstantString* str = indexExpr->index->as<AstExprConstantString>();
if (name && str)
{
result.exprs[indexExpr] = fileResolver->concat(*name, std::string_view(str->value.data, str->value.size));
}
indexExpr->index->visit(this);
return false;
}
bool visit(AstExprTypeAssertion* expr) override bool visit(AstExprTypeAssertion* expr) override
{ {
// suppress `require() :: any`
return false; return false;
} }
// If we see game:GetService("StringLiteral") or Game:GetService("StringLiteral"), then rewrite to game.StringLiteral. bool visit(AstExprCall* expr) override
// Else traverse arguments and trace requires to them.
bool visit(AstExprCall* call) override
{ {
for (AstExpr* arg : call->args) AstExprGlobal* global = expr->func->as<AstExprGlobal>();
arg->visit(this);
call->func->visit(this); if (global && global->name == "require" && expr->args.size >= 1)
requireCalls.push_back(expr);
AstExprGlobal* globalName = call->func->as<AstExprGlobal>(); return true;
if (globalName && globalName->name == "require" && call->args.size >= 1)
{
if (const ModuleName* moduleName = result.exprs.find(call->args.data[0]))
result.requires.push_back({*moduleName, call->location});
return false;
}
AstExprIndexName* indexName = call->func->as<AstExprIndexName>();
if (!indexName)
return false;
std::optional<ModuleName> rootName = fromAstFragment(indexName->expr);
if (FFlag::LuauTraceRequireLookupChild && !rootName)
{
if (const ModuleName* moduleName = result.exprs.find(indexName->expr))
rootName = *moduleName;
}
if (!rootName)
return false;
bool supportedLookup = indexName->index == "GetService" ||
(FFlag::LuauTraceRequireLookupChild && (indexName->index == "FindFirstChild" || indexName->index == "WaitForChild"));
if (!supportedLookup)
return false;
if (call->args.size != 1)
return false;
AstExprConstantString* name = call->args.data[0]->as<AstExprConstantString>();
if (!name)
return false;
std::string_view v{name->value.data, name->value.size};
if (v.end() != std::find(v.begin(), v.end(), '/'))
return false;
result.exprs[call] = fileResolver->concat(*rootName, v);
// 'WaitForChild' can be used on modules that are not awailable at the typecheck time, but will be awailable at runtime
// If we fail to find such module, we will not report an UnknownRequire error
if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild")
result.optional[call] = true;
return false;
} }
bool visit(AstStatLocal* stat) override
{
for (size_t i = 0; i < stat->vars.size && i < stat->values.size; ++i)
{
AstLocal* local = stat->vars.data[i];
AstExpr* expr = stat->values.data[i];
// track initializing expression to be able to trace modules through locals
locals[local] = expr;
}
return true;
}
bool visit(AstStatAssign* stat) override
{
for (size_t i = 0; i < stat->vars.size; ++i)
{
// locals that are assigned don't have a known expression
if (AstExprLocal* expr = stat->vars.data[i]->as<AstExprLocal>())
locals[expr->local] = nullptr;
}
return true;
}
bool visit(AstType* node) override
{
// allow resolving require inside `typeof` annotations
return true;
}
AstExpr* getDependent(AstExpr* node)
{
if (AstExprLocal* expr = node->as<AstExprLocal>())
return locals[expr->local];
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
return expr->expr;
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
return expr->expr;
else if (AstExprCall* expr = node->as<AstExprCall>(); expr && expr->self)
return expr->func->as<AstExprIndexName>()->expr;
else
return nullptr;
}
void process()
{
ModuleInfo moduleContext{currentModuleName};
// seed worklist with require arguments
work.reserve(requireCalls.size());
for (AstExprCall* require : requireCalls)
work.push_back(require->args.data[0]);
// push all dependent expressions to the work stack; note that the vector is modified during traversal
for (size_t i = 0; i < work.size(); ++i)
if (AstExpr* dep = getDependent(work[i]))
work.push_back(dep);
// resolve all expressions to a module info
for (size_t i = work.size(); i > 0; --i)
{
AstExpr* expr = work[i - 1];
// when multiple expressions depend on the same one we push it to work queue multiple times
if (result.exprs.contains(expr))
continue;
std::optional<ModuleInfo> info;
if (AstExpr* dep = getDependent(expr))
{
const ModuleInfo* context = result.exprs.find(dep);
// locals just inherit their dependent context, no resolution required
if (expr->is<AstExprLocal>())
info = context ? std::optional<ModuleInfo>(*context) : std::nullopt;
else
info = fileResolver->resolveModule(context, expr);
}
else
{
info = fileResolver->resolveModule(&moduleContext, expr);
}
if (info)
result.exprs[expr] = std::move(*info);
}
// resolve all requires according to their argument
result.requireList.reserve(requireCalls.size());
for (AstExprCall* require : requireCalls)
{
AstExpr* arg = require->args.data[0];
if (const ModuleInfo* info = result.exprs.find(arg))
{
result.requireList.push_back({info->name, require->location});
ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info!
result.exprs[require] = std::move(infoCopy);
}
else
{
result.exprs[require] = {}; // mark require as unresolved
}
}
}
RequireTraceResult& result;
FileResolver* fileResolver;
ModuleName currentModuleName;
DenseHashMap<AstLocal*, AstExpr*> locals;
std::vector<AstExpr*> work;
std::vector<AstExprCall*> requireCalls;
}; };
} // anonymous namespace RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName)
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName)
{ {
RequireTracer tracer{fileResolver, std::move(currentModuleName)}; RequireTraceResult result;
RequireTracer tracer{result, fileResolver, currentModuleName};
root->visit(&tracer); root->visit(&tracer);
return tracer.result; tracer.process();
return result;
} }
} // namespace Luau } // namespace Luau

170
Analysis/src/Scope.cpp Normal file
View file

@ -0,0 +1,170 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Scope.h"
namespace Luau
{
Scope::Scope(TypePackId returnType)
: parent(nullptr)
, returnType(returnType)
, level(TypeLevel())
{
}
Scope::Scope(const ScopePtr& parent, int subLevel)
: parent(parent)
, returnType(parent->returnType)
, level(parent->level.incr())
{
level = level.incr();
level.subLevel = subLevel;
}
void Scope::addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun)
{
exportedTypeBindings[name] = tyFun;
builtinTypeNames.insert(name);
}
std::optional<TypeId> Scope::lookup(Symbol sym) const
{
auto r = const_cast<Scope*>(this)->lookupEx(sym);
if (r)
return r->first;
else
return std::nullopt;
}
std::optional<std::pair<TypeId, Scope*>> Scope::lookupEx(Symbol sym)
{
Scope* s = this;
while (true)
{
auto it = s->bindings.find(sym);
if (it != s->bindings.end())
return std::pair{it->second.typeId, s};
if (s->parent)
s = s->parent.get();
else
return std::nullopt;
}
}
// TODO: We might kill Scope::lookup(Symbol) once data flow is fully fleshed out with type states and control flow analysis.
std::optional<TypeId> Scope::lookup(DefId def) const
{
for (const Scope* current = this; current; current = current->parent.get())
{
if (auto ty = current->dcrRefinements.find(def))
return *ty;
}
return std::nullopt;
}
std::optional<TypeFun> Scope::lookupType(const Name& name)
{
const Scope* scope = this;
while (true)
{
auto it = scope->exportedTypeBindings.find(name);
if (it != scope->exportedTypeBindings.end())
return it->second;
it = scope->privateTypeBindings.find(name);
if (it != scope->privateTypeBindings.end())
return it->second;
if (scope->parent)
scope = scope->parent.get();
else
return std::nullopt;
}
}
std::optional<TypeFun> Scope::lookupImportedType(const Name& moduleAlias, const Name& name)
{
const Scope* scope = this;
while (scope)
{
auto it = scope->importedTypeBindings.find(moduleAlias);
if (it == scope->importedTypeBindings.end())
{
scope = scope->parent.get();
continue;
}
auto it2 = it->second.find(name);
if (it2 == it->second.end())
{
scope = scope->parent.get();
continue;
}
return it2->second;
}
return std::nullopt;
}
std::optional<TypePackId> Scope::lookupPack(const Name& name)
{
const Scope* scope = this;
while (true)
{
auto it = scope->privateTypePackBindings.find(name);
if (it != scope->privateTypePackBindings.end())
return it->second;
if (scope->parent)
scope = scope->parent.get();
else
return std::nullopt;
}
}
std::optional<Binding> Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) const
{
const Scope* scope = this;
while (scope)
{
for (const auto& [n, binding] : scope->bindings)
{
if (n.local && n.local->name == name.c_str())
return binding;
else if (n.global.value && n.global == name.c_str())
return binding;
}
scope = scope->parent.get();
if (!traverseScopeChain)
break;
}
return std::nullopt;
}
bool subsumesStrict(Scope* left, Scope* right)
{
while (right)
{
if (right->parent.get() == left)
return true;
right = right->parent.get();
}
return false;
}
bool subsumes(Scope* left, Scope* right)
{
return left == right || subsumesStrict(left, right);
}
} // namespace Luau

View file

@ -2,28 +2,44 @@
#include "Luau/Substitution.h" #include "Luau/Substitution.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/Clone.h"
#include "Luau/TxnLog.h"
#include <algorithm> #include <algorithm>
#include <stdexcept> #include <stdexcept>
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 0) LUAU_FASTFLAGVARIABLE(LuauSubstitutionFixMissingFields, false)
LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauClonePublicInterfaceLess)
LUAU_FASTFLAG(LuauRankNTypes) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000)
LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false)
LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false)
namespace Luau namespace Luau
{ {
void Tarjan::visitChildren(TypeId ty, int index) void Tarjan::visitChildren(TypeId ty, int index)
{ {
ty = follow(ty); LUAU_ASSERT(ty == log->follow(ty));
if (FFlag::LuauRankNTypes && ignoreChildren(ty)) if (ignoreChildren(ty))
return; return;
if (auto pty = log->pending(ty))
ty = &pty->pending;
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty)) if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{ {
if (FFlag::LuauSubstitutionFixMissingFields)
{
for (TypeId generic : ftv->generics)
visitChild(generic);
for (TypePackId genericPack : ftv->genericPacks)
visitChild(genericPack);
}
visitChild(ftv->argTypes); visitChild(ftv->argTypes);
visitChild(ftv->retType); visitChild(ftv->retTypes);
} }
else if (const TableTypeVar* ttv = get<TableTypeVar>(ty)) else if (const TableTypeVar* ttv = get<TableTypeVar>(ty))
{ {
@ -35,8 +51,12 @@ void Tarjan::visitChildren(TypeId ty, int index)
visitChild(ttv->indexer->indexType); visitChild(ttv->indexer->indexType);
visitChild(ttv->indexer->indexResultType); visitChild(ttv->indexer->indexResultType);
} }
for (TypeId itp : ttv->instantiatedTypeParams) for (TypeId itp : ttv->instantiatedTypeParams)
visitChild(itp); visitChild(itp);
for (TypePackId itp : ttv->instantiatedTypePackParams)
visitChild(itp);
} }
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty)) else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
{ {
@ -53,15 +73,41 @@ void Tarjan::visitChildren(TypeId ty, int index)
for (TypeId part : itv->parts) for (TypeId part : itv->parts)
visitChild(part); visitChild(part);
} }
else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty))
{
for (TypeId a : petv->typeArguments)
visitChild(a);
for (TypePackId a : petv->packArguments)
visitChild(a);
}
else if (const ClassTypeVar* ctv = get<ClassTypeVar>(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv)
{
for (auto [name, prop] : ctv->props)
visitChild(prop.type);
if (ctv->parent)
visitChild(*ctv->parent);
if (ctv->metatable)
visitChild(*ctv->metatable);
}
else if (const NegationTypeVar* ntv = get<NegationTypeVar>(ty))
{
visitChild(ntv->ty);
}
} }
void Tarjan::visitChildren(TypePackId tp, int index) void Tarjan::visitChildren(TypePackId tp, int index)
{ {
tp = follow(tp); LUAU_ASSERT(tp == log->follow(tp));
if (FFlag::LuauRankNTypes && ignoreChildren(tp)) if (ignoreChildren(tp))
return; return;
if (auto ptp = log->pending(tp))
tp = &ptp->pending;
if (const TypePack* tpp = get<TypePack>(tp)) if (const TypePack* tpp = get<TypePack>(tp))
{ {
for (TypeId tv : tpp->head) for (TypeId tv : tpp->head)
@ -77,7 +123,7 @@ void Tarjan::visitChildren(TypePackId tp, int index)
std::pair<int, bool> Tarjan::indexify(TypeId ty) std::pair<int, bool> Tarjan::indexify(TypeId ty)
{ {
ty = follow(ty); ty = log->follow(ty);
bool fresh = !typeToIndex.contains(ty); bool fresh = !typeToIndex.contains(ty);
int& index = typeToIndex[ty]; int& index = typeToIndex[ty];
@ -95,7 +141,7 @@ std::pair<int, bool> Tarjan::indexify(TypeId ty)
std::pair<int, bool> Tarjan::indexify(TypePackId tp) std::pair<int, bool> Tarjan::indexify(TypePackId tp)
{ {
tp = follow(tp); tp = log->follow(tp);
bool fresh = !packToIndex.contains(tp); bool fresh = !packToIndex.contains(tp);
int& index = packToIndex[tp]; int& index = packToIndex[tp];
@ -113,7 +159,7 @@ std::pair<int, bool> Tarjan::indexify(TypePackId tp)
void Tarjan::visitChild(TypeId ty) void Tarjan::visitChild(TypeId ty)
{ {
ty = follow(ty); ty = log->follow(ty);
edgesTy.push_back(ty); edgesTy.push_back(ty);
edgesTp.push_back(nullptr); edgesTp.push_back(nullptr);
@ -121,7 +167,7 @@ void Tarjan::visitChild(TypeId ty)
void Tarjan::visitChild(TypePackId tp) void Tarjan::visitChild(TypePackId tp)
{ {
tp = follow(tp); tp = log->follow(tp);
edgesTy.push_back(nullptr); edgesTy.push_back(nullptr);
edgesTp.push_back(tp); edgesTp.push_back(tp);
@ -138,7 +184,7 @@ TarjanResult Tarjan::loop()
if (currEdge == -1) if (currEdge == -1)
{ {
++childCount; ++childCount;
if (FInt::LuauTarjanChildLimit > 0 && FInt::LuauTarjanChildLimit < childCount) if (childLimit > 0 && (FFlag::LuauUnknownAndNeverType ? childLimit <= childCount : childLimit < childCount))
return TarjanResult::TooManyChildren; return TarjanResult::TooManyChildren;
stack.push_back(index); stack.push_back(index);
@ -223,27 +269,14 @@ TarjanResult Tarjan::loop()
return TarjanResult::Ok; return TarjanResult::Ok;
} }
void Tarjan::clear()
{
typeToIndex.clear();
indexToType.clear();
packToIndex.clear();
indexToPack.clear();
lowlink.clear();
stack.clear();
onStack.clear();
edgesTy.clear();
edgesTp.clear();
worklist.clear();
}
TarjanResult Tarjan::visitRoot(TypeId ty) TarjanResult Tarjan::visitRoot(TypeId ty)
{ {
childCount = 0; childCount = 0;
ty = follow(ty); if (childLimit == 0)
childLimit = FInt::LuauTarjanChildLimit;
ty = log->follow(ty);
clear();
auto [index, fresh] = indexify(ty); auto [index, fresh] = indexify(ty);
worklist.push_back({index, -1, -1}); worklist.push_back({index, -1, -1});
return loop(); return loop();
@ -252,14 +285,34 @@ TarjanResult Tarjan::visitRoot(TypeId ty)
TarjanResult Tarjan::visitRoot(TypePackId tp) TarjanResult Tarjan::visitRoot(TypePackId tp)
{ {
childCount = 0; childCount = 0;
tp = follow(tp); if (childLimit == 0)
childLimit = FInt::LuauTarjanChildLimit;
tp = log->follow(tp);
clear();
auto [index, fresh] = indexify(tp); auto [index, fresh] = indexify(tp);
worklist.push_back({index, -1, -1}); worklist.push_back({index, -1, -1});
return loop(); return loop();
} }
void FindDirty::clearTarjan()
{
dirty.clear();
typeToIndex.clear();
packToIndex.clear();
indexToType.clear();
indexToPack.clear();
stack.clear();
onStack.clear();
lowlink.clear();
edgesTy.clear();
edgesTp.clear();
worklist.clear();
}
bool FindDirty::getDirty(int index) bool FindDirty::getDirty(int index)
{ {
if (dirty.size() <= size_t(index)) if (dirty.size() <= size_t(index))
@ -311,107 +364,122 @@ void FindDirty::visitSCC(int index)
TarjanResult FindDirty::findDirty(TypeId ty) TarjanResult FindDirty::findDirty(TypeId ty)
{ {
dirty.clear();
return visitRoot(ty); return visitRoot(ty);
} }
TarjanResult FindDirty::findDirty(TypePackId tp) TarjanResult FindDirty::findDirty(TypePackId tp)
{ {
dirty.clear();
return visitRoot(tp); return visitRoot(tp);
} }
std::optional<TypeId> Substitution::substitute(TypeId ty) std::optional<TypeId> Substitution::substitute(TypeId ty)
{ {
ty = follow(ty); ty = log->follow(ty);
newTypes.clear();
newPacks.clear(); // clear algorithm state for reentrancy
if (FFlag::LuauSubstitutionReentrant)
clearTarjan();
auto result = findDirty(ty); auto result = findDirty(ty);
if (result != TarjanResult::Ok) if (result != TarjanResult::Ok)
return std::nullopt; return std::nullopt;
for (auto [oldTy, newTy] : newTypes) for (auto [oldTy, newTy] : newTypes)
replaceChildren(newTy); {
if (FFlag::LuauSubstitutionReentrant)
{
if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy))
{
replaceChildren(newTy);
replacedTypes.insert(newTy);
}
}
else
{
if (!ignoreChildren(oldTy))
replaceChildren(newTy);
}
}
for (auto [oldTp, newTp] : newPacks) for (auto [oldTp, newTp] : newPacks)
replaceChildren(newTp); {
if (FFlag::LuauSubstitutionReentrant)
{
if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp))
{
replaceChildren(newTp);
replacedTypePacks.insert(newTp);
}
}
else
{
if (!ignoreChildren(oldTp))
replaceChildren(newTp);
}
}
TypeId newTy = replace(ty); TypeId newTy = replace(ty);
return newTy; return newTy;
} }
std::optional<TypePackId> Substitution::substitute(TypePackId tp) std::optional<TypePackId> Substitution::substitute(TypePackId tp)
{ {
tp = follow(tp); tp = log->follow(tp);
newTypes.clear();
newPacks.clear(); // clear algorithm state for reentrancy
if (FFlag::LuauSubstitutionReentrant)
clearTarjan();
auto result = findDirty(tp); auto result = findDirty(tp);
if (result != TarjanResult::Ok) if (result != TarjanResult::Ok)
return std::nullopt; return std::nullopt;
for (auto [oldTy, newTy] : newTypes) for (auto [oldTy, newTy] : newTypes)
replaceChildren(newTy); {
if (FFlag::LuauSubstitutionReentrant)
{
if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy))
{
replaceChildren(newTy);
replacedTypes.insert(newTy);
}
}
else
{
if (!ignoreChildren(oldTy))
replaceChildren(newTy);
}
}
for (auto [oldTp, newTp] : newPacks) for (auto [oldTp, newTp] : newPacks)
replaceChildren(newTp); {
if (FFlag::LuauSubstitutionReentrant)
{
if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp))
{
replaceChildren(newTp);
replacedTypePacks.insert(newTp);
}
}
else
{
if (!ignoreChildren(oldTp))
replaceChildren(newTp);
}
}
TypePackId newTp = replace(tp); TypePackId newTp = replace(tp);
return newTp; return newTp;
} }
TypeId Substitution::clone(TypeId ty) TypeId Substitution::clone(TypeId ty)
{ {
ty = follow(ty); return shallowClone(ty, *arena, log, /* alwaysClone */ FFlag::LuauClonePublicInterfaceLess);
TypeId result = ty;
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{
FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf};
clone.generics = ftv->generics;
clone.genericPacks = ftv->genericPacks;
clone.magicFunction = ftv->magicFunction;
clone.tags = ftv->tags;
clone.argNames = ftv->argNames;
result = addType(std::move(clone));
}
else if (const TableTypeVar* ttv = get<TableTypeVar>(ty))
{
LUAU_ASSERT(!ttv->boundTo);
TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state};
clone.methodDefinitionLocations = ttv->methodDefinitionLocations;
clone.definitionModuleName = ttv->definitionModuleName;
clone.name = ttv->name;
clone.syntheticName = ttv->syntheticName;
clone.instantiatedTypeParams = ttv->instantiatedTypeParams;
if (FFlag::LuauSecondTypecheckKnowsTheDataModel)
clone.tags = ttv->tags;
result = addType(std::move(clone));
}
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
{
MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable};
clone.syntheticName = mtv->syntheticName;
result = addType(std::move(clone));
}
else if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
{
UnionTypeVar clone;
clone.options = utv->options;
result = addType(std::move(clone));
}
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(ty))
{
IntersectionTypeVar clone;
clone.parts = itv->parts;
result = addType(std::move(clone));
}
asMutable(result)->documentationSymbol = ty->documentationSymbol;
return result;
} }
TypePackId Substitution::clone(TypePackId tp) TypePackId Substitution::clone(TypePackId tp)
{ {
tp = follow(tp); tp = log->follow(tp);
if (auto ptp = log->pending(tp))
tp = &ptp->pending;
if (const TypePack* tpp = get<TypePack>(tp)) if (const TypePack* tpp = get<TypePack>(tp))
{ {
TypePack clone; TypePack clone;
@ -423,33 +491,48 @@ TypePackId Substitution::clone(TypePackId tp)
{ {
VariadicTypePack clone; VariadicTypePack clone;
clone.ty = vtp->ty; clone.ty = vtp->ty;
if (FFlag::LuauSubstitutionFixMissingFields)
clone.hidden = vtp->hidden;
return addTypePack(std::move(clone)); return addTypePack(std::move(clone));
} }
else if (FFlag::LuauClonePublicInterfaceLess)
{
return addTypePack(*tp);
}
else else
return tp; return tp;
} }
void Substitution::foundDirty(TypeId ty) void Substitution::foundDirty(TypeId ty)
{ {
ty = follow(ty); ty = log->follow(ty);
if (FFlag::LuauSubstitutionReentrant && newTypes.contains(ty))
return;
if (isDirty(ty)) if (isDirty(ty))
newTypes[ty] = clean(ty); newTypes[ty] = follow(clean(ty));
else else
newTypes[ty] = clone(ty); newTypes[ty] = follow(clone(ty));
} }
void Substitution::foundDirty(TypePackId tp) void Substitution::foundDirty(TypePackId tp)
{ {
tp = follow(tp); tp = log->follow(tp);
if (FFlag::LuauSubstitutionReentrant && newPacks.contains(tp))
return;
if (isDirty(tp)) if (isDirty(tp))
newPacks[tp] = clean(tp); newPacks[tp] = follow(clean(tp));
else else
newPacks[tp] = clone(tp); newPacks[tp] = follow(clone(tp));
} }
TypeId Substitution::replace(TypeId ty) TypeId Substitution::replace(TypeId ty)
{ {
ty = follow(ty); ty = log->follow(ty);
if (TypeId* prevTy = newTypes.find(ty)) if (TypeId* prevTy = newTypes.find(ty))
return *prevTy; return *prevTy;
else else
@ -458,7 +541,8 @@ TypeId Substitution::replace(TypeId ty)
TypePackId Substitution::replace(TypePackId tp) TypePackId Substitution::replace(TypePackId tp)
{ {
tp = follow(tp); tp = log->follow(tp);
if (TypePackId* prevTp = newPacks.find(tp)) if (TypePackId* prevTp = newPacks.find(tp))
return *prevTp; return *prevTp;
else else
@ -467,15 +551,26 @@ TypePackId Substitution::replace(TypePackId tp)
void Substitution::replaceChildren(TypeId ty) void Substitution::replaceChildren(TypeId ty)
{ {
ty = follow(ty); LUAU_ASSERT(ty == log->follow(ty));
if (FFlag::LuauRankNTypes && ignoreChildren(ty)) if (ignoreChildren(ty))
return;
if (ty->owningArena != arena)
return; return;
if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty)) if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty))
{ {
if (FFlag::LuauSubstitutionFixMissingFields)
{
for (TypeId& generic : ftv->generics)
generic = replace(generic);
for (TypePackId& genericPack : ftv->genericPacks)
genericPack = replace(genericPack);
}
ftv->argTypes = replace(ftv->argTypes); ftv->argTypes = replace(ftv->argTypes);
ftv->retType = replace(ftv->retType); ftv->retTypes = replace(ftv->retTypes);
} }
else if (TableTypeVar* ttv = getMutable<TableTypeVar>(ty)) else if (TableTypeVar* ttv = getMutable<TableTypeVar>(ty))
{ {
@ -487,8 +582,12 @@ void Substitution::replaceChildren(TypeId ty)
ttv->indexer->indexType = replace(ttv->indexer->indexType); ttv->indexer->indexType = replace(ttv->indexer->indexType);
ttv->indexer->indexResultType = replace(ttv->indexer->indexResultType); ttv->indexer->indexResultType = replace(ttv->indexer->indexResultType);
} }
for (TypeId& itp : ttv->instantiatedTypeParams) for (TypeId& itp : ttv->instantiatedTypeParams)
itp = replace(itp); itp = replace(itp);
for (TypePackId& itp : ttv->instantiatedTypePackParams)
itp = replace(itp);
} }
else if (MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(ty)) else if (MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(ty))
{ {
@ -505,13 +604,39 @@ void Substitution::replaceChildren(TypeId ty)
for (TypeId& part : itv->parts) for (TypeId& part : itv->parts)
part = replace(part); part = replace(part);
} }
else if (PendingExpansionTypeVar* petv = getMutable<PendingExpansionTypeVar>(ty))
{
for (TypeId& a : petv->typeArguments)
a = replace(a);
for (TypePackId& a : petv->packArguments)
a = replace(a);
}
else if (ClassTypeVar* ctv = getMutable<ClassTypeVar>(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv)
{
for (auto& [name, prop] : ctv->props)
prop.type = replace(prop.type);
if (ctv->parent)
ctv->parent = replace(*ctv->parent);
if (ctv->metatable)
ctv->metatable = replace(*ctv->metatable);
}
else if (NegationTypeVar* ntv = getMutable<NegationTypeVar>(ty))
{
ntv->ty = replace(ntv->ty);
}
} }
void Substitution::replaceChildren(TypePackId tp) void Substitution::replaceChildren(TypePackId tp)
{ {
tp = follow(tp); LUAU_ASSERT(tp == log->follow(tp));
if (FFlag::LuauRankNTypes && ignoreChildren(tp)) if (ignoreChildren(tp))
return;
if (tp->owningArena != arena)
return; return;
if (TypePack* tpp = getMutable<TypePack>(tp)) if (TypePack* tpp = getMutable<TypePack>(tp))

400
Analysis/src/ToDot.cpp Normal file
View file

@ -0,0 +1,400 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ToDot.h"
#include "Luau/ToString.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
#include "Luau/StringUtils.h"
#include <unordered_map>
#include <unordered_set>
namespace Luau
{
namespace
{
struct StateDot
{
StateDot(ToDotOptions opts)
: opts(opts)
{
}
ToDotOptions opts;
std::unordered_set<TypeId> seenTy;
std::unordered_set<TypePackId> seenTp;
std::unordered_map<TypeId, int> tyToIndex;
std::unordered_map<TypePackId, int> tpToIndex;
int nextIndex = 1;
std::string result;
bool canDuplicatePrimitive(TypeId ty);
void visitChildren(TypeId ty, int index);
void visitChildren(TypePackId ty, int index);
void visitChild(TypeId ty, int parentIndex, const char* linkName = nullptr);
void visitChild(TypePackId tp, int parentIndex, const char* linkName = nullptr);
void startNode(int index);
void finishNode();
void startNodeLabel();
void finishNodeLabel(TypeId ty);
void finishNodeLabel(TypePackId tp);
};
bool StateDot::canDuplicatePrimitive(TypeId ty)
{
if (get<BoundTypeVar>(ty))
return false;
return get<PrimitiveTypeVar>(ty) || get<AnyTypeVar>(ty);
}
void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName)
{
if (!tyToIndex.count(ty) || (opts.duplicatePrimitives && canDuplicatePrimitive(ty)))
tyToIndex[ty] = nextIndex++;
int index = tyToIndex[ty];
if (parentIndex != 0)
{
if (linkName)
formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, index, linkName);
else
formatAppend(result, "n%d -> n%d;\n", parentIndex, index);
}
if (opts.duplicatePrimitives && canDuplicatePrimitive(ty))
{
if (get<PrimitiveTypeVar>(ty))
formatAppend(result, "n%d [label=\"%s\"];\n", index, toString(ty).c_str());
else if (get<AnyTypeVar>(ty))
formatAppend(result, "n%d [label=\"any\"];\n", index);
}
else
{
visitChildren(ty, index);
}
}
void StateDot::visitChild(TypePackId tp, int parentIndex, const char* linkName)
{
if (!tpToIndex.count(tp))
tpToIndex[tp] = nextIndex++;
if (parentIndex != 0)
{
if (linkName)
formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, tpToIndex[tp], linkName);
else
formatAppend(result, "n%d -> n%d;\n", parentIndex, tpToIndex[tp]);
}
visitChildren(tp, tpToIndex[tp]);
}
void StateDot::startNode(int index)
{
formatAppend(result, "n%d [", index);
}
void StateDot::finishNode()
{
formatAppend(result, "];\n");
}
void StateDot::startNodeLabel()
{
formatAppend(result, "label=\"");
}
void StateDot::finishNodeLabel(TypeId ty)
{
if (opts.showPointers)
formatAppend(result, "\n0x%p", ty);
// additional common attributes can be added here as well
result += "\"";
}
void StateDot::finishNodeLabel(TypePackId tp)
{
if (opts.showPointers)
formatAppend(result, "\n0x%p", tp);
// additional common attributes can be added here as well
result += "\"";
}
void StateDot::visitChildren(TypeId ty, int index)
{
if (seenTy.count(ty))
return;
seenTy.insert(ty);
startNode(index);
startNodeLabel();
if (const BoundTypeVar* btv = get<BoundTypeVar>(ty))
{
formatAppend(result, "BoundTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
visitChild(btv->boundTo, index);
}
else if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{
formatAppend(result, "FunctionTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
visitChild(ftv->argTypes, index, "arg");
visitChild(ftv->retTypes, index, "ret");
}
else if (const TableTypeVar* ttv = get<TableTypeVar>(ty))
{
if (ttv->name)
formatAppend(result, "TableTypeVar %s", ttv->name->c_str());
else if (ttv->syntheticName)
formatAppend(result, "TableTypeVar %s", ttv->syntheticName->c_str());
else
formatAppend(result, "TableTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
if (ttv->boundTo)
return visitChild(*ttv->boundTo, index, "boundTo");
for (const auto& [name, prop] : ttv->props)
visitChild(prop.type, index, name.c_str());
if (ttv->indexer)
{
visitChild(ttv->indexer->indexType, index, "[index]");
visitChild(ttv->indexer->indexResultType, index, "[value]");
}
for (TypeId itp : ttv->instantiatedTypeParams)
visitChild(itp, index, "typeParam");
for (TypePackId itp : ttv->instantiatedTypePackParams)
visitChild(itp, index, "typePackParam");
}
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
{
formatAppend(result, "MetatableTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
visitChild(mtv->table, index, "table");
visitChild(mtv->metatable, index, "metatable");
}
else if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
{
formatAppend(result, "UnionTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
for (TypeId opt : utv->options)
visitChild(opt, index);
}
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(ty))
{
formatAppend(result, "IntersectionTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
for (TypeId part : itv->parts)
visitChild(part, index);
}
else if (const GenericTypeVar* gtv = get<GenericTypeVar>(ty))
{
if (gtv->explicitName)
formatAppend(result, "GenericTypeVar %s", gtv->name.c_str());
else
formatAppend(result, "GenericTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
}
else if (const FreeTypeVar* ftv = get<FreeTypeVar>(ty))
{
formatAppend(result, "FreeTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
}
else if (get<AnyTypeVar>(ty))
{
formatAppend(result, "AnyTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
}
else if (get<PrimitiveTypeVar>(ty))
{
formatAppend(result, "PrimitiveTypeVar %s", toString(ty).c_str());
finishNodeLabel(ty);
finishNode();
}
else if (get<ErrorTypeVar>(ty))
{
formatAppend(result, "ErrorTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
}
else if (const ClassTypeVar* ctv = get<ClassTypeVar>(ty))
{
formatAppend(result, "ClassTypeVar %s", ctv->name.c_str());
finishNodeLabel(ty);
finishNode();
for (const auto& [name, prop] : ctv->props)
visitChild(prop.type, index, name.c_str());
if (ctv->parent)
visitChild(*ctv->parent, index, "[parent]");
if (ctv->metatable)
visitChild(*ctv->metatable, index, "[metatable]");
}
else if (const SingletonTypeVar* stv = get<SingletonTypeVar>(ty))
{
std::string res;
if (const StringSingleton* ss = get<StringSingleton>(stv))
{
// Don't put in quotes anywhere. If it's outside of the call to escape,
// then it's invalid syntax. If it's inside, then escaping is super noisy.
res = "string: " + escape(ss->value);
}
else if (const BooleanSingleton* bs = get<BooleanSingleton>(stv))
{
res = "boolean: ";
res += bs->value ? "true" : "false";
}
else
LUAU_ASSERT(!"unknown singleton type");
formatAppend(result, "SingletonTypeVar %s", res.c_str());
finishNodeLabel(ty);
finishNode();
}
else
{
LUAU_ASSERT(!"unknown type kind");
finishNodeLabel(ty);
finishNode();
}
}
void StateDot::visitChildren(TypePackId tp, int index)
{
if (seenTp.count(tp))
return;
seenTp.insert(tp);
startNode(index);
startNodeLabel();
if (const BoundTypePack* btp = get<BoundTypePack>(tp))
{
formatAppend(result, "BoundTypePack %d", index);
finishNodeLabel(tp);
finishNode();
visitChild(btp->boundTo, index);
}
else if (const TypePack* tpp = get<TypePack>(tp))
{
formatAppend(result, "TypePack %d", index);
finishNodeLabel(tp);
finishNode();
for (TypeId tv : tpp->head)
visitChild(tv, index);
if (tpp->tail)
visitChild(*tpp->tail, index, "tail");
}
else if (const VariadicTypePack* vtp = get<VariadicTypePack>(tp))
{
formatAppend(result, "VariadicTypePack %s%d", vtp->hidden ? "hidden " : "", index);
finishNodeLabel(tp);
finishNode();
visitChild(vtp->ty, index);
}
else if (const FreeTypePack* ftp = get<FreeTypePack>(tp))
{
formatAppend(result, "FreeTypePack %d", index);
finishNodeLabel(tp);
finishNode();
}
else if (const GenericTypePack* gtp = get<GenericTypePack>(tp))
{
if (gtp->explicitName)
formatAppend(result, "GenericTypePack %s", gtp->name.c_str());
else
formatAppend(result, "GenericTypePack %d", index);
finishNodeLabel(tp);
finishNode();
}
else if (get<Unifiable::Error>(tp))
{
formatAppend(result, "ErrorTypePack %d", index);
finishNodeLabel(tp);
finishNode();
}
else
{
LUAU_ASSERT(!"unknown type pack kind");
finishNodeLabel(tp);
finishNode();
}
}
} // namespace
std::string toDot(TypeId ty, const ToDotOptions& opts)
{
StateDot state{opts};
state.result = "digraph graphname {\n";
state.visitChild(ty, 0);
state.result += "}";
return state.result;
}
std::string toDot(TypePackId tp, const ToDotOptions& opts)
{
StateDot state{opts};
state.result = "digraph graphname {\n";
state.visitChild(tp, 0);
state.result += "}";
return state.result;
}
std::string toDot(TypeId ty)
{
return toDot(ty, {});
}
std::string toDot(TypePackId tp)
{
return toDot(tp, {});
}
void dumpDot(TypeId ty)
{
printf("%s\n", toDot(ty).c_str());
}
void dumpDot(TypePackId tp)
{
printf("%s\n", toDot(tp).c_str());
}
} // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TopoSortStatements.h" #include "Luau/TopoSortStatements.h"
#include "Luau/Error.h"
/* Decide the order in which we typecheck Lua statements in a block. /* Decide the order in which we typecheck Lua statements in a block.
* *
* Algorithm: * Algorithm:
@ -26,9 +27,10 @@
* 3. Cyclic dependencies can be resolved by picking an arbitrary statement to check first. * 3. Cyclic dependencies can be resolved by picking an arbitrary statement to check first.
*/ */
#include "Luau/Parser.h" #include "Luau/Ast.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/Common.h" #include "Luau/Common.h"
#include "Luau/StringUtils.h"
#include <algorithm> #include <algorithm>
#include <deque> #include <deque>
@ -148,7 +150,7 @@ Identifier mkName(const AstStatFunction& function)
auto name = mkName(*function.name); auto name = mkName(*function.name);
LUAU_ASSERT(bool(name)); LUAU_ASSERT(bool(name));
if (!name) if (!name)
throw std::runtime_error("Internal error: Function declaration has a bad name"); throw InternalCompilerError("Internal error: Function declaration has a bad name");
return *name; return *name;
} }
@ -214,6 +216,7 @@ struct ArcCollector : public AstVisitor
} }
} }
// Adds a dependency from the current node to the named node.
void add(const Identifier& name) void add(const Identifier& name)
{ {
Node** it = map.find(name); Node** it = map.find(name);
@ -253,7 +256,7 @@ struct ArcCollector : public AstVisitor
{ {
auto name = mkName(*node->name); auto name = mkName(*node->name);
if (!name) if (!name)
throw std::runtime_error("Internal error: AstStatFunction has a bad name"); throw InternalCompilerError("Internal error: AstStatFunction has a bad name");
add(*name); add(*name);
return true; return true;
@ -298,8 +301,15 @@ struct ArcCollector : public AstVisitor
struct ContainsFunctionCall : public AstVisitor struct ContainsFunctionCall : public AstVisitor
{ {
bool alsoReturn = false;
bool result = false; bool result = false;
ContainsFunctionCall() = default;
explicit ContainsFunctionCall(bool alsoReturn)
: alsoReturn(alsoReturn)
{
}
bool visit(AstExpr*) override bool visit(AstExpr*) override
{ {
return !result; // short circuit if result is true return !result; // short circuit if result is true
@ -318,6 +328,17 @@ struct ContainsFunctionCall : public AstVisitor
return false; return false;
} }
bool visit(AstStatReturn* stat) override
{
if (alsoReturn)
{
result = true;
return false;
}
else
return AstVisitor::visit(stat);
}
bool visit(AstExprFunction*) override bool visit(AstExprFunction*) override
{ {
return false; return false;
@ -479,6 +500,13 @@ bool containsFunctionCall(const AstStat& stat)
return cfc.result; return cfc.result;
} }
bool containsFunctionCallOrReturn(const AstStat& stat)
{
detail::ContainsFunctionCall cfc{true};
const_cast<AstStat&>(stat).visit(&cfc);
return cfc.result;
}
bool isFunction(const AstStat& stat) bool isFunction(const AstStat& stat)
{ {
return stat.is<AstStatFunction>() || stat.is<AstStatLocalFunction>(); return stat.is<AstStatFunction>() || stat.is<AstStatLocalFunction>();

View file

@ -10,65 +10,8 @@
#include <limits> #include <limits>
#include <math.h> #include <math.h>
LUAU_FASTFLAG(LuauGenericFunctions)
namespace namespace
{ {
std::string escape(std::string_view s)
{
std::string r;
r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting
for (uint8_t c : s)
{
if (c >= ' ' && c != '\\' && c != '\'' && c != '\"')
r += c;
else
{
r += '\\';
switch (c)
{
case '\a':
r += 'a';
break;
case '\b':
r += 'b';
break;
case '\f':
r += 'f';
break;
case '\n':
r += 'n';
break;
case '\r':
r += 'r';
break;
case '\t':
r += 't';
break;
case '\v':
r += 'v';
break;
case '\'':
r += '\'';
break;
case '\"':
r += '\"';
break;
case '\\':
r += '\\';
break;
default:
Luau::formatAppend(r, "%03u", c);
}
}
}
return r;
}
bool isIdentifierStartChar(char c) bool isIdentifierStartChar(char c)
{ {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_'; return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_';
@ -96,9 +39,6 @@ struct Writer
{ {
virtual ~Writer() {} virtual ~Writer() {}
virtual void begin() {}
virtual void end() {}
virtual void advance(const Position&) = 0; virtual void advance(const Position&) = 0;
virtual void newline() = 0; virtual void newline() = 0;
virtual void space() = 0; virtual void space() = 0;
@ -130,6 +70,7 @@ struct StringWriter : Writer
if (pos.column < newPos.column) if (pos.column < newPos.column)
write(std::string(newPos.column - pos.column, ' ')); write(std::string(newPos.column - pos.column, ' '));
} }
void maybeSpace(const Position& newPos, int reserve) override void maybeSpace(const Position& newPos, int reserve) override
{ {
if (pos.column + reserve < newPos.column) if (pos.column + reserve < newPos.column)
@ -264,26 +205,25 @@ struct Printer
} }
} }
void visualizeWithSelf(AstExpr& expr, bool self) void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg)
{ {
if (!self) advance(annotation.location.begin);
return visualize(expr); if (const AstTypePackVariadic* variadicTp = annotation.as<AstTypePackVariadic>())
AstExprIndexName* func = expr.as<AstExprIndexName>();
LUAU_ASSERT(func);
visualize(*func->expr);
writer.symbol(":");
advance(func->indexLocation.begin);
writer.identifier(func->index.value);
}
void visualizeTypePackAnnotation(const AstTypePack& annotation)
{
if (const AstTypePackVariadic* variadic = annotation.as<AstTypePackVariadic>())
{ {
if (!forVarArg)
writer.symbol("...");
visualizeTypeAnnotation(*variadicTp->variadicType);
}
else if (const AstTypePackGeneric* genericTp = annotation.as<AstTypePackGeneric>())
{
writer.symbol(genericTp->genericName.value);
writer.symbol("..."); writer.symbol("...");
visualizeTypeAnnotation(*variadic->variadicType); }
else if (const AstTypePackExplicit* explicitTp = annotation.as<AstTypePackExplicit>())
{
LUAU_ASSERT(!forVarArg);
visualizeTypeList(explicitTp->typeList, true);
} }
else else
{ {
@ -307,7 +247,7 @@ struct Printer
// Only variadic tail // Only variadic tail
if (list.types.size == 0) if (list.types.size == 0)
{ {
visualizeTypePackAnnotation(*list.tailType); visualizeTypePackAnnotation(*list.tailType, false);
} }
else else
{ {
@ -335,7 +275,7 @@ struct Printer
if (list.tailType) if (list.tailType)
{ {
writer.symbol(","); writer.symbol(",");
visualizeTypePackAnnotation(*list.tailType); visualizeTypePackAnnotation(*list.tailType, false);
} }
writer.symbol(")"); writer.symbol(")");
@ -412,7 +352,7 @@ struct Printer
} }
else if (const auto& a = expr.as<AstExprCall>()) else if (const auto& a = expr.as<AstExprCall>())
{ {
visualizeWithSelf(*a->func, a->self); visualize(*a->func);
writer.symbol("("); writer.symbol("(");
bool first = true; bool first = true;
@ -431,7 +371,7 @@ struct Printer
else if (const auto& a = expr.as<AstExprIndexName>()) else if (const auto& a = expr.as<AstExprIndexName>())
{ {
visualize(*a->expr); visualize(*a->expr);
writer.symbol("."); writer.symbol(std::string(1, a->op));
writer.write(a->index.value); writer.write(a->index.value);
} }
else if (const auto& a = expr.as<AstExprIndexExpr>()) else if (const auto& a = expr.as<AstExprIndexExpr>())
@ -532,6 +472,7 @@ struct Printer
case AstExprBinary::CompareLt: case AstExprBinary::CompareLt:
case AstExprBinary::CompareGt: case AstExprBinary::CompareGt:
writer.maybeSpace(a->right->location.begin, 2); writer.maybeSpace(a->right->location.begin, 2);
writer.symbol(toString(a->op));
break; break;
case AstExprBinary::Concat: case AstExprBinary::Concat:
case AstExprBinary::CompareNe: case AstExprBinary::CompareNe:
@ -540,19 +481,57 @@ struct Printer
case AstExprBinary::CompareGe: case AstExprBinary::CompareGe:
case AstExprBinary::Or: case AstExprBinary::Or:
writer.maybeSpace(a->right->location.begin, 3); writer.maybeSpace(a->right->location.begin, 3);
writer.keyword(toString(a->op));
break; break;
case AstExprBinary::And: case AstExprBinary::And:
writer.maybeSpace(a->right->location.begin, 4); writer.maybeSpace(a->right->location.begin, 4);
writer.keyword(toString(a->op));
break; break;
} }
writer.symbol(toString(a->op));
visualize(*a->right); visualize(*a->right);
} }
else if (const auto& a = expr.as<AstExprTypeAssertion>()) else if (const auto& a = expr.as<AstExprTypeAssertion>())
{ {
visualize(*a->expr); visualize(*a->expr);
if (writeTypes)
{
writer.maybeSpace(a->annotation->location.begin, 2);
writer.symbol("::");
visualizeTypeAnnotation(*a->annotation);
}
}
else if (const auto& a = expr.as<AstExprIfElse>())
{
writer.keyword("if");
visualize(*a->condition);
writer.keyword("then");
visualize(*a->trueExpr);
writer.keyword("else");
visualize(*a->falseExpr);
}
else if (const auto& a = expr.as<AstExprInterpString>())
{
writer.symbol("`");
size_t index = 0;
for (const auto& string : a->strings)
{
writer.write(escape(std::string_view(string.data, string.size), /* escapeForInterpString = */ true));
if (index < a->expressions.size)
{
writer.symbol("{");
visualize(*a->expressions.data[index]);
writer.symbol("}");
}
index++;
}
writer.symbol("`");
} }
else if (const auto& a = expr.as<AstExprError>()) else if (const auto& a = expr.as<AstExprError>())
{ {
@ -759,24 +738,31 @@ struct Printer
switch (a->op) switch (a->op)
{ {
case AstExprBinary::Add: case AstExprBinary::Add:
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("+="); writer.symbol("+=");
break; break;
case AstExprBinary::Sub: case AstExprBinary::Sub:
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("-="); writer.symbol("-=");
break; break;
case AstExprBinary::Mul: case AstExprBinary::Mul:
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("*="); writer.symbol("*=");
break; break;
case AstExprBinary::Div: case AstExprBinary::Div:
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("/="); writer.symbol("/=");
break; break;
case AstExprBinary::Mod: case AstExprBinary::Mod:
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("%="); writer.symbol("%=");
break; break;
case AstExprBinary::Pow: case AstExprBinary::Pow:
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("^="); writer.symbol("^=");
break; break;
case AstExprBinary::Concat: case AstExprBinary::Concat:
writer.maybeSpace(a->value->location.begin, 3);
writer.symbol("..="); writer.symbol("..=");
break; break;
default: default:
@ -788,7 +774,7 @@ struct Printer
else if (const auto& a = program.as<AstStatFunction>()) else if (const auto& a = program.as<AstStatFunction>())
{ {
writer.keyword("function"); writer.keyword("function");
visualizeWithSelf(*a->name, a->func->self != nullptr); visualize(*a->name);
visualizeFunctionBody(*a->func); visualizeFunctionBody(*a->func);
} }
else if (const auto& a = program.as<AstStatLocalFunction>()) else if (const auto& a = program.as<AstStatLocalFunction>())
@ -807,7 +793,7 @@ struct Printer
writer.keyword("type"); writer.keyword("type");
writer.identifier(a->name.value); writer.identifier(a->name.value);
if (a->generics.size > 0) if (a->generics.size > 0 || a->genericPacks.size > 0)
{ {
writer.symbol("<"); writer.symbol("<");
CommaSeparatorInserter comma(writer); CommaSeparatorInserter comma(writer);
@ -815,8 +801,34 @@ struct Printer
for (auto o : a->generics) for (auto o : a->generics)
{ {
comma(); comma();
writer.identifier(o.value);
writer.advance(o.location.begin);
writer.identifier(o.name.value);
if (o.defaultValue)
{
writer.maybeSpace(o.defaultValue->location.begin, 2);
writer.symbol("=");
visualizeTypeAnnotation(*o.defaultValue);
}
} }
for (auto o : a->genericPacks)
{
comma();
writer.advance(o.location.begin);
writer.identifier(o.name.value);
writer.symbol("...");
if (o.defaultValue)
{
writer.maybeSpace(o.defaultValue->location.begin, 2);
writer.symbol("=");
visualizeTypePackAnnotation(*o.defaultValue, false);
}
}
writer.symbol(">"); writer.symbol(">");
} }
writer.maybeSpace(a->type->location.begin, 2); writer.maybeSpace(a->type->location.begin, 2);
@ -853,19 +865,23 @@ struct Printer
void visualizeFunctionBody(AstExprFunction& func) void visualizeFunctionBody(AstExprFunction& func)
{ {
if (FFlag::LuauGenericFunctions && (func.generics.size > 0 || func.genericPacks.size > 0)) if (func.generics.size > 0 || func.genericPacks.size > 0)
{ {
CommaSeparatorInserter comma(writer); CommaSeparatorInserter comma(writer);
writer.symbol("<"); writer.symbol("<");
for (const auto& o : func.generics) for (const auto& o : func.generics)
{ {
comma(); comma();
writer.identifier(o.value);
writer.advance(o.location.begin);
writer.identifier(o.name.value);
} }
for (const auto& o : func.genericPacks) for (const auto& o : func.genericPacks)
{ {
comma(); comma();
writer.identifier(o.value);
writer.advance(o.location.begin);
writer.identifier(o.name.value);
writer.symbol("..."); writer.symbol("...");
} }
writer.symbol(">"); writer.symbol(">");
@ -892,23 +908,24 @@ struct Printer
if (func.vararg) if (func.vararg)
{ {
comma(); comma();
advance(func.varargLocation.begin);
writer.symbol("..."); writer.symbol("...");
if (func.varargAnnotation) if (func.varargAnnotation)
{ {
writer.symbol(":"); writer.symbol(":");
visualizeTypePackAnnotation(*func.varargAnnotation); visualizeTypePackAnnotation(*func.varargAnnotation, true);
} }
} }
writer.symbol(")"); writer.symbol(")");
if (writeTypes && func.hasReturnAnnotation) if (writeTypes && func.returnAnnotation)
{ {
writer.symbol(":"); writer.symbol(":");
writer.space(); writer.space();
visualizeTypeList(func.returnAnnotation, false); visualizeTypeList(*func.returnAnnotation, false);
} }
visualizeBlock(*func.body); visualizeBlock(*func.body);
@ -959,34 +976,49 @@ struct Printer
advance(typeAnnotation.location.begin); advance(typeAnnotation.location.begin);
if (const auto& a = typeAnnotation.as<AstTypeReference>()) if (const auto& a = typeAnnotation.as<AstTypeReference>())
{ {
if (a->prefix)
{
writer.write(a->prefix->value);
writer.symbol(".");
}
writer.write(a->name.value); writer.write(a->name.value);
if (a->generics.size > 0) if (a->parameters.size > 0 || a->hasParameterList)
{ {
CommaSeparatorInserter comma(writer); CommaSeparatorInserter comma(writer);
writer.symbol("<"); writer.symbol("<");
for (auto o : a->generics) for (auto o : a->parameters)
{ {
comma(); comma();
visualizeTypeAnnotation(*o);
if (o.type)
visualizeTypeAnnotation(*o.type);
else
visualizeTypePackAnnotation(*o.typePack, false);
} }
writer.symbol(">"); writer.symbol(">");
} }
} }
else if (const auto& a = typeAnnotation.as<AstTypeFunction>()) else if (const auto& a = typeAnnotation.as<AstTypeFunction>())
{ {
if (FFlag::LuauGenericFunctions && (a->generics.size > 0 || a->genericPacks.size > 0)) if (a->generics.size > 0 || a->genericPacks.size > 0)
{ {
CommaSeparatorInserter comma(writer); CommaSeparatorInserter comma(writer);
writer.symbol("<"); writer.symbol("<");
for (const auto& o : a->generics) for (const auto& o : a->generics)
{ {
comma(); comma();
writer.identifier(o.value);
writer.advance(o.location.begin);
writer.identifier(o.name.value);
} }
for (const auto& o : a->genericPacks) for (const auto& o : a->genericPacks)
{ {
comma(); comma();
writer.identifier(o.value);
writer.advance(o.location.begin);
writer.identifier(o.name.value);
writer.symbol("..."); writer.symbol("...");
} }
writer.symbol(">"); writer.symbol(">");
@ -1001,31 +1033,42 @@ struct Printer
} }
else if (const auto& a = typeAnnotation.as<AstTypeTable>()) else if (const auto& a = typeAnnotation.as<AstTypeTable>())
{ {
CommaSeparatorInserter comma(writer); AstTypeReference* indexType = a->indexer ? a->indexer->indexType->as<AstTypeReference>() : nullptr;
writer.symbol("{"); if (a->props.size == 0 && indexType && indexType->name == "number")
for (std::size_t i = 0; i < a->props.size; ++i)
{ {
comma(); writer.symbol("{");
advance(a->props.data[i].location.begin);
writer.identifier(a->props.data[i].name.value);
if (a->props.data[i].type)
{
writer.symbol(":");
visualizeTypeAnnotation(*a->props.data[i].type);
}
}
if (a->indexer)
{
comma();
writer.symbol("[");
visualizeTypeAnnotation(*a->indexer->indexType);
writer.symbol("]");
writer.symbol(":");
visualizeTypeAnnotation(*a->indexer->resultType); visualizeTypeAnnotation(*a->indexer->resultType);
writer.symbol("}");
}
else
{
CommaSeparatorInserter comma(writer);
writer.symbol("{");
for (std::size_t i = 0; i < a->props.size; ++i)
{
comma();
advance(a->props.data[i].location.begin);
writer.identifier(a->props.data[i].name.value);
if (a->props.data[i].type)
{
writer.symbol(":");
visualizeTypeAnnotation(*a->props.data[i].type);
}
}
if (a->indexer)
{
comma();
writer.symbol("[");
visualizeTypeAnnotation(*a->indexer->indexType);
writer.symbol("]");
writer.symbol(":");
visualizeTypeAnnotation(*a->indexer->resultType);
}
writer.symbol("}");
} }
writer.symbol("}");
} }
else if (auto a = typeAnnotation.as<AstTypeTypeof>()) else if (auto a = typeAnnotation.as<AstTypeTypeof>())
{ {
@ -1049,7 +1092,16 @@ struct Printer
auto rta = r->as<AstTypeReference>(); auto rta = r->as<AstTypeReference>();
if (rta && rta->name == "nil") if (rta && rta->name == "nil")
{ {
bool wrap = l->as<AstTypeIntersection>() || l->as<AstTypeFunction>();
if (wrap)
writer.symbol("(");
visualizeTypeAnnotation(*l); visualizeTypeAnnotation(*l);
if (wrap)
writer.symbol(")");
writer.symbol("?"); writer.symbol("?");
return; return;
} }
@ -1063,7 +1115,15 @@ struct Printer
writer.symbol("|"); writer.symbol("|");
} }
bool wrap = a->types.data[i]->as<AstTypeIntersection>() || a->types.data[i]->as<AstTypeFunction>();
if (wrap)
writer.symbol("(");
visualizeTypeAnnotation(*a->types.data[i]); visualizeTypeAnnotation(*a->types.data[i]);
if (wrap)
writer.symbol(")");
} }
} }
else if (const auto& a = typeAnnotation.as<AstTypeIntersection>()) else if (const auto& a = typeAnnotation.as<AstTypeIntersection>())
@ -1076,9 +1136,25 @@ struct Printer
writer.symbol("&"); writer.symbol("&");
} }
bool wrap = a->types.data[i]->as<AstTypeUnion>() || a->types.data[i]->as<AstTypeFunction>();
if (wrap)
writer.symbol("(");
visualizeTypeAnnotation(*a->types.data[i]); visualizeTypeAnnotation(*a->types.data[i]);
if (wrap)
writer.symbol(")");
} }
} }
else if (const auto& a = typeAnnotation.as<AstTypeSingletonBool>())
{
writer.keyword(a->value ? "true" : "false");
}
else if (const auto& a = typeAnnotation.as<AstTypeSingletonString>())
{
writer.string(std::string_view(a->value.data, a->value.size));
}
else if (typeAnnotation.is<AstTypeError>()) else if (typeAnnotation.is<AstTypeError>())
{ {
writer.symbol("%error-type%"); writer.symbol("%error-type%");
@ -1090,31 +1166,27 @@ struct Printer
} }
}; };
void dump(AstNode* node) std::string toString(AstNode* node)
{ {
StringWriter writer; StringWriter writer;
writer.pos = node->location.begin;
Printer printer(writer); Printer printer(writer);
printer.writeTypes = true; printer.writeTypes = true;
if (auto statNode = dynamic_cast<AstStat*>(node)) if (auto statNode = dynamic_cast<AstStat*>(node))
{
printer.visualize(*statNode); printer.visualize(*statNode);
printf("%s\n", writer.str().c_str());
}
else if (auto exprNode = dynamic_cast<AstExpr*>(node)) else if (auto exprNode = dynamic_cast<AstExpr*>(node))
{
printer.visualize(*exprNode); printer.visualize(*exprNode);
printf("%s\n", writer.str().c_str());
}
else if (auto typeNode = dynamic_cast<AstType*>(node)) else if (auto typeNode = dynamic_cast<AstType*>(node))
{
printer.visualizeTypeAnnotation(*typeNode); printer.visualizeTypeAnnotation(*typeNode);
printf("%s\n", writer.str().c_str());
} return writer.str();
else }
{
printf("Can't dump this node\n"); void dump(AstNode* node)
} {
printf("%s\n", toString(node).c_str());
} }
std::string transpile(AstStatBlock& block) std::string transpile(AstStatBlock& block)
@ -1123,6 +1195,7 @@ std::string transpile(AstStatBlock& block)
Printer(writer).visualizeBlock(block); Printer(writer).visualizeBlock(block);
return writer.str(); return writer.str();
} }
std::string transpileWithTypes(AstStatBlock& block) std::string transpileWithTypes(AstStatBlock& block)
{ {
StringWriter writer; StringWriter writer;
@ -1132,7 +1205,7 @@ std::string transpileWithTypes(AstStatBlock& block)
return writer.str(); return writer.str();
} }
TranspileResult transpile(std::string_view source, ParseOptions options) TranspileResult transpile(std::string_view source, ParseOptions options, bool withTypes)
{ {
auto allocator = Allocator{}; auto allocator = Allocator{};
auto names = AstNameTable{allocator}; auto names = AstNameTable{allocator};
@ -1150,6 +1223,9 @@ TranspileResult transpile(std::string_view source, ParseOptions options)
if (!parseResult.root) if (!parseResult.root)
return TranspileResult{"", {}, "Internal error: Parser yielded empty parse tree"}; return TranspileResult{"", {}, "Internal error: Parser yielded empty parse tree"};
if (withTypes)
return TranspileResult{transpileWithTypes(*parseResult.root)};
return TranspileResult{transpile(*parseResult.root)}; return TranspileResult{transpile(*parseResult.root)};
} }

View file

@ -1,72 +1,426 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/ToString.h"
#include "Luau/TypeArena.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include <algorithm> #include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauUnknownAndNeverType)
namespace Luau namespace Luau
{ {
void TxnLog::operator()(TypeId a) const std::string nullPendingResult = "<nullptr>";
std::string toString(PendingType* pending)
{ {
typeVarChanges.emplace_back(a, *a); if (pending == nullptr)
return nullPendingResult;
return toString(pending->pending);
} }
void TxnLog::operator()(TypePackId a) std::string dump(PendingType* pending)
{ {
typePackChanges.emplace_back(a, *a); if (pending == nullptr)
{
printf("%s\n", nullPendingResult.c_str());
return nullPendingResult;
}
ToStringOptions opts;
opts.exhaustive = true;
opts.functionTypeArguments = true;
std::string result = toString(pending->pending, opts);
printf("%s\n", result.c_str());
return result;
} }
void TxnLog::operator()(TableTypeVar* a) std::string toString(PendingTypePack* pending)
{ {
tableChanges.emplace_back(a, a->boundTo); if (pending == nullptr)
return nullPendingResult;
return toString(pending->pending);
} }
void TxnLog::rollback() std::string dump(PendingTypePack* pending)
{ {
for (auto it = typeVarChanges.rbegin(); it != typeVarChanges.rend(); ++it) if (pending == nullptr)
std::swap(*asMutable(it->first), it->second); {
printf("%s\n", nullPendingResult.c_str());
return nullPendingResult;
}
for (auto it = typePackChanges.rbegin(); it != typePackChanges.rend(); ++it) ToStringOptions opts;
std::swap(*asMutable(it->first), it->second); opts.exhaustive = true;
opts.functionTypeArguments = true;
std::string result = toString(pending->pending, opts);
printf("%s\n", result.c_str());
return result;
}
for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it) static const TxnLog emptyLog;
std::swap(it->first->boundTo, it->second);
const TxnLog* TxnLog::empty()
{
return &emptyLog;
} }
void TxnLog::concat(TxnLog rhs) void TxnLog::concat(TxnLog rhs)
{ {
typeVarChanges.insert(typeVarChanges.end(), rhs.typeVarChanges.begin(), rhs.typeVarChanges.end()); for (auto& [ty, rep] : rhs.typeVarChanges)
rhs.typeVarChanges.clear(); typeVarChanges[ty] = std::move(rep);
typePackChanges.insert(typePackChanges.end(), rhs.typePackChanges.begin(), rhs.typePackChanges.end()); for (auto& [tp, rep] : rhs.typePackChanges)
rhs.typePackChanges.clear(); typePackChanges[tp] = std::move(rep);
tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end());
rhs.tableChanges.clear();
seen.swap(rhs.seen);
rhs.seen.clear();
} }
bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) void TxnLog::concatAsIntersections(TxnLog rhs, NotNull<TypeArena> arena)
{ {
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); for (auto& [ty, rightRep] : rhs.typeVarChanges)
return (seen.end() != std::find(seen.begin(), seen.end(), sortedPair)); {
if (auto leftRep = typeVarChanges.find(ty))
{
TypeId leftTy = arena->addType((*leftRep)->pending);
TypeId rightTy = arena->addType(rightRep->pending);
typeVarChanges[ty]->pending.ty = IntersectionTypeVar{{leftTy, rightTy}};
}
else
typeVarChanges[ty] = std::move(rightRep);
}
for (auto& [tp, rep] : rhs.typePackChanges)
typePackChanges[tp] = std::move(rep);
}
void TxnLog::concatAsUnion(TxnLog rhs, NotNull<TypeArena> arena)
{
for (auto& [ty, rightRep] : rhs.typeVarChanges)
{
if (auto leftRep = typeVarChanges.find(ty))
{
TypeId leftTy = arena->addType((*leftRep)->pending);
TypeId rightTy = arena->addType(rightRep->pending);
typeVarChanges[ty]->pending.ty = UnionTypeVar{{leftTy, rightTy}};
}
else
typeVarChanges[ty] = std::move(rightRep);
}
for (auto& [tp, rep] : rhs.typePackChanges)
typePackChanges[tp] = std::move(rep);
}
void TxnLog::commit()
{
for (auto& [ty, rep] : typeVarChanges)
asMutable(ty)->reassign(rep.get()->pending);
for (auto& [tp, rep] : typePackChanges)
asMutable(tp)->reassign(rep.get()->pending);
clear();
}
void TxnLog::clear()
{
typeVarChanges.clear();
typePackChanges.clear();
}
TxnLog TxnLog::inverse()
{
TxnLog inversed(sharedSeen);
for (auto& [ty, _rep] : typeVarChanges)
inversed.typeVarChanges[ty] = std::make_unique<PendingType>(*ty);
for (auto& [tp, _rep] : typePackChanges)
inversed.typePackChanges[tp] = std::make_unique<PendingTypePack>(*tp);
return inversed;
}
bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) const
{
return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs);
} }
void TxnLog::pushSeen(TypeId lhs, TypeId rhs) void TxnLog::pushSeen(TypeId lhs, TypeId rhs)
{ {
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs);
seen.push_back(sortedPair);
} }
void TxnLog::popSeen(TypeId lhs, TypeId rhs) void TxnLog::popSeen(TypeId lhs, TypeId rhs)
{ {
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs);
LUAU_ASSERT(sortedPair == seen.back()); }
seen.pop_back();
bool TxnLog::haveSeen(TypePackId lhs, TypePackId rhs) const
{
return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs);
}
void TxnLog::pushSeen(TypePackId lhs, TypePackId rhs)
{
pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs);
}
void TxnLog::popSeen(TypePackId lhs, TypePackId rhs)
{
popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs);
}
bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const
{
const std::pair<TypeOrPackId, TypeOrPackId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair))
{
return true;
}
return false;
}
void TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs)
{
const std::pair<TypeOrPackId, TypeOrPackId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
sharedSeen->push_back(sortedPair);
}
void TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs)
{
const std::pair<TypeOrPackId, TypeOrPackId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
LUAU_ASSERT(sortedPair == sharedSeen->back());
sharedSeen->pop_back();
}
PendingType* TxnLog::queue(TypeId ty)
{
LUAU_ASSERT(!ty->persistent);
// Explicitly don't look in ancestors. If we have discovered something new
// about this type, we don't want to mutate the parent's state.
auto& pending = typeVarChanges[ty];
if (!pending)
{
pending = std::make_unique<PendingType>(*ty);
pending->pending.owningArena = nullptr;
}
return pending.get();
}
PendingTypePack* TxnLog::queue(TypePackId tp)
{
LUAU_ASSERT(!tp->persistent);
// Explicitly don't look in ancestors. If we have discovered something new
// about this type, we don't want to mutate the parent's state.
auto& pending = typePackChanges[tp];
if (!pending)
{
pending = std::make_unique<PendingTypePack>(*tp);
pending->pending.owningArena = nullptr;
}
return pending.get();
}
PendingType* TxnLog::pending(TypeId ty) const
{
// This function will technically work if `this` is nullptr, but this
// indicates a bug, so we explicitly assert.
LUAU_ASSERT(static_cast<const void*>(this) != nullptr);
for (const TxnLog* current = this; current; current = current->parent)
{
if (auto it = current->typeVarChanges.find(ty))
return it->get();
}
return nullptr;
}
PendingTypePack* TxnLog::pending(TypePackId tp) const
{
// This function will technically work if `this` is nullptr, but this
// indicates a bug, so we explicitly assert.
LUAU_ASSERT(static_cast<const void*>(this) != nullptr);
for (const TxnLog* current = this; current; current = current->parent)
{
if (auto it = current->typePackChanges.find(tp))
return it->get();
}
return nullptr;
}
PendingType* TxnLog::replace(TypeId ty, TypeVar replacement)
{
PendingType* newTy = queue(ty);
newTy->pending.reassign(replacement);
return newTy;
}
PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement)
{
PendingTypePack* newTp = queue(tp);
newTp->pending.reassign(replacement);
return newTp;
}
PendingType* TxnLog::bindTable(TypeId ty, std::optional<TypeId> newBoundTo)
{
LUAU_ASSERT(get<TableTypeVar>(ty));
PendingType* newTy = queue(ty);
if (TableTypeVar* ttv = Luau::getMutable<TableTypeVar>(newTy))
ttv->boundTo = newBoundTo;
return newTy;
}
PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel)
{
LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty));
PendingType* newTy = queue(ty);
if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy))
{
ftv->level = newLevel;
}
else if (TableTypeVar* ttv = Luau::getMutable<TableTypeVar>(newTy))
{
LUAU_ASSERT(ttv->state == TableState::Free || ttv->state == TableState::Generic);
ttv->level = newLevel;
}
else if (FunctionTypeVar* ftv = Luau::getMutable<FunctionTypeVar>(newTy))
{
ftv->level = newLevel;
}
return newTy;
}
PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel)
{
LUAU_ASSERT(get<FreeTypePack>(tp));
PendingTypePack* newTp = queue(tp);
if (FreeTypePack* ftp = Luau::getMutable<FreeTypePack>(newTp))
{
ftp->level = newLevel;
}
return newTp;
}
PendingType* TxnLog::changeScope(TypeId ty, NotNull<Scope> newScope)
{
LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty));
PendingType* newTy = queue(ty);
if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy))
{
ftv->scope = newScope;
}
else if (TableTypeVar* ttv = Luau::getMutable<TableTypeVar>(newTy))
{
LUAU_ASSERT(ttv->state == TableState::Free || ttv->state == TableState::Generic);
ttv->scope = newScope;
}
else if (FunctionTypeVar* ftv = Luau::getMutable<FunctionTypeVar>(newTy))
{
ftv->scope = newScope;
}
return newTy;
}
PendingTypePack* TxnLog::changeScope(TypePackId tp, NotNull<Scope> newScope)
{
LUAU_ASSERT(get<FreeTypePack>(tp));
PendingTypePack* newTp = queue(tp);
if (FreeTypePack* ftp = Luau::getMutable<FreeTypePack>(newTp))
{
ftp->scope = newScope;
}
return newTp;
}
PendingType* TxnLog::changeIndexer(TypeId ty, std::optional<TableIndexer> indexer)
{
LUAU_ASSERT(get<TableTypeVar>(ty));
PendingType* newTy = queue(ty);
if (TableTypeVar* ttv = Luau::getMutable<TableTypeVar>(newTy))
{
ttv->indexer = indexer;
}
return newTy;
}
std::optional<TypeLevel> TxnLog::getLevel(TypeId ty) const
{
if (FreeTypeVar* ftv = getMutable<FreeTypeVar>(ty))
return ftv->level;
else if (TableTypeVar* ttv = getMutable<TableTypeVar>(ty); ttv && (ttv->state == TableState::Free || ttv->state == TableState::Generic))
return ttv->level;
else if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty))
return ftv->level;
return std::nullopt;
}
TypeId TxnLog::follow(TypeId ty) const
{
return Luau::follow(ty, [this](TypeId ty) {
PendingType* state = this->pending(ty);
if (state == nullptr)
return ty;
// Ugly: Fabricate a TypeId that doesn't adhere to most of the invariants
// that normally apply. This is safe because follow will only call get<>
// on the returned pointer.
return const_cast<const TypeVar*>(&state->pending);
});
}
TypePackId TxnLog::follow(TypePackId tp) const
{
return Luau::follow(tp, [this](TypePackId tp) {
PendingTypePack* state = this->pending(tp);
if (state == nullptr)
return tp;
// Ugly: Fabricate a TypePackId that doesn't adhere to most of the
// invariants that normally apply. This is safe because follow will
// only call get<> on the returned pointer.
return const_cast<const TypePackVar*>(&state->pending);
});
}
std::pair<std::vector<TypeId>, std::vector<TypePackId>> TxnLog::getChanges() const
{
std::pair<std::vector<TypeId>, std::vector<TypePackId>> result;
for (const auto& [typeId, _newState] : typeVarChanges)
result.first.push_back(typeId);
for (const auto& [typePackId, _newState] : typePackChanges)
result.second.push_back(typePackId);
return result;
} }
} // namespace Luau } // namespace Luau

115
Analysis/src/TypeArena.cpp Normal file
View file

@ -0,0 +1,115 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeArena.h"
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false);
namespace Luau
{
void TypeArena::clear()
{
typeVars.clear();
typePacks.clear();
}
TypeId TypeArena::addTV(TypeVar&& tv)
{
TypeId allocated = typeVars.allocate(std::move(tv));
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(TypeLevel level)
{
TypeId allocated = typeVars.allocate(FreeTypeVar{level});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(Scope* scope)
{
TypeId allocated = typeVars.allocate(FreeTypeVar{scope});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(Scope* scope, TypeLevel level)
{
TypeId allocated = typeVars.allocate(FreeTypeVar{scope, level});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::freshTypePack(Scope* scope)
{
TypePackId allocated = typePacks.allocate(FreeTypePack{scope});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(std::initializer_list<TypeId> types)
{
TypePackId allocated = typePacks.allocate(TypePack{std::move(types)});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(std::vector<TypeId> types, std::optional<TypePackId> tail)
{
TypePackId allocated = typePacks.allocate(TypePack{std::move(types), tail});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(TypePack tp)
{
TypePackId allocated = typePacks.allocate(std::move(tp));
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(TypePackVar tp)
{
TypePackId allocated = typePacks.allocate(std::move(tp));
asMutable(allocated)->owningArena = this;
return allocated;
}
void freeze(TypeArena& arena)
{
if (!FFlag::DebugLuauFreezeArena)
return;
arena.typeVars.freeze();
arena.typePacks.freeze();
}
void unfreeze(TypeArena& arena)
{
if (!FFlag::DebugLuauFreezeArena)
return;
arena.typeVars.unfreeze();
arena.typePacks.unfreeze();
}
} // namespace Luau

View file

@ -3,16 +3,15 @@
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/Parser.h"
#include "Luau/RecursionCounter.h" #include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include <string> #include <string>
LUAU_FASTFLAG(LuauGenericFunctions)
static char* allocateString(Luau::Allocator& allocator, std::string_view contents) static char* allocateString(Luau::Allocator& allocator, std::string_view contents)
{ {
char* result = (char*)allocator.allocate(contents.size() + 1); char* result = (char*)allocator.allocate(contents.size() + 1);
@ -31,15 +30,31 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data
return result; return result;
} }
using SyntheticNames = std::unordered_map<const void*, char*>;
namespace Luau namespace Luau
{ {
static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen)
{
size_t s = syntheticNames->size();
char*& n = (*syntheticNames)[&gen];
if (!n)
{
std::string str = gen.explicitName ? gen.name : generateName(s);
n = static_cast<char*>(allocator->allocate(str.size() + 1));
strcpy(n, str.c_str());
}
return n;
}
class TypeRehydrationVisitor class TypeRehydrationVisitor
{ {
mutable std::map<void*, int> seen; std::map<void*, int> seen;
mutable int count = 0; int count = 0;
bool hasSeen(const void* tv) const bool hasSeen(const void* tv)
{ {
void* ttv = const_cast<void*>(tv); void* ttv = const_cast<void*>(tv);
auto it = seen.find(ttv); auto it = seen.find(ttv);
@ -51,13 +66,16 @@ class TypeRehydrationVisitor
} }
public: public:
TypeRehydrationVisitor(Allocator* alloc, const TypeRehydrationOptions& options = TypeRehydrationOptions()) TypeRehydrationVisitor(Allocator* alloc, SyntheticNames* syntheticNames, const TypeRehydrationOptions& options = TypeRehydrationOptions())
: allocator(alloc) : allocator(alloc)
, syntheticNames(syntheticNames)
, options(options) , options(options)
{ {
} }
AstType* operator()(const PrimitiveTypeVar& ptv) const AstTypePack* rehydrate(TypePackId tp);
AstType* operator()(const PrimitiveTypeVar& ptv)
{ {
switch (ptv.type) switch (ptv.type)
{ {
@ -75,26 +93,57 @@ public:
return nullptr; return nullptr;
} }
} }
AstType* operator()(const AnyTypeVar&) const
AstType* operator()(const BlockedTypeVar& btv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*blocked*"));
}
AstType* operator()(const PendingExpansionTypeVar& petv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*pending-expansion*"));
}
AstType* operator()(const SingletonTypeVar& stv)
{
if (const BooleanSingleton* bs = get<BooleanSingleton>(&stv))
return allocator->alloc<AstTypeSingletonBool>(Location(), bs->value);
else if (const StringSingleton* ss = get<StringSingleton>(&stv))
{
AstArray<char> value;
value.data = const_cast<char*>(ss->value.c_str());
value.size = strlen(value.data);
return allocator->alloc<AstTypeSingletonString>(Location(), value);
}
else
return nullptr;
}
AstType* operator()(const AnyTypeVar&)
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any")); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"));
} }
AstType* operator()(const TableTypeVar& ttv) const AstType* operator()(const TableTypeVar& ttv)
{ {
RecursionCounter counter(&count); RecursionCounter counter(&count);
if (ttv.name && options.bannedNames.find(*ttv.name) == options.bannedNames.end()) if (ttv.name && options.bannedNames.find(*ttv.name) == options.bannedNames.end())
{ {
AstArray<AstType*> generics; AstArray<AstTypeOrPack> parameters;
generics.size = ttv.instantiatedTypeParams.size(); parameters.size = ttv.instantiatedTypeParams.size();
generics.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * generics.size)); parameters.data = static_cast<AstTypeOrPack*>(allocator->allocate(sizeof(AstTypeOrPack) * parameters.size));
for (size_t i = 0; i < ttv.instantiatedTypeParams.size(); ++i) for (size_t i = 0; i < ttv.instantiatedTypeParams.size(); ++i)
{ {
generics.data[i] = Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty); parameters.data[i] = {Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty), {}};
} }
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(ttv.name->c_str()), generics); for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i)
{
parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])};
}
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters);
} }
if (hasSeen(&ttv)) if (hasSeen(&ttv))
@ -133,12 +182,12 @@ public:
return allocator->alloc<AstTypeTable>(Location(), props, indexer); return allocator->alloc<AstTypeTable>(Location(), props, indexer);
} }
AstType* operator()(const MetatableTypeVar& mtv) const AstType* operator()(const MetatableTypeVar& mtv)
{ {
return Luau::visit(*this, mtv.table->ty); return Luau::visit(*this, mtv.table->ty);
} }
AstType* operator()(const ClassTypeVar& ctv) const AstType* operator()(const ClassTypeVar& ctv)
{ {
RecursionCounter counter(&count); RecursionCounter counter(&count);
@ -165,47 +214,31 @@ public:
return allocator->alloc<AstTypeTable>(Location(), props); return allocator->alloc<AstTypeTable>(Location(), props);
} }
AstType* operator()(const FunctionTypeVar& ftv) const AstType* operator()(const FunctionTypeVar& ftv)
{ {
RecursionCounter counter(&count); RecursionCounter counter(&count);
if (hasSeen(&ftv)) if (hasSeen(&ftv))
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Cycle>")); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Cycle>"));
AstArray<AstName> generics; AstArray<AstGenericType> generics;
if (FFlag::LuauGenericFunctions) generics.size = ftv.generics.size();
generics.data = static_cast<AstGenericType*>(allocator->allocate(sizeof(AstGenericType) * generics.size));
size_t numGenerics = 0;
for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it)
{ {
generics.size = ftv.generics.size(); if (auto gtv = get<GenericTypeVar>(*it))
generics.data = static_cast<AstName*>(allocator->allocate(sizeof(AstName) * generics.size)); generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr};
size_t i = 0;
for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it)
{
if (auto gtv = get<GenericTypeVar>(*it))
generics.data[i++] = AstName(gtv->name.c_str());
}
}
else
{
generics.size = 0;
generics.data = nullptr;
} }
AstArray<AstName> genericPacks; AstArray<AstGenericTypePack> genericPacks;
if (FFlag::LuauGenericFunctions) genericPacks.size = ftv.genericPacks.size();
genericPacks.data = static_cast<AstGenericTypePack*>(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size));
size_t numGenericPacks = 0;
for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it)
{ {
genericPacks.size = ftv.genericPacks.size(); if (auto gtv = get<GenericTypeVar>(*it))
genericPacks.data = static_cast<AstName*>(allocator->allocate(sizeof(AstName) * genericPacks.size)); genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr};
size_t i = 0;
for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it)
{
if (auto gtv = get<GenericTypeVar>(*it))
genericPacks.data[i++] = AstName(gtv->name.c_str());
}
}
else
{
generics.size = 0;
generics.data = nullptr;
} }
AstArray<AstType*> argTypes; AstArray<AstType*> argTypes;
@ -221,13 +254,7 @@ public:
AstTypePack* argTailAnnotation = nullptr; AstTypePack* argTailAnnotation = nullptr;
if (argTail) if (argTail)
{ argTailAnnotation = rehydrate(*argTail);
TypePackId tail = *argTail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tail))
{
argTailAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*this, vtp->ty->ty));
}
}
AstArray<std::optional<AstArgumentName>> argNames; AstArray<std::optional<AstArgumentName>> argNames;
argNames.size = ftv.argNames.size(); argNames.size = ftv.argNames.size();
@ -235,14 +262,16 @@ public:
size_t i = 0; size_t i = 0;
for (const auto& el : ftv.argNames) for (const auto& el : ftv.argNames)
{ {
std::optional<AstArgumentName>* arg = &argNames.data[i++];
if (el) if (el)
argNames.data[i++] = {AstName(el->name.c_str()), el->location}; new (arg) std::optional<AstArgumentName>(AstArgumentName(AstName(el->name.c_str()), el->location));
else else
argNames.data[i++] = {}; new (arg) std::optional<AstArgumentName>();
} }
AstArray<AstType*> returnTypes; AstArray<AstType*> returnTypes;
const auto& [retVector, retTail] = flatten(ftv.retType); const auto& [retVector, retTail] = flatten(ftv.retTypes);
returnTypes.size = retVector.size(); returnTypes.size = retVector.size();
returnTypes.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * returnTypes.size)); returnTypes.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * returnTypes.size));
for (size_t i = 0; i < returnTypes.size; ++i) for (size_t i = 0; i < returnTypes.size; ++i)
@ -254,34 +283,28 @@ public:
AstTypePack* retTailAnnotation = nullptr; AstTypePack* retTailAnnotation = nullptr;
if (retTail) if (retTail)
{ retTailAnnotation = rehydrate(*retTail);
TypePackId tail = *retTail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tail))
{
retTailAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*this, vtp->ty->ty));
}
}
return allocator->alloc<AstTypeFunction>( return allocator->alloc<AstTypeFunction>(
Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation});
} }
AstType* operator()(const Unifiable::Error&) const AstType* operator()(const Unifiable::Error&)
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("Unifiable<Error>")); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("Unifiable<Error>"));
} }
AstType* operator()(const GenericTypeVar& gtv) const AstType* operator()(const GenericTypeVar& gtv)
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(gtv.name.c_str())); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)));
} }
AstType* operator()(const Unifiable::Bound<TypeId>& bound) const AstType* operator()(const Unifiable::Bound<TypeId>& bound)
{ {
return Luau::visit(*this, bound.boundTo->ty); return Luau::visit(*this, bound.boundTo->ty);
} }
AstType* operator()(Unifiable::Free ftv) const AstType* operator()(const FreeTypeVar& ftv)
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free")); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free"));
} }
AstType* operator()(const UnionTypeVar& uv) const AstType* operator()(const UnionTypeVar& uv)
{ {
AstArray<AstType*> unionTypes; AstArray<AstType*> unionTypes;
unionTypes.size = uv.options.size(); unionTypes.size = uv.options.size();
@ -292,7 +315,7 @@ public:
} }
return allocator->alloc<AstTypeUnion>(Location(), unionTypes); return allocator->alloc<AstTypeUnion>(Location(), unionTypes);
} }
AstType* operator()(const IntersectionTypeVar& uv) const AstType* operator()(const IntersectionTypeVar& uv)
{ {
AstArray<AstType*> intersectionTypes; AstArray<AstType*> intersectionTypes;
intersectionTypes.size = uv.parts.size(); intersectionTypes.size = uv.parts.size();
@ -303,16 +326,105 @@ public:
} }
return allocator->alloc<AstTypeIntersection>(Location(), intersectionTypes); return allocator->alloc<AstTypeIntersection>(Location(), intersectionTypes);
} }
AstType* operator()(const LazyTypeVar& ltv) const AstType* operator()(const LazyTypeVar& ltv)
{ {
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Lazy?>")); return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Lazy?>"));
} }
AstType* operator()(const UnknownTypeVar& ttv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName{"unknown"});
}
AstType* operator()(const NeverTypeVar& ttv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName{"never"});
}
AstType* operator()(const NegationTypeVar& ntv)
{
// FIXME: do the same thing we do with ErrorTypeVar
throw InternalCompilerError("Cannot convert NegationTypeVar into AstNode");
}
private: private:
Allocator* allocator; Allocator* allocator;
SyntheticNames* syntheticNames;
const TypeRehydrationOptions& options; const TypeRehydrationOptions& options;
}; };
class TypePackRehydrationVisitor
{
public:
TypePackRehydrationVisitor(Allocator* allocator, SyntheticNames* syntheticNames, TypeRehydrationVisitor* typeVisitor)
: allocator(allocator)
, syntheticNames(syntheticNames)
, typeVisitor(typeVisitor)
{
LUAU_ASSERT(allocator);
LUAU_ASSERT(syntheticNames);
LUAU_ASSERT(typeVisitor);
}
AstTypePack* operator()(const BoundTypePack& btp) const
{
return Luau::visit(*this, btp.boundTo->ty);
}
AstTypePack* operator()(const BlockedTypePack& btp) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("*blocked*"));
}
AstTypePack* operator()(const TypePack& tp) const
{
AstArray<AstType*> head;
head.size = tp.head.size();
head.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * tp.head.size()));
for (size_t i = 0; i < tp.head.size(); i++)
head.data[i] = Luau::visit(*typeVisitor, tp.head[i]->ty);
AstTypePack* tail = nullptr;
if (tp.tail)
tail = Luau::visit(*this, (*tp.tail)->ty);
return allocator->alloc<AstTypePackExplicit>(Location(), AstTypeList{head, tail});
}
AstTypePack* operator()(const VariadicTypePack& vtp) const
{
if (vtp.hidden)
return nullptr;
return allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*typeVisitor, vtp.ty->ty));
}
AstTypePack* operator()(const GenericTypePack& gtp) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName(getName(allocator, syntheticNames, gtp)));
}
AstTypePack* operator()(const FreeTypePack& gtp) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("free"));
}
AstTypePack* operator()(const Unifiable::Error&) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName("Unifiable<Error>"));
}
private:
Allocator* allocator;
SyntheticNames* syntheticNames;
TypeRehydrationVisitor* typeVisitor;
};
AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp)
{
TypePackRehydrationVisitor tprv(allocator, syntheticNames, this);
return Luau::visit(tprv, tp->ty);
}
class TypeAttacher : public AstVisitor class TypeAttacher : public AstVisitor
{ {
public: public:
@ -344,7 +456,7 @@ public:
{ {
if (!type) if (!type)
return nullptr; return nullptr;
return Luau::visit(TypeRehydrationVisitor(allocator), (*type)->ty); return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), (*type)->ty);
} }
AstArray<Luau::AstType*> typeAstPack(TypePackId type) AstArray<Luau::AstType*> typeAstPack(TypePackId type)
@ -356,7 +468,7 @@ public:
result.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * v.size())); result.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * v.size()));
for (size_t i = 0; i < v.size(); ++i) for (size_t i = 0; i < v.size(); ++i)
{ {
result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator), v[i]->ty); result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), v[i]->ty);
} }
return result; return result;
} }
@ -385,6 +497,20 @@ public:
{ {
return visitLocal(al->local); return visitLocal(al->local);
} }
virtual bool visit(AstStatFor* stat) override
{
visitLocal(stat->var);
return true;
}
virtual bool visit(AstStatForIn* stat) override
{
for (size_t i = 0; i < stat->vars.size; ++i)
visitLocal(stat->vars.data[i]);
return true;
}
virtual bool visit(AstExprFunction* fn) override virtual bool visit(AstExprFunction* fn) override
{ {
// TODO: add generics if the inferred type of the function is generic CLI-39908 // TODO: add generics if the inferred type of the function is generic CLI-39908
@ -394,22 +520,17 @@ public:
visitLocal(arg); visitLocal(arg);
} }
if (!fn->hasReturnAnnotation) if (!fn->returnAnnotation)
{ {
if (auto result = getScope(fn->body->location)) if (auto result = getScope(fn->body->location))
{ {
TypePackId ret = result->returnType; TypePackId ret = result->returnType;
fn->hasReturnAnnotation = true;
AstTypePack* variadicAnnotation = nullptr; AstTypePack* variadicAnnotation = nullptr;
const auto& [v, tail] = flatten(ret); const auto& [v, tail] = flatten(ret);
if (tail) if (tail)
{ variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail);
TypePackId tailPack = *tail;
if (const VariadicTypePack* vtp = get<VariadicTypePack>(tailPack))
variadicAnnotation = allocator->alloc<AstTypePackVariadic>(Location(), typeAst(vtp->ty));
}
fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation}; fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation};
} }
@ -421,6 +542,7 @@ public:
private: private:
Module& module; Module& module;
Allocator* allocator; Allocator* allocator;
SyntheticNames syntheticNames;
}; };
void attachTypeData(SourceModule& source, Module& result) void attachTypeData(SourceModule& source, Module& result)
@ -431,7 +553,8 @@ void attachTypeData(SourceModule& source, Module& result)
AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options) AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options)
{ {
return Luau::visit(TypeRehydrationVisitor(allocator, options), type->ty); SyntheticNames syntheticNames;
return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames, options), type->ty);
} }
} // namespace Luau } // namespace Luau

File diff suppressed because it is too large Load diff

Some files were not shown because too many files have changed in this diff Show more