Merge branch 'master' of https://github.com/Roblox/luau into definition-module-name

This commit is contained in:
JohnnyMorganz 2023-11-18 12:13:22 +01:00
commit 6c820b278e
563 changed files with 61780 additions and 28789 deletions

View file

@ -1,5 +1,5 @@
blank_issues_enabled: false
contact_links:
- name: Questions
url: https://github.com/Roblox/luau/discussions
url: https://github.com/luau-lang/luau/discussions
about: Please use GitHub Discussions if you have questions or need help.

6
.github/codecov.yml vendored
View file

@ -1 +1,7 @@
comment: false
coverage:
status:
patch: false
project:
default:
informational: true

View file

@ -1,185 +0,0 @@
name: benchmark-dev
on:
push:
branches:
- master
paths-ignore:
- "docs/**"
- "papers/**"
- "rfcs/**"
- "*.md"
jobs:
windows:
name: windows-${{matrix.arch}}
strategy:
fail-fast: false
matrix:
os: [windows-latest]
arch: [Win32, x64]
bench:
- {
script: "run-benchmarks",
timeout: 12,
title: "Luau Benchmarks",
}
benchResultsRepo:
- { name: "luau-lang/benchmark-data", branch: "main" }
runs-on: ${{ matrix.os }}
steps:
- name: Checkout Luau repository
uses: actions/checkout@v3
- name: Build Luau
shell: bash # necessary for fail-fast
run: |
mkdir build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release
cmake --build . --target Luau.Repl.CLI --config Release
cmake --build . --target Luau.Analyze.CLI --config Release
- name: Move build files to root
run: |
move build/Release/* .
- uses: actions/setup-python@v3
with:
python-version: "3.9"
architecture: "x64"
- name: Install python dependencies
run: |
python -m pip install requests
python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose
- name: Run benchmark
run: |
python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt
- name: Push benchmark results
id: pushBenchmarkAttempt1
continue-on-error: true
uses: ./.github/workflows/push-results
with:
repository: ${{ matrix.benchResultsRepo.name }}
branch: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})"
bench_tool: "benchmarkluau"
bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
- name: Push benchmark results (Attempt 2)
id: pushBenchmarkAttempt2
continue-on-error: true
if: steps.pushBenchmarkAttempt1.outcome == 'failure'
uses: ./.github/workflows/push-results
with:
repository: ${{ matrix.benchResultsRepo.name }}
branch: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})"
bench_tool: "benchmarkluau"
bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
- name: Push benchmark results (Attempt 3)
id: pushBenchmarkAttempt3
continue-on-error: true
if: steps.pushBenchmarkAttempt2.outcome == 'failure'
uses: ./.github/workflows/push-results
with:
repository: ${{ matrix.benchResultsRepo.name }}
branch: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})"
bench_tool: "benchmarkluau"
bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
unix:
name: ${{matrix.os}}
strategy:
fail-fast: false
matrix:
os: [ubuntu-20.04, macos-latest]
bench:
- {
script: "run-benchmarks",
timeout: 12,
title: "Luau Benchmarks",
}
benchResultsRepo:
- { name: "luau-lang/benchmark-data", branch: "main" }
runs-on: ${{ matrix.os }}
steps:
- name: Checkout Luau repository
uses: actions/checkout@v3
- name: Build Luau
run: make config=release luau luau-analyze
- uses: actions/setup-python@v3
with:
python-version: "3.9"
architecture: "x64"
- name: Install python dependencies
run: |
python -m pip install requests
python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose
- name: Run benchmark
run: |
python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt
- name: Push benchmark results
id: pushBenchmarkAttempt1
continue-on-error: true
uses: ./.github/workflows/push-results
with:
repository: ${{ matrix.benchResultsRepo.name }}
branch: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
bench_name: ${{ matrix.bench.title }}
bench_tool: "benchmarkluau"
bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
- name: Push benchmark results (Attempt 2)
id: pushBenchmarkAttempt2
continue-on-error: true
if: steps.pushBenchmarkAttempt1.outcome == 'failure'
uses: ./.github/workflows/push-results
with:
repository: ${{ matrix.benchResultsRepo.name }}
branch: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
bench_name: ${{ matrix.bench.title }}
bench_tool: "benchmarkluau"
bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
- name: Push benchmark results (Attempt 3)
id: pushBenchmarkAttempt3
continue-on-error: true
if: steps.pushBenchmarkAttempt2.outcome == 'failure'
uses: ./.github/workflows/push-results
with:
repository: ${{ matrix.benchResultsRepo.name }}
branch: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
bench_name: ${{ matrix.bench.title }}
bench_tool: "benchmarkluau"
bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"

View file

@ -25,6 +25,7 @@ jobs:
- name: Install valgrind
run: |
sudo apt-get update
sudo apt-get install valgrind
- name: Build Luau (gcc)
@ -41,7 +42,7 @@ jobs:
- name: Build Luau (clang)
run: |
make config=profile clean
CXX=clang++ make config=profile luau luau-analyze
CXX=clang++ make config=profile luau luau-analyze luau-compile
- name: Run benchmark (bench-gcc)
run: |
@ -62,22 +63,24 @@ jobs:
}
valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-nonstrict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-strict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=DebugLuauDeferredConstraintResolution bench/other/LuauPolyfillMap.lua 2>&1 | filter map-dcr | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/regex.lua 2>&1 | filter regex-nonstrict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/regex.lua 2>&1 | filter regex-strict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=DebugLuauDeferredConstraintResolution bench/other/regex.lua 2>&1 | filter regex-dcr | tee -a analyze-output.txt
- name: Run benchmark (compile)
run: |
filter() {
awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau --compile"}'
awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau-compile"}'
}
valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --null -O0 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --codegennull -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2-codegen | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau-compile --codegennull -O2 bench/other/regex.lua 2>&1 | filter regex-O2-codegen | tee -a compile-output.txt
- name: Checkout benchmark results
uses: actions/checkout@v3
@ -122,7 +125,7 @@ jobs:
- name: Store results (compile)
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: luau --compile
name: luau-compile
tool: "benchmarkluau"
output-file-path: ./compile-output.txt
external-data-json-path: ./gh-pages/compile.json

View file

@ -42,9 +42,10 @@ jobs:
./luau-tests -ts=Conformance --codegen -O2 --fflags=true
- name: make cli
run: |
make -j2 config=sanitize werror=1 luau luau-analyze # match config with tests to improve build time
make -j2 config=sanitize werror=1 luau luau-analyze luau-compile # match config with tests to improve build time
./luau tests/conformance/assert.lua
./luau-analyze tests/conformance/assert.lua
./luau-compile tests/conformance/assert.lua
windows:
runs-on: windows-latest
@ -76,9 +77,10 @@ jobs:
- name: cmake cli
shell: bash # necessary for fail-fast
run: |
cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI --config Debug # match config with tests to improve build time
cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Debug # match config with tests to improve build time
Debug/luau tests/conformance/assert.lua
Debug/luau-analyze tests/conformance/assert.lua
Debug/luau-compile tests/conformance/assert.lua
coverage:
runs-on: ubuntu-20.04

View file

@ -38,7 +38,7 @@ jobs:
- name: configure
run: cmake . -DCMAKE_BUILD_TYPE=Release
- name: build
run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI --config Release -j 2
run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Release -j 2
- name: pack
if: matrix.os.name != 'windows'
run: zip luau-${{matrix.os.name}}.zip luau*

View file

@ -1,63 +0,0 @@
name: Checkout & push results
description: Checkout a given repo and push results to GitHub
inputs:
repository:
required: true
type: string
description: The benchmark results repository to check out
branch:
required: true
type: string
description: The benchmark results repository's branch to check out
token:
required: true
type: string
description: The GitHub token to use for pushing results
path:
required: true
type: string
description: The path to check out the results repository to
bench_name:
required: true
type: string
bench_tool:
required: true
type: string
bench_output_file_path:
required: true
type: string
bench_external_data_json_path:
required: true
type: string
runs:
using: "composite"
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
repository: ${{ inputs.repository }}
ref: ${{ inputs.branch }}
token: ${{ inputs.token }}
path: ${{ inputs.path }}
- name: Store results
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: ${{ inputs.bench_name }}
tool: ${{ inputs.bench_tool }}
gh-pages-branch: ${{ inputs.branch }}
output-file-path: ${{ inputs.bench_output_file_path }}
external-data-json-path: ${{ inputs.bench_external_data_json_path }}
- name: Push benchmark results
shell: bash
run: |
echo "Pushing benchmark results..."
cd gh-pages
git config user.name github-actions
git config user.email github@users.noreply.github.com
git add *.json
git commit -m "Add benchmarks results for ${{ github.sha }}"
git push
cd ..

View file

@ -22,7 +22,7 @@ jobs:
- name: configure
run: cmake . -DCMAKE_BUILD_TYPE=Release
- name: build
run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI --config Release -j 2
run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Release -j 2
- uses: actions/upload-artifact@v2
if: matrix.os.name != 'windows'
with:

1
.gitignore vendored
View file

@ -10,4 +10,5 @@
/luau
/luau-tests
/luau-analyze
/luau-compile
__pycache__

8
.idea/.gitignore generated vendored Normal file
View file

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

1
.idea/.name generated Normal file
View file

@ -0,0 +1 @@
Luau

7
.idea/codeStyles/Project.xml generated Normal file
View file

@ -0,0 +1,7 @@
<component name="ProjectCodeStyleConfiguration">
<code_scheme name="Project" version="173">
<clangFormatSettings>
<option name="ENABLED" value="true" />
</clangFormatSettings>
</code_scheme>
</component>

5
.idea/codeStyles/codeStyleConfig.xml generated Normal file
View file

@ -0,0 +1,5 @@
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="USE_PER_PROJECT_SETTINGS" value="true" />
</state>
</component>

2
.idea/luau.iml generated Normal file
View file

@ -0,0 +1,2 @@
<?xml version="1.0" encoding="UTF-8"?>
<module classpath="CMake" type="CPP_MODULE" version="4" />

4
.idea/misc.xml generated Normal file
View file

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="CMakeWorkspace" PROJECT_DIR="$PROJECT_DIR$" />
</project>

8
.idea/modules.xml generated Normal file
View file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/luau.iml" filepath="$PROJECT_DIR$/.idea/luau.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated Normal file
View file

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

View file

@ -4,7 +4,7 @@
#include "Luau/NotNull.h"
#include "Luau/Substitution.h"
#include "Luau/Type.h"
#include "Luau/TypeFwd.h"
#include <memory>
@ -39,4 +39,4 @@ struct Anyification : Substitution
bool ignoreChildren(TypePackId ty) override;
};
} // namespace Luau
} // namespace Luau

View file

@ -3,7 +3,7 @@
#include "Luau/Substitution.h"
#include "Luau/TxnLog.h"
#include "Luau/Type.h"
#include "Luau/TypeFwd.h"
namespace Luau
{

View file

@ -3,6 +3,7 @@
#include "Luau/Ast.h"
#include "Luau/Documentation.h"
#include "Luau/TypeFwd.h"
#include <memory>
@ -13,9 +14,6 @@ struct Binding;
struct SourceModule;
struct Module;
struct Type;
using TypeId = const Type*;
using ScopePtr = std::shared_ptr<struct Scope>;
struct ExprOrLocal
@ -64,8 +62,11 @@ private:
};
std::vector<AstNode*> findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos);
std::vector<AstNode*> findAncestryAtPositionForAutocomplete(AstStatBlock* root, Position pos);
std::vector<AstNode*> findAstAncestryOfPosition(const SourceModule& source, Position pos, bool includeTypes = false);
std::vector<AstNode*> findAstAncestryOfPosition(AstStatBlock* root, Position pos, bool includeTypes = false);
AstNode* findNodeAtPosition(const SourceModule& source, Position pos);
AstNode* findNodeAtPosition(AstStatBlock* root, Position pos);
AstExpr* findExprAtPosition(const SourceModule& source, Position pos);
ScopePtr findScopeAtPosition(const Module& module, Position pos);
std::optional<Binding> findBindingAtPosition(const Module& module, const SourceModule& source, Position pos);

View file

@ -38,6 +38,7 @@ enum class AutocompleteEntryKind
String,
Type,
Module,
GeneratedFunction,
};
enum class ParenthesesRecommendation
@ -70,6 +71,10 @@ struct AutocompleteEntry
std::optional<std::string> documentationSymbol = std::nullopt;
Tags tags;
ParenthesesRecommendation parens = ParenthesesRecommendation::None;
std::optional<std::string> insertText;
// Only meaningful if kind is Property.
bool indexedWithSelf = false;
};
using AutocompleteEntryMap = std::unordered_map<std::string, AutocompleteEntry>;
@ -94,4 +99,6 @@ using StringCompletionCallback =
AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback);
constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)";
} // namespace Luau

View file

@ -1,75 +0,0 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Def.h"
#include "Luau/NotNull.h"
#include "Luau/Variant.h"
#include <string>
#include <optional>
namespace Luau
{
using NullableBreadcrumbId = const struct Breadcrumb*;
using BreadcrumbId = NotNull<const struct Breadcrumb>;
struct FieldMetadata
{
std::string prop;
};
struct SubscriptMetadata
{
BreadcrumbId key;
};
using Metadata = Variant<FieldMetadata, SubscriptMetadata>;
struct Breadcrumb
{
NullableBreadcrumbId previous;
DefId def;
std::optional<Metadata> metadata;
std::vector<BreadcrumbId> children;
};
inline Breadcrumb* asMutable(NullableBreadcrumbId breadcrumb)
{
LUAU_ASSERT(breadcrumb);
return const_cast<Breadcrumb*>(breadcrumb);
}
template<typename T>
const T* getMetadata(NullableBreadcrumbId breadcrumb)
{
if (!breadcrumb || !breadcrumb->metadata)
return nullptr;
return get_if<T>(&*breadcrumb->metadata);
}
struct BreadcrumbArena
{
TypedAllocator<Breadcrumb> allocator;
template<typename... Args>
BreadcrumbId add(NullableBreadcrumbId previous, DefId def, Args&&... args)
{
Breadcrumb* bc = allocator.allocate(Breadcrumb{previous, def, std::forward<Args>(args)...});
if (previous)
asMutable(previous)->children.push_back(NotNull{bc});
return NotNull{bc};
}
template<typename T, typename... Args>
BreadcrumbId emplace(NullableBreadcrumbId previous, DefId def, Args&&... args)
{
Breadcrumb* bc = allocator.allocate(Breadcrumb{previous, def, Metadata{T{std::forward<Args>(args)...}}});
if (previous)
asMutable(previous)->children.push_back(NotNull{bc});
return NotNull{bc};
}
};
} // namespace Luau

View file

@ -14,11 +14,7 @@ struct GlobalTypes;
struct TypeChecker;
struct TypeArena;
void registerBuiltinTypes(GlobalTypes& globals);
void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals);
void registerBuiltinGlobals(Frontend& frontend);
void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete = false);
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types);
TypeId makeIntersection(TypeArena& arena, std::vector<TypeId>&& types);

View file

@ -0,0 +1,24 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <atomic>
namespace Luau
{
struct FrontendCancellationToken
{
void cancel()
{
cancelled.store(true);
}
bool requested()
{
return cancelled.load();
}
std::atomic<bool> cancelled;
};
} // namespace Luau

View file

@ -16,6 +16,8 @@ using SeenTypePacks = std::unordered_map<TypePackId, TypePackId>;
struct CloneState
{
NotNull<BuiltinTypes> builtinTypes;
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
@ -26,7 +28,4 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState);
TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);
TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone = false);
TypeId shallowClone(TypeId ty, NotNull<TypeArena> dest);
} // namespace Luau

View file

@ -4,8 +4,8 @@
#include "Luau/Ast.h" // Used for some of the enumerations
#include "Luau/DenseHash.h"
#include "Luau/NotNull.h"
#include "Luau/Type.h"
#include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
#include <string>
#include <memory>
@ -16,12 +16,6 @@ namespace Luau
struct Scope;
struct Type;
using TypeId = const Type*;
struct TypePackVar;
using TypePackId = const TypePackVar*;
// subType <: superType
struct SubtypeConstraint
{
@ -34,6 +28,11 @@ struct PackSubtypeConstraint
{
TypePackId subPack;
TypePackId superPack;
// HACK!! TODO clip.
// We need to know which of `PackSubtypeConstraint` are emitted from `AstStatReturn` vs any others.
// Then we force these specific `PackSubtypeConstraint` to only dispatch in the order of the `return`s.
bool returns = false;
};
// generalizedType ~ gen sourceType
@ -50,37 +49,15 @@ struct InstantiationConstraint
TypeId superType;
};
struct UnaryConstraint
{
AstExprUnary::Op op;
TypeId operandType;
TypeId resultType;
};
// let L : leftType
// let R : rightType
// in
// L op R : resultType
struct BinaryConstraint
{
AstExprBinary::Op op;
TypeId leftType;
TypeId rightType;
TypeId resultType;
// When we dispatch this constraint, we update the key at this map to record
// the overload that we selected.
const AstNode* astFragment;
DenseHashMap<const AstNode*, TypeId>* astOriginalCallTypes;
DenseHashMap<const AstNode*, TypeId>* astOverloadResolvedTypes;
};
// iteratee is iterable
// iterators is the iteration types.
struct IterableConstraint
{
TypePackId iterator;
TypePackId variables;
const AstNode* nextAstFragment;
DenseHashMap<const AstNode*, TypeId>* astForInNextTypes;
};
// name(namedType) = name
@ -105,8 +82,12 @@ struct FunctionCallConstraint
TypeId fn;
TypePackId argsPack;
TypePackId result;
class AstExprCall* callSite;
class AstExprCall* callSite = nullptr;
std::vector<std::optional<TypeId>> discriminantTypes;
// When we dispatch this constraint, we update the key at this map to record
// the overload that we selected.
DenseHashMap<const AstNode*, TypeId>* astOverloadResolvedTypes = nullptr;
};
// result ~ prim ExpectedType SomeSingletonType MultitonType
@ -139,6 +120,24 @@ struct HasPropConstraint
TypeId resultType;
TypeId subjectType;
std::string prop;
// HACK: We presently need types like true|false or string|"hello" when
// deciding whether a particular literal expression should have a singleton
// type. This boolean is set to true when extracting the property type of a
// value that may be a union of tables.
//
// For example, in the following code fragment, we want the lookup of the
// success property to yield true|false when extracting an expectedType in
// this expression:
//
// type Result<T, E> = {success:true, result: T} | {success:false, error: E}
//
// local r: Result<number, string> = {success=true, result=9}
//
// If we naively simplify the expectedType to boolean, we will erroneously
// compute the type boolean for the success property of the table literal.
// This causes type checking to fail.
bool suppressSimplification = false;
};
// result ~ setProp subjectType ["prop", "prop2", ...] propType
@ -191,11 +190,66 @@ struct UnpackConstraint
{
TypePackId resultPack;
TypePackId sourcePack;
// UnpackConstraint is sometimes used to resolve the types of assignments.
// When this is the case, any LocalTypes in resultPack can have their
// domains extended by the corresponding type from sourcePack.
bool resultIsLValue = false;
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint,
BinaryConstraint, IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint,
HasPropConstraint, SetPropConstraint, SetIndexerConstraint, SingletonOrTopTypeConstraint, UnpackConstraint>;
// resultType ~ refine type mode discriminant
//
// Compute type & discriminant (or type | discriminant) as soon as possible (but
// no sooner), simplify, and bind resultType to that type.
struct RefineConstraint
{
enum
{
Intersection,
Union
} mode;
TypeId resultType;
TypeId type;
TypeId discriminant;
};
// resultType ~ T0 op T1 op ... op TN
//
// op is either union or intersection. If any of the input types are blocked,
// this constraint will block unless forced.
struct SetOpConstraint
{
enum
{
Intersection,
Union
} mode;
TypeId resultType;
std::vector<TypeId> types;
};
// ty ~ reduce ty
//
// Try to reduce ty, if it is a TypeFamilyInstanceType. Otherwise, do nothing.
struct ReduceConstraint
{
TypeId ty;
};
// tp ~ reduce tp
//
// Analogous to ReduceConstraint, but for type packs.
struct ReducePackConstraint
{
TypePackId tp;
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, IterableConstraint,
NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint, HasPropConstraint, SetPropConstraint,
SetIndexerConstraint, SingletonOrTopTypeConstraint, UnpackConstraint, RefineConstraint, SetOpConstraint, ReduceConstraint, ReducePackConstraint>;
struct Constraint
{
@ -209,6 +263,8 @@ struct Constraint
ConstraintV c;
std::vector<NotNull<Constraint>> dependencies;
DenseHashSet<TypeId> getFreeTypes() const;
};
using ConstraintPtr = std::unique_ptr<Constraint>;

View file

@ -2,15 +2,20 @@
#pragma once
#include "Luau/Ast.h"
#include "Luau/Refinement.h"
#include "Luau/Constraint.h"
#include "Luau/ControlFlow.h"
#include "Luau/DataFlowGraph.h"
#include "Luau/InsertionOrderedMap.h"
#include "Luau/Module.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Normalize.h"
#include "Luau/NotNull.h"
#include "Luau/Refinement.h"
#include "Luau/Symbol.h"
#include "Luau/Type.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypeUtils.h"
#include "Luau/Variant.h"
#include "Luau/Normalize.h"
#include <memory>
#include <vector>
@ -52,21 +57,35 @@ struct InferencePack
}
};
struct ConstraintGraphBuilder
struct ConstraintGenerator
{
// A list of all the scopes in the module. This vector holds ownership of the
// scope pointers; the scopes themselves borrow pointers to other scopes to
// define the scope hierarchy.
std::vector<std::pair<Location, ScopePtr>> scopes;
ModuleName moduleName;
ModulePtr module;
NotNull<BuiltinTypes> builtinTypes;
const NotNull<TypeArena> arena;
// The root scope of the module we're generating constraints for.
// This is null when the CGB is initially constructed.
// This is null when the CG is initially constructed.
Scope* rootScope;
struct InferredBinding
{
Scope* scope;
Location location;
TypeIds types;
};
// Some locals have multiple type states. We wish for Scope::bindings to
// map each local name onto the union of every type that the local can have
// over its lifetime, so we use this map to accumulate the set of types it
// might have.
//
// See the functions recordInferredBinding and fillInInferredBindings.
DenseHashMap<Symbol, InferredBinding> inferredBindings{{}};
// Constraints that go straight to the solver.
std::vector<ConstraintPtr> constraints;
@ -85,18 +104,33 @@ struct ConstraintGraphBuilder
// It is pretty uncommon for constraint generation to itself produce errors, but it can happen.
std::vector<TypeError> errors;
// Needed to be able to enable error-suppression preservation for immediate refinements.
NotNull<Normalizer> normalizer;
// Needed to resolve modules to make 'require' import types properly.
NotNull<ModuleResolver> moduleResolver;
// Occasionally constraint generation needs to produce an ICE.
const NotNull<InternalErrorReporter> ice;
ScopePtr globalScope;
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
std::vector<RequireCycle> requireCycles;
DcrLogger* logger;
ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, DcrLogger* logger,
NotNull<DataFlowGraph> dfg);
ConstraintGenerator(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver,
NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope,
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, DcrLogger* logger, NotNull<DataFlowGraph> dfg,
std::vector<RequireCycle> requireCycles);
/**
* The entry point to the ConstraintGenerator. This will construct a set
* of scopes, constraints, and free types that can be solved later.
* @param block the root block to generate constraints for.
*/
void visitModuleRoot(AstStatBlock* block);
private:
/**
* Fabricates a new free type belonging to a given scope.
* @param scope the scope the free type belongs to.
@ -116,6 +150,8 @@ struct ConstraintGraphBuilder
*/
ScopePtr childScope(AstNode* node, const ScopePtr& parent);
std::optional<TypeId> lookup(Scope* scope, DefId def);
/**
* Adds a new constraint with no dependencies to a given scope.
* @param scope the scope to add the constraint to.
@ -132,38 +168,44 @@ struct ConstraintGraphBuilder
*/
NotNull<Constraint> addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
struct RefinementPartition
{
// Types that we want to intersect against the type of the expression.
std::vector<TypeId> discriminantTypes;
// Sometimes the type we're discriminating against is implicitly nil.
bool shouldAppendNilType = false;
};
using RefinementContext = InsertionOrderedMap<DefId, RefinementPartition>;
void unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, std::vector<ConstraintV>* constraints);
void computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, bool eq, std::vector<ConstraintV>* constraints);
void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement);
/**
* The entry point to the ConstraintGraphBuilder. This will construct a set
* of scopes, constraints, and free types that can be solved later.
* @param block the root block to generate constraints for.
*/
void visit(AstStatBlock* block);
ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block);
void visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block);
void visit(const ScopePtr& scope, AstStat* stat);
void visit(const ScopePtr& scope, AstStatBlock* block);
void visit(const ScopePtr& scope, AstStatLocal* local);
void visit(const ScopePtr& scope, AstStatFor* for_);
void visit(const ScopePtr& scope, AstStatForIn* forIn);
void visit(const ScopePtr& scope, AstStatWhile* while_);
void visit(const ScopePtr& scope, AstStatRepeat* repeat);
void visit(const ScopePtr& scope, AstStatLocalFunction* function);
void visit(const ScopePtr& scope, AstStatFunction* function);
void visit(const ScopePtr& scope, AstStatReturn* ret);
void visit(const ScopePtr& scope, AstStatAssign* assign);
void visit(const ScopePtr& scope, AstStatCompoundAssign* assign);
void visit(const ScopePtr& scope, AstStatIf* ifStatement);
void visit(const ScopePtr& scope, AstStatTypeAlias* alias);
void visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal);
void visit(const ScopePtr& scope, AstStatDeclareClass* declareClass);
void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction);
void visit(const ScopePtr& scope, AstStatError* error);
ControlFlow visit(const ScopePtr& scope, AstStat* stat);
ControlFlow visit(const ScopePtr& scope, AstStatBlock* block);
ControlFlow visit(const ScopePtr& scope, AstStatLocal* local);
ControlFlow visit(const ScopePtr& scope, AstStatFor* for_);
ControlFlow visit(const ScopePtr& scope, AstStatForIn* forIn);
ControlFlow visit(const ScopePtr& scope, AstStatWhile* while_);
ControlFlow visit(const ScopePtr& scope, AstStatRepeat* repeat);
ControlFlow visit(const ScopePtr& scope, AstStatLocalFunction* function);
ControlFlow visit(const ScopePtr& scope, AstStatFunction* function);
ControlFlow visit(const ScopePtr& scope, AstStatReturn* ret);
ControlFlow visit(const ScopePtr& scope, AstStatAssign* assign);
ControlFlow visit(const ScopePtr& scope, AstStatCompoundAssign* assign);
ControlFlow visit(const ScopePtr& scope, AstStatIf* ifStatement);
ControlFlow visit(const ScopePtr& scope, AstStatTypeAlias* alias);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareClass* declareClass);
ControlFlow visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction);
ControlFlow visit(const ScopePtr& scope, AstStatError* error);
InferencePack checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs, const std::vector<std::optional<TypeId>>& expectedTypes = {});
InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector<std::optional<TypeId>>& expectedTypes = {});
InferencePack checkPack(
const ScopePtr& scope, AstExpr* expr, const std::vector<std::optional<TypeId>>& expectedTypes = {}, bool generalize = true);
InferencePack checkPack(const ScopePtr& scope, AstExprCall* call);
@ -173,9 +215,11 @@ struct ConstraintGraphBuilder
* @param expr the expression to check.
* @param expectedType the type of the expression that is expected from its
* surrounding context. Used to implement bidirectional type checking.
* @param generalize If true, generalize any lambdas that are encountered.
* @return the type of the expression.
*/
Inference check(const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {}, bool forceSingleton = false);
Inference check(
const ScopePtr& scope, AstExpr* expr, std::optional<TypeId> expectedType = {}, bool forceSingleton = false, bool generalize = true);
Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional<TypeId> expectedType, bool forceSingleton);
Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional<TypeId> expectedType, bool forceSingleton);
@ -183,6 +227,7 @@ struct ConstraintGraphBuilder
Inference check(const ScopePtr& scope, AstExprGlobal* global);
Inference check(const ScopePtr& scope, AstExprIndexName* indexName);
Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional<TypeId> expectedType, bool generalize);
Inference check(const ScopePtr& scope, AstExprUnary* unary);
Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional<TypeId> expectedType);
@ -191,9 +236,16 @@ struct ConstraintGraphBuilder
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, RefinementId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
std::vector<TypeId> checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs);
TypeId checkLValue(const ScopePtr& scope, AstExpr* expr);
/**
* Generate constraints to assign assignedTy to the expression expr
* @returns the type of the expression. This may or may not be assignedTy itself.
*/
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExpr* expr, TypeId assignedTy);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprLocal* local, TypeId assignedTy);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprGlobal* global, TypeId assignedTy);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprIndexName* indexName, TypeId assignedTy);
std::optional<TypeId> checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr, TypeId assignedTy);
TypeId updateProperty(const ScopePtr& scope, AstExpr* expr, TypeId assignedTy);
struct FunctionSignature
{
@ -208,7 +260,8 @@ struct ConstraintGraphBuilder
ScopePtr bodyScope;
};
FunctionSignature checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn, std::optional<TypeId> expectedType = {});
FunctionSignature checkFunctionSignature(
const ScopePtr& parent, AstExprFunction* fn, std::optional<TypeId> expectedType = {}, std::optional<Location> originalName = {});
/**
* Checks the body of a function expression.
@ -277,12 +330,18 @@ struct ConstraintGraphBuilder
/** Scan the program for global definitions.
*
* ConstraintGraphBuilder needs to differentiate between globals and accesses to undefined symbols. Doing this "for
* ConstraintGenerator needs to differentiate between globals and accesses to undefined symbols. Doing this "for
* real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an
* initial scan of the AST and note what globals are defined.
*/
void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program);
// Record the fact that a particular local has a particular type in at least
// one of its states.
void recordInferredBinding(AstLocal* local, TypeId ty);
void fillInInferredBindings(const ScopePtr& globalScope, AstStatBlock* block);
/** Given a function type annotation, return a vector describing the expected types of the calls to the function
* For example, calling a function with annotation ((number) -> string & ((string) -> number))
* yields a vector of size 1, with value: [number | string]

View file

@ -3,14 +3,18 @@
#pragma once
#include "Luau/Constraint.h"
#include "Luau/DenseHash.h"
#include "Luau/Error.h"
#include "Luau/Location.h"
#include "Luau/Module.h"
#include "Luau/Normalize.h"
#include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypeReduction.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFwd.h"
#include "Luau/Variant.h"
#include <utility>
#include <vector>
namespace Luau
@ -49,11 +53,10 @@ struct HashInstantiationSignature
struct ConstraintSolver
{
TypeArena* arena;
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
InternalErrorReporter iceReporter;
NotNull<Normalizer> normalizer;
NotNull<TypeReduction> reducer;
// The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints;
NotNull<Scope> rootScope;
@ -75,6 +78,13 @@ struct ConstraintSolver
std::unordered_map<BlockedConstraintId, std::vector<NotNull<const Constraint>>, HashBlockedConstraintId> blocked;
// Memoized instantiations of type aliases.
DenseHashMap<InstantiationSignature, TypeId, HashInstantiationSignature> instantiatedAliases{{}};
// Breadcrumbs for where a free type's upper bound was expanded. We use
// these to provide more helpful error messages when a free type is solved
// as never unexpectedly.
DenseHashMap<TypeId, std::vector<std::pair<Location, TypeId>>> upperBoundContributors{nullptr};
// A mapping from free types to the number of unresolved constraints that mention them.
DenseHashMap<TypeId, size_t> unresolvedConstraints{{}};
// Recorded errors that take place within the solver.
ErrorVec errors;
@ -83,10 +93,11 @@ struct ConstraintSolver
std::vector<RequireCycle> requireCycles;
DcrLogger* logger;
TypeCheckLimits limits;
explicit ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName, NotNull<TypeReduction> reducer, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles,
DcrLogger* logger);
ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger,
TypeCheckLimits limits);
// Randomize the order in which to dispatch constraints
void randomize(unsigned seed);
@ -111,8 +122,6 @@ struct ConstraintSolver
bool tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const InstantiationConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const UnaryConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const BinaryConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint);
@ -123,6 +132,10 @@ struct ConstraintSolver
bool tryDispatch(const SetIndexerConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const RefineConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const SetOpConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const ReduceConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const ReducePackConstraint& c, NotNull<const Constraint> constraint, bool force);
// for a, ... in some_table do
// also handles __iter metamethod
@ -132,8 +145,10 @@ struct ConstraintSolver
bool tryDispatchIterableFunction(
TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(TypeId subjectType, const std::string& propName);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set<TypeId>& seen);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(
TypeId subjectType, const std::string& propName, bool suppressSimplification = false);
std::pair<std::vector<TypeId>, std::optional<TypeId>> lookupTableProp(
TypeId subjectType, const std::string& propName, bool suppressSimplification, DenseHashSet<TypeId>& seen);
void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);
/**
@ -143,21 +158,39 @@ struct ConstraintSolver
bool block(TypeId target, NotNull<const Constraint> constraint);
bool block(TypePackId target, NotNull<const Constraint> constraint);
// Traverse the type. If any blocked or pending types are found, block
// the constraint on them.
// Block on every target
template<typename T>
bool block(const T& targets, NotNull<const Constraint> constraint)
{
for (TypeId target : targets)
block(target, constraint);
return false;
}
/**
* For all constraints that are blocked on one constraint, make them block
* on a new constraint.
* @param source the constraint to copy blocks from.
* @param addition the constraint that other constraints should now block on.
*/
void inheritBlocks(NotNull<const Constraint> source, NotNull<const Constraint> addition);
// Traverse the type. If any pending types are found, block the constraint
// on them.
//
// Returns false if a type blocks the constraint.
//
// FIXME: This use of a boolean for the return result is an appalling
// interface.
bool recursiveBlock(TypeId target, NotNull<const Constraint> constraint);
bool recursiveBlock(TypePackId target, NotNull<const Constraint> constraint);
bool blockOnPendingTypes(TypeId target, NotNull<const Constraint> constraint);
bool blockOnPendingTypes(TypePackId target, NotNull<const Constraint> constraint);
void unblock(NotNull<const Constraint> progressed);
void unblock(TypeId progressed);
void unblock(TypePackId progressed);
void unblock(const std::vector<TypeId>& types);
void unblock(const std::vector<TypePackId>& packs);
void unblock(TypeId progressed, Location location);
void unblock(TypePackId progressed, Location location);
void unblock(const std::vector<TypeId>& types, Location location);
void unblock(const std::vector<TypePackId>& packs, Location location);
/**
* @returns true if the TypeId is in a blocked state.
@ -180,8 +213,9 @@ struct ConstraintSolver
* the result.
* @param subType the sub-type to unify.
* @param superType the super-type to unify.
* @returns optionally a unification too complex error if unification failed
*/
void unify(TypeId subType, TypeId superType, NotNull<Scope> scope);
std::optional<TypeError> unify(NotNull<Scope> scope, Location location, TypeId subType, TypeId superType);
/**
* Creates a new Unifier and performs a single unification operation. Commits
@ -189,7 +223,7 @@ struct ConstraintSolver
* @param subPack the sub-type pack to unify.
* @param superPack the super-type pack to unify.
*/
void unify(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope);
ErrorVec unify(NotNull<Scope> scope, Location location, TypePackId subPack, TypePackId superPack);
/** Pushes a new solver constraint to the solver.
* @param cv the body of the constraint.
@ -210,7 +244,41 @@ struct ConstraintSolver
void reportError(TypeErrorData&& data, const Location& location);
void reportError(TypeError e);
/**
* Checks the existing set of constraints to see if there exist any that contain
* the provided free type, indicating that it is not yet ready to be replaced by
* one of its bounds.
* @param ty the free type that to check for related constraints
* @returns whether or not it is unsafe to replace the free type by one of its bounds
*/
bool hasUnresolvedConstraints(TypeId ty);
private:
/** Helper used by tryDispatch(SubtypeConstraint) and
* tryDispatch(PackSubtypeConstraint)
*
* Attempts to unify subTy with superTy. If doing so would require unifying
* BlockedTypes, fail and block the constraint on those BlockedTypes.
*
* If unification fails, replace all free types with errorType.
*
* If unification succeeds, unblock every type changed by the unification.
*/
template <typename TID>
bool tryUnify(NotNull<const Constraint> constraint, TID subTy, TID superTy);
/**
* Bind a BlockedType to another type while taking care not to bind it to
* itself in the case that resultTy == blockedTy. This can happen if we
* have a tautological constraint. When it does, we must instead bind
* blockedTy to a fresh type belonging to an appropriate scope.
*
* To determine which scope is appropriate, we also accept rootTy, which is
* to be the type that contains blockedTy.
*/
void bindBlockedType(TypeId blockedTy, TypeId resultTy, TypeId rootTy, Location location);
/**
* Marks a constraint as being blocked on a type or type pack. The constraint
* solver will not attempt to dispatch blocked constraints until their
@ -231,7 +299,10 @@ private:
TypeId errorRecoveryType() const;
TypePackId errorRecoveryTypePack() const;
TypeId unionOfTypes(TypeId a, TypeId b, NotNull<Scope> scope, bool unifyFreeTypes);
TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp);
void throwTimeLimitError();
void throwUserCancelError();
ToStringOptions opts;
};

View file

@ -0,0 +1,36 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <memory>
namespace Luau
{
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
enum class ControlFlow
{
None = 0b00001,
Returns = 0b00010,
Throws = 0b00100,
Breaks = 0b01000,
Continues = 0b10000,
};
inline ControlFlow operator&(ControlFlow a, ControlFlow b)
{
return ControlFlow(int(a) & int(b));
}
inline ControlFlow operator|(ControlFlow a, ControlFlow b)
{
return ControlFlow(int(a) | int(b));
}
inline bool matches(ControlFlow a, ControlFlow b)
{
return (a & b) != ControlFlow(0);
}
} // namespace Luau

View file

@ -3,29 +3,47 @@
// Do not include LValue. It should never be used here.
#include "Luau/Ast.h"
#include "Luau/Breadcrumb.h"
#include "Luau/ControlFlow.h"
#include "Luau/DenseHash.h"
#include "Luau/Def.h"
#include "Luau/Symbol.h"
#include "Luau/TypedAllocator.h"
#include <unordered_map>
namespace Luau
{
struct RefinementKey
{
const RefinementKey* parent = nullptr;
DefId def;
std::optional<std::string> propName;
};
struct RefinementKeyArena
{
TypedAllocator<RefinementKey> allocator;
const RefinementKey* leaf(DefId def);
const RefinementKey* node(const RefinementKey* parent, DefId def, const std::string& propName);
};
struct DataFlowGraph
{
DataFlowGraph(DataFlowGraph&&) = default;
DataFlowGraph& operator=(DataFlowGraph&&) = default;
NullableBreadcrumbId getBreadcrumb(const AstExpr* expr) const;
DefId getDef(const AstExpr* expr) const;
// Look up for the rvalue def for a compound assignment.
std::optional<DefId> getRValueDefForCompoundAssign(const AstExpr* expr) const;
BreadcrumbId getBreadcrumb(const AstLocal* local) const;
BreadcrumbId getBreadcrumb(const AstExprLocal* local) const;
BreadcrumbId getBreadcrumb(const AstExprGlobal* global) const;
DefId getDef(const AstLocal* local) const;
BreadcrumbId getBreadcrumb(const AstStatDeclareGlobal* global) const;
BreadcrumbId getBreadcrumb(const AstStatDeclareFunction* func) const;
DefId getDef(const AstStatDeclareGlobal* global) const;
DefId getDef(const AstStatDeclareFunction* func) const;
const RefinementKey* getRefinementKey(const AstExpr* expr) const;
private:
DataFlowGraph() = default;
@ -33,17 +51,23 @@ private:
DataFlowGraph(const DataFlowGraph&) = delete;
DataFlowGraph& operator=(const DataFlowGraph&) = delete;
DefArena defs;
BreadcrumbArena breadcrumbs;
DefArena defArena;
RefinementKeyArena keyArena;
DenseHashMap<const AstExpr*, NullableBreadcrumbId> astBreadcrumbs{nullptr};
DenseHashMap<const AstExpr*, const Def*> astDefs{nullptr};
// Sometimes we don't have the AstExprLocal* but we have AstLocal*, and sometimes we need to extract that DefId.
DenseHashMap<const AstLocal*, NullableBreadcrumbId> localBreadcrumbs{nullptr};
DenseHashMap<const AstLocal*, const Def*> localDefs{nullptr};
// There's no AstStatDeclaration, and it feels useless to introduce it just to enforce an invariant in one place.
// All keys in this maps are really only statements that ambiently declares a symbol.
DenseHashMap<const AstStat*, NullableBreadcrumbId> declaredBreadcrumbs{nullptr};
DenseHashMap<const AstStat*, const Def*> declaredDefs{nullptr};
// Compound assignments are in a weird situation where the local being assigned to is also being used at its
// previous type implicitly in an rvalue position. This map provides the previous binding.
DenseHashMap<const AstExpr*, const Def*> compoundAssignDefs{nullptr};
DenseHashMap<const AstExpr*, const RefinementKey*> astRefinementKeys{nullptr};
friend struct DataFlowGraphBuilder;
};
@ -51,15 +75,29 @@ private:
struct DfgScope
{
DfgScope* parent;
DenseHashMap<Symbol, NullableBreadcrumbId> bindings{Symbol{}};
DenseHashMap<const Def*, std::unordered_map<std::string, NullableBreadcrumbId>> props{nullptr};
bool isLoopScope;
NullableBreadcrumbId lookup(Symbol symbol) const;
NullableBreadcrumbId lookup(DefId def, const std::string& key) const;
using Bindings = DenseHashMap<Symbol, const Def*>;
using Props = DenseHashMap<const Def*, std::unordered_map<std::string, const Def*>>;
Bindings bindings{Symbol{}};
Props props{nullptr};
std::optional<DefId> lookup(Symbol symbol) const;
std::optional<DefId> lookup(DefId def, const std::string& key) const;
void inherit(const DfgScope* childScope);
bool canUpdateDefinition(Symbol symbol) const;
bool canUpdateDefinition(DefId def, const std::string& key) const;
};
struct DataFlowResult
{
DefId def;
const RefinementKey* parent = nullptr;
};
// Currently unsound. We do not presently track the control flow of the program.
// Additionally, we do not presently track assignments.
struct DataFlowGraphBuilder
{
static DataFlowGraph build(AstStatBlock* root, NotNull<struct InternalErrorReporter> handle);
@ -71,61 +109,69 @@ private:
DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete;
DataFlowGraph graph;
NotNull<DefArena> defs{&graph.defs};
NotNull<BreadcrumbArena> breadcrumbs{&graph.breadcrumbs};
NotNull<DefArena> defArena{&graph.defArena};
NotNull<RefinementKeyArena> keyArena{&graph.keyArena};
struct InternalErrorReporter* handle = nullptr;
DfgScope* moduleScope = nullptr;
std::vector<std::unique_ptr<DfgScope>> scopes;
DfgScope* childScope(DfgScope* scope);
DfgScope* childScope(DfgScope* scope, bool isLoopScope = false);
void visit(DfgScope* scope, AstStatBlock* b);
void visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b);
void join(DfgScope* p, DfgScope* a, DfgScope* b);
void joinBindings(DfgScope::Bindings& p, const DfgScope::Bindings& a, const DfgScope::Bindings& b);
void joinProps(DfgScope::Props& p, const DfgScope::Props& a, const DfgScope::Props& b);
void visit(DfgScope* scope, AstStat* s);
void visit(DfgScope* scope, AstStatIf* i);
void visit(DfgScope* scope, AstStatWhile* w);
void visit(DfgScope* scope, AstStatRepeat* r);
void visit(DfgScope* scope, AstStatBreak* b);
void visit(DfgScope* scope, AstStatContinue* c);
void visit(DfgScope* scope, AstStatReturn* r);
void visit(DfgScope* scope, AstStatExpr* e);
void visit(DfgScope* scope, AstStatLocal* l);
void visit(DfgScope* scope, AstStatFor* f);
void visit(DfgScope* scope, AstStatForIn* f);
void visit(DfgScope* scope, AstStatAssign* a);
void visit(DfgScope* scope, AstStatCompoundAssign* c);
void visit(DfgScope* scope, AstStatFunction* f);
void visit(DfgScope* scope, AstStatLocalFunction* l);
void visit(DfgScope* scope, AstStatTypeAlias* t);
void visit(DfgScope* scope, AstStatDeclareGlobal* d);
void visit(DfgScope* scope, AstStatDeclareFunction* d);
void visit(DfgScope* scope, AstStatDeclareClass* d);
void visit(DfgScope* scope, AstStatError* error);
DefId lookup(DfgScope* scope, Symbol symbol);
DefId lookup(DfgScope* scope, DefId def, const std::string& key);
BreadcrumbId visitExpr(DfgScope* scope, AstExpr* e);
BreadcrumbId visitExpr(DfgScope* scope, AstExprLocal* l);
BreadcrumbId visitExpr(DfgScope* scope, AstExprGlobal* g);
BreadcrumbId visitExpr(DfgScope* scope, AstExprCall* c);
BreadcrumbId visitExpr(DfgScope* scope, AstExprIndexName* i);
BreadcrumbId visitExpr(DfgScope* scope, AstExprIndexExpr* i);
BreadcrumbId visitExpr(DfgScope* scope, AstExprFunction* f);
BreadcrumbId visitExpr(DfgScope* scope, AstExprTable* t);
BreadcrumbId visitExpr(DfgScope* scope, AstExprUnary* u);
BreadcrumbId visitExpr(DfgScope* scope, AstExprBinary* b);
BreadcrumbId visitExpr(DfgScope* scope, AstExprTypeAssertion* t);
BreadcrumbId visitExpr(DfgScope* scope, AstExprIfElse* i);
BreadcrumbId visitExpr(DfgScope* scope, AstExprInterpString* i);
BreadcrumbId visitExpr(DfgScope* scope, AstExprError* error);
ControlFlow visit(DfgScope* scope, AstStatBlock* b);
ControlFlow visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b);
void visitLValue(DfgScope* scope, AstExpr* e);
void visitLValue(DfgScope* scope, AstExprLocal* l);
void visitLValue(DfgScope* scope, AstExprGlobal* g);
void visitLValue(DfgScope* scope, AstExprIndexName* i);
void visitLValue(DfgScope* scope, AstExprIndexExpr* i);
void visitLValue(DfgScope* scope, AstExprError* e);
ControlFlow visit(DfgScope* scope, AstStat* s);
ControlFlow visit(DfgScope* scope, AstStatIf* i);
ControlFlow visit(DfgScope* scope, AstStatWhile* w);
ControlFlow visit(DfgScope* scope, AstStatRepeat* r);
ControlFlow visit(DfgScope* scope, AstStatBreak* b);
ControlFlow visit(DfgScope* scope, AstStatContinue* c);
ControlFlow visit(DfgScope* scope, AstStatReturn* r);
ControlFlow visit(DfgScope* scope, AstStatExpr* e);
ControlFlow visit(DfgScope* scope, AstStatLocal* l);
ControlFlow visit(DfgScope* scope, AstStatFor* f);
ControlFlow visit(DfgScope* scope, AstStatForIn* f);
ControlFlow visit(DfgScope* scope, AstStatAssign* a);
ControlFlow visit(DfgScope* scope, AstStatCompoundAssign* c);
ControlFlow visit(DfgScope* scope, AstStatFunction* f);
ControlFlow visit(DfgScope* scope, AstStatLocalFunction* l);
ControlFlow visit(DfgScope* scope, AstStatTypeAlias* t);
ControlFlow visit(DfgScope* scope, AstStatDeclareGlobal* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareFunction* d);
ControlFlow visit(DfgScope* scope, AstStatDeclareClass* d);
ControlFlow visit(DfgScope* scope, AstStatError* error);
DataFlowResult visitExpr(DfgScope* scope, AstExpr* e);
DataFlowResult visitExpr(DfgScope* scope, AstExprGroup* group);
DataFlowResult visitExpr(DfgScope* scope, AstExprLocal* l);
DataFlowResult visitExpr(DfgScope* scope, AstExprGlobal* g);
DataFlowResult visitExpr(DfgScope* scope, AstExprCall* c);
DataFlowResult visitExpr(DfgScope* scope, AstExprIndexName* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprIndexExpr* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprFunction* f);
DataFlowResult visitExpr(DfgScope* scope, AstExprTable* t);
DataFlowResult visitExpr(DfgScope* scope, AstExprUnary* u);
DataFlowResult visitExpr(DfgScope* scope, AstExprBinary* b);
DataFlowResult visitExpr(DfgScope* scope, AstExprTypeAssertion* t);
DataFlowResult visitExpr(DfgScope* scope, AstExprIfElse* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprInterpString* i);
DataFlowResult visitExpr(DfgScope* scope, AstExprError* error);
void visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef, bool isCompoundAssignment = false);
void visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef, bool isCompoundAssignment);
void visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef, bool isCompoundAssignment);
void visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef);
void visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef);
void visitLValue(DfgScope* scope, AstExprError* e, DefId incomingDef);
void visitType(DfgScope* scope, AstType* t);
void visitType(DfgScope* scope, AstTypeReference* r);

View file

@ -23,6 +23,7 @@ using DefId = NotNull<const Def>;
*/
struct Cell
{
bool subscripted = false;
};
/**
@ -71,13 +72,15 @@ const T* get(DefId def)
return get_if<T>(&def->v);
}
bool containsSubscriptedDefinition(DefId def);
struct DefArena
{
TypedAllocator<Def> allocator;
DefId freshCell();
// TODO: implement once we have cases where we need to merge in definitions
// DefId phi(const std::vector<DefId>& defs);
DefId freshCell(bool subscripted = false);
DefId phi(DefId a, DefId b);
DefId phi(const std::vector<DefId>& defs);
};
} // namespace Luau

View file

@ -0,0 +1,199 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/TypeFwd.h"
#include "Luau/UnifierSharedState.h"
#include <optional>
#include <string>
#include <unordered_set>
namespace Luau
{
struct DiffPathNode
{
// TODO: consider using Variants to simplify toString implementation
enum Kind
{
TableProperty,
FunctionArgument,
FunctionReturn,
Union,
Intersection,
Negation,
};
Kind kind;
// non-null when TableProperty
std::optional<Name> tableProperty;
// non-null when FunctionArgument (unless variadic arg), FunctionReturn (unless variadic arg), Union, or Intersection (i.e. anonymous fields)
std::optional<size_t> index;
/**
* Do not use for leaf nodes
*/
DiffPathNode(Kind kind)
: kind(kind)
{
}
DiffPathNode(Kind kind, std::optional<Name> tableProperty, std::optional<size_t> index)
: kind(kind)
, tableProperty(tableProperty)
, index(index)
{
}
std::string toString() const;
static DiffPathNode constructWithTableProperty(Name tableProperty);
static DiffPathNode constructWithKindAndIndex(Kind kind, size_t index);
static DiffPathNode constructWithKind(Kind kind);
};
struct DiffPathNodeLeaf
{
std::optional<TypeId> ty;
std::optional<Name> tableProperty;
std::optional<int> minLength;
bool isVariadic;
// TODO: Rename to anonymousIndex, for both union and Intersection
std::optional<size_t> unionIndex;
DiffPathNodeLeaf(
std::optional<TypeId> ty, std::optional<Name> tableProperty, std::optional<int> minLength, bool isVariadic, std::optional<size_t> unionIndex)
: ty(ty)
, tableProperty(tableProperty)
, minLength(minLength)
, isVariadic(isVariadic)
, unionIndex(unionIndex)
{
}
static DiffPathNodeLeaf detailsNormal(TypeId ty);
static DiffPathNodeLeaf detailsTableProperty(TypeId ty, Name tableProperty);
static DiffPathNodeLeaf detailsUnionIndex(TypeId ty, size_t index);
static DiffPathNodeLeaf detailsLength(int minLength, bool isVariadic);
static DiffPathNodeLeaf nullopts();
};
struct DiffPath
{
std::vector<DiffPathNode> path;
std::string toString(bool prependDot) const;
};
struct DiffError
{
enum Kind
{
Normal,
MissingTableProperty,
MissingUnionMember,
MissingIntersectionMember,
IncompatibleGeneric,
LengthMismatchInFnArgs,
LengthMismatchInFnRets,
};
Kind kind;
DiffPath diffPath;
DiffPathNodeLeaf left;
DiffPathNodeLeaf right;
std::string leftRootName;
std::string rightRootName;
DiffError(Kind kind, DiffPathNodeLeaf left, DiffPathNodeLeaf right, std::string leftRootName, std::string rightRootName)
: kind(kind)
, left(left)
, right(right)
, leftRootName(leftRootName)
, rightRootName(rightRootName)
{
checkValidInitialization(left, right);
}
DiffError(Kind kind, DiffPath diffPath, DiffPathNodeLeaf left, DiffPathNodeLeaf right, std::string leftRootName, std::string rightRootName)
: kind(kind)
, diffPath(diffPath)
, left(left)
, right(right)
, leftRootName(leftRootName)
, rightRootName(rightRootName)
{
checkValidInitialization(left, right);
}
std::string toString(bool multiLine = false) const;
private:
std::string toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf, bool multiLine) const;
void checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right);
void checkNonMissingPropertyLeavesHaveNulloptTableProperty() const;
};
struct DifferResult
{
std::optional<DiffError> diffError;
DifferResult() {}
DifferResult(DiffError diffError)
: diffError(diffError)
{
}
void wrapDiffPath(DiffPathNode node);
};
struct DifferEnvironment
{
TypeId rootLeft;
TypeId rootRight;
std::optional<std::string> externalSymbolLeft;
std::optional<std::string> externalSymbolRight;
DenseHashMap<TypeId, TypeId> genericMatchedPairs;
DenseHashMap<TypePackId, TypePackId> genericTpMatchedPairs;
DifferEnvironment(
TypeId rootLeft, TypeId rootRight, std::optional<std::string> externalSymbolLeft, std::optional<std::string> externalSymbolRight)
: rootLeft(rootLeft)
, rootRight(rootRight)
, externalSymbolLeft(externalSymbolLeft)
, externalSymbolRight(externalSymbolRight)
, genericMatchedPairs(nullptr)
, genericTpMatchedPairs(nullptr)
{
}
bool isProvenEqual(TypeId left, TypeId right) const;
bool isAssumedEqual(TypeId left, TypeId right) const;
void recordProvenEqual(TypeId left, TypeId right);
void pushVisiting(TypeId left, TypeId right);
void popVisiting();
std::vector<std::pair<TypeId, TypeId>>::const_reverse_iterator visitingBegin() const;
std::vector<std::pair<TypeId, TypeId>>::const_reverse_iterator visitingEnd() const;
std::string getDevFixFriendlyNameLeft() const;
std::string getDevFixFriendlyNameRight() const;
private:
// TODO: consider using DenseHashSet
std::unordered_set<std::pair<TypeId, TypeId>, TypeIdPairHash> provenEqual;
// Ancestors of current types
std::unordered_set<std::pair<TypeId, TypeId>, TypeIdPairHash> visiting;
std::vector<std::pair<TypeId, TypeId>> visitingStack;
};
DifferResult diff(TypeId ty1, TypeId ty2);
DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional<std::string> symbol1, std::optional<std::string> symbol2);
/**
* True if ty is a "simple" type, i.e. cannot contain types.
* string, number, boolean are simple types.
* function and table are not simple types.
*/
bool isSimple(TypeId ty);
} // namespace Luau

View file

@ -2,8 +2,12 @@
#pragma once
#include "Luau/Location.h"
#include "Luau/NotNull.h"
#include "Luau/Type.h"
#include "Luau/Variant.h"
#include "Luau/Ast.h"
#include <set>
namespace Luau
{
@ -318,6 +322,7 @@ struct TypePackMismatch
{
TypePackId wantedTp;
TypePackId givenTp;
std::string reason;
bool operator==(const TypePackMismatch& rhs) const;
};
@ -329,12 +334,51 @@ struct DynamicPropertyLookupOnClassesUnsafe
bool operator==(const DynamicPropertyLookupOnClassesUnsafe& rhs) const;
};
struct UninhabitedTypeFamily
{
TypeId ty;
bool operator==(const UninhabitedTypeFamily& rhs) const;
};
struct UninhabitedTypePackFamily
{
TypePackId tp;
bool operator==(const UninhabitedTypePackFamily& rhs) const;
};
struct WhereClauseNeeded
{
TypeId ty;
bool operator==(const WhereClauseNeeded& rhs) const;
};
struct PackWhereClauseNeeded
{
TypePackId tp;
bool operator==(const PackWhereClauseNeeded& rhs) const;
};
struct CheckedFunctionCallError
{
TypeId expected;
TypeId passed;
std::string checkedFunctionName;
// TODO: make this a vector<argumentIndices>
size_t argumentIndex;
bool operator==(const CheckedFunctionCallError& rhs) const;
};
using TypeErrorData = Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods,
DuplicateTypeDefinition, CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire,
IncorrectGenericParameterCount, SyntaxError, CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, InternalError,
CannotCallNonFunction, ExtraInformation, DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning,
DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty,
TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch, DynamicPropertyLookupOnClassesUnsafe>;
TypesAreUnrelated, NormalizationTooComplex, TypePackMismatch, DynamicPropertyLookupOnClassesUnsafe, UninhabitedTypeFamily,
UninhabitedTypePackFamily, WhereClauseNeeded, PackWhereClauseNeeded, CheckedFunctionCallError>;
struct TypeErrorSummary
{
@ -403,7 +447,7 @@ std::string toString(const TypeError& error, TypeErrorToStringOptions options);
bool containsParseErrorName(const TypeError& error);
// Copy any types named in the error into destArena.
void copyErrors(ErrorVec& errors, struct TypeArena& destArena);
void copyErrors(ErrorVec& errors, struct TypeArena& destArena, NotNull<BuiltinTypes> builtinTypes);
// Internal Compiler Error
struct InternalErrorReporter

View file

@ -2,13 +2,15 @@
#pragma once
#include "Luau/Config.h"
#include "Luau/GlobalTypes.h"
#include "Luau/Module.h"
#include "Luau/ModuleResolver.h"
#include "Luau/RequireTracer.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/Variant.h"
#include <mutex>
#include <string>
#include <vector>
#include <optional>
@ -27,6 +29,8 @@ struct FileResolver;
struct ModuleResolver;
struct ParseResult;
struct HotComment;
struct BuildQueueItem;
struct FrontendCancellationToken;
struct LoadDefinitionFileResult
{
@ -36,9 +40,6 @@ struct LoadDefinitionFileResult
ModulePtr module;
};
LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view definition,
const std::string& packageName, bool captureComments);
std::optional<Mode> parseMode(const std::vector<HotComment>& hotcomments);
std::vector<std::string_view> parsePathExpr(const AstExpr& pathExpr);
@ -69,7 +70,8 @@ struct SourceNode
}
ModuleName name;
std::unordered_set<ModuleName> requireSet;
std::string humanReadableName;
DenseHashSet<ModuleName> requireSet{{}};
std::vector<std::pair<ModuleName, Location>> requireLocations;
bool dirtySourceModule = true;
bool dirtyModule = true;
@ -89,14 +91,29 @@ struct FrontendOptions
// order to get more precise type information)
bool forAutocomplete = false;
bool runLintChecks = false;
// If not empty, randomly shuffle the constraint set before attempting to
// solve. Use this value to seed the random number generator.
std::optional<unsigned> randomizeConstraintResolutionSeed;
std::optional<LintOptions> enabledLintWarnings;
std::shared_ptr<FrontendCancellationToken> cancellationToken;
// Time limit for typechecking a single module
std::optional<double> moduleTimeLimitSec;
// When true, some internal complexity limits will be scaled down for modules that miss the limit set by moduleTimeLimitSec
bool applyInternalLimitScaling = false;
};
struct CheckResult
{
std::vector<TypeError> errors;
LintResult lintResult;
std::vector<ModuleName> timeoutHits;
};
@ -109,7 +126,13 @@ struct FrontendModuleResolver : ModuleResolver
std::optional<ModuleInfo> resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override;
std::string getHumanReadableModuleName(const ModuleName& moduleName) const override;
void setModule(const ModuleName& moduleName, ModulePtr module);
void clearModules();
private:
Frontend* frontend;
mutable std::mutex moduleMutex;
std::unordered_map<ModuleName, ModulePtr> modules;
};
@ -131,10 +154,11 @@ struct Frontend
Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {});
CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess
// Parse module graph and prepare SourceNode/SourceModule data, including required dependencies without running typechecking
void parse(const ModuleName& name);
LintResult lint(const ModuleName& name, std::optional<LintOptions> enabledLintWarnings = {});
LintResult lint(const SourceModule& module, std::optional<LintOptions> enabledLintWarnings = {});
// Parse and typecheck module graph
CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess
bool isDirty(const ModuleName& name, bool forAutocomplete = false) const;
void markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty = nullptr);
@ -156,25 +180,43 @@ struct Frontend
ScopePtr addEnvironment(const std::string& environmentName);
ScopePtr getEnvironmentScope(const std::string& environmentName) const;
void registerBuiltinDefinition(const std::string& name, std::function<void(TypeChecker&, GlobalTypes&, ScopePtr)>);
void registerBuiltinDefinition(const std::string& name, std::function<void(Frontend&, GlobalTypes&, ScopePtr)>);
void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName);
LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName, bool captureComments);
LoadDefinitionFileResult loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName,
bool captureComments, bool typeCheckForAutocomplete = false);
// Batch module checking. Queue modules and check them together, retrieve results with 'getCheckResult'
// If provided, 'executeTask' function is allowed to call the 'task' function on any thread and return without waiting for 'task' to complete
void queueModuleCheck(const std::vector<ModuleName>& names);
void queueModuleCheck(const ModuleName& name);
std::vector<ModuleName> checkQueuedModules(std::optional<FrontendOptions> optionOverride = {},
std::function<void(std::function<void()> task)> executeTask = {}, std::function<void(size_t done, size_t total)> progress = {});
std::optional<CheckResult> getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false);
private:
ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector<RequireCycle> requireCycles, bool forAutocomplete = false, bool recordJsonLog = false);
ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector<RequireCycle> requireCycles, std::optional<ScopePtr> environmentScope,
bool forAutocomplete, bool recordJsonLog, TypeCheckLimits typeCheckLimits);
std::pair<SourceNode*, SourceModule*> getSourceNode(const ModuleName& name);
SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions);
bool parseGraph(std::vector<ModuleName>& buildQueue, const ModuleName& root, bool forAutocomplete);
bool parseGraph(
std::vector<ModuleName>& buildQueue, const ModuleName& root, bool forAutocomplete, std::function<bool(const ModuleName&)> canSkip = {});
void addBuildQueueItems(std::vector<BuildQueueItem>& items, std::vector<ModuleName>& buildQueue, bool cycleDetected,
DenseHashSet<Luau::ModuleName>& seen, const FrontendOptions& frontendOptions);
void checkBuildQueueItem(BuildQueueItem& item);
void checkBuildQueueItems(std::vector<BuildQueueItem>& items);
void recordItemResult(const BuildQueueItem& item);
static LintResult classifyLints(const std::vector<LintWarning>& warnings, const Config& config);
ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const;
std::unordered_map<std::string, ScopePtr> environments;
std::unordered_map<std::string, std::function<void(TypeChecker&, GlobalTypes&, ScopePtr)>> builtinDefinitions;
std::unordered_map<std::string, std::function<void(Frontend&, GlobalTypes&, ScopePtr)>> builtinDefinitions;
BuiltinTypes builtinTypes_;
@ -182,29 +224,35 @@ public:
const NotNull<BuiltinTypes> builtinTypes;
FileResolver* fileResolver;
FrontendModuleResolver moduleResolver;
FrontendModuleResolver moduleResolverForAutocomplete;
GlobalTypes globals;
GlobalTypes globalsForAutocomplete;
TypeChecker typeChecker;
TypeChecker typeCheckerForAutocomplete;
ConfigResolver* configResolver;
FrontendOptions options;
InternalErrorReporter iceHandler;
std::function<void(const ModuleName& name, const ScopePtr& scope, bool forAutocomplete)> prepareModuleScope;
std::unordered_map<ModuleName, SourceNode> sourceNodes;
std::unordered_map<ModuleName, SourceModule> sourceModules;
std::unordered_map<ModuleName, std::shared_ptr<SourceNode>> sourceNodes;
std::unordered_map<ModuleName, std::shared_ptr<SourceModule>> sourceModules;
std::unordered_map<ModuleName, RequireTraceResult> requireTrace;
Stats stats = {};
std::vector<ModuleName> moduleQueue;
};
ModulePtr check(const SourceModule& sourceModule, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes,
ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler, NotNull<ModuleResolver> moduleResolver, NotNull<FileResolver> fileResolver,
const ScopePtr& globalScope, FrontendOptions options);
const ScopePtr& globalScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, FrontendOptions options,
TypeCheckLimits limits);
ModulePtr check(const SourceModule& sourceModule, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes,
ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector<RequireCycle>& requireCycles, NotNull<BuiltinTypes> builtinTypes,
NotNull<InternalErrorReporter> iceHandler, NotNull<ModuleResolver> moduleResolver, NotNull<FileResolver> fileResolver,
const ScopePtr& globalScope, FrontendOptions options, bool recordJsonLog);
const ScopePtr& globalScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, FrontendOptions options,
TypeCheckLimits limits, bool recordJsonLog);
} // namespace Luau

View file

@ -0,0 +1,25 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Module.h"
#include "Luau/NotNull.h"
#include "Luau/Scope.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeFwd.h"
namespace Luau
{
struct GlobalTypes
{
explicit GlobalTypes(NotNull<BuiltinTypes> builtinTypes);
NotNull<BuiltinTypes> builtinTypes; // Global types are based on builtin types
TypeArena globalTypes;
SourceModule globalNames; // names for symbols entered into globalScope
ScopePtr globalScope; // shared by all modules
};
} // namespace Luau

View file

@ -0,0 +1,134 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include <unordered_map>
#include <vector>
#include <type_traits>
#include <iterator>
namespace Luau
{
template<typename K, typename V>
struct InsertionOrderedMap
{
static_assert(std::is_trivially_copyable_v<K>, "key must be trivially copyable");
private:
using vec = std::vector<std::pair<K, V>>;
public:
using iterator = typename vec::iterator;
using const_iterator = typename vec::const_iterator;
void insert(K k, V v)
{
if (indices.count(k) != 0)
return;
pairs.push_back(std::make_pair(k, std::move(v)));
indices[k] = pairs.size() - 1;
}
void clear()
{
pairs.clear();
indices.clear();
}
size_t size() const
{
LUAU_ASSERT(pairs.size() == indices.size());
return pairs.size();
}
bool contains(const K& k) const
{
return indices.count(k) > 0;
}
const V* get(const K& k) const
{
auto it = indices.find(k);
if (it == indices.end())
return nullptr;
else
return &pairs.at(it->second).second;
}
V* get(const K& k)
{
auto it = indices.find(k);
if (it == indices.end())
return nullptr;
else
return &pairs.at(it->second).second;
}
const_iterator begin() const
{
return pairs.begin();
}
const_iterator end() const
{
return pairs.end();
}
iterator begin()
{
return pairs.begin();
}
iterator end()
{
return pairs.end();
}
const_iterator find(K k) const
{
auto indicesIt = indices.find(k);
if (indicesIt == indices.end())
return end();
else
return begin() + indicesIt->second;
}
iterator find(K k)
{
auto indicesIt = indices.find(k);
if (indicesIt == indices.end())
return end();
else
return begin() + indicesIt->second;
}
void erase(iterator it)
{
if (it == pairs.end())
return;
K k = it->first;
auto indexIt = indices.find(k);
if (indexIt == indices.end())
return;
size_t removed = indexIt->second;
indices.erase(indexIt);
pairs.erase(it);
for (auto& [_, index] : indices)
{
if (index > removed)
--index;
}
}
private:
vec pairs;
std::unordered_map<K, size_t> indices;
};
} // namespace Luau

View file

@ -1,22 +1,25 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/NotNull.h"
#include "Luau/Substitution.h"
#include "Luau/Type.h"
#include "Luau/TypeFwd.h"
#include "Luau/Unifiable.h"
namespace Luau
{
struct TypeArena;
struct TxnLog;
struct TypeArena;
struct TypeCheckLimits;
// A substitution which replaces generic types in a given set by free types.
struct ReplaceGenerics : Substitution
{
ReplaceGenerics(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope, const std::vector<TypeId>& generics,
const std::vector<TypePackId>& genericPacks)
ReplaceGenerics(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope,
const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks)
: Substitution(log, arena)
, builtinTypes(builtinTypes)
, level(level)
, scope(scope)
, generics(generics)
@ -24,6 +27,8 @@ struct ReplaceGenerics : Substitution
{
}
NotNull<BuiltinTypes> builtinTypes;
TypeLevel level;
Scope* scope;
std::vector<TypeId> generics;
@ -38,13 +43,16 @@ struct ReplaceGenerics : Substitution
// A substitution which replaces generic functions by monomorphic functions
struct Instantiation : Substitution
{
Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level, Scope* scope)
Instantiation(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope)
: Substitution(log, arena)
, builtinTypes(builtinTypes)
, level(level)
, scope(scope)
{
}
NotNull<BuiltinTypes> builtinTypes;
TypeLevel level;
Scope* scope;
bool ignoreChildren(TypeId ty) override;
@ -54,4 +62,21 @@ struct Instantiation : Substitution
TypePackId clean(TypePackId tp) override;
};
/** Attempt to instantiate a type. Only used under local type inference.
*
* When given a generic function type, instantiate() will return a copy with the
* generics replaced by fresh types. Instantiation will return the same TypeId
* back if the function does not have any generics.
*
* All higher order generics are left as-is. For example, instantiation of
* <X>(<Y>(Y) -> (X, Y)) -> (X, Y) is (<Y>(Y) -> ('x, Y)) -> ('x, Y)
*
* We substitute the generic X for the free 'x, but leave the generic Y alone.
*
* Instantiation fails only when processing the type causes internal recursion
* limits to be exceeded.
*/
std::optional<TypeId> instantiate(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, TypeId ty);
} // namespace Luau

View file

@ -5,6 +5,7 @@
#include "Luau/Location.h"
#include "Luau/Type.h"
#include "Luau/Ast.h"
#include "Luau/TypePath.h"
#include <ostream>
@ -48,4 +49,14 @@ std::ostream& operator<<(std::ostream& lhs, const TypePackVar& tv);
std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted);
std::ostream& operator<<(std::ostream& lhs, TypeId ty);
std::ostream& operator<<(std::ostream& lhs, TypePackId tp);
namespace TypePath
{
std::ostream& operator<<(std::ostream& lhs, const Path& path);
}; // namespace TypePath
} // namespace Luau

View file

@ -3,6 +3,7 @@
#include "Luau/Variant.h"
#include "Luau/Symbol.h"
#include "Luau/TypeFwd.h"
#include <memory>
#include <unordered_map>
@ -10,9 +11,6 @@
namespace Luau
{
struct Type;
using TypeId = const Type*;
struct Field;
// Deprecated. Do not use in new work.

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/LinterConfig.h"
#include "Luau/Location.h"
#include <memory>
@ -15,86 +16,15 @@ class AstStat;
class AstNameTable;
struct TypeChecker;
struct Module;
struct HotComment;
using ScopePtr = std::shared_ptr<struct Scope>;
struct LintWarning
{
// Make sure any new lint codes are documented here: https://luau-lang.org/lint
// Note that in Studio, the active set of lint warnings is determined by FStringStudioLuauLints
enum Code
{
Code_Unknown = 0,
Code_UnknownGlobal = 1, // superseded by type checker
Code_DeprecatedGlobal = 2,
Code_GlobalUsedAsLocal = 3,
Code_LocalShadow = 4, // disabled in Studio
Code_SameLineStatement = 5, // disabled in Studio
Code_MultiLineStatement = 6,
Code_LocalUnused = 7, // disabled in Studio
Code_FunctionUnused = 8, // disabled in Studio
Code_ImportUnused = 9, // disabled in Studio
Code_BuiltinGlobalWrite = 10,
Code_PlaceholderRead = 11,
Code_UnreachableCode = 12,
Code_UnknownType = 13,
Code_ForRange = 14,
Code_UnbalancedAssignment = 15,
Code_ImplicitReturn = 16, // disabled in Studio, superseded by type checker in strict mode
Code_DuplicateLocal = 17,
Code_FormatString = 18,
Code_TableLiteral = 19,
Code_UninitializedLocal = 20,
Code_DuplicateFunction = 21,
Code_DeprecatedApi = 22,
Code_TableOperations = 23,
Code_DuplicateCondition = 24,
Code_MisleadingAndOr = 25,
Code_CommentDirective = 26,
Code_IntegerParsing = 27,
Code_ComparisonPrecedence = 28,
Code__Count
};
Code code;
Location location;
std::string text;
static const char* getName(Code code);
static Code parseName(const char* name);
static uint64_t parseMask(const std::vector<HotComment>& hotcomments);
};
struct LintResult
{
std::vector<LintWarning> errors;
std::vector<LintWarning> warnings;
};
struct LintOptions
{
uint64_t warningMask = 0;
void enableWarning(LintWarning::Code code)
{
warningMask |= 1ull << code;
}
void disableWarning(LintWarning::Code code)
{
warningMask &= ~(1ull << code);
}
bool isEnabled(LintWarning::Code code) const
{
return 0 != (warningMask & (1ull << code));
}
void setDefaults();
};
std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module,
const std::vector<HotComment>& hotcomments, const LintOptions& options);

View file

@ -19,6 +19,7 @@ static const std::unordered_map<AstExprBinary::Op, const char*> kBinaryOpMetamet
{AstExprBinary::Op::Sub, "__sub"},
{AstExprBinary::Op::Mul, "__mul"},
{AstExprBinary::Op::Div, "__div"},
{AstExprBinary::Op::FloorDiv, "__idiv"},
{AstExprBinary::Op::Pow, "__pow"},
{AstExprBinary::Op::Mod, "__mod"},
{AstExprBinary::Op::Concat, "__concat"},

View file

@ -2,6 +2,7 @@
#pragma once
#include "Luau/Error.h"
#include "Luau/Linter.h"
#include "Luau/FileResolver.h"
#include "Luau/ParseOptions.h"
#include "Luau/ParseResult.h"
@ -27,7 +28,9 @@ class AstTypePack;
/// Root of the AST of a parsed source file
struct SourceModule
{
ModuleName name; // DataModel path if possible. Filename if not.
ModuleName name; // Module identifier or a filename
std::string humanReadableName;
SourceCode::Type type = SourceCode::None;
std::optional<std::string> environmentName;
bool cyclic = false;
@ -50,6 +53,7 @@ struct SourceModule
};
bool isWithinComment(const SourceModule& sourceModule, Position pos);
bool isWithinComment(const ParseResult& result, Position pos);
struct RequireCycle
{
@ -61,6 +65,9 @@ struct Module
{
~Module();
ModuleName name;
std::string humanReadableName;
TypeArena interfaceTypes;
TypeArena internalTypes;
@ -74,23 +81,41 @@ struct Module
DenseHashMap<const AstExpr*, TypePackId> astTypePacks{nullptr};
DenseHashMap<const AstExpr*, TypeId> astExpectedTypes{nullptr};
// For AST nodes that are function calls, this map provides the
// unspecialized type of the function that was called. If a function call
// resolves to a __call metamethod application, this map will point at that
// metamethod.
//
// This is useful for type checking and Signature Help.
DenseHashMap<const AstNode*, TypeId> astOriginalCallTypes{nullptr};
// The specialization of a function that was selected. If the function is
// generic, those generic type parameters will be replaced with the actual
// types that were passed. If the function is an overload, this map will
// point at the specific overloads that were selected.
DenseHashMap<const AstNode*, TypeId> astOverloadResolvedTypes{nullptr};
// Only used with for...in loops. The computed type of the next() function
// is kept here for type checking.
DenseHashMap<const AstNode*, TypeId> astForInNextTypes{nullptr};
DenseHashMap<const AstType*, TypeId> astResolvedTypes{nullptr};
DenseHashMap<const AstType*, TypeId> astOriginalResolvedTypes{nullptr};
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
// Map AST nodes to the scope they create. Cannot be NotNull<Scope> because we need a sentinel value for the map.
DenseHashMap<const AstNode*, Scope*> astScopes{nullptr};
DenseHashMap<TypeId, std::vector<std::pair<Location, TypeId>>> upperBoundContributors{nullptr};
std::unique_ptr<struct TypeReduction> reduction;
// Map AST nodes to the scope they create. Cannot be NotNull<Scope> because
// we need a sentinel value for the map.
DenseHashMap<const AstNode*, Scope*> astScopes{nullptr};
std::unordered_map<Name, TypeId> declaredGlobals;
ErrorVec errors;
LintResult lintResult;
Mode mode;
SourceCode::Type type;
double checkDurationSec = 0.0;
bool timeout = false;
bool cancelled = false;
TypePackId returnType = nullptr;
std::unordered_map<Name, TypeFun> exportedTypeBindings;
@ -98,8 +123,9 @@ struct Module
bool hasModuleScope() const;
ScopePtr getModuleScope() const;
// Once a module has been typechecked, we clone its public interface into a separate arena.
// This helps us to force Type ownership into a DAG rather than a DCG.
// Once a module has been typechecked, we clone its public interface into a
// separate arena. This helps us to force Type ownership into a DAG rather
// than a DCG.
void clonePublicInterface(NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
};

View file

@ -0,0 +1,19 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Module.h"
#include "Luau/NotNull.h"
#include "Luau/DataFlowGraph.h"
namespace Luau
{
struct BuiltinTypes;
struct UnifierSharedState;
struct TypeCheckLimits;
void checkNonStrict(NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, const SourceModule& sourceModule, Module* module);
} // namespace Luau

View file

@ -2,10 +2,15 @@
#pragma once
#include "Luau/NotNull.h"
#include "Luau/Type.h"
#include "Luau/Set.h"
#include "Luau/TypeFwd.h"
#include "Luau/UnifierSharedState.h"
#include <initializer_list>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>
namespace Luau
{
@ -13,17 +18,18 @@ namespace Luau
struct InternalErrorReporter;
struct Module;
struct Scope;
struct BuiltinTypes;
using ModulePtr = std::shared_ptr<Module>;
bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
bool isConsistentSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter& ice);
class TypeIds
{
private:
std::unordered_set<TypeId> types;
DenseHashMap<TypeId, bool> types{nullptr};
std::vector<TypeId> order;
std::size_t hash = 0;
@ -31,10 +37,15 @@ public:
using iterator = std::vector<TypeId>::iterator;
using const_iterator = std::vector<TypeId>::const_iterator;
TypeIds(const TypeIds&) = default;
TypeIds(TypeIds&&) = default;
TypeIds() = default;
~TypeIds() = default;
TypeIds(std::initializer_list<TypeId> tys);
TypeIds(const TypeIds&) = default;
TypeIds& operator=(const TypeIds&) = default;
TypeIds(TypeIds&&) = default;
TypeIds& operator=(TypeIds&&) = default;
void insert(TypeId ty);
@ -62,6 +73,7 @@ public:
bool operator==(const TypeIds& there) const;
size_t getHash() const;
bool isNever() const;
};
} // namespace Luau
@ -189,12 +201,8 @@ struct NormalizedClassType
// this type may contain `error`.
struct NormalizedFunctionType
{
NormalizedFunctionType();
bool isTop = false;
// TODO: Remove this wrapping optional when clipping
// FFlagLuauNegatedFunctionTypes.
std::optional<TypeIds> parts;
TypeIds parts;
void resetToNever();
void resetToTop();
@ -203,18 +211,16 @@ struct NormalizedFunctionType
};
// A normalized generic/free type is a union, where each option is of the form (X & T) where
// * X is either a free type or a generic
// * X is either a free type, a generic or a blocked type.
// * T is a normalized type.
struct NormalizedType;
using NormalizedTyvars = std::unordered_map<TypeId, std::unique_ptr<NormalizedType>>;
bool isInhabited_DEPRECATED(const NormalizedType& norm);
// A normalized type is either any, unknown, or one of the form P | T | F | G where
// * P is a union of primitive types (including singletons, classes and the error type)
// * T is a union of table types
// * F is a union of an intersection of function types
// * G is a union of generic/free normalized types, intersected with a normalized type
// * G is a union of generic/free/blocked types, intersected with a normalized type
struct NormalizedType
{
// The top part of the type.
@ -228,10 +234,6 @@ struct NormalizedType
NormalizedClassType classes;
// The class part of the type.
// Each element of this set is a class, and none of the classes are subclasses of each other.
TypeIds DEPRECATED_classes;
// The error part of the type.
// This type is either never or the error type.
TypeId errors;
@ -252,6 +254,10 @@ struct NormalizedType
// This type is either never or thread.
TypeId threads;
// The buffer part of the type.
// This type is either never or buffer.
TypeId buffers;
// The (meta)table part of the type.
// Each element of this set is a (meta)table type, or the top `table` type.
// An empty set denotes never.
@ -273,22 +279,57 @@ struct NormalizedType
NormalizedType(NormalizedType&&) = default;
NormalizedType& operator=(NormalizedType&&) = default;
// IsType functions
bool isUnknown() const;
/// Returns true if the type is exactly a number. Behaves like Type::isNumber()
bool isExactlyNumber() const;
/// Returns true if the type is a subtype of string(it could be a singleton). Behaves like Type::isString()
bool isSubtypeOfString() const;
/// Returns true if this type should result in error suppressing behavior.
bool shouldSuppressErrors() const;
/// Returns true if this type contains the primitve top table type, `table`.
bool hasTopTable() const;
// Helpers that improve readability of the above (they just say if the component is present)
bool hasTops() const;
bool hasBooleans() const;
bool hasClasses() const;
bool hasErrors() const;
bool hasNils() const;
bool hasNumbers() const;
bool hasStrings() const;
bool hasThreads() const;
bool hasBuffers() const;
bool hasTables() const;
bool hasFunctions() const;
bool hasTyvars() const;
};
class Normalizer
{
std::unordered_map<TypeId, std::unique_ptr<NormalizedType>> cachedNormals;
std::unordered_map<const TypeIds*, TypeId> cachedIntersections;
std::unordered_map<const TypeIds*, TypeId> cachedUnions;
std::unordered_map<const TypeIds*, std::unique_ptr<TypeIds>> cachedTypeIds;
DenseHashMap<TypeId, bool> cachedIsInhabited{nullptr};
DenseHashMap<std::pair<TypeId, TypeId>, bool, TypeIdPairHash> cachedIsInhabitedIntersection{{nullptr, nullptr}};
bool withinResourceLimits();
public:
TypeArena* arena;
NotNull<BuiltinTypes> builtinTypes;
NotNull<UnifierSharedState> sharedState;
bool cacheInhabitance = false;
Normalizer(TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> sharedState);
Normalizer(TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> sharedState, bool cacheInhabitance = false);
Normalizer(const Normalizer&) = delete;
Normalizer(Normalizer&&) = delete;
Normalizer() = delete;
@ -323,7 +364,7 @@ public:
void unionTablesWithTable(TypeIds& heres, TypeId there);
void unionTables(TypeIds& heres, const TypeIds& theres);
bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1);
bool unionNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes, int ignoreSmallerTyvars = -1);
// ------- Negations
std::optional<NormalizedType> negateNormal(const NormalizedType& here);
@ -335,8 +376,6 @@ public:
// ------- Normalizing intersections
TypeId intersectionOfTops(TypeId here, TypeId there);
TypeId intersectionOfBools(TypeId here, TypeId there);
void DEPRECATED_intersectClasses(TypeIds& heres, const TypeIds& theres);
void DEPRECATED_intersectClassesWithClass(TypeIds& heres, TypeId there);
void intersectClasses(NormalizedClassType& heres, const NormalizedClassType& theres);
void intersectClassesWithClass(NormalizedClassType& heres, TypeId there);
void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there);
@ -347,13 +386,18 @@ public:
std::optional<TypeId> intersectionOfFunctions(TypeId here, TypeId there);
void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there);
void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress);
bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there);
bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set<TypeId>& seenSetTypes);
bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1);
bool intersectNormalWithTy(NormalizedType& here, TypeId there);
bool intersectNormalWithTy(NormalizedType& here, TypeId there, Set<TypeId>& seenSetTypes);
bool normalizeIntersections(const std::vector<TypeId>& intersections, NormalizedType& outType);
// Check for inhabitance
bool isInhabited(TypeId ty, std::unordered_set<TypeId> seen = {});
bool isInhabited(const NormalizedType* norm, std::unordered_set<TypeId> seen = {});
bool isInhabited(TypeId ty);
bool isInhabited(TypeId ty, Set<TypeId> seen);
bool isInhabited(const NormalizedType* norm, Set<TypeId> seen = {nullptr});
// Check for intersections being inhabited
bool isIntersectionInhabited(TypeId left, TypeId right);
// -------- Convert back from a normalized type to a type
TypeId typeFromNormal(const NormalizedType& norm);

View file

@ -4,15 +4,13 @@
#include "Luau/Location.h"
#include "Luau/LValue.h"
#include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
#include <vector>
namespace Luau
{
struct Type;
using TypeId = const Type*;
struct TruthyPredicate;
struct IsAPredicate;
struct TypeGuardPredicate;

View file

@ -1,7 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Type.h"
#include "Luau/TypeFwd.h"
#include "Luau/DenseHash.h"
#include "Luau/Unifiable.h"
#include <vector>
#include <optional>
namespace Luau
{
@ -10,6 +15,29 @@ struct TypeArena;
struct Scope;
void quantify(TypeId ty, TypeLevel level);
std::optional<TypeId> quantify(TypeArena* arena, TypeId ty, Scope* scope);
// TODO: This is eerily similar to the pattern that NormalizedClassType
// implements. We could, and perhaps should, merge them together.
template<typename K, typename V>
struct OrderedMap
{
std::vector<K> keys;
DenseHashMap<K, V> pairings{nullptr};
void push(K k, V v)
{
keys.push_back(k);
pairings[k] = v;
}
};
struct QuantifierResult
{
TypeId result;
OrderedMap<TypeId, TypeId> insertedGenerics;
OrderedMap<TypePackId, TypePackId> insertedGenericPacks;
};
std::optional<QuantifierResult> quantify(TypeArena* arena, TypeId ty, Scope* scope);
} // namespace Luau

View file

@ -4,14 +4,13 @@
#include "Luau/NotNull.h"
#include "Luau/TypedAllocator.h"
#include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
namespace Luau
{
using BreadcrumbId = NotNull<const struct Breadcrumb>;
struct Type;
using TypeId = const Type*;
struct RefinementKey;
using DefId = NotNull<const struct Def>;
struct Variadic;
struct Negation;
@ -52,7 +51,7 @@ struct Equivalence
struct Proposition
{
BreadcrumbId breadcrumb;
const RefinementKey* key;
TypeId discriminantTy;
};
@ -69,7 +68,7 @@ struct RefinementArena
RefinementId conjunction(RefinementId lhs, RefinementId rhs);
RefinementId disjunction(RefinementId lhs, RefinementId rhs);
RefinementId equivalence(RefinementId lhs, RefinementId rhs);
RefinementId proposition(BreadcrumbId breadcrumb, TypeId discriminantTy);
RefinementId proposition(const RefinementKey* key, TypeId discriminantTy);
private:
TypedAllocator<Refinement> allocator;

View file

@ -2,9 +2,13 @@
#pragma once
#include "Luau/Def.h"
#include "Luau/LValue.h"
#include "Luau/Location.h"
#include "Luau/NotNull.h"
#include "Luau/Type.h"
#include "Luau/DenseHash.h"
#include "Luau/Symbol.h"
#include "Luau/Unifiable.h"
#include <unordered_map>
#include <optional>
@ -52,20 +56,32 @@ struct Scope
void addBuiltinTypeBinding(const Name& name, const TypeFun& tyFun);
std::optional<TypeId> lookup(Symbol sym) const;
std::optional<TypeId> lookupUnrefinedType(DefId def) const;
std::optional<TypeId> lookup(DefId def) const;
std::optional<std::pair<TypeId, Scope*>> lookupEx(DefId def);
std::optional<std::pair<Binding*, Scope*>> lookupEx(Symbol sym);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
std::optional<TypeFun> lookupType(const Name& name) const;
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name) const;
std::unordered_map<Name, TypePackId> privateTypePackBindings;
std::optional<TypePackId> lookupPack(const Name& name);
std::optional<TypePackId> lookupPack(const Name& name) const;
// WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2)
std::optional<Binding> linearSearchForBinding(const std::string& name, bool traverseScopeChain = true) const;
RefinementMap refinements;
DenseHashMap<const Def*, TypeId> dcrRefinements{nullptr};
// This can be viewed as the "unrefined" type of each binding.
DenseHashMap<const Def*, TypeId> lvalueTypes{nullptr};
// Luau values are routinely refined more narrowly than their actual
// inferred type through control flow statements. We retain those refined
// types here.
DenseHashMap<const Def*, TypeId> rvalueRefinements{nullptr};
void inheritAssignments(const ScopePtr& childScope);
void inheritRefinements(const ScopePtr& childScope);
// For mutually recursive type aliases, it's important that
// they use the same types for the same names.

171
Analysis/include/Luau/Set.h Normal file
View file

@ -0,0 +1,171 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
namespace Luau
{
template<typename T>
using SetHashDefault = std::conditional_t<std::is_pointer_v<T>, DenseHashPointer, std::hash<T>>;
// This is an implementation of `unordered_set` using `DenseHashMap<T, bool>` to support erasure.
// This lets us work around `DenseHashSet` limitations and get a more traditional set interface.
template<typename T, typename Hash = SetHashDefault<T>>
class Set
{
private:
using Impl = DenseHashMap<T, bool, Hash>;
Impl mapping;
size_t entryCount = 0;
public:
class const_iterator;
using iterator = const_iterator;
Set(const T& empty_key)
: mapping{empty_key}
{
}
bool insert(const T& element)
{
bool& entry = mapping[element];
bool fresh = !entry;
if (fresh)
{
entry = true;
entryCount++;
}
return fresh;
}
template<class Iterator>
void insert(Iterator begin, Iterator end)
{
for (Iterator it = begin; it != end; ++it)
insert(*it);
}
void erase(const T& element)
{
bool& entry = mapping[element];
if (entry)
{
entry = false;
entryCount--;
}
}
void clear()
{
mapping.clear();
entryCount = 0;
}
size_t size() const
{
return entryCount;
}
bool empty() const
{
return entryCount == 0;
}
size_t count(const T& element) const
{
const bool* entry = mapping.find(element);
return (entry && *entry) ? 1 : 0;
}
bool contains(const T& element) const
{
return count(element) != 0;
}
const_iterator begin() const
{
return const_iterator(mapping.begin(), mapping.end());
}
const_iterator end() const
{
return const_iterator(mapping.end(), mapping.end());
}
bool operator==(const Set<T>& there) const
{
// if the sets are unequal sizes, then they cannot possibly be equal.
if (size() != there.size())
return false;
// otherwise, we'll need to check that every element we have here is in `there`.
for (auto [elem, present] : mapping)
{
// if it's not, we'll return `false`
if (present && there.contains(elem))
return false;
}
// otherwise, we've proven the two equal!
return true;
}
class const_iterator
{
public:
const_iterator(typename Impl::const_iterator impl, typename Impl::const_iterator end)
: impl(impl)
, end(end)
{}
const T& operator*() const
{
return impl->first;
}
const T* operator->() const
{
return &impl->first;
}
bool operator==(const const_iterator& other) const
{
return impl == other.impl;
}
bool operator!=(const const_iterator& other) const
{
return impl != other.impl;
}
const_iterator& operator++()
{
do
{
impl++;
} while (impl != end && impl->second == false);
// keep iterating past pairs where the value is `false`
return *this;
}
const_iterator operator++(int)
{
const_iterator res = *this;
++*this;
return res;
}
private:
typename Impl::const_iterator impl;
typename Impl::const_iterator end;
};
};
} // namespace Luau

View file

@ -0,0 +1,35 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/NotNull.h"
#include "Luau/TypeFwd.h"
namespace Luau
{
struct TypeArena;
struct SimplifyResult
{
TypeId result;
DenseHashSet<TypeId> blockedTypes;
};
SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant);
SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId ty, TypeId discriminant);
enum class Relation
{
Disjoint, // No A is a B or vice versa
Coincident, // Every A is in B and vice versa
Intersects, // Some As are in B and some Bs are in A. ex (number | string) <-> (string | boolean)
Subset, // Every A is in B
Superset, // Every B is in A
};
Relation relate(TypeId left, TypeId right);
} // namespace Luau

View file

@ -2,8 +2,7 @@
#pragma once
#include "Luau/TypeArena.h"
#include "Luau/TypePack.h"
#include "Luau/Type.h"
#include "Luau/TypeFwd.h"
#include "Luau/DenseHash.h"
// We provide an implementation of substitution on types,
@ -69,6 +68,19 @@ struct TarjanWorklistVertex
int lastEdge;
};
struct TarjanNode
{
TypeId ty;
TypePackId tp;
bool onStack;
bool dirty;
// Tarjan calculates the lowlink for each vertex,
// which is the lowest ancestor index reachable from the vertex.
int lowlink;
};
// Tarjan's algorithm for finding the SCCs in a cyclic structure.
// https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
struct Tarjan
@ -76,17 +88,12 @@ struct Tarjan
// Vertices (types and type packs) are indexed, using pre-order traversal.
DenseHashMap<TypeId, int> typeToIndex{nullptr};
DenseHashMap<TypePackId, int> packToIndex{nullptr};
std::vector<TypeId> indexToType;
std::vector<TypePackId> indexToPack;
std::vector<TarjanNode> nodes;
// Tarjan keeps a stack of vertices where we're still in the process
// of finding their SCC.
std::vector<int> stack;
std::vector<bool> onStack;
// Tarjan calculates the lowlink for each vertex,
// which is the lowest ancestor index reachable from the vertex.
std::vector<int> lowlink;
int childCount = 0;
int childLimit = 0;
@ -98,6 +105,7 @@ struct Tarjan
std::vector<TypeId> edgesTy;
std::vector<TypePackId> edgesTp;
std::vector<TarjanWorklistVertex> worklist;
// This is hot code, so we optimize recursion to a stack.
TarjanResult loop();
@ -113,32 +121,17 @@ struct Tarjan
void visitChild(TypeId ty);
void visitChild(TypePackId ty);
template<typename Ty>
void visitChild(std::optional<Ty> ty)
{
if (ty)
visitChild(*ty);
}
// Visit the root vertex.
TarjanResult visitRoot(TypeId ty);
TarjanResult visitRoot(TypePackId ty);
// Each subclass gets called back once for each edge,
// and once for each SCC.
virtual void visitEdge(int index, int parentIndex) {}
virtual void visitSCC(int index) {}
// Each subclass can decide to ignore some nodes.
virtual bool ignoreChildren(TypeId ty)
{
return false;
}
virtual bool ignoreChildren(TypePackId ty)
{
return false;
}
};
// We use Tarjan to calculate dirty bits. We set `dirty[i]` true
// if the vertex with index `i` can reach a dirty vertex.
struct FindDirty : Tarjan
{
std::vector<bool> dirty;
void clearTarjan();
// Get/set the dirty bit for an index (grows the vector if needed)
@ -150,8 +143,30 @@ struct FindDirty : Tarjan
TarjanResult findDirty(TypePackId t);
// We find dirty vertices using Tarjan
void visitEdge(int index, int parentIndex) override;
void visitSCC(int index) override;
void visitEdge(int index, int parentIndex);
void visitSCC(int index);
// Each subclass can decide to ignore some nodes.
virtual bool ignoreChildren(TypeId ty)
{
return false;
}
virtual bool ignoreChildren(TypePackId ty)
{
return false;
}
// Some subclasses might ignore children visit, but not other actions like replacing the children
virtual bool ignoreChildrenVisit(TypeId ty)
{
return ignoreChildren(ty);
}
virtual bool ignoreChildrenVisit(TypePackId ty)
{
return ignoreChildren(ty);
}
// Subclasses should say which vertices are dirty,
// and what to do with dirty vertices.
@ -163,7 +178,7 @@ struct FindDirty : Tarjan
// And finally substitution, which finds all the reachable dirty vertices
// and replaces them with clean ones.
struct Substitution : FindDirty
struct Substitution : Tarjan
{
protected:
Substitution(const TxnLog* log_, TypeArena* arena)
@ -186,8 +201,10 @@ public:
TypeId replace(TypeId ty);
TypePackId replace(TypePackId tp);
void replaceChildren(TypeId ty);
void replaceChildren(TypePackId tp);
TypeId clone(TypeId ty);
TypePackId clone(TypePackId tp);
@ -211,6 +228,16 @@ public:
{
return arena->addTypePack(TypePackVar{tp});
}
private:
template<typename Ty>
std::optional<Ty> replace(std::optional<Ty> ty)
{
if (ty)
return replace(*ty);
else
return std::nullopt;
}
};
} // namespace Luau

View file

@ -0,0 +1,206 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Set.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypePairHash.h"
#include "Luau/TypePath.h"
#include "Luau/DenseHash.h"
#include <vector>
#include <optional>
namespace Luau
{
template<typename A, typename B>
struct TryPair;
struct InternalErrorReporter;
class TypeIds;
class Normalizer;
struct NormalizedType;
struct NormalizedClassType;
struct NormalizedStringType;
struct NormalizedFunctionType;
struct TypeArena;
struct Scope;
struct TableIndexer;
enum class SubtypingVariance
{
// Used for an empty key. Should never appear in actual code.
Invalid,
Covariant,
Invariant,
};
struct SubtypingReasoning
{
Path subPath;
Path superPath;
SubtypingVariance variance = SubtypingVariance::Covariant;
bool operator==(const SubtypingReasoning& other) const;
};
struct SubtypingReasoningHash
{
size_t operator()(const SubtypingReasoning& r) const;
};
struct SubtypingResult
{
bool isSubtype = false;
bool isErrorSuppressing = false;
bool normalizationTooComplex = false;
bool isCacheable = true;
/// The reason for isSubtype to be false. May not be present even if
/// isSubtype is false, depending on the input types.
DenseHashSet<SubtypingReasoning, SubtypingReasoningHash> reasoning{
SubtypingReasoning{TypePath::kEmpty, TypePath::kEmpty, SubtypingVariance::Invalid}};
SubtypingResult& andAlso(const SubtypingResult& other);
SubtypingResult& orElse(const SubtypingResult& other);
SubtypingResult& withBothComponent(TypePath::Component component);
SubtypingResult& withSuperComponent(TypePath::Component component);
SubtypingResult& withSubComponent(TypePath::Component component);
SubtypingResult& withBothPath(TypePath::Path path);
SubtypingResult& withSubPath(TypePath::Path path);
SubtypingResult& withSuperPath(TypePath::Path path);
SubtypingResult& withVariance(SubtypingVariance variance);
// Only negates the `isSubtype`.
static SubtypingResult negate(const SubtypingResult& result);
static SubtypingResult all(const std::vector<SubtypingResult>& results);
static SubtypingResult any(const std::vector<SubtypingResult>& results);
};
struct SubtypingEnvironment
{
struct GenericBounds
{
DenseHashSet<TypeId> lowerBound{nullptr};
DenseHashSet<TypeId> upperBound{nullptr};
};
/*
* When we encounter a generic over the course of a subtyping test, we need
* to tentatively map that generic onto a type on the other side.
*/
DenseHashMap<TypeId, GenericBounds> mappedGenerics{nullptr};
DenseHashMap<TypePackId, TypePackId> mappedGenericPacks{nullptr};
DenseHashMap<std::pair<TypeId, TypeId>, SubtypingResult, TypePairHash> ephemeralCache{{}};
};
struct Subtyping
{
NotNull<BuiltinTypes> builtinTypes;
NotNull<TypeArena> arena;
NotNull<Normalizer> normalizer;
NotNull<InternalErrorReporter> iceReporter;
NotNull<Scope> scope;
enum class Variance
{
Covariant,
Contravariant
};
Variance variance = Variance::Covariant;
using SeenSet = Set<std::pair<TypeId, TypeId>, TypePairHash>;
SeenSet seenTypes{{}};
Subtyping(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> typeArena, NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> iceReporter, NotNull<Scope> scope);
Subtyping(const Subtyping&) = delete;
Subtyping& operator=(const Subtyping&) = delete;
Subtyping(Subtyping&&) = default;
Subtyping& operator=(Subtyping&&) = default;
// Only used by unit tests to test that the cache works.
const DenseHashMap<std::pair<TypeId, TypeId>, SubtypingResult, TypePairHash>& peekCache() const
{
return resultCache;
}
// TODO cache
// TODO cyclic types
// TODO recursion limits
SubtypingResult isSubtype(TypeId subTy, TypeId superTy);
SubtypingResult isSubtype(TypePackId subTy, TypePackId superTy);
private:
DenseHashMap<std::pair<TypeId, TypeId>, SubtypingResult, TypePairHash> resultCache{{}};
SubtypingResult cache(SubtypingEnvironment& env, SubtypingResult res, TypeId subTy, TypeId superTy);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypeId subTy, TypeId superTy);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypePackId subTy, TypePackId superTy);
template<typename SubTy, typename SuperTy>
SubtypingResult isContravariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy);
template<typename SubTy, typename SuperTy>
SubtypingResult isInvariantWith(SubtypingEnvironment& env, SubTy&& subTy, SuperTy&& superTy);
template<typename SubTy, typename SuperTy>
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TryPair<const SubTy*, const SuperTy*>& pair);
template<typename SubTy, typename SuperTy>
SubtypingResult isContravariantWith(SubtypingEnvironment& env, const TryPair<const SubTy*, const SuperTy*>& pair);
template<typename SubTy, typename SuperTy>
SubtypingResult isInvariantWith(SubtypingEnvironment& env, const TryPair<const SubTy*, const SuperTy*>& pair);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypeId subTy, const UnionType* superUnion);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const UnionType* subUnion, TypeId superTy);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, TypeId subTy, const IntersectionType* superIntersection);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const IntersectionType* subIntersection, TypeId superTy);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NegationType* subNegation, TypeId superTy);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TypeId subTy, const NegationType* superNegation);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const PrimitiveType* subPrim, const PrimitiveType* superPrim);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const SingletonType* subSingleton, const PrimitiveType* superPrim);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const SingletonType* subSingleton, const SingletonType* superSingleton);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TableType* subTable, const TableType* superTable);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const MetatableType* superMt);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const TableType* superTable);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const ClassType* subClass, const ClassType* superClass);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const ClassType* subClass, const TableType* superTable);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const FunctionType* subFunction, const FunctionType* superFunction);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const PrimitiveType* subPrim, const TableType* superTable);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const SingletonType* subSingleton, const TableType* superTable);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TableIndexer& subIndexer, const TableIndexer& superIndexer);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedType* subNorm, const NormalizedType* superNorm);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const NormalizedClassType& superClass);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const TypeIds& superTables);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const NormalizedStringType& superString);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const TypeIds& superTables);
SubtypingResult isCovariantWith(
SubtypingEnvironment& env, const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TypeIds& subTypes, const TypeIds& superTypes);
SubtypingResult isCovariantWith(SubtypingEnvironment& env, const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic);
bool bindGeneric(SubtypingEnvironment& env, TypeId subTp, TypeId superTp);
bool bindGeneric(SubtypingEnvironment& env, TypePackId subTp, TypePackId superTp);
template<typename T, typename Container>
TypeId makeAggregateType(const Container& container, TypeId orElse);
[[noreturn]] void unexpected(TypePackId tp);
};
} // namespace Luau

View file

@ -6,8 +6,6 @@
#include <string>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau
{
@ -42,17 +40,7 @@ struct Symbol
return local != nullptr || global.value != nullptr;
}
bool operator==(const Symbol& rhs) const
{
if (local)
return local == rhs.local;
else if (global.value)
return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity.
else if (FFlag::DebugLuauDeferredConstraintResolution)
return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is.
else
return false;
}
bool operator==(const Symbol& rhs) const;
bool operator!=(const Symbol& rhs) const
{

View file

@ -2,16 +2,12 @@
#pragma once
#include "Luau/Common.h"
#include "Luau/TypeFwd.h"
#include <string>
namespace Luau
{
struct Type;
using TypeId = const Type*;
struct TypePackVar;
using TypePackId = const TypePackVar*;
struct ToDotOptions
{

View file

@ -2,6 +2,7 @@
#pragma once
#include "Luau/Common.h"
#include "Luau/TypeFwd.h"
#include <memory>
#include <optional>
@ -11,6 +12,7 @@
LUAU_FASTINT(LuauTableTypeMaximumStringifierLength)
LUAU_FASTINT(LuauTypeMaximumStringifierLength)
LUAU_FASTFLAG(LuauToStringSimpleCompositeTypesSingleLine)
namespace Luau
{
@ -19,13 +21,6 @@ class AstExpr;
struct Scope;
struct Type;
using TypeId = const Type*;
struct TypePackVar;
using TypePackId = const TypePackVar*;
struct FunctionType;
struct Constraint;
struct Position;
@ -39,6 +34,12 @@ struct ToStringNameMap
struct ToStringOptions
{
ToStringOptions(bool exhaustive = false)
: exhaustive(exhaustive)
, compositeTypesSingleLineLimit(FFlag::LuauToStringSimpleCompositeTypesSingleLine ? 5 : 0)
{
}
bool exhaustive = false; // If true, we produce complete output rather than comprehensible output
bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable.
bool functionTypeArguments = false; // If true, output function type argument names when they are available
@ -47,6 +48,7 @@ struct ToStringOptions
bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self
size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes
size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength);
size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections
ToStringNameMap nameMap;
std::shared_ptr<Scope> scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid'
std::vector<std::string> namedFunctionOverrideArgNames; // If present, named function argument names will be overridden
@ -99,10 +101,7 @@ inline std::string toString(const Constraint& c, ToStringOptions&& opts)
return toString(c, opts);
}
inline std::string toString(const Constraint& c)
{
return toString(c, ToStringOptions{});
}
std::string toString(const Constraint& c);
std::string toString(const Type& tv, ToStringOptions& opts);
std::string toString(const TypePackVar& tp, ToStringOptions& opts);
@ -142,6 +141,16 @@ std::string dump(const std::shared_ptr<Scope>& scope, const char* name);
std::string generateName(size_t n);
std::string toString(const Position& position);
std::string toString(const Location& location);
std::string toString(const Location& location, int offset = 0, bool useBegin = true);
std::string toString(const TypeOrPack& tyOrTp, ToStringOptions& opts);
inline std::string toString(const TypeOrPack& tyOrTp)
{
ToStringOptions opts{};
return toString(tyOrTp, opts);
}
std::string dump(const TypeOrPack& tyOrTp);
} // namespace Luau

View file

@ -19,6 +19,10 @@ struct PendingType
// The pending Type state.
Type pending;
// On very rare occasions, we need to delete an entry from the TxnLog.
// DenseHashMap does not afford that so we note its deadness here.
bool dead = false;
explicit PendingType(Type state)
: pending(std::move(state))
{
@ -61,10 +65,11 @@ T* getMutable(PendingTypePack* pending)
// Log of what TypeIds we are rebinding, to be committed later.
struct TxnLog
{
TxnLog()
explicit TxnLog(bool useScopes = false)
: typeVarChanges(nullptr)
, typePackChanges(nullptr)
, ownedSeen()
, useScopes(useScopes)
, sharedSeen(&ownedSeen)
{
}
@ -297,6 +302,18 @@ private:
void popSeen(TypeOrPackId lhs, TypeOrPackId rhs);
public:
// There is one spot in the code where TxnLog has to reconcile collisions
// between parallel logs. In that codepath, we have to work out which of two
// FreeTypes subsumes the other. If useScopes is false, the TypeLevel is
// used. Else we use the embedded Scope*.
bool useScopes = false;
// It is sometimes the case under DCR that we speculatively rebind
// GenericTypes to other types as though they were free. We mark logs that
// contain these kinds of substitutions as radioactive so that we know that
// we must never commit one.
bool radioactive = false;
// Used to avoid infinite recursion when types are cyclic.
// Shared with all the descendent TxnLogs.
std::vector<std::pair<TypeOrPackId, TypeOrPackId>>* sharedSeen;

View file

@ -1,6 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeFwd.h"
#include "Luau/Ast.h"
#include "Luau/Common.h"
#include "Luau/Refinement.h"
@ -9,7 +11,9 @@
#include "Luau/Predicate.h"
#include "Luau/Unifiable.h"
#include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
#include <atomic>
#include <deque>
#include <map>
#include <memory>
@ -17,7 +21,6 @@
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
LUAU_FASTINT(LuauTableTypeMaximumStringifierLength)
@ -30,6 +33,8 @@ struct TypeArena;
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
struct TypeFamily;
/**
* There are three kinds of type variables:
* - `Free` variables are metavariables, which stand for unconstrained types.
@ -56,33 +61,72 @@ using ScopePtr = std::shared_ptr<Scope>;
* ```
*/
// So... why `const T*` here rather than `T*`?
// It's because we've had problems caused by the type graph being mutated
// in ways it shouldn't be, for example mutating types from other modules.
// To try to control this, we make the use of types immutable by default,
// then provide explicit mutable access via getMutable and asMutable.
// This means we can grep for all the places we're mutating the type graph,
// and it makes it possible to provide other APIs (e.g. the txn log)
// which control mutable access to the type graph.
struct TypePackVar;
using TypePackId = const TypePackVar*;
struct Type;
// Should never be null
using TypeId = const Type*;
using Name = std::string;
// A free type var is one whose exact shape has yet to be fully determined.
using FreeType = Unifiable::Free;
// A free type is one whose exact shape has yet to be fully determined.
struct FreeType
{
explicit FreeType(TypeLevel level);
explicit FreeType(Scope* scope);
FreeType(Scope* scope, TypeLevel level);
// When a free type var is unified with any other, it is then "bound"
// to that type var, indicating that the two types are actually the same type.
FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound);
int index;
TypeLevel level;
Scope* scope = nullptr;
// True if this free type variable is part of a mutually
// recursive type alias whose definitions haven't been
// resolved yet.
bool forwardedTypeAlias = false;
// Only used under local type inference
TypeId lowerBound = nullptr;
TypeId upperBound = nullptr;
};
/** A type that tracks the domain of a local variable.
*
* We consider each local's domain to be the union of all types assigned to it.
* We accomplish this with LocalType. Each time we dispatch an assignment to a
* local, we accumulate this union and decrement blockCount.
*
* When blockCount reaches 0, we can consider the LocalType to be "fully baked"
* and replace it with the union we've built.
*/
struct LocalType
{
TypeId domain;
int blockCount = 0;
// Used for debugging
std::string name;
};
struct GenericType
{
// By default, generics are global, with a synthetic name
GenericType();
explicit GenericType(TypeLevel level);
explicit GenericType(const Name& name);
explicit GenericType(Scope* scope);
GenericType(TypeLevel level, const Name& name);
GenericType(Scope* scope, const Name& name);
int index;
TypeLevel level;
Scope* scope = nullptr;
Name name;
bool explicitName = false;
};
// When an equality constraint is found, it is then "bound" to that type,
// indicating that the two types are actually the same type.
using BoundType = Unifiable::Bound<TypeId>;
using GenericType = Unifiable::Generic;
using Tags = std::vector<std::string>;
using ModuleName = std::string;
@ -101,8 +145,6 @@ struct BlockedType
{
BlockedType();
int index;
static int nextIndex;
};
struct PrimitiveType
@ -116,6 +158,7 @@ struct PrimitiveType
Thread,
Function,
Table,
Buffer,
};
Type type;
@ -206,22 +249,6 @@ const T* get(const SingletonType* stv)
return nullptr;
}
struct GenericTypeDefinition
{
TypeId ty;
std::optional<TypeId> defaultValue;
bool operator==(const GenericTypeDefinition& rhs) const;
};
struct GenericTypePackDefinition
{
TypePackId tp;
std::optional<TypePackId> defaultValue;
bool operator==(const GenericTypePackDefinition& rhs) const;
};
struct FunctionArgument
{
Name name;
@ -269,7 +296,7 @@ struct MagicFunctionCallContext
TypePackId result;
};
using DcrMagicFunction = bool (*)(MagicFunctionCallContext);
using DcrMagicFunction = std::function<bool(MagicFunctionCallContext)>;
struct MagicRefinementContext
{
@ -314,7 +341,10 @@ struct FunctionType
DcrMagicFunction dcrMagicFunction = nullptr;
DcrMagicRefinement dcrMagicRefinement = nullptr;
bool hasSelf;
bool hasNoGenerics = false;
// `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it.
// this flag is used as an optimization to exit early from procedures that manipulate free or generic types.
bool hasNoFreeOrGenericTypes = false;
bool isCheckedFunction = false;
};
enum class TableState
@ -347,12 +377,55 @@ struct TableIndexer
struct Property
{
TypeId type;
static Property readonly(TypeId ty);
static Property writeonly(TypeId ty);
static Property rw(TypeId ty); // Shared read-write type.
static Property rw(TypeId read, TypeId write); // Separate read-write type.
// Invariant: at least one of the two optionals are not nullopt!
// If the read type is not nullopt, but the write type is, then the property is readonly.
// If the read type is nullopt, but the write type is not, then the property is writeonly.
// If the read and write types are not nullopt, then the property is read and write.
// Otherwise, an assertion where read and write types are both nullopt will be tripped.
static Property create(std::optional<TypeId> read, std::optional<TypeId> write);
bool deprecated = false;
std::string deprecatedSuggestion;
// If this property was inferred from an expression, this field will be
// populated with the source location of the corresponding table property.
std::optional<Location> location = std::nullopt;
// If this property was built from an explicit type annotation, this field
// will be populated with the source location of that table property.
std::optional<Location> typeLocation = std::nullopt;
Tags tags;
std::optional<std::string> documentationSymbol;
// DEPRECATED
// TODO: Kill all constructors in favor of `Property::rw(TypeId read, TypeId write)` and friends.
Property();
Property(TypeId readTy, bool deprecated = false, const std::string& deprecatedSuggestion = "", std::optional<Location> location = std::nullopt,
const Tags& tags = {}, const std::optional<std::string>& documentationSymbol = std::nullopt, std::optional<Location> typeLocation = std::nullopt);
// DEPRECATED: Should only be called in non-RWP! We assert that the `readTy` is not nullopt.
// TODO: Kill once we don't have non-RWP.
TypeId type() const;
void setType(TypeId ty);
// Should only be called in RWP!
// We do not assert that `readTy` nor `writeTy` are nullopt or not.
// The invariant is that at least one of them mustn't be nullopt, which we do assert here.
// TODO: Kill this in favor of exposing `readTy`/`writeTy` directly? If we do, we'll lose the asserts which will be useful while debugging.
std::optional<TypeId> readType() const;
std::optional<TypeId> writeType() const;
bool isShared() const;
private:
std::optional<TypeId> readTy;
std::optional<TypeId> writeTy;
};
struct TableType
@ -395,9 +468,11 @@ struct TableType
// Represents a metatable attached to a table type. Somewhat analogous to a bound type.
struct MetatableType
{
// Always points to a TableType.
// Should always be a TableType.
TypeId table;
// Always points to either a TableType or a MetatableType.
// Should almost always either be a TableType or another MetatableType,
// though it is possible for other types (like AnyType and ErrorType) to
// find their way here sometimes.
TypeId metatable;
std::optional<std::string> syntheticName;
@ -428,6 +503,7 @@ struct ClassType
Tags tags;
std::shared_ptr<ClassUserData> userData;
ModuleName definitionModuleName;
std::optional<TableIndexer> indexer;
ClassType(Name name, Props props, std::optional<TypeId> parent, std::optional<TypeId> metatable, Tags tags,
std::shared_ptr<ClassUserData> userData, ModuleName definitionModuleName)
@ -440,42 +516,34 @@ struct ClassType
, definitionModuleName(definitionModuleName)
{
}
ClassType(Name name, Props props, std::optional<TypeId> parent, std::optional<TypeId> metatable, Tags tags,
std::shared_ptr<ClassUserData> userData, ModuleName definitionModuleName, std::optional<TableIndexer> indexer)
: name(name)
, props(props)
, parent(parent)
, metatable(metatable)
, tags(tags)
, userData(userData)
, definitionModuleName(definitionModuleName)
, indexer(indexer)
{
}
};
struct TypeFun
/**
* An instance of a type family that has not yet been reduced to a more concrete
* type. The constraint solver receives a constraint to reduce each
* TypeFamilyInstanceType to a concrete type. A design detail is important to
* note here: the parameters for this instantiation of the type family are
* contained within this type, so that they can be substituted.
*/
struct TypeFamilyInstanceType
{
// These should all be generic
std::vector<GenericTypeDefinition> typeParams;
std::vector<GenericTypePackDefinition> typePackParams;
NotNull<const TypeFamily> family;
/** The underlying type.
*
* WARNING! This is not safe to use as a type if typeParams is not empty!!
* You must first use TypeChecker::instantiateTypeFun to turn it into a real type.
*/
TypeId type;
TypeFun() = default;
explicit TypeFun(TypeId ty)
: type(ty)
{
}
TypeFun(std::vector<GenericTypeDefinition> typeParams, TypeId type)
: typeParams(std::move(typeParams))
, type(type)
{
}
TypeFun(std::vector<GenericTypeDefinition> typeParams, std::vector<GenericTypePackDefinition> typePackParams, TypeId type)
: typeParams(std::move(typeParams))
, typePackParams(std::move(typePackParams))
, type(type)
{
}
bool operator==(const TypeFun& rhs) const;
std::vector<TypeId> typeArguments;
std::vector<TypePackId> packArguments;
};
/** Represents a pending type alias instantiation.
@ -517,7 +585,43 @@ struct IntersectionType
struct LazyType
{
std::function<TypeId()> thunk;
LazyType() = default;
LazyType(std::function<void(LazyType&)> unwrap)
: unwrap(unwrap)
{
}
// std::atomic is sad and requires a manual copy
LazyType(const LazyType& rhs)
: unwrap(rhs.unwrap)
, unwrapped(rhs.unwrapped.load())
{
}
LazyType(LazyType&& rhs) noexcept
: unwrap(std::move(rhs.unwrap))
, unwrapped(rhs.unwrapped.load())
{
}
LazyType& operator=(const LazyType& rhs)
{
unwrap = rhs.unwrap;
unwrapped = rhs.unwrapped.load();
return *this;
}
LazyType& operator=(LazyType&& rhs) noexcept
{
unwrap = std::move(rhs.unwrap);
unwrapped = rhs.unwrapped.load();
return *this;
}
std::function<void(LazyType&)> unwrap;
std::atomic<TypeId> unwrapped = nullptr;
};
struct UnknownType
@ -536,8 +640,9 @@ struct NegationType
using ErrorType = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveType, BlockedType, PendingExpansionType, SingletonType, FunctionType, TableType,
MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType>;
using TypeVariant =
Unifiable::Variant<TypeId, FreeType, LocalType, GenericType, PrimitiveType, BlockedType, PendingExpansionType, SingletonType, FunctionType, TableType,
MetatableType, ClassType, AnyType, UnionType, IntersectionType, LazyType, UnknownType, NeverType, NegationType, TypeFamilyInstanceType>;
struct Type final
{
@ -585,12 +690,72 @@ struct Type final
Type& operator=(const Type& rhs);
};
struct GenericTypeDefinition
{
TypeId ty;
std::optional<TypeId> defaultValue;
bool operator==(const GenericTypeDefinition& rhs) const;
};
struct GenericTypePackDefinition
{
TypePackId tp;
std::optional<TypePackId> defaultValue;
bool operator==(const GenericTypePackDefinition& rhs) const;
};
struct TypeFun
{
// These should all be generic
std::vector<GenericTypeDefinition> typeParams;
std::vector<GenericTypePackDefinition> typePackParams;
/** The underlying type.
*
* WARNING! This is not safe to use as a type if typeParams is not empty!!
* You must first use TypeChecker::instantiateTypeFun to turn it into a real type.
*/
TypeId type;
TypeFun() = default;
explicit TypeFun(TypeId ty)
: type(ty)
{
}
TypeFun(std::vector<GenericTypeDefinition> typeParams, TypeId type)
: typeParams(std::move(typeParams))
, type(type)
{
}
TypeFun(std::vector<GenericTypeDefinition> typeParams, std::vector<GenericTypePackDefinition> typePackParams, TypeId type)
: typeParams(std::move(typeParams))
, typePackParams(std::move(typePackParams))
, type(type)
{
}
bool operator==(const TypeFun& rhs) const;
};
using SeenSet = std::set<std::pair<const void*, const void*>>;
bool areEqual(SeenSet& seen, const Type& lhs, const Type& rhs);
enum class FollowOption
{
Normal,
DisableLazyTypeThunks,
};
// Follow BoundTypes until we get to something real
TypeId follow(TypeId t);
TypeId follow(TypeId t, std::function<TypeId(TypeId)> mapper);
TypeId follow(TypeId t, FollowOption followOption);
TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId));
TypeId follow(TypeId t, FollowOption followOption, const void* context, TypeId (*mapper)(const void*, TypeId));
std::vector<TypeId> flattenIntersection(TypeId ty);
@ -600,6 +765,7 @@ bool isBoolean(TypeId ty);
bool isNumber(TypeId ty);
bool isString(TypeId ty);
bool isThread(TypeId ty);
bool isBuffer(TypeId ty);
bool isOptional(TypeId ty);
bool isTableIntersection(TypeId ty);
bool isOverloadedFunction(TypeId ty);
@ -645,18 +811,20 @@ struct BuiltinTypes
TypeId errorRecoveryType() const;
TypePackId errorRecoveryTypePack() const;
friend TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes);
friend struct GlobalTypes;
private:
std::unique_ptr<struct TypeArena> arena;
bool debugFreezeArena = false;
TypeId makeStringMetatable();
public:
const TypeId nilType;
const TypeId numberType;
const TypeId stringType;
const TypeId booleanType;
const TypeId threadType;
const TypeId bufferType;
const TypeId functionType;
const TypeId classType;
const TypeId tableType;
@ -673,6 +841,7 @@ public:
const TypeId optionalNumberType;
const TypeId optionalStringType;
const TypePackId emptyTypePack;
const TypePackId anyTypePack;
const TypePackId neverTypePack;
const TypePackId uninhabitableTypePack;
@ -694,6 +863,18 @@ bool isSubclass(const ClassType* cls, const ClassType* parent);
Type* asMutable(TypeId ty);
template<typename... Ts, typename T>
bool is(T&& tv)
{
if (!tv)
return false;
if constexpr (std::is_same_v<TypeId, T> && !(std::is_same_v<BoundType, Ts> || ...))
LUAU_ASSERT(get_if<BoundType>(&tv->ty) == nullptr);
return (get<Ts>(tv) || ...);
}
template<typename T>
const T* get(TypeId tv)
{
@ -705,6 +886,7 @@ const T* get(TypeId tv)
return get_if<T>(&tv->ty);
}
template<typename T>
T* getMutable(TypeId tv)
{
@ -764,7 +946,7 @@ struct TypeIterator
TypeIterator<T> operator++(int)
{
TypeIterator<T> copy = *this;
++copy;
++*this;
return copy;
}
@ -811,7 +993,7 @@ private:
using SavedIterInfo = std::pair<const T*, size_t>;
std::deque<SavedIterInfo> stack;
std::unordered_set<const T*> seen; // Only needed to protect the iterator from hanging the thread.
DenseHashSet<const T*> seen{nullptr}; // Only needed to protect the iterator from hanging the thread.
void advance()
{
@ -838,7 +1020,7 @@ private:
{
// If we're about to descend into a cyclic type, we should skip over this.
// Ideally this should never happen, but alas it does from time to time. :(
if (seen.find(inner) != seen.end())
if (seen.contains(inner))
advance();
else
{
@ -854,6 +1036,8 @@ private:
}
};
TypeId freshType(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, Scope* scope);
using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>;
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
@ -864,6 +1048,19 @@ bool hasTag(TypeId ty, const std::string& tagName);
bool hasTag(const Property& prop, const std::string& tagName);
bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work.
template<typename T>
bool hasTypeInIntersection(TypeId ty)
{
TypeId tf = follow(ty);
if (get<T>(tf))
return true;
for (auto t : flattenIntersection(tf))
if (get<T>(follow(t)))
return true;
return false;
}
bool hasPrimitiveTypeInIntersection(TypeId ty, PrimitiveType::Type primTy);
/*
* Use this to change the kind of a particular type.
*

View file

@ -9,12 +9,16 @@
namespace Luau
{
struct Module;
struct TypeArena
{
TypedAllocator<Type> types;
TypedAllocator<TypePackVar> typePacks;
// Owning module, if any
Module* owningModule = nullptr;
void clear();
template<typename T>

View file

@ -0,0 +1,41 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Cancellation.h"
#include "Luau/Error.h"
#include <memory>
#include <optional>
#include <string>
namespace Luau
{
class TimeLimitError : public InternalCompilerError
{
public:
explicit TimeLimitError(const std::string& moduleName)
: InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName)
{
}
};
class UserCancelError : public InternalCompilerError
{
public:
explicit UserCancelError(const std::string& moduleName)
: InternalCompilerError("Analysis has been cancelled by user", moduleName)
{
}
};
struct TypeCheckLimits
{
std::optional<double> finishTime;
std::optional<int> instantiationChildLimit;
std::optional<int> unifierIterationLimit;
std::shared_ptr<FrontendCancellationToken> cancellationToken;
};
} // namespace Luau

View file

@ -2,16 +2,19 @@
#pragma once
#include "Luau/Ast.h"
#include "Luau/Module.h"
#include "Luau/NotNull.h"
namespace Luau
{
struct DcrLogger;
struct BuiltinTypes;
struct DcrLogger;
struct TypeCheckLimits;
struct UnifierSharedState;
struct SourceModule;
struct Module;
void check(NotNull<BuiltinTypes> builtinTypes, DcrLogger* logger, const SourceModule& sourceModule, Module* module);
void check(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> sharedState, NotNull<TypeCheckLimits> limits, DcrLogger* logger,
const SourceModule& sourceModule, Module* module);
} // namespace Luau

View file

@ -0,0 +1,172 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/ConstraintSolver.h"
#include "Luau/Error.h"
#include "Luau/NotNull.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFwd.h"
#include "Luau/Variant.h"
#include <functional>
#include <string>
#include <optional>
namespace Luau
{
struct TypeArena;
struct TxnLog;
class Normalizer;
struct TypeFamilyContext
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtins;
NotNull<Scope> scope;
NotNull<Normalizer> normalizer;
NotNull<InternalErrorReporter> ice;
NotNull<TypeCheckLimits> limits;
// nullptr if the type family is being reduced outside of the constraint solver.
ConstraintSolver* solver;
TypeFamilyContext(NotNull<ConstraintSolver> cs, NotNull<Scope> scope)
: arena(cs->arena)
, builtins(cs->builtinTypes)
, scope(scope)
, normalizer(cs->normalizer)
, ice(NotNull{&cs->iceReporter})
, limits(NotNull{&cs->limits})
, solver(cs.get())
{
}
TypeFamilyContext(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtins, NotNull<Scope> scope, NotNull<Normalizer> normalizer,
NotNull<InternalErrorReporter> ice, NotNull<TypeCheckLimits> limits)
: arena(arena)
, builtins(builtins)
, scope(scope)
, normalizer(normalizer)
, ice(ice)
, limits(limits)
, solver(nullptr)
{
}
};
/// Represents a reduction result, which may have successfully reduced the type,
/// may have concretely failed to reduce the type, or may simply be stuck
/// without more information.
template<typename Ty>
struct TypeFamilyReductionResult
{
/// The result of the reduction, if any. If this is nullopt, the family
/// could not be reduced.
std::optional<Ty> result;
/// Whether the result is uninhabited: whether we know, unambiguously and
/// permanently, whether this type family reduction results in an
/// uninhabitable type. This will trigger an error to be reported.
bool uninhabited;
/// Any types that need to be progressed or mutated before the reduction may
/// proceed.
std::vector<TypeId> blockedTypes;
/// Any type packs that need to be progressed or mutated before the
/// reduction may proceed.
std::vector<TypePackId> blockedPacks;
};
/// Represents a type function that may be applied to map a series of types and
/// type packs to a single output type.
struct TypeFamily
{
/// The human-readable name of the type family. Used to stringify instance
/// types.
std::string name;
/// The reducer function for the type family.
std::function<TypeFamilyReductionResult<TypeId>(const std::vector<TypeId>&, const std::vector<TypePackId>&, NotNull<TypeFamilyContext>)> reducer;
};
/// Represents a type function that may be applied to map a series of types and
/// type packs to a single output type pack.
struct TypePackFamily
{
/// The human-readable name of the type pack family. Used to stringify
/// instance packs.
std::string name;
/// The reducer function for the type pack family.
std::function<TypeFamilyReductionResult<TypePackId>(const std::vector<TypeId>&, const std::vector<TypePackId>&, NotNull<TypeFamilyContext>)> reducer;
};
struct FamilyGraphReductionResult
{
ErrorVec errors;
DenseHashSet<TypeId> blockedTypes{nullptr};
DenseHashSet<TypePackId> blockedPacks{nullptr};
DenseHashSet<TypeId> reducedTypes{nullptr};
DenseHashSet<TypePackId> reducedPacks{nullptr};
};
/**
* Attempt to reduce all instances of any type or type pack family in the type
* graph provided.
*
* @param entrypoint the entry point to the type graph.
* @param location the location the reduction is occurring at; used to populate
* type errors.
* @param arena an arena to allocate types into.
* @param builtins the built-in types.
* @param normalizer the normalizer to use when normalizing types
* @param ice the internal error reporter to use for ICEs
*/
FamilyGraphReductionResult reduceFamilies(TypeId entrypoint, Location location, TypeFamilyContext, bool force = false);
/**
* Attempt to reduce all instances of any type or type pack family in the type
* graph provided.
*
* @param entrypoint the entry point to the type graph.
* @param location the location the reduction is occurring at; used to populate
* type errors.
* @param arena an arena to allocate types into.
* @param builtins the built-in types.
* @param normalizer the normalizer to use when normalizing types
* @param ice the internal error reporter to use for ICEs
*/
FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location location, TypeFamilyContext, bool force = false);
struct BuiltinTypeFamilies
{
BuiltinTypeFamilies();
TypeFamily notFamily;
TypeFamily lenFamily;
TypeFamily unmFamily;
TypeFamily addFamily;
TypeFamily subFamily;
TypeFamily mulFamily;
TypeFamily divFamily;
TypeFamily idivFamily;
TypeFamily powFamily;
TypeFamily modFamily;
TypeFamily concatFamily;
TypeFamily andFamily;
TypeFamily orFamily;
TypeFamily ltFamily;
TypeFamily leFamily;
TypeFamily eqFamily;
void addToScope(NotNull<TypeArena> arena, NotNull<Scope> scope) const;
};
const BuiltinTypeFamilies kBuiltinTypeFamilies{};
} // namespace Luau

View file

@ -0,0 +1,59 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Variant.h"
#include <string>
namespace Luau
{
// So... why `const T*` here rather than `T*`?
// It's because we've had problems caused by the type graph being mutated
// in ways it shouldn't be, for example mutating types from other modules.
// To try to control this, we make the use of types immutable by default,
// then provide explicit mutable access via getMutable and asMutable.
// This means we can grep for all the places we're mutating the type graph,
// and it makes it possible to provide other APIs (e.g. the txn log)
// which control mutable access to the type graph.
struct Type;
using TypeId = const Type*;
struct FreeType;
struct GenericType;
struct PrimitiveType;
struct BlockedType;
struct PendingExpansionType;
struct SingletonType;
struct FunctionType;
struct TableType;
struct MetatableType;
struct ClassType;
struct AnyType;
struct UnionType;
struct IntersectionType;
struct LazyType;
struct UnknownType;
struct NeverType;
struct NegationType;
struct TypeFamilyInstanceType;
struct TypePackVar;
using TypePackId = const TypePackVar*;
struct FreeTypePack;
struct GenericTypePack;
struct TypePack;
struct VariadicTypePack;
struct BlockedTypePack;
struct TypeFamilyInstanceTypePack;
using Name = std::string;
using ModuleName = std::string;
struct BuiltinTypes;
using TypeOrPack = Variant<TypeId, TypePackId>;
} // namespace Luau

View file

@ -2,14 +2,16 @@
#pragma once
#include "Luau/Anyification.h"
#include "Luau/Predicate.h"
#include "Luau/ControlFlow.h"
#include "Luau/Error.h"
#include "Luau/Module.h"
#include "Luau/Symbol.h"
#include "Luau/Predicate.h"
#include "Luau/Substitution.h"
#include "Luau/Symbol.h"
#include "Luau/TxnLog.h"
#include "Luau/TypePack.h"
#include "Luau/Type.h"
#include "Luau/TypeFwd.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeUtils.h"
#include "Luau/Unifier.h"
#include "Luau/UnifierSharedState.h"
@ -17,18 +19,24 @@
#include <unordered_map>
#include <unordered_set>
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
namespace Luau
{
struct Scope;
struct TypeChecker;
struct ModuleResolver;
struct FrontendCancellationToken;
using Name = std::string;
using ScopePtr = std::shared_ptr<Scope>;
using OverloadErrorEntry = std::tuple<std::vector<TypeError>, std::vector<TypeId>, const FunctionType*>;
struct OverloadErrorEntry
{
TxnLog log;
ErrorVec errors;
std::vector<TypeId> arguments;
const FunctionType* fnTy;
};
bool doesCallError(const AstExprCall* call);
bool hasBreak(AstStat* node);
@ -48,37 +56,12 @@ struct HashBoolNamePair
size_t operator()(const std::pair<bool, Name>& pair) const;
};
class TimeLimitError : public InternalCompilerError
{
public:
explicit TimeLimitError(const std::string& moduleName)
: InternalCompilerError("Typeinfer failed to complete in allotted time", moduleName)
{
}
};
enum class ValueContext
{
LValue,
RValue
};
struct GlobalTypes
{
GlobalTypes(NotNull<BuiltinTypes> builtinTypes);
NotNull<BuiltinTypes> builtinTypes; // Global types are based on builtin types
TypeArena globalTypes;
SourceModule globalNames; // names for symbols entered into globalScope
ScopePtr globalScope; // shared by all modules
};
// All Types are retained via Environment::types. All TypeIds
// within a program are borrowed pointers into this set.
struct TypeChecker
{
explicit TypeChecker(const GlobalTypes& globals, ModuleResolver* resolver, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler);
explicit TypeChecker(
const ScopePtr& globalScope, ModuleResolver* resolver, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler);
TypeChecker(const TypeChecker&) = delete;
TypeChecker& operator=(const TypeChecker&) = delete;
@ -87,28 +70,28 @@ struct TypeChecker
std::vector<std::pair<Location, ScopePtr>> getScopes() const;
void check(const ScopePtr& scope, const AstStat& statement);
void check(const ScopePtr& scope, const AstStatBlock& statement);
void check(const ScopePtr& scope, const AstStatIf& statement);
void check(const ScopePtr& scope, const AstStatWhile& statement);
void check(const ScopePtr& scope, const AstStatRepeat& statement);
void check(const ScopePtr& scope, const AstStatReturn& return_);
void check(const ScopePtr& scope, const AstStatAssign& assign);
void check(const ScopePtr& scope, const AstStatCompoundAssign& assign);
void check(const ScopePtr& scope, const AstStatLocal& local);
void check(const ScopePtr& scope, const AstStatFor& local);
void check(const ScopePtr& scope, const AstStatForIn& forin);
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function);
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function);
void check(const ScopePtr& scope, const AstStatTypeAlias& typealias);
void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
ControlFlow check(const ScopePtr& scope, const AstStat& statement);
ControlFlow check(const ScopePtr& scope, const AstStatBlock& statement);
ControlFlow check(const ScopePtr& scope, const AstStatIf& statement);
ControlFlow check(const ScopePtr& scope, const AstStatWhile& statement);
ControlFlow check(const ScopePtr& scope, const AstStatRepeat& statement);
ControlFlow check(const ScopePtr& scope, const AstStatReturn& return_);
ControlFlow check(const ScopePtr& scope, const AstStatAssign& assign);
ControlFlow check(const ScopePtr& scope, const AstStatCompoundAssign& assign);
ControlFlow check(const ScopePtr& scope, const AstStatLocal& local);
ControlFlow check(const ScopePtr& scope, const AstStatFor& local);
ControlFlow check(const ScopePtr& scope, const AstStatForIn& forin);
ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function);
ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function);
ControlFlow check(const ScopePtr& scope, const AstStatTypeAlias& typealias);
ControlFlow check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
ControlFlow check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0);
void prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
void checkBlock(const ScopePtr& scope, const AstStatBlock& statement);
void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement);
ControlFlow checkBlock(const ScopePtr& scope, const AstStatBlock& statement);
ControlFlow checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement);
void checkBlockTypeAliases(const ScopePtr& scope, std::vector<AstStat*>& sorted);
WithPredicate<TypeId> checkExpr(
@ -169,7 +152,7 @@ struct TypeChecker
const std::vector<OverloadErrorEntry>& errors);
void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount,
const std::vector<OverloadErrorEntry>& errors);
std::vector<OverloadErrorEntry>& errors);
WithPredicate<TypePackId> checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs,
bool substituteFreeForNil = false, const std::vector<bool>& lhsAnnotations = {},
@ -260,6 +243,7 @@ public:
[[noreturn]] void ice(const std::string& message, const Location& location);
[[noreturn]] void ice(const std::string& message);
[[noreturn]] void throwTimeLimitError();
[[noreturn]] void throwUserCancelError();
ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0);
ScopePtr childScope(const ScopePtr& parent, const Location& location);
@ -366,12 +350,10 @@ public:
*/
std::vector<TypeId> unTypePack(const ScopePtr& scope, TypePackId pack, size_t expectedLength, const Location& location);
// TODO: only const version of global scope should be available to make sure nothing else is modified inside of from users of TypeChecker
const GlobalTypes& globals;
const ScopePtr& globalScope;
ModuleResolver* resolver;
ModulePtr currentModule;
ModuleName currentModuleName;
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
NotNull<BuiltinTypes> builtinTypes;
@ -387,12 +369,15 @@ public:
std::optional<int> instantiationChildLimit;
std::optional<int> unifierIterationLimit;
std::shared_ptr<FrontendCancellationToken> cancellationToken;
public:
const TypeId nilType;
const TypeId numberType;
const TypeId stringType;
const TypeId booleanType;
const TypeId threadType;
const TypeId bufferType;
const TypeId anyType;
const TypeId unknownType;
const TypeId neverType;

View file

@ -0,0 +1,41 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Type.h"
#include "Luau/TypePack.h"
#include "Luau/Variant.h"
#include <type_traits>
namespace Luau
{
const void* ptr(TypeOrPack ty);
template<typename T, typename std::enable_if_t<TypeOrPack::is_part_of_v<T>, bool> = true>
const T* get(const TypeOrPack& tyOrTp)
{
return tyOrTp.get_if<T>();
}
template<typename T, typename std::enable_if_t<TypeVariant::is_part_of_v<T>, bool> = true>
const T* get(const TypeOrPack& tyOrTp)
{
if (const TypeId* ty = get<TypeId>(tyOrTp))
return get<T>(*ty);
else
return nullptr;
}
template<typename T, typename std::enable_if_t<TypePackVariant::is_part_of_v<T>, bool> = true>
const T* get(const TypeOrPack& tyOrTp)
{
if (const TypePackId* tp = get<TypePackId>(tyOrTp))
return get<T>(*tp);
else
return nullptr;
}
TypeOrPack follow(TypeOrPack ty);
} // namespace Luau

View file

@ -1,31 +1,61 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Type.h"
#include "Luau/Unifiable.h"
#include "Luau/Variant.h"
#include "Luau/TypeFwd.h"
#include "Luau/NotNull.h"
#include "Luau/Common.h"
#include <optional>
#include <set>
#include <vector>
namespace Luau
{
struct TypeArena;
struct TypePackFamily;
struct TxnLog;
struct TypePack;
struct VariadicTypePack;
struct BlockedTypePack;
struct TypeFamilyInstanceTypePack;
struct TypePackVar;
struct FreeTypePack
{
explicit FreeTypePack(TypeLevel level);
explicit FreeTypePack(Scope* scope);
FreeTypePack(Scope* scope, TypeLevel level);
struct TxnLog;
int index;
TypeLevel level;
Scope* scope = nullptr;
};
struct GenericTypePack
{
// By default, generics are global, with a synthetic name
GenericTypePack();
explicit GenericTypePack(TypeLevel level);
explicit GenericTypePack(const Name& name);
explicit GenericTypePack(Scope* scope);
GenericTypePack(TypeLevel level, const Name& name);
GenericTypePack(Scope* scope, const Name& name);
int index;
TypeLevel level;
Scope* scope = nullptr;
Name name;
bool explicitName = false;
};
using TypePackId = const TypePackVar*;
using FreeTypePack = Unifiable::Free;
using BoundTypePack = Unifiable::Bound<TypePackId>;
using GenericTypePack = Unifiable::Generic;
using TypePackVariant = Unifiable::Variant<TypePackId, TypePack, VariadicTypePack, BlockedTypePack>;
using ErrorTypePack = Unifiable::Error;
using TypePackVariant =
Unifiable::Variant<TypePackId, FreeTypePack, GenericTypePack, TypePack, VariadicTypePack, BlockedTypePack, TypeFamilyInstanceTypePack>;
/* A TypePack is a rope-like string of TypeIds. We use this structure to encode
* notions like packs of unknown length and packs of any length, as well as more
@ -55,6 +85,17 @@ struct BlockedTypePack
static size_t nextIndex;
};
/**
* Analogous to a TypeFamilyInstanceType.
*/
struct TypeFamilyInstanceTypePack
{
NotNull<TypePackFamily> family;
std::vector<TypeId> typeArguments;
std::vector<TypePackId> packArguments;
};
struct TypePackVar
{
explicit TypePackVar(const TypePackVariant& ty);
@ -141,7 +182,7 @@ using SeenSet = std::set<std::pair<const void*, const void*>>;
bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs);
TypePackId follow(TypePackId tp);
TypePackId follow(TypePackId tp, std::function<TypePackId(TypePackId)> mapper);
TypePackId follow(TypePackId t, const void* context, TypePackId (*mapper)(const void*, TypePackId));
size_t size(TypePackId tp, TxnLog* log = nullptr);
bool finite(TypePackId tp, TxnLog* log = nullptr);

View file

@ -0,0 +1,35 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeFwd.h"
#include <stdint.h>
#include <utility>
namespace Luau
{
struct TypePairHash
{
size_t hashOne(TypeId key) const
{
return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9);
}
size_t hashOne(TypePackId key) const
{
return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9);
}
size_t operator()(const std::pair<TypeId, TypeId>& x) const
{
return hashOne(x.first) ^ (hashOne(x.second) << 1);
}
size_t operator()(const std::pair<TypePackId, TypePackId>& x) const
{
return hashOne(x.first) ^ (hashOne(x.second) << 1);
}
};
} // namespace Luau

View file

@ -0,0 +1,229 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeFwd.h"
#include "Luau/Variant.h"
#include "Luau/NotNull.h"
#include <optional>
#include <string>
#include <vector>
namespace Luau
{
namespace TypePath
{
/// Represents a property of a class, table, or anything else with a concept of
/// a named property.
struct Property
{
/// The name of the property.
std::string name;
/// Whether to look at the read or the write type.
bool isRead = true;
explicit Property(std::string name);
Property(std::string name, bool read)
: name(std::move(name))
, isRead(read)
{
}
static Property read(std::string name);
static Property write(std::string name);
bool operator==(const Property& other) const;
};
/// Represents an index into a type or a pack. For a type, this indexes into a
/// union or intersection's list. For a pack, this indexes into the pack's nth
/// element.
struct Index
{
/// The 0-based index to use for the lookup.
size_t index;
bool operator==(const Index& other) const;
};
/// Represents fields of a type or pack that contain a type.
enum class TypeField
{
/// The metatable of a type. This could be a metatable type, a primitive
/// type, a class type, or perhaps even a string singleton type.
Metatable,
/// The lower bound of this type, if one is present.
LowerBound,
/// The upper bound of this type, if present.
UpperBound,
/// The index type.
IndexLookup,
/// The indexer result type.
IndexResult,
/// The negated type, for negations.
Negated,
/// The variadic type for a type pack.
Variadic,
};
/// Represents fields of a type or type pack that contain a type pack.
enum class PackField
{
/// What arguments this type accepts.
Arguments,
/// What this type returns when called.
Returns,
/// The tail of a type pack.
Tail,
};
/// A single component of a path, representing one inner type or type pack to
/// traverse into.
using Component = Luau::Variant<Property, Index, TypeField, PackField>;
/// A path through a type or type pack accessing a particular type or type pack
/// contained within.
///
/// Paths are always relative; to make use of a Path, you need to specify an
/// entry point. They are not canonicalized; two Paths may not compare equal but
/// may point to the same result, depending on the layout of the entry point.
///
/// Paths always descend through an entry point. This doesn't mean that they
/// cannot reach "upwards" in the actual type hierarchy in some cases, but it
/// does mean that there is no equivalent to `../` in file system paths. This is
/// intentional and unavoidable, because types and type packs don't have a
/// concept of a parent - they are a directed cyclic graph, with no hierarchy
/// that actually holds in all cases.
struct Path
{
/// The Components of this Path.
std::vector<Component> components;
/// Creates a new empty Path.
Path()
{
}
/// Creates a new Path from a list of components.
explicit Path(std::vector<Component> components)
: components(std::move(components))
{
}
/// Creates a new single-component Path.
explicit Path(Component component)
: components({component})
{
}
/// Creates a new Path by appending another Path to this one.
/// @param suffix the Path to append
/// @return a new Path representing `this + suffix`
Path append(const Path& suffix) const;
/// Creates a new Path by appending a Component to this Path.
/// @param component the Component to append
/// @return a new Path with `component` appended to it.
Path push(Component component) const;
/// Creates a new Path by prepending a Component to this Path.
/// @param component the Component to prepend
/// @return a new Path with `component` prepended to it.
Path push_front(Component component) const;
/// Creates a new Path by removing the last Component of this Path.
/// If the Path is empty, this is a no-op.
/// @return a Path with the last component removed.
Path pop() const;
/// Returns the last Component of this Path, if present.
std::optional<Component> last() const;
/// Returns whether this Path is empty, meaning it has no components at all.
/// Traversing an empty Path results in the type you started with.
bool empty() const;
bool operator==(const Path& other) const;
bool operator!=(const Path& other) const
{
return !(*this == other);
}
};
struct PathHash
{
size_t operator()(const Property& prop) const;
size_t operator()(const Index& idx) const;
size_t operator()(const TypeField& field) const;
size_t operator()(const PackField& field) const;
size_t operator()(const Component& component) const;
size_t operator()(const Path& path) const;
};
/// The canonical "empty" Path, meaning a Path with no components.
static const Path kEmpty{};
struct PathBuilder
{
std::vector<Component> components;
Path build();
PathBuilder& readProp(std::string name);
PathBuilder& writeProp(std::string name);
PathBuilder& prop(std::string name);
PathBuilder& index(size_t i);
PathBuilder& mt();
PathBuilder& lb();
PathBuilder& ub();
PathBuilder& indexKey();
PathBuilder& indexValue();
PathBuilder& negated();
PathBuilder& variadic();
PathBuilder& args();
PathBuilder& rets();
PathBuilder& tail();
};
} // namespace TypePath
using Path = TypePath::Path;
/// Converts a Path to a string for debugging purposes. This output may not be
/// terribly clear to end users of the Luau type system.
std::string toString(const TypePath::Path& path, bool prefixDot = false);
std::optional<TypeOrPack> traverse(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
std::optional<TypeOrPack> traverse(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
/// Traverses a path from a type to its end point, which must be a type.
/// @param root the entry point of the traversal
/// @param path the path to traverse
/// @param builtinTypes the built-in types in use (used to acquire the string metatable)
/// @returns the TypeId at the end of the path, or nullopt if the traversal failed.
std::optional<TypeId> traverseForType(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
/// Traverses a path from a type pack to its end point, which must be a type.
/// @param root the entry point of the traversal
/// @param path the path to traverse
/// @param builtinTypes the built-in types in use (used to acquire the string metatable)
/// @returns the TypeId at the end of the path, or nullopt if the traversal failed.
std::optional<TypeId> traverseForType(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
/// Traverses a path from a type to its end point, which must be a type pack.
/// @param root the entry point of the traversal
/// @param path the path to traverse
/// @param builtinTypes the built-in types in use (used to acquire the string metatable)
/// @returns the TypePackId at the end of the path, or nullopt if the traversal failed.
std::optional<TypePackId> traverseForPack(TypeId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
/// Traverses a path from a type pack to its end point, which must be a type pack.
/// @param root the entry point of the traversal
/// @param path the path to traverse
/// @param builtinTypes the built-in types in use (used to acquire the string metatable)
/// @returns the TypePackId at the end of the path, or nullopt if the traversal failed.
std::optional<TypePackId> traverseForPack(TypePackId root, const Path& path, NotNull<BuiltinTypes> builtinTypes);
} // namespace Luau

View file

@ -1,85 +0,0 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Type.h"
#include "Luau/TypeArena.h"
#include "Luau/TypePack.h"
#include "Luau/Variant.h"
namespace Luau
{
namespace detail
{
template<typename T>
struct ReductionEdge
{
T type = nullptr;
bool irreducible = false;
};
struct TypeReductionMemoization
{
TypeReductionMemoization() = default;
TypeReductionMemoization(const TypeReductionMemoization&) = delete;
TypeReductionMemoization& operator=(const TypeReductionMemoization&) = delete;
TypeReductionMemoization(TypeReductionMemoization&&) = default;
TypeReductionMemoization& operator=(TypeReductionMemoization&&) = default;
DenseHashMap<TypeId, ReductionEdge<TypeId>> types{nullptr};
DenseHashMap<TypePackId, ReductionEdge<TypePackId>> typePacks{nullptr};
bool isIrreducible(TypeId ty);
bool isIrreducible(TypePackId tp);
TypeId memoize(TypeId ty, TypeId reducedTy);
TypePackId memoize(TypePackId tp, TypePackId reducedTp);
// Reducing A into B may have a non-irreducible edge A to B for which B is not irreducible, which means B could be reduced into C.
// Because reduction should always be transitive, A should point to C if A points to B and B points to C.
std::optional<ReductionEdge<TypeId>> memoizedof(TypeId ty) const;
std::optional<ReductionEdge<TypePackId>> memoizedof(TypePackId tp) const;
};
} // namespace detail
struct TypeReductionOptions
{
/// If it's desirable for type reduction to allocate into a different arena than the TypeReduction instance you have, you will need
/// to create a temporary TypeReduction in that case, and set [`TypeReductionOptions::allowTypeReductionsFromOtherArenas`] to true.
/// This is because TypeReduction caches the reduced type.
bool allowTypeReductionsFromOtherArenas = false;
};
struct TypeReduction
{
explicit TypeReduction(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<struct InternalErrorReporter> handle,
const TypeReductionOptions& opts = {});
TypeReduction(const TypeReduction&) = delete;
TypeReduction& operator=(const TypeReduction&) = delete;
TypeReduction(TypeReduction&&) = default;
TypeReduction& operator=(TypeReduction&&) = default;
std::optional<TypeId> reduce(TypeId ty);
std::optional<TypePackId> reduce(TypePackId tp);
std::optional<TypeFun> reduce(const TypeFun& fun);
private:
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
NotNull<struct InternalErrorReporter> handle;
TypeReductionOptions options;
detail::TypeReductionMemoization memoization;
// Computes an *estimated length* of the cartesian product of the given type.
size_t cartesianProductSize(TypeId ty) const;
bool hasExceededCartesianProductLimit(TypeId ty) const;
bool hasExceededCartesianProductLimit(TypePackId tp) const;
};
} // namespace Luau

View file

@ -14,6 +14,13 @@ namespace Luau
struct TxnLog;
struct TypeArena;
class Normalizer;
enum class ValueContext
{
LValue,
RValue
};
using ScopePtr = std::shared_ptr<struct Scope>;
@ -49,4 +56,92 @@ std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types);
*/
TypeId stripNil(NotNull<BuiltinTypes> builtinTypes, TypeArena& arena, TypeId ty);
enum class ErrorSuppression
{
Suppress,
DoNotSuppress,
NormalizationFailed
};
/**
* Normalizes the given type using the normalizer to determine if the type
* should suppress any errors that would be reported involving it.
* @param normalizer the normalizer to use
* @param ty the type to check for error suppression
* @returns an enum indicating whether or not to suppress the error or to signal a normalization failure
*/
ErrorSuppression shouldSuppressErrors(NotNull<Normalizer> normalizer, TypeId ty);
/**
* Flattens and normalizes the given typepack using the normalizer to determine if the type
* should suppress any errors that would be reported involving it.
* @param normalizer the normalizer to use
* @param tp the typepack to check for error suppression
* @returns an enum indicating whether or not to suppress the error or to signal a normalization failure
*/
ErrorSuppression shouldSuppressErrors(NotNull<Normalizer> normalizer, TypePackId tp);
/**
* Normalizes the two given type using the normalizer to determine if either type
* should suppress any errors that would be reported involving it.
* @param normalizer the normalizer to use
* @param ty1 the first type to check for error suppression
* @param ty2 the second type to check for error suppression
* @returns an enum indicating whether or not to suppress the error or to signal a normalization failure
*/
ErrorSuppression shouldSuppressErrors(NotNull<Normalizer> normalizer, TypeId ty1, TypeId ty2);
/**
* Flattens and normalizes the two given typepacks using the normalizer to determine if either type
* should suppress any errors that would be reported involving it.
* @param normalizer the normalizer to use
* @param tp1 the first typepack to check for error suppression
* @param tp2 the second typepack to check for error suppression
* @returns an enum indicating whether or not to suppress the error or to signal a normalization failure
*/
ErrorSuppression shouldSuppressErrors(NotNull<Normalizer> normalizer, TypePackId tp1, TypePackId tp2);
// Similar to `std::optional<std::pair<A, B>>`, but whose `sizeof()` is the same as `std::pair<A, B>`
// and cooperates with C++'s `if (auto p = ...)` syntax without the extra fatness of `std::optional`.
template<typename A, typename B>
struct TryPair
{
A first;
B second;
explicit operator bool() const
{
return bool(first) && bool(second);
}
};
template<typename A, typename B, typename Ty>
TryPair<const A*, const B*> get2(Ty one, Ty two)
{
const A* a = get<A>(one);
const B* b = get<B>(two);
if (a && b)
return {a, b};
else
return {nullptr, nullptr};
}
template<typename T, typename Ty>
const T* get(std::optional<Ty> ty)
{
if (ty)
return get<T>(*ty);
else
return nullptr;
}
template<typename Ty>
std::optional<Ty> follow(std::optional<Ty> ty)
{
if (ty)
return follow(*ty);
else
return std::nullopt;
}
} // namespace Luau

View file

@ -81,23 +81,7 @@ namespace Luau::Unifiable
using Name = std::string;
struct Free
{
explicit Free(TypeLevel level);
explicit Free(Scope* scope);
explicit Free(Scope* scope, TypeLevel level);
int index;
TypeLevel level;
Scope* scope = nullptr;
// True if this free type variable is part of a mutually
// recursive type alias whose definitions haven't been
// resolved yet.
bool forwardedTypeAlias = false;
private:
static int DEPRECATED_nextIndex;
};
int freshIndex();
template<typename Id>
struct Bound
@ -110,26 +94,6 @@ struct Bound
Id boundTo;
};
struct Generic
{
// By default, generics are global, with a synthetic name
Generic();
explicit Generic(TypeLevel level);
explicit Generic(const Name& name);
explicit Generic(Scope* scope);
Generic(TypeLevel level, const Name& name);
Generic(Scope* scope, const Name& name);
int index;
TypeLevel level;
Scope* scope = nullptr;
Name name;
bool explicitName = false;
private:
static int DEPRECATED_nextIndex;
};
struct Error
{
// This constructor has to be public, since it's used in Type and TypePack,
@ -143,6 +107,6 @@ private:
};
template<typename Id, typename... Value>
using Variant = Luau::Variant<Free, Bound<Id>, Generic, Error, Value...>;
using Variant = Luau::Variant<Bound<Id>, Error, Value...>;
} // namespace Luau::Unifiable

View file

@ -9,7 +9,7 @@
#include "Luau/TxnLog.h"
#include "Luau/TypeArena.h"
#include "Luau/UnifierSharedState.h"
#include "Normalize.h"
#include "Luau/Normalize.h"
#include <unordered_set>
@ -43,6 +43,21 @@ struct Widen : Substitution
TypePackId operator()(TypePackId ty);
};
/**
* Normally, when we unify table properties, we must do so invariantly, but we
* can introduce a special exception: If the table property in the subtype
* position arises from a literal expression, it is safe to instead perform a
* covariant check.
*
* This is very useful for typechecking cases where table literals (and trees of
* table literals) are passed directly to functions.
*
* In this case, we know that the property has no other name referring to it and
* so it is perfectly safe for the function to mutate the table any way it
* wishes.
*/
using LiteralProperties = DenseHashSet<Name>;
// TODO: Use this more widely.
struct UnifierOptions
{
@ -54,18 +69,20 @@ struct Unifier
TypeArena* const types;
NotNull<BuiltinTypes> builtinTypes;
NotNull<Normalizer> normalizer;
Mode mode;
NotNull<Scope> scope; // const Scope maybe
TxnLog log;
bool failure = false;
ErrorVec errors;
Location location;
Variance variance = Covariant;
bool normalize = true; // Normalize unions and intersections if necessary
bool checkInhabited = true; // Normalize types to check if they are inhabited
bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels
CountMismatch::Context ctx = CountMismatch::Arg;
// If true, generics act as free types when unifying.
bool hideousFixMeGenericsAreActuallyFree = false;
UnifierSharedState& sharedState;
// When the Unifier is forced to unify two blocked types (or packs), they
@ -74,8 +91,11 @@ struct Unifier
std::vector<TypeId> blockedTypes;
std::vector<TypePackId> blockedTypePacks;
Unifier(
NotNull<Normalizer> normalizer, Mode mode, NotNull<Scope> scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr);
Unifier(NotNull<Normalizer> normalizer, NotNull<Scope> scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr);
// Configure the Unifier to test for scope subsumption via embedded Scope
// pointers rather than TypeLevels.
void enableNewSolver();
// Test whether the two type vars unify. Never commits the result.
ErrorVec canUnify(TypeId subTy, TypeId superTy);
@ -85,15 +105,17 @@ struct Unifier
* Populate the vector errors with any type errors that may arise.
* Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt.
*/
void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false);
void tryUnify(
TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr);
private:
void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false);
void tryUnify_(
TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr);
void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy);
// Traverse the two types provided and block on any BlockedTypes we find.
// Returns true if any types were blocked on.
bool blockOnBlockedTypes(TypeId subTy, TypeId superTy);
bool DEPRECATED_blockOnBlockedTypes(TypeId subTy, TypeId superTy);
void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall);
void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionType* uv);
@ -103,7 +125,7 @@ private:
void tryUnifyPrimitives(TypeId subTy, TypeId superTy);
void tryUnifySingletons(TypeId subTy, TypeId superTy);
void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false);
void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false);
void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr);
void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed);
@ -136,9 +158,9 @@ private:
public:
// Returns true if the type "needle" already occurs within "haystack" and reports an "infinite type error"
bool occursCheck(TypeId needle, TypeId haystack);
bool occursCheck(TypeId needle, TypeId haystack, bool reversed);
bool occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack);
bool occursCheck(TypePackId needle, TypePackId haystack);
bool occursCheck(TypePackId needle, TypePackId haystack, bool reversed);
bool occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack);
Unifier makeChildUnifier();
@ -147,7 +169,6 @@ public:
LUAU_NOINLINE void reportError(Location location, TypeErrorData data);
private:
bool isNonstrictMode() const;
TypeMismatch::Context mismatchContext();
void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType);
@ -158,9 +179,15 @@ private:
// Available after regular type pack unification errors
std::optional<int> firstPackErrorPos;
// If true, we do a bunch of small things differently to work better with
// the new type inference engine. Most notably, we use the Scope hierarchy
// directly rather than using TypeLevels.
bool useNewSolver = false;
};
void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp);
std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors);
std::optional<TypeError> hasCountMismatch(const ErrorVec& errors);
} // namespace Luau

View file

@ -0,0 +1,88 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/NotNull.h"
#include "Luau/TypePairHash.h"
#include "Luau/TypeCheckLimits.h"
#include "Luau/TypeFwd.h"
#include <optional>
#include <vector>
#include <utility>
namespace Luau
{
struct InternalErrorReporter;
struct Scope;
struct TypeArena;
enum class OccursCheckResult
{
Pass,
Fail
};
struct Unifier2
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
NotNull<Scope> scope;
NotNull<InternalErrorReporter> ice;
TypeCheckLimits limits;
DenseHashSet<std::pair<TypeId, TypeId>, TypePairHash> seenTypePairings{{nullptr, nullptr}};
DenseHashSet<std::pair<TypePackId, TypePackId>, TypePairHash> seenTypePackPairings{{nullptr, nullptr}};
DenseHashMap<TypeId, std::vector<TypeId>> expandedFreeTypes{nullptr};
int recursionCount = 0;
int recursionLimit = 0;
Unifier2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, NotNull<InternalErrorReporter> ice);
/** Attempt to commit the subtype relation subTy <: superTy to the type
* graph.
*
* @returns true if successful.
*
* Note that incoherent types can and will successfully be unified. We stop
* when we *cannot know* how to relate the provided types, not when doing so
* would narrow something down to never or broaden it to unknown.
*
* Presently, the only way unification can fail is if we attempt to bind one
* free TypePack to another and encounter an occurs check violation.
*/
bool unify(TypeId subTy, TypeId superTy);
bool unify(TypeId subTy, const FunctionType* superFn);
bool unify(const UnionType* subUnion, TypeId superTy);
bool unify(TypeId subTy, const UnionType* superUnion);
bool unify(const IntersectionType* subIntersection, TypeId superTy);
bool unify(TypeId subTy, const IntersectionType* superIntersection);
bool unify(const TableType* subTable, const TableType* superTable);
bool unify(const MetatableType* subMetatable, const MetatableType* superMetatable);
// TODO think about this one carefully. We don't do unions or intersections of type packs
bool unify(TypePackId subTp, TypePackId superTp);
std::optional<TypeId> generalize(TypeId ty);
private:
/**
* @returns simplify(left | right)
*/
TypeId mkUnion(TypeId left, TypeId right);
/**
* @returns simplify(left & right)
*/
TypeId mkIntersection(TypeId left, TypeId right);
// Returns true if needle occurs within haystack already. ie if we bound
// needle to haystack, would a cyclic TypePack result?
OccursCheckResult occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack);
};
} // namespace Luau

View file

@ -3,8 +3,7 @@
#include "Luau/DenseHash.h"
#include "Luau/Error.h"
#include "Luau/Type.h"
#include "Luau/TypePack.h"
#include "Luau/TypeFwd.h"
#include <utility>

View file

@ -1,13 +1,13 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include <initializer_list>
#include <new>
#include <type_traits>
#include <initializer_list>
#include <stddef.h>
#include <utility>
#include <stddef.h>
namespace Luau
{
@ -44,6 +44,9 @@ private:
public:
using first_alternative = typename First<Ts...>::type;
template<typename T>
static constexpr bool is_part_of_v = std::disjunction_v<typename std::is_same<std::decay_t<Ts>, T>...>;
Variant()
{
static_assert(std::is_default_constructible_v<first_alternative>, "first alternative type must be default constructible");

View file

@ -9,6 +9,9 @@
#include "Luau/Type.h"
LUAU_FASTINT(LuauVisitRecursionLimit)
LUAU_FASTFLAG(LuauBoundLazyTypes2)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTFLAG(DebugLuauReadWriteProperties)
namespace Luau
{
@ -94,6 +97,10 @@ struct GenericTypeVisitor
{
return visit(ty);
}
virtual bool visit(TypeId ty, const LocalType& ftv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const GenericType& gtv)
{
return visit(ty);
@ -158,6 +165,10 @@ struct GenericTypeVisitor
{
return visit(ty);
}
virtual bool visit(TypeId ty, const TypeFamilyInstanceType& tfit)
{
return visit(ty);
}
virtual bool visit(TypePackId tp)
{
@ -191,6 +202,10 @@ struct GenericTypeVisitor
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const TypeFamilyInstanceTypePack& tfitp)
{
return visit(tp);
}
void traverse(TypeId ty)
{
@ -210,7 +225,31 @@ struct GenericTypeVisitor
traverse(btv->boundTo);
}
else if (auto ftv = get<FreeType>(ty))
visit(ty, *ftv);
{
if (FFlag::DebugLuauDeferredConstraintResolution)
{
if (visit(ty, *ftv))
{
// TODO: Replace these if statements with assert()s when we
// delete FFlag::DebugLuauDeferredConstraintResolution.
//
// When the old solver is used, these pointers are always
// unused. When the new solver is used, they are never null.
if (ftv->lowerBound)
traverse(ftv->lowerBound);
if (ftv->upperBound)
traverse(ftv->upperBound);
}
}
else
visit(ty, *ftv);
}
else if (auto lt = get<LocalType>(ty))
{
if (visit(ty, *lt))
traverse(lt->domain);
}
else if (auto gtv = get<GenericType>(ty))
visit(ty, *gtv);
else if (auto etv = get<ErrorType>(ty))
@ -241,7 +280,18 @@ struct GenericTypeVisitor
else
{
for (auto& [_name, prop] : ttv->props)
traverse(prop.type);
{
if (FFlag::DebugLuauReadWriteProperties)
{
if (auto ty = prop.readType())
traverse(*ty);
if (auto ty = prop.writeType())
traverse(*ty);
}
else
traverse(prop.type());
}
if (ttv->indexer)
{
@ -264,13 +314,30 @@ struct GenericTypeVisitor
if (visit(ty, *ctv))
{
for (const auto& [name, prop] : ctv->props)
traverse(prop.type);
{
if (FFlag::DebugLuauReadWriteProperties)
{
if (auto ty = prop.readType())
traverse(*ty);
if (auto ty = prop.writeType())
traverse(*ty);
}
else
traverse(prop.type());
}
if (ctv->parent)
traverse(*ctv->parent);
if (ctv->metatable)
traverse(*ctv->metatable);
if (ctv->indexer)
{
traverse(ctv->indexer->indexType);
traverse(ctv->indexer->indexResultType);
}
}
}
else if (auto atv = get<AnyType>(ty))
@ -291,9 +358,12 @@ struct GenericTypeVisitor
traverse(partTy);
}
}
else if (get<LazyType>(ty))
else if (auto ltv = get<LazyType>(ty))
{
// Visiting into LazyType may necessarily cause infinite expansion, so we don't do that on purpose.
if (TypeId unwrapped = ltv->unwrapped)
traverse(unwrapped);
// Visiting into LazyType that hasn't been unwrapped may necessarily cause infinite expansion, so we don't do that on purpose.
// Asserting also makes no sense, because the type _will_ happen here, most likely as a property of some ClassType
// that doesn't need to be expanded.
}
@ -321,6 +391,17 @@ struct GenericTypeVisitor
if (visit(ty, *ntv))
traverse(ntv->ty);
}
else if (auto tfit = get<TypeFamilyInstanceType>(ty))
{
if (visit(ty, *tfit))
{
for (TypeId p : tfit->typeArguments)
traverse(p);
for (TypePackId p : tfit->packArguments)
traverse(p);
}
}
else
LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypeId) is not exhaustive!");
@ -341,10 +422,10 @@ struct GenericTypeVisitor
traverse(btv->boundTo);
}
else if (auto ftv = get<Unifiable::Free>(tp))
else if (auto ftv = get<FreeTypePack>(tp))
visit(tp, *ftv);
else if (auto gtv = get<Unifiable::Generic>(tp))
else if (auto gtv = get<GenericTypePack>(tp))
visit(tp, *gtv);
else if (auto etv = get<Unifiable::Error>(tp))
@ -370,6 +451,17 @@ struct GenericTypeVisitor
}
else if (auto btp = get<BlockedTypePack>(tp))
visit(tp, *btp);
else if (auto tfitp = get<TypeFamilyInstanceTypePack>(tp))
{
if (visit(tp, *tfitp))
{
for (TypeId t : tfitp->typeArguments)
traverse(t);
for (TypePackId t : tfitp->packArguments)
traverse(t);
}
}
else
LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypePackId) is not exhaustive!");

View file

@ -6,8 +6,6 @@
#include "Luau/Normalize.h"
#include "Luau/TxnLog.h"
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
namespace Luau
{
@ -78,7 +76,7 @@ TypePackId Anyification::clean(TypePackId tp)
bool Anyification::ignoreChildren(TypeId ty)
{
if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassType>(ty))
if (get<ClassType>(ty))
return true;
return ty->persistent;

View file

@ -2,8 +2,6 @@
#include "Luau/ApplyTypeFunction.h"
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
namespace Luau
{
@ -33,7 +31,7 @@ bool ApplyTypeFunction::ignoreChildren(TypeId ty)
{
if (get<GenericType>(ty))
return true;
else if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassType>(ty))
else if (get<ClassType>(ty))
return true;
else
return false;

View file

@ -8,6 +8,8 @@
#include <math.h>
LUAU_FASTFLAG(LuauClipExtraHasEndProps);
namespace Luau
{
@ -374,7 +376,8 @@ struct AstJsonEncoder : public AstVisitor
PROP(body);
PROP(functionDepth);
PROP(debugname);
PROP(hasEnd);
if (!FFlag::LuauClipExtraHasEndProps)
write("hasEnd", node->DEPRECATED_hasEnd);
});
}
@ -514,6 +517,8 @@ struct AstJsonEncoder : public AstVisitor
return writeString("Mul");
case AstExprBinary::Div:
return writeString("Div");
case AstExprBinary::FloorDiv:
return writeString("FloorDiv");
case AstExprBinary::Mod:
return writeString("Mod");
case AstExprBinary::Pow:
@ -536,6 +541,8 @@ struct AstJsonEncoder : public AstVisitor
return writeString("And");
case AstExprBinary::Or:
return writeString("Or");
default:
LUAU_ASSERT(!"Unknown Op");
}
}
@ -567,6 +574,11 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstStatBlock* node)
{
writeNode(node, "AstStatBlock", [&]() {
if (FFlag::LuauClipExtraHasEndProps)
{
writeRaw(",\"hasEnd\":");
write(node->hasEnd);
}
writeRaw(",\"body\":[");
bool comma = false;
for (AstStat* stat : node->body)
@ -590,7 +602,8 @@ struct AstJsonEncoder : public AstVisitor
if (node->elsebody)
PROP(elsebody);
write("hasThen", node->thenLocation.has_value());
PROP(hasEnd);
if (!FFlag::LuauClipExtraHasEndProps)
write("hasEnd", node->DEPRECATED_hasEnd);
});
}
@ -600,7 +613,8 @@ struct AstJsonEncoder : public AstVisitor
PROP(condition);
PROP(body);
PROP(hasDo);
PROP(hasEnd);
if (!FFlag::LuauClipExtraHasEndProps)
write("hasEnd", node->DEPRECATED_hasEnd);
});
}
@ -609,7 +623,8 @@ struct AstJsonEncoder : public AstVisitor
writeNode(node, "AstStatRepeat", [&]() {
PROP(condition);
PROP(body);
PROP(hasUntil);
if (!FFlag::LuauClipExtraHasEndProps)
write("hasUntil", node->DEPRECATED_hasUntil);
});
}
@ -655,7 +670,8 @@ struct AstJsonEncoder : public AstVisitor
PROP(step);
PROP(body);
PROP(hasDo);
PROP(hasEnd);
if (!FFlag::LuauClipExtraHasEndProps)
write("hasEnd", node->DEPRECATED_hasEnd);
});
}
@ -667,7 +683,8 @@ struct AstJsonEncoder : public AstVisitor
PROP(body);
PROP(hasIn);
PROP(hasDo);
PROP(hasEnd);
if (!FFlag::LuauClipExtraHasEndProps)
write("hasEnd", node->DEPRECATED_hasEnd);
});
}
@ -752,6 +769,7 @@ struct AstJsonEncoder : public AstVisitor
if (node->superName)
write("superName", *node->superName);
PROP(props);
PROP(indexer);
});
}
@ -776,7 +794,10 @@ struct AstJsonEncoder : public AstVisitor
writeNode(node, "AstTypeReference", [&]() {
if (node->prefix)
PROP(prefix);
if (node->prefixLocation)
write("prefixLocation", *node->prefixLocation);
PROP(name);
PROP(nameLocation);
PROP(parameters);
});
}

View file

@ -11,7 +11,7 @@
#include <algorithm>
LUAU_FASTFLAG(LuauCompleteTableKeysBetter);
LUAU_FASTFLAG(DebugLuauReadWriteProperties)
namespace Luau
{
@ -31,24 +31,12 @@ struct AutocompleteNodeFinder : public AstVisitor
bool visit(AstExpr* expr) override
{
if (FFlag::LuauCompleteTableKeysBetter)
if (expr->location.begin <= pos && pos <= expr->location.end)
{
if (expr->location.begin <= pos && pos <= expr->location.end)
{
ancestry.push_back(expr);
return true;
}
return false;
}
else
{
if (expr->location.begin < pos && pos <= expr->location.end)
{
ancestry.push_back(expr);
return true;
}
return false;
ancestry.push_back(expr);
return true;
}
return false;
}
bool visit(AstStat* stat) override
@ -160,6 +148,16 @@ struct FindNode : public AstVisitor
return false;
}
bool visit(AstStatFunction* node) override
{
visit(static_cast<AstNode*>(node));
if (node->name->location.contains(pos))
node->name->visit(this);
else if (node->func->location.contains(pos))
node->func->visit(this);
return false;
}
bool visit(AstStatBlock* block) override
{
visit(static_cast<AstNode*>(block));
@ -200,6 +198,16 @@ struct FindFullAncestry final : public AstVisitor
return false;
}
bool visit(AstStatFunction* node) override
{
visit(static_cast<AstNode*>(node));
if (node->name->location.contains(pos))
node->name->visit(this);
else if (node->func->location.contains(pos))
node->func->visit(this);
return false;
}
bool visit(AstNode* node) override
{
if (node->location.contains(pos))
@ -225,33 +233,48 @@ struct FindFullAncestry final : public AstVisitor
std::vector<AstNode*> findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos)
{
AutocompleteNodeFinder finder{pos, source.root};
source.root->visit(&finder);
return findAncestryAtPositionForAutocomplete(source.root, pos);
}
std::vector<AstNode*> findAncestryAtPositionForAutocomplete(AstStatBlock* root, Position pos)
{
AutocompleteNodeFinder finder{pos, root};
root->visit(&finder);
return finder.ancestry;
}
std::vector<AstNode*> findAstAncestryOfPosition(const SourceModule& source, Position pos, bool includeTypes)
{
const Position end = source.root->location.end;
return findAstAncestryOfPosition(source.root, pos, includeTypes);
}
std::vector<AstNode*> findAstAncestryOfPosition(AstStatBlock* root, Position pos, bool includeTypes)
{
const Position end = root->location.end;
if (pos > end)
pos = end;
FindFullAncestry finder(pos, end, includeTypes);
source.root->visit(&finder);
root->visit(&finder);
return finder.nodes;
}
AstNode* findNodeAtPosition(const SourceModule& source, Position pos)
{
const Position end = source.root->location.end;
if (pos < source.root->location.begin)
return source.root;
return findNodeAtPosition(source.root, pos);
}
AstNode* findNodeAtPosition(AstStatBlock* root, Position pos)
{
const Position end = root->location.end;
if (pos < root->location.begin)
return root;
if (pos > end)
pos = end;
FindNode findNode{pos, end};
findNode.visit(source.root);
findNode.visit(root);
return findNode.best;
}
@ -500,12 +523,28 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
if (const TableType* ttv = get<TableType>(parentTy))
{
if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end())
return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol);
{
if (FFlag::DebugLuauReadWriteProperties)
{
if (auto ty = propIt->second.readType())
return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol);
}
else
return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol);
}
}
else if (const ClassType* ctv = get<ClassType>(parentTy))
{
if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end())
return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol);
{
if (FFlag::DebugLuauReadWriteProperties)
{
if (auto ty = propIt->second.readType())
return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol);
}
else
return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol);
}
}
}
}

View file

@ -5,16 +5,19 @@
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Frontend.h"
#include "Luau/ToString.h"
#include "Luau/Subtyping.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeReduction.h"
#include <algorithm>
#include <unordered_set>
#include <utility>
LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false);
LUAU_FASTFLAGVARIABLE(LuauAutocompleteSkipNormalization, false);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(DebugLuauReadWriteProperties);
LUAU_FASTFLAG(LuauClipExtraHasEndProps);
LUAU_FASTFLAGVARIABLE(LuauAutocompleteDoEnd, false);
LUAU_FASTFLAGVARIABLE(LuauAutocompleteStringLiteralBounds, false);
static const std::unordered_set<std::string> kStatementStartingKeywords = {
"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
@ -142,16 +145,24 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, T
InternalErrorReporter iceReporter;
UnifierSharedState unifierState(&iceReporter);
Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}};
Unifier unifier(NotNull<Normalizer>{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant);
if (FFlag::LuauAutocompleteSkipNormalization)
if (FFlag::DebugLuauDeferredConstraintResolution)
{
Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&iceReporter}, scope};
return subtyping.isSubtype(subTy, superTy).isSubtype;
}
else
{
Unifier unifier(NotNull<Normalizer>{&normalizer}, scope, Location(), Variance::Covariant);
// Cost of normalization can be too high for autocomplete response time requirements
unifier.normalize = false;
unifier.checkInhabited = false;
return unifier.canUnify(subTy, superTy).empty();
}
return unifier.canUnify(subTy, superTy).empty();
}
static TypeCorrectKind checkTypeCorrectKind(
@ -266,25 +277,27 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul
// already populated, it takes precedence over the property we found just now.
if (result.count(name) == 0 && name != kParseNameError)
{
Luau::TypeId type = Luau::follow(prop.type);
Luau::TypeId type;
if (FFlag::DebugLuauReadWriteProperties)
{
if (auto ty = prop.readType())
type = follow(*ty);
else
continue;
}
else
type = follow(prop.type());
TypeCorrectKind typeCorrect = indexType == PropIndexType::Key
? TypeCorrectKind::Correct
: checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type);
ParenthesesRecommendation parens =
indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect);
result[name] = AutocompleteEntry{
AutocompleteEntryKind::Property,
type,
prop.deprecated,
isWrongIndexer(type),
typeCorrect,
containingClass,
&prop,
prop.documentationSymbol,
{},
parens,
};
result[name] = AutocompleteEntry{AutocompleteEntryKind::Property, type, prop.deprecated, isWrongIndexer(type), typeCorrect,
containingClass, &prop, prop.documentationSymbol, {}, parens, {}, indexType == PropIndexType::Colon};
}
}
};
@ -293,7 +306,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul
auto indexIt = mtable->props.find("__index");
if (indexIt != mtable->props.end())
{
TypeId followed = follow(indexIt->second.type);
TypeId followed = follow(indexIt->second.type());
if (get<TableType>(followed) || get<MetatableType>(followed))
{
autocompleteProps(module, typeArena, builtinTypes, rootTy, followed, indexType, nodes, result, seen);
@ -454,8 +467,19 @@ AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position posi
return result;
}
static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AutocompleteEntryMap& result)
static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node, Position position, AutocompleteEntryMap& result)
{
if (FFlag::LuauAutocompleteStringLiteralBounds)
{
if (position == node->location.begin || position == node->location.end)
{
if (auto str = node->as<AstExprConstantString>(); str && str->quoteStyle == AstExprConstantString::Quoted)
return;
else if (node->is<AstExprInterpString>())
return;
}
}
auto formatKey = [addQuotes](const std::string& key) {
if (addQuotes)
return "\"" + escape(key) + "\"";
@ -584,14 +608,13 @@ std::optional<TypeId> getLocalTypeInScopeAt(const Module& module, Position posit
return {};
}
static std::optional<Name> tryGetTypeNameInScope(ScopePtr scope, TypeId ty)
template<typename T>
static std::optional<std::string> tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments)
{
if (!canSuggestInferredType(scope, ty))
return std::nullopt;
ToStringOptions opts;
opts.useLineBreaks = false;
opts.hideTableKind = true;
opts.functionTypeArguments = functionTypeArguments;
opts.scope = scope;
ToStringResult name = toStringDetailed(ty, opts);
@ -601,6 +624,14 @@ static std::optional<Name> tryGetTypeNameInScope(ScopePtr scope, TypeId ty)
return name.name;
}
static std::optional<Name> tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool functionTypeArguments = false)
{
if (!canSuggestInferredType(scope, ty))
return std::nullopt;
return tryToStringDetailed(scope, ty, functionTypeArguments);
}
static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position)
{
std::optional<TypeId> ty;
@ -981,25 +1012,14 @@ T* extractStat(const std::vector<AstNode*>& ancestry)
AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr;
AstNode* greatGrandParent = ancestry.size() >= 4 ? ancestry.rbegin()[3] : nullptr;
if (FFlag::LuauCompleteTableKeysBetter)
{
if (!grandParent)
return nullptr;
if (!grandParent)
return nullptr;
if (T* t = parent->as<T>(); t && grandParent->is<AstStatBlock>())
return t;
if (T* t = parent->as<T>(); t && grandParent->is<AstStatBlock>())
return t;
if (!greatGrandParent)
return nullptr;
}
else
{
if (T* t = parent->as<T>(); t && parent->is<AstStatBlock>())
return t;
if (!grandParent || !greatGrandParent)
return nullptr;
}
if (!greatGrandParent)
return nullptr;
if (T* t = greatGrandParent->as<T>(); t && grandParent->is<AstStatBlock>() && parent->is<AstStatError>() && isIdentifier(node))
return t;
@ -1049,18 +1069,57 @@ static AutocompleteEntryMap autocompleteStatement(
for (const auto& kw : kStatementStartingKeywords)
result.emplace(kw, AutocompleteEntry{AutocompleteEntryKind::Keyword});
for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it)
if (FFlag::LuauClipExtraHasEndProps)
{
if (AstStatForIn* statForIn = (*it)->as<AstStatForIn>(); statForIn && !statForIn->hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
if (AstStatFor* statFor = (*it)->as<AstStatFor>(); statFor && !statFor->hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
if (AstStatIf* statIf = (*it)->as<AstStatIf>(); statIf && !statIf->hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
if (AstStatWhile* statWhile = (*it)->as<AstStatWhile>(); statWhile && !statWhile->hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
if (AstExprFunction* exprFunction = (*it)->as<AstExprFunction>(); exprFunction && !exprFunction->hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it)
{
if (AstStatForIn* statForIn = (*it)->as<AstStatForIn>(); statForIn && !statForIn->body->hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
else if (AstStatFor* statFor = (*it)->as<AstStatFor>(); statFor && !statFor->body->hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
else if (AstStatIf* statIf = (*it)->as<AstStatIf>())
{
bool hasEnd = statIf->thenbody->hasEnd;
if (statIf->elsebody)
{
if (AstStatBlock* elseBlock = statIf->elsebody->as<AstStatBlock>())
hasEnd = elseBlock->hasEnd;
}
if (!hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
}
else if (AstStatWhile* statWhile = (*it)->as<AstStatWhile>(); statWhile && !statWhile->body->hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
else if (AstExprFunction* exprFunction = (*it)->as<AstExprFunction>(); exprFunction && !exprFunction->body->hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
if (FFlag::LuauAutocompleteDoEnd)
{
if (AstStatBlock* exprBlock = (*it)->as<AstStatBlock>(); exprBlock && !exprBlock->hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
}
}
}
else
{
for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it)
{
if (AstStatForIn* statForIn = (*it)->as<AstStatForIn>(); statForIn && !statForIn->DEPRECATED_hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
else if (AstStatFor* statFor = (*it)->as<AstStatFor>(); statFor && !statFor->DEPRECATED_hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
else if (AstStatIf* statIf = (*it)->as<AstStatIf>(); statIf && !statIf->DEPRECATED_hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
else if (AstStatWhile* statWhile = (*it)->as<AstStatWhile>(); statWhile && !statWhile->DEPRECATED_hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
else if (AstExprFunction* exprFunction = (*it)->as<AstExprFunction>(); exprFunction && !exprFunction->DEPRECATED_hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
if (FFlag::LuauAutocompleteDoEnd)
{
if (AstStatBlock* exprBlock = (*it)->as<AstStatBlock>(); exprBlock && !exprBlock->hasEnd)
result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
}
}
}
if (ancestry.size() >= 2)
@ -1075,8 +1134,16 @@ static AutocompleteEntryMap autocompleteStatement(
}
}
if (AstStatRepeat* statRepeat = parent->as<AstStatRepeat>(); statRepeat && !statRepeat->hasUntil)
result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword});
if (FFlag::LuauClipExtraHasEndProps)
{
if (AstStatRepeat* statRepeat = parent->as<AstStatRepeat>(); statRepeat && !statRepeat->body->hasEnd)
result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword});
}
else
{
if (AstStatRepeat* statRepeat = parent->as<AstStatRepeat>(); statRepeat && !statRepeat->DEPRECATED_hasUntil)
result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword});
}
}
if (ancestry.size() >= 4)
@ -1090,8 +1157,16 @@ static AutocompleteEntryMap autocompleteStatement(
}
}
if (AstStatRepeat* statRepeat = extractStat<AstStatRepeat>(ancestry); statRepeat && !statRepeat->hasUntil)
result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword});
if (FFlag::LuauClipExtraHasEndProps)
{
if (AstStatRepeat* statRepeat = extractStat<AstStatRepeat>(ancestry); statRepeat && !statRepeat->body->hasEnd)
result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword});
}
else
{
if (AstStatRepeat* statRepeat = extractStat<AstStatRepeat>(ancestry); statRepeat && !statRepeat->DEPRECATED_hasUntil)
result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword});
}
return result;
}
@ -1195,7 +1270,7 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu
result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction};
if (auto ty = findExpectedTypeAt(module, node, position))
autocompleteStringSingleton(*ty, true, result);
autocompleteStringSingleton(*ty, true, node, position, result);
}
return AutocompleteContext::Expression;
@ -1301,6 +1376,14 @@ static std::optional<AutocompleteEntryMap> autocompleteStringParams(const Source
return std::nullopt;
}
if (!nodes.back()->is<AstExprError>())
{
if (nodes.back()->location.end == position || nodes.back()->location.begin == position)
{
return std::nullopt;
}
}
AstExprCall* candidate = nodes.at(nodes.size() - 2)->as<AstExprCall>();
if (!candidate)
{
@ -1365,6 +1448,139 @@ static AutocompleteResult autocompleteWhileLoopKeywords(std::vector<AstNode*> an
return {std::move(ret), std::move(ancestry), AutocompleteContext::Keyword};
}
static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy)
{
std::string result = "function(";
auto [args, tail] = Luau::flatten(funcTy.argTypes);
bool first = true;
// Skip the implicit 'self' argument if call is indexed with ':'
for (size_t argIdx = 0; argIdx < args.size(); ++argIdx)
{
if (!first)
result += ", ";
else
first = false;
std::string name;
if (argIdx < funcTy.argNames.size() && funcTy.argNames[argIdx])
name = funcTy.argNames[argIdx]->name;
else
name = "a" + std::to_string(argIdx);
if (std::optional<Name> type = tryGetTypeNameInScope(scope, args[argIdx], true))
result += name + ": " + *type;
else
result += name;
}
if (tail && (Luau::isVariadic(*tail) || Luau::get<Luau::FreeTypePack>(Luau::follow(*tail))))
{
if (!first)
result += ", ";
std::optional<std::string> varArgType;
if (const VariadicTypePack* pack = get<VariadicTypePack>(follow(*tail)))
{
if (std::optional<std::string> res = tryToStringDetailed(scope, pack->ty, true))
varArgType = std::move(res);
}
if (varArgType)
result += "...: " + *varArgType;
else
result += "...";
}
result += ")";
auto [rets, retTail] = Luau::flatten(funcTy.retTypes);
if (const size_t totalRetSize = rets.size() + (retTail ? 1 : 0); totalRetSize > 0)
{
if (std::optional<std::string> returnTypes = tryToStringDetailed(scope, funcTy.retTypes, true))
{
result += ": ";
bool wrap = totalRetSize != 1;
if (wrap)
result += "(";
result += *returnTypes;
if (wrap)
result += ")";
}
}
result += " end";
return result;
}
static std::optional<AutocompleteEntry> makeAnonymousAutofilled(
const ModulePtr& module, Position position, const AstNode* node, const std::vector<AstNode*>& ancestry)
{
const AstExprCall* call = node->as<AstExprCall>();
if (!call && ancestry.size() > 1)
call = ancestry[ancestry.size() - 2]->as<AstExprCall>();
if (!call)
return std::nullopt;
if (!call->location.containsClosed(position) || call->func->location.containsClosed(position))
return std::nullopt;
TypeId* typeIter = module->astTypes.find(call->func);
if (!typeIter)
return std::nullopt;
const FunctionType* outerFunction = get<FunctionType>(follow(*typeIter));
if (!outerFunction)
return std::nullopt;
size_t argument = 0;
for (size_t i = 0; i < call->args.size; ++i)
{
if (call->args.data[i]->location.containsClosed(position))
{
argument = i;
break;
}
}
if (call->self)
argument++;
std::optional<TypeId> argType;
auto [args, tail] = flatten(outerFunction->argTypes);
if (argument < args.size())
argType = args[argument];
if (!argType)
return std::nullopt;
TypeId followed = follow(*argType);
const FunctionType* type = get<FunctionType>(followed);
if (!type)
{
if (const UnionType* unionType = get<UnionType>(followed))
{
if (std::optional<const FunctionType*> nonnullFunction = returnFirstNonnullOptionOfType<FunctionType>(unionType))
type = *nonnullFunction;
}
}
if (!type)
return std::nullopt;
const ScopePtr scope = findScopeAtPosition(*module, position);
if (!scope)
return std::nullopt;
AutocompleteEntry entry;
entry.kind = AutocompleteEntryKind::GeneratedFunction;
entry.typeCorrect = TypeCorrectKind::Correct;
entry.type = argType;
entry.insertText = makeAnonymous(scope, *type);
return std::make_optional(std::move(entry));
}
static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, NotNull<BuiltinTypes> builtinTypes,
TypeArena* typeArena, Scope* globalScope, Position position, StringCompletionCallback callback)
{
@ -1533,23 +1749,20 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
{
auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry);
if (FFlag::LuauCompleteTableKeysBetter)
{
if (auto nodeIt = module->astExpectedTypes.find(node->asExpr()))
autocompleteStringSingleton(*nodeIt, !node->is<AstExprConstantString>(), result);
if (auto nodeIt = module->astExpectedTypes.find(node->asExpr()))
autocompleteStringSingleton(*nodeIt, !node->is<AstExprConstantString>(), node, position, result);
if (!key)
if (!key)
{
// If there is "no key," it may be that the user
// intends for the current token to be the key, but
// has yet to type the `=` sign.
//
// If the key type is a union of singleton strings,
// suggest those too.
if (auto ttv = get<TableType>(follow(*it)); ttv && ttv->indexer)
{
// If there is "no key," it may be that the user
// intends for the current token to be the key, but
// has yet to type the `=` sign.
//
// If the key type is a union of singleton strings,
// suggest those too.
if (auto ttv = get<TableType>(follow(*it)); ttv && ttv->indexer)
{
autocompleteStringSingleton(ttv->indexer->indexType, false, result);
}
autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result);
}
}
@ -1586,7 +1799,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
AutocompleteEntryMap result;
if (auto it = module->astExpectedTypes.find(node->asExpr()))
autocompleteStringSingleton(*it, false, result);
autocompleteStringSingleton(*it, false, node, position, result);
if (ancestry.size() >= 2)
{
@ -1600,7 +1813,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe)
{
if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left))
autocompleteStringSingleton(*it, false, result);
autocompleteStringSingleton(*it, false, node, position, result);
}
}
}
@ -1619,7 +1832,12 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
return {};
if (node->asExpr())
return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position);
{
AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position);
if (std::optional<AutocompleteEntry> generated = makeAnonymousAutofilled(module, position, node, ancestry))
ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated);
return ret;
}
else if (node->asStat())
return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement};
@ -1628,12 +1846,6 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback)
{
// FIXME: We can improve performance here by parsing without checking.
// The old type graph is probably fine. (famous last words!)
FrontendOptions opts;
opts.forAutocomplete = true;
frontend.check(moduleName, opts);
const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule)
return {};

View file

@ -7,22 +7,24 @@
#include "Luau/Common.h"
#include "Luau/ToString.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/ConstraintGenerator.h"
#include "Luau/NotNull.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeFamily.h"
#include "Luau/TypePack.h"
#include "Luau/Type.h"
#include "Luau/TypeUtils.h"
#include <algorithm>
LUAU_FASTFLAGVARIABLE(LuauDeprecateTableGetnForeach, false)
/** FIXME: Many of these type definitions are not quite completely accurate.
*
* Some of them require richer generics than we have. For instance, we do not yet have a way to talk
* about a function that takes any number of values, but where each value must have some specific type.
*/
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau
{
@ -54,6 +56,7 @@ TypeId makeIntersection(TypeArena& arena, std::vector<TypeId>&& types)
TypeId makeOption(NotNull<BuiltinTypes> builtinTypes, TypeArena& arena, TypeId t)
{
LUAU_ASSERT(t);
return makeUnion(arena, {builtinTypes->nilType, t});
}
@ -202,19 +205,7 @@ void assignPropDocumentationSymbols(TableType::Props& props, const std::string&
}
}
void registerBuiltinTypes(GlobalTypes& globals)
{
globals.globalScope->addBuiltinTypeBinding("any", TypeFun{{}, globals.builtinTypes->anyType});
globals.globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, globals.builtinTypes->nilType});
globals.globalScope->addBuiltinTypeBinding("number", TypeFun{{}, globals.builtinTypes->numberType});
globals.globalScope->addBuiltinTypeBinding("string", TypeFun{{}, globals.builtinTypes->stringType});
globals.globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, globals.builtinTypes->booleanType});
globals.globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, globals.builtinTypes->threadType});
globals.globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, globals.builtinTypes->unknownType});
globals.globalScope->addBuiltinTypeBinding("never", TypeFun{{}, globals.builtinTypes->neverType});
}
void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals)
void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete)
{
LUAU_ASSERT(!globals.globalTypes.types.isFrozen());
LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen());
@ -222,8 +213,11 @@ void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals)
TypeArena& arena = globals.globalTypes;
NotNull<BuiltinTypes> builtinTypes = globals.builtinTypes;
LoadDefinitionFileResult loadResult =
Luau::loadDefinitionFile(typeChecker, globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false);
if (FFlag::DebugLuauDeferredConstraintResolution)
kBuiltinTypeFamilies.addToScope(NotNull{&arena}, NotNull{globals.globalScope.get()});
LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(
globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false, typeCheckForAutocomplete);
LUAU_ASSERT(loadResult.success);
TypeId genericK = arena.addType(GenericType{"K"});
@ -238,7 +232,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals)
auto it = stringMetatableTable->props.find("__index");
LUAU_ASSERT(it != stringMetatableTable->props.end());
addGlobalBinding(globals, "string", it->second.type, "@luau");
addGlobalBinding(globals, "string", it->second.type(), "@luau");
// next<K, V>(t: Table<K, V>, i: K?) -> (K?, V)
TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}});
@ -298,122 +292,531 @@ void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals)
ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze");
ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone");
if (FFlag::LuauDeprecateTableGetnForeach)
{
ttv->props["getn"].deprecated = true;
ttv->props["getn"].deprecatedSuggestion = "#";
ttv->props["foreach"].deprecated = true;
ttv->props["foreachi"].deprecated = true;
}
ttv->props["getn"].deprecated = true;
ttv->props["getn"].deprecatedSuggestion = "#";
ttv->props["foreach"].deprecated = true;
ttv->props["foreachi"].deprecated = true;
attachMagicFunction(ttv->props["pack"].type, magicFunctionPack);
attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack);
attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack);
attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack);
}
attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire);
attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire);
}
void registerBuiltinGlobals(Frontend& frontend)
static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
{
GlobalTypes& globals = frontend.globals;
const char* options = "cdiouxXeEfgGqs*";
LUAU_ASSERT(!globals.globalTypes.types.isFrozen());
LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen());
std::vector<TypeId> result;
registerBuiltinTypes(globals);
for (size_t i = 0; i < size; ++i)
{
if (data[i] == '%')
{
i++;
TypeArena& arena = globals.globalTypes;
NotNull<BuiltinTypes> builtinTypes = globals.builtinTypes;
if (i < size && data[i] == '%')
continue;
LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(getBuiltinDefinitionSource(), "@luau", /* captureComments */ false);
LUAU_ASSERT(loadResult.success);
// we just ignore all characters (including flags/precision) up until first alphabetic character
while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*')))
i++;
TypeId genericK = arena.addType(GenericType{"K"});
TypeId genericV = arena.addType(GenericType{"V"});
TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), globals.globalScope->level, TableState::Generic});
if (i == size)
break;
std::optional<TypeId> stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes);
LUAU_ASSERT(stringMetatableTy);
const TableType* stringMetatableTable = get<TableType>(follow(*stringMetatableTy));
LUAU_ASSERT(stringMetatableTable);
if (data[i] == 'q' || data[i] == 's')
result.push_back(builtinTypes->stringType);
else if (data[i] == '*')
result.push_back(builtinTypes->unknownType);
else if (strchr(options, data[i]))
result.push_back(builtinTypes->numberType);
else
result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType));
}
}
auto it = stringMetatableTable->props.find("__index");
LUAU_ASSERT(it != stringMetatableTable->props.end());
return result;
}
addGlobalBinding(globals, "string", it->second.type, "@luau");
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
// next<K, V>(t: Table<K, V>, i: K?) -> (K?, V)
TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}});
TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}});
addGlobalBinding(globals, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau");
TypeArena& arena = typechecker.currentModule->internalTypes;
TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV});
AstExprConstantString* fmt = nullptr;
if (auto index = expr.func->as<AstExprIndexName>(); index && expr.self)
{
if (auto group = index->expr->as<AstExprGroup>())
fmt = group->expr->as<AstExprConstantString>();
else
fmt = index->expr->as<AstExprConstantString>();
}
TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack});
TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, builtinTypes->nilType}});
if (!expr.self && expr.args.size > 0)
fmt = expr.args.data[0]->as<AstExprConstantString>();
// pairs<K, V>(t: Table<K, V>) -> ((Table<K, V>, K?) -> (K?, V), Table<K, V>, nil)
addGlobalBinding(globals, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau");
if (!fmt)
return std::nullopt;
TypeId genericMT = arena.addType(GenericType{"MT"});
std::vector<TypeId> expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size);
const auto& [params, tail] = flatten(paramPack);
TableType tab{TableState::Generic, globals.globalScope->level};
TypeId tabTy = arena.addType(tab);
size_t paramOffset = 1;
size_t dataOffset = expr.self ? 0 : 1;
TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT});
// unify the prefix one argument at a time
for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i)
{
Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location;
addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau");
typechecker.unify(params[i + paramOffset], expected[i], scope, location);
}
// clang-format off
// setmetatable<T: {}, MT>(T, MT) -> { @metatable MT, T }
addGlobalBinding(globals, "setmetatable",
arena.addType(
FunctionType{
{genericMT},
{},
arena.addTypePack(TypePack{{tabTy, genericMT}}),
arena.addTypePack(TypePack{{tableMetaMT}})
// if we know the argument count or if we have too many arguments for sure, we can issue an error
size_t numActualParams = params.size();
size_t numExpectedParams = expected.size() + 1; // + 1 for the format string
if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams))
typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}});
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})};
}
static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
{
TypeArena* arena = context.solver->arena;
AstExprConstantString* fmt = nullptr;
if (auto index = context.callSite->func->as<AstExprIndexName>(); index && context.callSite->self)
{
if (auto group = index->expr->as<AstExprGroup>())
fmt = group->expr->as<AstExprConstantString>();
else
fmt = index->expr->as<AstExprConstantString>();
}
if (!context.callSite->self && context.callSite->args.size > 0)
fmt = context.callSite->args.data[0]->as<AstExprConstantString>();
if (!fmt)
return false;
std::vector<TypeId> expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size);
const auto& [params, tail] = flatten(context.arguments);
size_t paramOffset = 1;
// unify the prefix one argument at a time
for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i)
{
context.solver->unify(context.solver->rootScope, context.callSite->location, params[i + paramOffset], expected[i]);
}
// if we know the argument count or if we have too many arguments for sure, we can issue an error
size_t numActualParams = params.size();
size_t numExpectedParams = expected.size() + 1; // + 1 for the format string
if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams))
context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}});
TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType});
asMutable(context.result)->ty.emplace<BoundTypePack>(resultPack);
return true;
}
static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
{
std::vector<TypeId> result;
int depth = 0;
bool parsingSet = false;
for (size_t i = 0; i < size; ++i)
{
if (data[i] == '%')
{
++i;
if (!parsingSet && i < size && data[i] == 'b')
i += 2;
}
else if (!parsingSet && data[i] == '[')
{
parsingSet = true;
if (i + 1 < size && data[i + 1] == ']')
i += 1;
}
else if (parsingSet && data[i] == ']')
{
parsingSet = false;
}
else if (data[i] == '(')
{
if (parsingSet)
continue;
if (i + 1 < size && data[i + 1] == ')')
{
i++;
result.push_back(builtinTypes->optionalNumberType);
continue;
}
), "@luau"
);
// clang-format on
for (const auto& pair : globals.globalScope->bindings)
{
persist(pair.second.typeId);
if (TableType* ttv = getMutable<TableType>(pair.second.typeId))
++depth;
result.push_back(builtinTypes->optionalStringType);
}
else if (data[i] == ')')
{
if (!ttv->name)
ttv->name = "typeof(" + toString(pair.first) + ")";
if (parsingSet)
continue;
--depth;
if (depth < 0)
break;
}
}
attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert);
attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable);
attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect);
attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect);
if (depth != 0 || parsingSet)
return std::vector<TypeId>();
if (TableType* ttv = getMutable<TableType>(getGlobalBinding(globals, "table")))
if (result.empty())
result.push_back(builtinTypes->optionalStringType);
return result;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() != 2)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t index = expr.self ? 0 : 1;
if (expr.args.size > index)
pattern = expr.args.data[index]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
std::vector<TypeId> returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypePackId emptyPack = arena.addTypePack({});
const TypePackId returnList = arena.addTypePack(returnTypes);
const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList});
return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})};
}
static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() != 2)
return false;
TypeArena* arena = context.solver->arena;
AstExprConstantString* pattern = nullptr;
size_t index = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > index)
pattern = context.callSite->args.data[index]->as<AstExprConstantString>();
if (!pattern)
return false;
std::vector<TypeId> returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType);
const TypePackId emptyPack = arena->addTypePack({});
const TypePackId returnList = arena->addTypePack(returnTypes);
const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList});
const TypePackId resTypePack = arena->addTypePack({iteratorType});
asMutable(context.result)->ty.emplace<BoundTypePack>(resTypePack);
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 3)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
std::vector<TypeId> returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() == 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location);
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() < 2 || params.size() > 3)
return false;
TypeArena* arena = context.solver->arena;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > patternIndex)
pattern = context.callSite->args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return false;
std::vector<TypeId> returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType);
const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}});
size_t initIndex = context.callSite->self ? 1 : 2;
if (params.size() == 3 && context.callSite->args.size > initIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber);
const TypePackId returnList = arena->addTypePack(returnTypes);
asMutable(context.result)->ty.emplace<BoundTypePack>(returnList);
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 4)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
bool plain = false;
size_t plainIndex = expr.self ? 2 : 3;
if (expr.args.size > plainIndex)
{
// tabTy is a generic table type which we can't express via declaration syntax yet
ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze");
ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone");
if (FFlag::LuauDeprecateTableGetnForeach)
{
ttv->props["getn"].deprecated = true;
ttv->props["getn"].deprecatedSuggestion = "#";
ttv->props["foreach"].deprecated = true;
ttv->props["foreachi"].deprecated = true;
}
attachMagicFunction(ttv->props["pack"].type, magicFunctionPack);
AstExprConstantBool* p = expr.args.data[plainIndex]->as<AstExprConstantBool>();
plain = p && p->value;
}
attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire);
attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire);
std::vector<TypeId> returnTypes;
if (!plain)
{
returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
}
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}});
const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() >= 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location);
if (params.size() == 4 && expr.args.size > plainIndex)
typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionFind(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() < 2 || params.size() > 4)
return false;
TypeArena* arena = context.solver->arena;
NotNull<BuiltinTypes> builtinTypes = context.solver->builtinTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > patternIndex)
pattern = context.callSite->args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return false;
bool plain = false;
size_t plainIndex = context.callSite->self ? 2 : 3;
if (context.callSite->args.size > plainIndex)
{
AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as<AstExprConstantBool>();
plain = p && p->value;
}
std::vector<TypeId> returnTypes;
if (!plain)
{
returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
}
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], builtinTypes->stringType);
const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}});
const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}});
size_t initIndex = context.callSite->self ? 1 : 2;
if (params.size() >= 3 && context.callSite->args.size > initIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber);
if (params.size() == 4 && context.callSite->args.size > plainIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[3], optionalBoolean);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});
const TypePackId returnList = arena->addTypePack(returnTypes);
asMutable(context.result)->ty.emplace<BoundTypePack>(returnList);
return true;
}
TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
{
NotNull<TypeArena> arena{builtinTypes->arena.get()};
const TypeId nilType = builtinTypes->nilType;
const TypeId numberType = builtinTypes->numberType;
const TypeId booleanType = builtinTypes->booleanType;
const TypeId stringType = builtinTypes->stringType;
const TypeId anyType = builtinTypes->anyType;
const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}});
const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}});
const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}});
const TypePackId oneStringPack = arena->addTypePack({stringType});
const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true});
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
const TypePackId emptyPack = arena->addTypePack({});
const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}});
const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}});
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType});
const TypeId replArgType =
arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}});
const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType});
const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})});
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
const TypeId matchFunc = arena->addType(
FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})});
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})});
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
TableType::Props stringLib = {
{"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}},
{"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}},
{"find", {findFunc}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, anyTypePack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"unpack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
anyTypePack,
})}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed});
if (TableType* ttv = getMutable<TableType>(tableType))
ttv->name = "typeof(string)";
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
@ -713,7 +1116,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
if (!checkRequirePath(typechecker, expr.args.data[0]))
return std::nullopt;
if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, expr))
if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModule->name, expr))
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})};
return std::nullopt;

View file

@ -1,22 +1,486 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Clone.h"
#include "Luau/NotNull.h"
#include "Luau/RecursionCounter.h"
#include "Luau/TxnLog.h"
#include "Luau/Type.h"
#include "Luau/TypePack.h"
#include "Luau/Unifiable.h"
LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing)
LUAU_FASTFLAG(LuauClonePublicInterfaceLess)
LUAU_FASTFLAG(DebugLuauReadWriteProperties)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300)
LUAU_FASTFLAGVARIABLE(LuauStacklessTypeClone3, false)
LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000)
namespace Luau
{
namespace
{
using Kind = Variant<TypeId, TypePackId>;
template<typename T>
const T* get(const Kind& kind)
{
return get_if<T>(&kind);
}
class TypeCloner2
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
// A queue of kinds where we cloned it, but whose interior types hasn't
// been updated to point to new clones. Once all of its interior types
// has been updated, it gets removed from the queue.
std::vector<Kind> queue;
NotNull<SeenTypes> types;
NotNull<SeenTypePacks> packs;
int steps = 0;
public:
TypeCloner2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<SeenTypes> types, NotNull<SeenTypePacks> packs)
: arena(arena)
, builtinTypes(builtinTypes)
, types(types)
, packs(packs)
{
}
TypeId clone(TypeId ty)
{
shallowClone(ty);
run();
if (hasExceededIterationLimit())
{
TypeId error = builtinTypes->errorRecoveryType();
(*types)[ty] = error;
return error;
}
return find(ty).value_or(builtinTypes->errorRecoveryType());
}
TypePackId clone(TypePackId tp)
{
shallowClone(tp);
run();
if (hasExceededIterationLimit())
{
TypePackId error = builtinTypes->errorRecoveryTypePack();
(*packs)[tp] = error;
return error;
}
return find(tp).value_or(builtinTypes->errorRecoveryTypePack());
}
private:
bool hasExceededIterationLimit() const
{
if (FInt::LuauTypeCloneIterationLimit == 0)
return false;
return steps + queue.size() >= size_t(FInt::LuauTypeCloneIterationLimit);
}
void run()
{
while (!queue.empty())
{
++steps;
if (hasExceededIterationLimit())
break;
Kind kind = queue.back();
queue.pop_back();
if (find(kind))
continue;
cloneChildren(kind);
}
}
std::optional<TypeId> find(TypeId ty) const
{
ty = follow(ty, FollowOption::DisableLazyTypeThunks);
if (auto it = types->find(ty); it != types->end())
return it->second;
else if (ty->persistent)
return ty;
return std::nullopt;
}
std::optional<TypePackId> find(TypePackId tp) const
{
tp = follow(tp);
if (auto it = packs->find(tp); it != packs->end())
return it->second;
else if (tp->persistent)
return tp;
return std::nullopt;
}
std::optional<Kind> find(Kind kind) const
{
if (auto ty = get<TypeId>(kind))
return find(*ty);
else if (auto tp = get<TypePackId>(kind))
return find(*tp);
else
{
LUAU_ASSERT(!"Unknown kind?");
return std::nullopt;
}
}
private:
TypeId shallowClone(TypeId ty)
{
// We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s.
ty = follow(ty, FollowOption::DisableLazyTypeThunks);
if (auto clone = find(ty))
return *clone;
else if (ty->persistent)
return ty;
TypeId target = arena->addType(ty->ty);
asMutable(target)->documentationSymbol = ty->documentationSymbol;
(*types)[ty] = target;
queue.push_back(target);
return target;
}
TypePackId shallowClone(TypePackId tp)
{
tp = follow(tp);
if (auto clone = find(tp))
return *clone;
else if (tp->persistent)
return tp;
TypePackId target = arena->addTypePack(tp->ty);
(*packs)[tp] = target;
queue.push_back(target);
return target;
}
Property shallowClone(const Property& p)
{
if (FFlag::DebugLuauReadWriteProperties)
{
std::optional<TypeId> cloneReadTy;
if (auto ty = p.readType())
cloneReadTy = shallowClone(*ty);
std::optional<TypeId> cloneWriteTy;
if (auto ty = p.writeType())
cloneWriteTy = shallowClone(*ty);
std::optional<Property> cloned = Property::create(cloneReadTy, cloneWriteTy);
LUAU_ASSERT(cloned);
cloned->deprecated = p.deprecated;
cloned->deprecatedSuggestion = p.deprecatedSuggestion;
cloned->location = p.location;
cloned->tags = p.tags;
cloned->documentationSymbol = p.documentationSymbol;
return *cloned;
}
else
{
return Property{
shallowClone(p.type()),
p.deprecated,
p.deprecatedSuggestion,
p.location,
p.tags,
p.documentationSymbol,
};
}
}
void cloneChildren(TypeId ty)
{
return visit(
[&](auto&& t) {
return cloneChildren(&t);
},
asMutable(ty)->ty);
}
void cloneChildren(TypePackId tp)
{
return visit(
[&](auto&& t) {
return cloneChildren(&t);
},
asMutable(tp)->ty);
}
void cloneChildren(Kind kind)
{
if (auto ty = get<TypeId>(kind))
return cloneChildren(*ty);
else if (auto tp = get<TypePackId>(kind))
return cloneChildren(*tp);
else
LUAU_ASSERT(!"Item holds neither TypeId nor TypePackId when enqueuing its children?");
}
// ErrorType and ErrorTypePack is an alias to this type.
void cloneChildren(Unifiable::Error* t)
{
// noop.
}
void cloneChildren(BoundType* t)
{
t->boundTo = shallowClone(t->boundTo);
}
void cloneChildren(FreeType* t)
{
if (t->lowerBound)
t->lowerBound = shallowClone(t->lowerBound);
if (t->upperBound)
t->upperBound = shallowClone(t->upperBound);
}
void cloneChildren(LocalType* t)
{
t->domain = shallowClone(t->domain);
}
void cloneChildren(GenericType* t)
{
// TOOD: clone upper bounds.
}
void cloneChildren(PrimitiveType* t)
{
// noop.
}
void cloneChildren(BlockedType* t)
{
// TODO: In the new solver, we should ice.
}
void cloneChildren(PendingExpansionType* t)
{
// TODO: In the new solver, we should ice.
}
void cloneChildren(SingletonType* t)
{
// noop.
}
void cloneChildren(FunctionType* t)
{
for (TypeId& g : t->generics)
g = shallowClone(g);
for (TypePackId& gp : t->genericPacks)
gp = shallowClone(gp);
t->argTypes = shallowClone(t->argTypes);
t->retTypes = shallowClone(t->retTypes);
}
void cloneChildren(TableType* t)
{
if (t->indexer)
{
t->indexer->indexType = shallowClone(t->indexer->indexType);
t->indexer->indexResultType = shallowClone(t->indexer->indexResultType);
}
for (auto& [_, p] : t->props)
p = shallowClone(p);
for (TypeId& ty : t->instantiatedTypeParams)
ty = shallowClone(ty);
for (TypePackId& tp : t->instantiatedTypePackParams)
tp = shallowClone(tp);
}
void cloneChildren(MetatableType* t)
{
t->table = shallowClone(t->table);
t->metatable = shallowClone(t->metatable);
}
void cloneChildren(ClassType* t)
{
for (auto& [_, p] : t->props)
p = shallowClone(p);
if (t->parent)
t->parent = shallowClone(*t->parent);
if (t->metatable)
t->metatable = shallowClone(*t->metatable);
if (t->indexer)
{
t->indexer->indexType = shallowClone(t->indexer->indexType);
t->indexer->indexResultType = shallowClone(t->indexer->indexResultType);
}
}
void cloneChildren(AnyType* t)
{
// noop.
}
void cloneChildren(UnionType* t)
{
for (TypeId& ty : t->options)
ty = shallowClone(ty);
}
void cloneChildren(IntersectionType* t)
{
for (TypeId& ty : t->parts)
ty = shallowClone(ty);
}
void cloneChildren(LazyType* t)
{
if (auto unwrapped = t->unwrapped.load())
t->unwrapped.store(shallowClone(unwrapped));
}
void cloneChildren(UnknownType* t)
{
// noop.
}
void cloneChildren(NeverType* t)
{
// noop.
}
void cloneChildren(NegationType* t)
{
t->ty = shallowClone(t->ty);
}
void cloneChildren(TypeFamilyInstanceType* t)
{
for (TypeId& ty : t->typeArguments)
ty = shallowClone(ty);
for (TypePackId& tp : t->packArguments)
tp = shallowClone(tp);
}
void cloneChildren(FreeTypePack* t)
{
// TODO: clone lower and upper bounds.
// TODO: In the new solver, we should ice.
}
void cloneChildren(GenericTypePack* t)
{
// TOOD: clone upper bounds.
}
void cloneChildren(BlockedTypePack* t)
{
// TODO: In the new solver, we should ice.
}
void cloneChildren(BoundTypePack* t)
{
t->boundTo = shallowClone(t->boundTo);
}
void cloneChildren(VariadicTypePack* t)
{
t->ty = shallowClone(t->ty);
}
void cloneChildren(TypePack* t)
{
for (TypeId& ty : t->head)
ty = shallowClone(ty);
if (t->tail)
t->tail = shallowClone(*t->tail);
}
void cloneChildren(TypeFamilyInstanceTypePack* t)
{
for (TypeId& ty : t->typeArguments)
ty = shallowClone(ty);
for (TypePackId& tp : t->packArguments)
tp = shallowClone(tp);
}
};
} // namespace
namespace
{
Property clone(const Property& prop, TypeArena& dest, CloneState& cloneState)
{
if (FFlag::DebugLuauReadWriteProperties)
{
std::optional<TypeId> cloneReadTy;
if (auto ty = prop.readType())
cloneReadTy = clone(*ty, dest, cloneState);
std::optional<TypeId> cloneWriteTy;
if (auto ty = prop.writeType())
cloneWriteTy = clone(*ty, dest, cloneState);
std::optional<Property> cloned = Property::create(cloneReadTy, cloneWriteTy);
LUAU_ASSERT(cloned);
cloned->deprecated = prop.deprecated;
cloned->deprecatedSuggestion = prop.deprecatedSuggestion;
cloned->location = prop.location;
cloned->tags = prop.tags;
cloned->documentationSymbol = prop.documentationSymbol;
return *cloned;
}
else
{
return Property{
clone(prop.type(), dest, cloneState),
prop.deprecated,
prop.deprecatedSuggestion,
prop.location,
prop.tags,
prop.documentationSymbol,
};
}
}
static TableIndexer clone(const TableIndexer& indexer, TypeArena& dest, CloneState& cloneState)
{
return TableIndexer{clone(indexer.indexType, dest, cloneState), clone(indexer.indexResultType, dest, cloneState)};
}
struct TypePackCloner;
/*
@ -44,10 +508,11 @@ struct TypeCloner
template<typename T>
void defaultClone(const T& t);
void operator()(const Unifiable::Free& t);
void operator()(const Unifiable::Generic& t);
void operator()(const Unifiable::Bound<TypeId>& t);
void operator()(const Unifiable::Error& t);
void operator()(const FreeType& t);
void operator()(const LocalType& t);
void operator()(const GenericType& t);
void operator()(const BoundType& t);
void operator()(const ErrorType& t);
void operator()(const BlockedType& t);
void operator()(const PendingExpansionType& t);
void operator()(const PrimitiveType& t);
@ -63,6 +528,7 @@ struct TypeCloner
void operator()(const UnknownType& t);
void operator()(const NeverType& t);
void operator()(const NegationType& t);
void operator()(const TypeFamilyInstanceType& t);
};
struct TypePackCloner
@ -89,15 +555,15 @@ struct TypePackCloner
seenTypePacks[typePackId] = cloned;
}
void operator()(const Unifiable::Free& t)
void operator()(const FreeTypePack& t)
{
defaultClone(t);
}
void operator()(const Unifiable::Generic& t)
void operator()(const GenericTypePack& t)
{
defaultClone(t);
}
void operator()(const Unifiable::Error& t)
void operator()(const ErrorTypePack& t)
{
defaultClone(t);
}
@ -112,8 +578,6 @@ struct TypePackCloner
void operator()(const Unifiable::Bound<TypePackId>& t)
{
TypePackId cloned = clone(t.boundTo, dest, cloneState);
if (FFlag::DebugLuauCopyBeforeNormalizing)
cloned = dest.addTypePack(TypePackVar{BoundTypePack{cloned}});
seenTypePacks[typePackId] = cloned;
}
@ -136,6 +600,22 @@ struct TypePackCloner
if (t.tail)
destTp->tail = clone(*t.tail, dest, cloneState);
}
void operator()(const TypeFamilyInstanceTypePack& t)
{
TypePackId cloned = dest.addTypePack(TypeFamilyInstanceTypePack{t.family, {}, {}});
TypeFamilyInstanceTypePack* destTp = getMutable<TypeFamilyInstanceTypePack>(cloned);
LUAU_ASSERT(destTp);
seenTypePacks[typePackId] = cloned;
destTp->typeArguments.reserve(t.typeArguments.size());
for (TypeId ty : t.typeArguments)
destTp->typeArguments.push_back(clone(ty, dest, cloneState));
destTp->packArguments.reserve(t.packArguments.size());
for (TypePackId tp : t.packArguments)
destTp->packArguments.push_back(clone(tp, dest, cloneState));
}
};
template<typename T>
@ -145,12 +625,24 @@ void TypeCloner::defaultClone(const T& t)
seenTypes[typeId] = cloned;
}
void TypeCloner::operator()(const Unifiable::Free& t)
void TypeCloner::operator()(const FreeType& t)
{
if (FFlag::DebugLuauDeferredConstraintResolution)
{
FreeType ft{t.scope, clone(t.lowerBound, dest, cloneState), clone(t.upperBound, dest, cloneState)};
TypeId res = dest.addType(ft);
seenTypes[typeId] = res;
}
else
defaultClone(t);
}
void TypeCloner::operator()(const LocalType& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const Unifiable::Generic& t)
void TypeCloner::operator()(const GenericType& t)
{
defaultClone(t);
}
@ -158,8 +650,6 @@ void TypeCloner::operator()(const Unifiable::Generic& t)
void TypeCloner::operator()(const Unifiable::Bound<TypeId>& t)
{
TypeId boundTo = clone(t.boundTo, dest, cloneState);
if (FFlag::DebugLuauCopyBeforeNormalizing)
boundTo = dest.addType(BoundType{boundTo});
seenTypes[typeId] = boundTo;
}
@ -224,13 +714,14 @@ void TypeCloner::operator()(const FunctionType& t)
ftv->argTypes = clone(t.argTypes, dest, cloneState);
ftv->argNames = t.argNames;
ftv->retTypes = clone(t.retTypes, dest, cloneState);
ftv->hasNoGenerics = t.hasNoGenerics;
ftv->hasNoFreeOrGenericTypes = t.hasNoFreeOrGenericTypes;
ftv->isCheckedFunction = t.isCheckedFunction;
}
void TypeCloner::operator()(const TableType& t)
{
// If table is now bound to another one, we ignore the content of the original
if (!FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo)
if (t.boundTo)
{
TypeId boundTo = clone(*t.boundTo, dest, cloneState);
seenTypes[typeId] = boundTo;
@ -247,14 +738,11 @@ void TypeCloner::operator()(const TableType& t)
ttv->level = TypeLevel{0, 0};
if (FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, cloneState);
for (const auto& [name, prop] : t.props)
ttv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags};
ttv->props[name] = clone(prop, dest, cloneState);
if (t.indexer)
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)};
ttv->indexer = clone(*t.indexer, dest, cloneState);
for (TypeId& arg : ttv->instantiatedTypeParams)
arg = clone(arg, dest, cloneState);
@ -285,13 +773,16 @@ void TypeCloner::operator()(const ClassType& t)
seenTypes[typeId] = result;
for (const auto& [name, prop] : t.props)
ctv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags};
ctv->props[name] = clone(prop, dest, cloneState);
if (t.parent)
ctv->parent = clone(*t.parent, dest, cloneState);
if (t.metatable)
ctv->metatable = clone(*t.metatable, dest, cloneState);
if (t.indexer)
ctv->indexer = clone(*t.indexer, dest, cloneState);
}
void TypeCloner::operator()(const AnyType& t)
@ -301,14 +792,19 @@ void TypeCloner::operator()(const AnyType& t)
void TypeCloner::operator()(const UnionType& t)
{
// We're just using this FreeType as a placeholder until we've finished
// cloning the parts of this union so it is okay that its bounds are
// nullptr. We'll never indirect them.
TypeId result = dest.addType(FreeType{nullptr, /*lowerBound*/ nullptr, /*upperBound*/ nullptr});
seenTypes[typeId] = result;
std::vector<TypeId> options;
options.reserve(t.options.size());
for (TypeId ty : t.options)
options.push_back(clone(ty, dest, cloneState));
TypeId result = dest.addType(UnionType{std::move(options)});
seenTypes[typeId] = result;
asMutable(result)->ty.emplace<UnionType>(std::move(options));
}
void TypeCloner::operator()(const IntersectionType& t)
@ -325,7 +821,14 @@ void TypeCloner::operator()(const IntersectionType& t)
void TypeCloner::operator()(const LazyType& t)
{
defaultClone(t);
if (TypeId unwrapped = t.unwrapped.load())
{
seenTypes[typeId] = clone(unwrapped, dest, cloneState);
}
else
{
defaultClone(t);
}
}
void TypeCloner::operator()(const UnknownType& t)
@ -347,6 +850,28 @@ void TypeCloner::operator()(const NegationType& t)
asMutable(result)->ty = NegationType{ty};
}
void TypeCloner::operator()(const TypeFamilyInstanceType& t)
{
TypeId result = dest.addType(TypeFamilyInstanceType{
t.family,
{},
{},
});
seenTypes[typeId] = result;
TypeFamilyInstanceType* tfit = getMutable<TypeFamilyInstanceType>(result);
LUAU_ASSERT(tfit != nullptr);
tfit->typeArguments.reserve(t.typeArguments.size());
for (TypeId p : t.typeArguments)
tfit->typeArguments.push_back(clone(p, dest, cloneState));
tfit->packArguments.reserve(t.packArguments.size());
for (TypePackId p : t.packArguments)
tfit->packArguments.push_back(clone(p, dest, cloneState));
}
} // anonymous namespace
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
@ -354,17 +879,25 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
if (tp->persistent)
return tp;
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
TypePackId& res = cloneState.seenTypePacks[tp];
if (res == nullptr)
if (FFlag::LuauStacklessTypeClone3)
{
TypePackCloner cloner{dest, tp, cloneState};
Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into.
TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
return cloner.clone(tp);
}
else
{
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
return res;
TypePackId& res = cloneState.seenTypePacks[tp];
if (res == nullptr)
{
TypePackCloner cloner{dest, tp, cloneState};
Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into.
}
return res;
}
}
TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
@ -372,136 +905,91 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
if (typeId->persistent)
return typeId;
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
TypeId& res = cloneState.seenTypes[typeId];
if (res == nullptr)
if (FFlag::LuauStacklessTypeClone3)
{
TypeCloner cloner{dest, typeId, cloneState};
Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into.
// Persistent types are not being cloned and we get the original type back which might be read-only
if (!res->persistent)
{
asMutable(res)->documentationSymbol = typeId->documentationSymbol;
}
TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
return cloner.clone(typeId);
}
else
{
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
return res;
TypeId& res = cloneState.seenTypes[typeId];
if (res == nullptr)
{
TypeCloner cloner{dest, typeId, cloneState};
Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into.
// Persistent types are not being cloned and we get the original type back which might be read-only
if (!res->persistent)
{
asMutable(res)->documentationSymbol = typeId->documentationSymbol;
}
}
return res;
}
}
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
{
TypeFun result;
for (auto param : typeFun.typeParams)
if (FFlag::LuauStacklessTypeClone3)
{
TypeId ty = clone(param.ty, dest, cloneState);
std::optional<TypeId> defaultValue;
TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
if (param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, cloneState);
TypeFun copy = typeFun;
result.typeParams.push_back({ty, defaultValue});
}
for (auto& param : copy.typeParams)
{
param.ty = cloner.clone(param.ty);
for (auto param : typeFun.typePackParams)
{
TypePackId tp = clone(param.tp, dest, cloneState);
std::optional<TypePackId> defaultValue;
if (param.defaultValue)
param.defaultValue = cloner.clone(*param.defaultValue);
}
if (param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, cloneState);
for (auto& param : copy.typePackParams)
{
param.tp = cloner.clone(param.tp);
result.typePackParams.push_back({tp, defaultValue});
}
if (param.defaultValue)
param.defaultValue = cloner.clone(*param.defaultValue);
}
result.type = clone(typeFun.type, dest, cloneState);
copy.type = cloner.clone(copy.type);
return result;
}
TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysClone)
{
ty = log->follow(ty);
TypeId result = ty;
if (auto pty = log->pending(ty))
ty = &pty->pending;
if (const FunctionType* ftv = get<FunctionType>(ty))
{
FunctionType clone = FunctionType{ftv->level, ftv->scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf};
clone.generics = ftv->generics;
clone.genericPacks = ftv->genericPacks;
clone.magicFunction = ftv->magicFunction;
clone.dcrMagicFunction = ftv->dcrMagicFunction;
clone.dcrMagicRefinement = ftv->dcrMagicRefinement;
clone.tags = ftv->tags;
clone.argNames = ftv->argNames;
result = dest.addType(std::move(clone));
}
else if (const TableType* ttv = get<TableType>(ty))
{
LUAU_ASSERT(!ttv->boundTo);
TableType clone = TableType{ttv->props, ttv->indexer, ttv->level, ttv->scope, ttv->state};
clone.definitionModuleName = ttv->definitionModuleName;
clone.definitionLocation = ttv->definitionLocation;
clone.name = ttv->name;
clone.syntheticName = ttv->syntheticName;
clone.instantiatedTypeParams = ttv->instantiatedTypeParams;
clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams;
clone.tags = ttv->tags;
result = dest.addType(std::move(clone));
}
else if (const MetatableType* mtv = get<MetatableType>(ty))
{
MetatableType clone = MetatableType{mtv->table, mtv->metatable};
clone.syntheticName = mtv->syntheticName;
result = dest.addType(std::move(clone));
}
else if (const UnionType* utv = get<UnionType>(ty))
{
UnionType clone;
clone.options = utv->options;
result = dest.addType(std::move(clone));
}
else if (const IntersectionType* itv = get<IntersectionType>(ty))
{
IntersectionType clone;
clone.parts = itv->parts;
result = dest.addType(std::move(clone));
}
else if (const PendingExpansionType* petv = get<PendingExpansionType>(ty))
{
PendingExpansionType clone{petv->prefix, petv->name, petv->typeArguments, petv->packArguments};
result = dest.addType(std::move(clone));
}
else if (const ClassType* ctv = get<ClassType>(ty); FFlag::LuauClonePublicInterfaceLess && ctv && alwaysClone)
{
ClassType clone{ctv->name, ctv->props, ctv->parent, ctv->metatable, ctv->tags, ctv->userData, ctv->definitionModuleName};
result = dest.addType(std::move(clone));
}
else if (FFlag::LuauClonePublicInterfaceLess && alwaysClone)
{
result = dest.addType(*ty);
}
else if (const NegationType* ntv = get<NegationType>(ty))
{
result = dest.addType(NegationType{ntv->ty});
return copy;
}
else
{
TypeFun result;
for (auto param : typeFun.typeParams)
{
TypeId ty = clone(param.ty, dest, cloneState);
std::optional<TypeId> defaultValue;
if (param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, cloneState);
result.typeParams.push_back({ty, defaultValue});
}
for (auto param : typeFun.typePackParams)
{
TypePackId tp = clone(param.tp, dest, cloneState);
std::optional<TypePackId> defaultValue;
if (param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, cloneState);
result.typePackParams.push_back({tp, defaultValue});
}
result.type = clone(typeFun.type, dest, cloneState);
return result;
asMutable(result)->documentationSymbol = ty->documentationSymbol;
return result;
}
TypeId shallowClone(TypeId ty, NotNull<TypeArena> dest)
{
return shallowClone(ty, *dest, TxnLog::empty());
}
}
} // namespace Luau

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Constraint.h"
#include "Luau/VisitType.h"
namespace Luau
{
@ -12,4 +13,40 @@ Constraint::Constraint(NotNull<Scope> scope, const Location& location, Constrain
{
}
struct FreeTypeCollector : TypeOnceVisitor
{
DenseHashSet<TypeId>* result;
FreeTypeCollector(DenseHashSet<TypeId>* result)
: result(result)
{
}
bool visit(TypeId ty, const FreeType&) override
{
result->insert(ty);
return false;
}
};
DenseHashSet<TypeId> Constraint::getFreeTypes() const
{
DenseHashSet<TypeId> types{{}};
FreeTypeCollector ftc{&types};
if (auto sc = get<SubtypeConstraint>(*this))
{
ftc.traverse(sc->subType);
ftc.traverse(sc->superType);
}
else if (auto psc = get<PackSubtypeConstraint>(*this))
{
ftc.traverse(psc->subPack);
ftc.traverse(psc->superPack);
}
return types;
}
} // namespace Luau

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1,12 +1,58 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Def.h"
#include "Luau/Common.h"
#include <algorithm>
#include <deque>
namespace Luau
{
DefId DefArena::freshCell()
bool containsSubscriptedDefinition(DefId def)
{
return NotNull{allocator.allocate(Def{Cell{}})};
if (auto cell = get<Cell>(def))
return cell->subscripted;
else if (auto phi = get<Phi>(def))
return std::any_of(phi->operands.begin(), phi->operands.end(), containsSubscriptedDefinition);
else
return false;
}
DefId DefArena::freshCell(bool subscripted)
{
return NotNull{allocator.allocate(Def{Cell{subscripted}})};
}
static void collectOperands(DefId def, std::vector<DefId>& operands)
{
if (std::find(operands.begin(), operands.end(), def) != operands.end())
return;
else if (get<Cell>(def))
operands.push_back(def);
else if (auto phi = get<Phi>(def))
{
for (const Def* operand : phi->operands)
collectOperands(NotNull{operand}, operands);
}
}
DefId DefArena::phi(DefId a, DefId b)
{
return phi({a, b});
}
DefId DefArena::phi(const std::vector<DefId>& defs)
{
std::vector<DefId> operands;
for (DefId operand : defs)
collectOperands(operand, operands);
// There's no need to allocate a Phi node for a singleton set.
if (operands.size() == 1)
return operands[0];
else
return NotNull{allocator.allocate(Def{Phi{std::move(operands)}})};
}
} // namespace Luau

948
Analysis/src/Differ.cpp Normal file
View file

@ -0,0 +1,948 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Differ.h"
#include "Luau/Common.h"
#include "Luau/Error.h"
#include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypePack.h"
#include "Luau/Unifiable.h"
#include <optional>
#include <string>
#include <unordered_set>
#include <vector>
namespace Luau
{
std::string DiffPathNode::toString() const
{
switch (kind)
{
case DiffPathNode::Kind::TableProperty:
{
if (!tableProperty.has_value())
throw InternalCompilerError{"DiffPathNode has kind TableProperty but tableProperty is nullopt"};
return *tableProperty;
break;
}
case DiffPathNode::Kind::FunctionArgument:
{
if (!index.has_value())
return "Arg[Variadic]";
// Add 1 because Lua is 1-indexed
return "Arg[" + std::to_string(*index + 1) + "]";
}
case DiffPathNode::Kind::FunctionReturn:
{
if (!index.has_value())
return "Ret[Variadic]";
// Add 1 because Lua is 1-indexed
return "Ret[" + std::to_string(*index + 1) + "]";
}
case DiffPathNode::Kind::Negation:
{
return "Negation";
}
default:
{
throw InternalCompilerError{"DiffPathNode::toString is not exhaustive"};
}
}
}
DiffPathNode DiffPathNode::constructWithTableProperty(Name tableProperty)
{
return DiffPathNode{DiffPathNode::Kind::TableProperty, tableProperty, std::nullopt};
}
DiffPathNode DiffPathNode::constructWithKindAndIndex(Kind kind, size_t index)
{
return DiffPathNode{kind, std::nullopt, index};
}
DiffPathNode DiffPathNode::constructWithKind(Kind kind)
{
return DiffPathNode{kind, std::nullopt, std::nullopt};
}
DiffPathNodeLeaf DiffPathNodeLeaf::detailsNormal(TypeId ty)
{
return DiffPathNodeLeaf{ty, std::nullopt, std::nullopt, false, std::nullopt};
}
DiffPathNodeLeaf DiffPathNodeLeaf::detailsTableProperty(TypeId ty, Name tableProperty)
{
return DiffPathNodeLeaf{ty, tableProperty, std::nullopt, false, std::nullopt};
}
DiffPathNodeLeaf DiffPathNodeLeaf::detailsUnionIndex(TypeId ty, size_t index)
{
return DiffPathNodeLeaf{ty, std::nullopt, std::nullopt, false, index};
}
DiffPathNodeLeaf DiffPathNodeLeaf::detailsLength(int minLength, bool isVariadic)
{
return DiffPathNodeLeaf{std::nullopt, std::nullopt, minLength, isVariadic, std::nullopt};
}
DiffPathNodeLeaf DiffPathNodeLeaf::nullopts()
{
return DiffPathNodeLeaf{std::nullopt, std::nullopt, std::nullopt, false, std::nullopt};
}
std::string DiffPath::toString(bool prependDot) const
{
std::string pathStr;
bool isFirstInForLoop = !prependDot;
for (auto node = path.rbegin(); node != path.rend(); node++)
{
if (isFirstInForLoop)
{
isFirstInForLoop = false;
}
else
{
pathStr += ".";
}
pathStr += node->toString();
}
return pathStr;
}
std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf, bool multiLine) const
{
std::string conditionalNewline = multiLine ? "\n" : " ";
std::string conditionalIndent = multiLine ? " " : "";
std::string pathStr{rootName + diffPath.toString(true)};
switch (kind)
{
case DiffError::Kind::Normal:
{
checkNonMissingPropertyLeavesHaveNulloptTableProperty();
return pathStr + conditionalNewline + "has type" + conditionalNewline + conditionalIndent + Luau::toString(*leaf.ty);
}
case DiffError::Kind::MissingTableProperty:
{
if (leaf.ty.has_value())
{
if (!leaf.tableProperty.has_value())
throw InternalCompilerError{"leaf.tableProperty is nullopt"};
return pathStr + "." + *leaf.tableProperty + conditionalNewline + "has type" + conditionalNewline + conditionalIndent +
Luau::toString(*leaf.ty);
}
else if (otherLeaf.ty.has_value())
{
if (!otherLeaf.tableProperty.has_value())
throw InternalCompilerError{"otherLeaf.tableProperty is nullopt"};
return pathStr + conditionalNewline + "is missing the property" + conditionalNewline + conditionalIndent + *otherLeaf.tableProperty;
}
throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"};
}
case DiffError::Kind::MissingUnionMember:
{
// TODO: do normal case
if (leaf.ty.has_value())
{
if (!leaf.unionIndex.has_value())
throw InternalCompilerError{"leaf.unionIndex is nullopt"};
return pathStr + conditionalNewline + "is a union containing type" + conditionalNewline + conditionalIndent + Luau::toString(*leaf.ty);
}
else if (otherLeaf.ty.has_value())
{
return pathStr + conditionalNewline + "is a union missing type" + conditionalNewline + conditionalIndent + Luau::toString(*otherLeaf.ty);
}
throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"};
}
case DiffError::Kind::MissingIntersectionMember:
{
// TODO: better message for intersections
// An intersection of just functions is always an "overloaded function"
// An intersection of just tables is always a "joined table"
if (leaf.ty.has_value())
{
if (!leaf.unionIndex.has_value())
throw InternalCompilerError{"leaf.unionIndex is nullopt"};
return pathStr + conditionalNewline + "is an intersection containing type" + conditionalNewline + conditionalIndent +
Luau::toString(*leaf.ty);
}
else if (otherLeaf.ty.has_value())
{
return pathStr + conditionalNewline + "is an intersection missing type" + conditionalNewline + conditionalIndent +
Luau::toString(*otherLeaf.ty);
}
throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"};
}
case DiffError::Kind::LengthMismatchInFnArgs:
{
if (!leaf.minLength.has_value())
throw InternalCompilerError{"leaf.minLength is nullopt"};
return pathStr + conditionalNewline + "takes " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " arguments";
}
case DiffError::Kind::LengthMismatchInFnRets:
{
if (!leaf.minLength.has_value())
throw InternalCompilerError{"leaf.minLength is nullopt"};
return pathStr + conditionalNewline + "returns " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " values";
}
default:
{
throw InternalCompilerError{"DiffPath::toStringALeaf is not exhaustive"};
}
}
}
void DiffError::checkNonMissingPropertyLeavesHaveNulloptTableProperty() const
{
if (left.tableProperty.has_value() || right.tableProperty.has_value())
throw InternalCompilerError{"Non-MissingProperty DiffError should have nullopt tableProperty in both leaves"};
}
std::string getDevFixFriendlyName(const std::optional<std::string>& maybeSymbol, TypeId ty)
{
if (maybeSymbol.has_value())
return *maybeSymbol;
if (auto table = get<TableType>(ty))
{
if (table->name.has_value())
return *table->name;
else if (table->syntheticName.has_value())
return *table->syntheticName;
}
if (auto metatable = get<MetatableType>(ty))
{
if (metatable->syntheticName.has_value())
{
return *metatable->syntheticName;
}
}
return "<unlabeled-symbol>";
}
std::string DifferEnvironment::getDevFixFriendlyNameLeft() const
{
return getDevFixFriendlyName(externalSymbolLeft, rootLeft);
}
std::string DifferEnvironment::getDevFixFriendlyNameRight() const
{
return getDevFixFriendlyName(externalSymbolRight, rootRight);
}
std::string DiffError::toString(bool multiLine) const
{
std::string conditionalNewline = multiLine ? "\n" : " ";
std::string conditionalIndent = multiLine ? " " : "";
switch (kind)
{
case DiffError::Kind::IncompatibleGeneric:
{
std::string diffPathStr{diffPath.toString(true)};
return "DiffError: these two types are not equal because the left generic at" + conditionalNewline + conditionalIndent + leftRootName +
diffPathStr + conditionalNewline + "cannot be the same type parameter as the right generic at" + conditionalNewline +
conditionalIndent + rightRootName + diffPathStr;
}
default:
{
return "DiffError: these two types are not equal because the left type at" + conditionalNewline + conditionalIndent +
toStringALeaf(leftRootName, left, right, multiLine) + "," + conditionalNewline + "while the right type at" + conditionalNewline +
conditionalIndent + toStringALeaf(rightRootName, right, left, multiLine);
}
}
}
void DiffError::checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right)
{
if (!left.ty.has_value() || !right.ty.has_value())
{
// TODO: think about whether this should be always thrown!
// For example, Kind::Primitive doesn't make too much sense to have a TypeId
// throw InternalCompilerError{"Left and Right fields are leaf nodes and must have a TypeId"};
}
}
void DifferResult::wrapDiffPath(DiffPathNode node)
{
if (!diffError.has_value())
{
throw InternalCompilerError{"Cannot wrap diffPath because there is no diffError"};
}
diffError->diffPath.path.push_back(node);
}
static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffMetatable(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffFunction(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffNegation(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffClass(DifferEnvironment& env, TypeId left, TypeId right);
struct FindSeteqCounterexampleResult
{
// nullopt if no counterexample found
std::optional<size_t> mismatchIdx;
// true if counterexample is in the left, false if cex is in the right
bool inLeft;
};
static FindSeteqCounterexampleResult findSeteqCounterexample(
DifferEnvironment& env, const std::vector<TypeId>& left, const std::vector<TypeId>& right);
static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right);
static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId right);
/**
* The last argument gives context info on which complex type contained the TypePack.
*/
static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right);
static DifferResult diffCanonicalTpShape(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind,
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& left, const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& right);
static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right);
static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypePackId right);
static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right)
{
const TableType* leftTable = get<TableType>(left);
const TableType* rightTable = get<TableType>(right);
LUAU_ASSERT(leftTable);
LUAU_ASSERT(rightTable);
for (auto const& [field, value] : leftTable->props)
{
if (rightTable->props.find(field) == rightTable->props.end())
{
// left has a field the right doesn't
return DifferResult{DiffError{
DiffError::Kind::MissingTableProperty,
DiffPathNodeLeaf::detailsTableProperty(value.type(), field),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
}
for (auto const& [field, value] : rightTable->props)
{
if (leftTable->props.find(field) == leftTable->props.end())
{
// right has a field the left doesn't
return DifferResult{DiffError{DiffError::Kind::MissingTableProperty, DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::detailsTableProperty(value.type(), field), env.getDevFixFriendlyNameLeft(), env.getDevFixFriendlyNameRight()}};
}
}
// left and right have the same set of keys
for (auto const& [field, leftValue] : leftTable->props)
{
auto const& rightValue = rightTable->props.at(field);
DifferResult differResult = diffUsingEnv(env, leftValue.type(), rightValue.type());
if (differResult.diffError.has_value())
{
differResult.wrapDiffPath(DiffPathNode::constructWithTableProperty(field));
return differResult;
}
}
return DifferResult{};
}
static DifferResult diffMetatable(DifferEnvironment& env, TypeId left, TypeId right)
{
const MetatableType* leftMetatable = get<MetatableType>(left);
const MetatableType* rightMetatable = get<MetatableType>(right);
LUAU_ASSERT(leftMetatable);
LUAU_ASSERT(rightMetatable);
DifferResult diffRes = diffUsingEnv(env, leftMetatable->table, rightMetatable->table);
if (diffRes.diffError.has_value())
{
return diffRes;
}
diffRes = diffUsingEnv(env, leftMetatable->metatable, rightMetatable->metatable);
if (diffRes.diffError.has_value())
{
diffRes.wrapDiffPath(DiffPathNode::constructWithTableProperty("__metatable"));
return diffRes;
}
return DifferResult{};
}
static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId right)
{
const PrimitiveType* leftPrimitive = get<PrimitiveType>(left);
const PrimitiveType* rightPrimitive = get<PrimitiveType>(right);
LUAU_ASSERT(leftPrimitive);
LUAU_ASSERT(rightPrimitive);
if (leftPrimitive->type != rightPrimitive->type)
{
return DifferResult{DiffError{
DiffError::Kind::Normal,
DiffPathNodeLeaf::detailsNormal(left),
DiffPathNodeLeaf::detailsNormal(right),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
return DifferResult{};
}
static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId right)
{
const SingletonType* leftSingleton = get<SingletonType>(left);
const SingletonType* rightSingleton = get<SingletonType>(right);
LUAU_ASSERT(leftSingleton);
LUAU_ASSERT(rightSingleton);
if (*leftSingleton != *rightSingleton)
{
return DifferResult{DiffError{
DiffError::Kind::Normal,
DiffPathNodeLeaf::detailsNormal(left),
DiffPathNodeLeaf::detailsNormal(right),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
return DifferResult{};
}
static DifferResult diffFunction(DifferEnvironment& env, TypeId left, TypeId right)
{
const FunctionType* leftFunction = get<FunctionType>(left);
const FunctionType* rightFunction = get<FunctionType>(right);
LUAU_ASSERT(leftFunction);
LUAU_ASSERT(rightFunction);
DifferResult differResult = diffTpi(env, DiffError::Kind::LengthMismatchInFnArgs, leftFunction->argTypes, rightFunction->argTypes);
if (differResult.diffError.has_value())
return differResult;
return diffTpi(env, DiffError::Kind::LengthMismatchInFnRets, leftFunction->retTypes, rightFunction->retTypes);
}
static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId right)
{
LUAU_ASSERT(get<GenericType>(left));
LUAU_ASSERT(get<GenericType>(right));
// Try to pair up the generics
bool isLeftFree = !env.genericMatchedPairs.contains(left);
bool isRightFree = !env.genericMatchedPairs.contains(right);
if (isLeftFree && isRightFree)
{
env.genericMatchedPairs[left] = right;
env.genericMatchedPairs[right] = left;
return DifferResult{};
}
else if (isLeftFree || isRightFree)
{
return DifferResult{DiffError{
DiffError::Kind::IncompatibleGeneric,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
// Both generics are already paired up
if (*env.genericMatchedPairs.find(left) == right)
return DifferResult{};
return DifferResult{DiffError{
DiffError::Kind::IncompatibleGeneric,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
static DifferResult diffNegation(DifferEnvironment& env, TypeId left, TypeId right)
{
const NegationType* leftNegation = get<NegationType>(left);
const NegationType* rightNegation = get<NegationType>(right);
LUAU_ASSERT(leftNegation);
LUAU_ASSERT(rightNegation);
DifferResult differResult = diffUsingEnv(env, leftNegation->ty, rightNegation->ty);
if (!differResult.diffError.has_value())
return DifferResult{};
differResult.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::Negation));
return differResult;
}
static DifferResult diffClass(DifferEnvironment& env, TypeId left, TypeId right)
{
const ClassType* leftClass = get<ClassType>(left);
const ClassType* rightClass = get<ClassType>(right);
LUAU_ASSERT(leftClass);
LUAU_ASSERT(rightClass);
if (leftClass == rightClass)
{
return DifferResult{};
}
return DifferResult{DiffError{
DiffError::Kind::Normal,
DiffPathNodeLeaf::detailsNormal(left),
DiffPathNodeLeaf::detailsNormal(right),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
static FindSeteqCounterexampleResult findSeteqCounterexample(
DifferEnvironment& env, const std::vector<TypeId>& left, const std::vector<TypeId>& right)
{
std::unordered_set<size_t> unmatchedRightIdxes;
for (size_t i = 0; i < right.size(); i++)
unmatchedRightIdxes.insert(i);
for (size_t leftIdx = 0; leftIdx < left.size(); leftIdx++)
{
bool leftIdxIsMatched = false;
auto unmatchedRightIdxIt = unmatchedRightIdxes.begin();
while (unmatchedRightIdxIt != unmatchedRightIdxes.end())
{
DifferResult differResult = diffUsingEnv(env, left[leftIdx], right[*unmatchedRightIdxIt]);
if (differResult.diffError.has_value())
{
unmatchedRightIdxIt++;
continue;
}
// unmatchedRightIdxIt is matched with current leftIdx
env.recordProvenEqual(left[leftIdx], right[*unmatchedRightIdxIt]);
leftIdxIsMatched = true;
unmatchedRightIdxIt = unmatchedRightIdxes.erase(unmatchedRightIdxIt);
}
if (!leftIdxIsMatched)
{
return FindSeteqCounterexampleResult{leftIdx, true};
}
}
if (unmatchedRightIdxes.empty())
return FindSeteqCounterexampleResult{std::nullopt, false};
return FindSeteqCounterexampleResult{*unmatchedRightIdxes.begin(), false};
}
static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right)
{
const UnionType* leftUnion = get<UnionType>(left);
const UnionType* rightUnion = get<UnionType>(right);
LUAU_ASSERT(leftUnion);
LUAU_ASSERT(rightUnion);
FindSeteqCounterexampleResult findSeteqCexResult = findSeteqCounterexample(env, leftUnion->options, rightUnion->options);
if (findSeteqCexResult.mismatchIdx.has_value())
{
if (findSeteqCexResult.inLeft)
return DifferResult{DiffError{
DiffError::Kind::MissingUnionMember,
DiffPathNodeLeaf::detailsUnionIndex(leftUnion->options[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
else
return DifferResult{DiffError{
DiffError::Kind::MissingUnionMember,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::detailsUnionIndex(rightUnion->options[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
// TODO: somehow detect mismatch index, likely using heuristics
return DifferResult{};
}
static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId right)
{
const IntersectionType* leftIntersection = get<IntersectionType>(left);
const IntersectionType* rightIntersection = get<IntersectionType>(right);
LUAU_ASSERT(leftIntersection);
LUAU_ASSERT(rightIntersection);
FindSeteqCounterexampleResult findSeteqCexResult = findSeteqCounterexample(env, leftIntersection->parts, rightIntersection->parts);
if (findSeteqCexResult.mismatchIdx.has_value())
{
if (findSeteqCexResult.inLeft)
return DifferResult{DiffError{
DiffError::Kind::MissingIntersectionMember,
DiffPathNodeLeaf::detailsUnionIndex(leftIntersection->parts[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
else
return DifferResult{DiffError{
DiffError::Kind::MissingIntersectionMember,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::detailsUnionIndex(rightIntersection->parts[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
// TODO: somehow detect mismatch index, likely using heuristics
return DifferResult{};
}
static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId right)
{
left = follow(left);
right = follow(right);
if (left->ty.index() != right->ty.index())
{
return DifferResult{DiffError{
DiffError::Kind::Normal,
DiffPathNodeLeaf::detailsNormal(left),
DiffPathNodeLeaf::detailsNormal(right),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
// Both left and right are the same variant
// Check cycles & caches
if (env.isAssumedEqual(left, right) || env.isProvenEqual(left, right))
return DifferResult{};
if (isSimple(left))
{
if (auto lp = get<PrimitiveType>(left))
return diffPrimitive(env, left, right);
else if (auto ls = get<SingletonType>(left))
{
return diffSingleton(env, left, right);
}
else if (auto la = get<AnyType>(left))
{
// Both left and right must be Any if either is Any for them to be equal!
return DifferResult{};
}
else if (auto lu = get<UnknownType>(left))
{
return DifferResult{};
}
else if (auto ln = get<NeverType>(left))
{
return DifferResult{};
}
else if (auto ln = get<NegationType>(left))
{
return diffNegation(env, left, right);
}
else if (auto lc = get<ClassType>(left))
{
return diffClass(env, left, right);
}
throw InternalCompilerError{"Unimplemented Simple TypeId variant for diffing"};
}
// Both left and right are the same non-Simple
// Non-simple types must record visits in the DifferEnvironment
env.pushVisiting(left, right);
if (auto lt = get<TableType>(left))
{
DifferResult diffRes = diffTable(env, left, right);
if (!diffRes.diffError.has_value())
{
env.recordProvenEqual(left, right);
}
env.popVisiting();
return diffRes;
}
if (auto lm = get<MetatableType>(left))
{
env.popVisiting();
return diffMetatable(env, left, right);
}
if (auto lf = get<FunctionType>(left))
{
DifferResult diffRes = diffFunction(env, left, right);
if (!diffRes.diffError.has_value())
{
env.recordProvenEqual(left, right);
}
env.popVisiting();
return diffRes;
}
if (auto lg = get<GenericType>(left))
{
DifferResult diffRes = diffGeneric(env, left, right);
if (!diffRes.diffError.has_value())
{
env.recordProvenEqual(left, right);
}
env.popVisiting();
return diffRes;
}
if (auto lu = get<UnionType>(left))
{
DifferResult diffRes = diffUnion(env, left, right);
if (!diffRes.diffError.has_value())
{
env.recordProvenEqual(left, right);
}
env.popVisiting();
return diffRes;
}
if (auto li = get<IntersectionType>(left))
{
DifferResult diffRes = diffIntersection(env, left, right);
if (!diffRes.diffError.has_value())
{
env.recordProvenEqual(left, right);
}
env.popVisiting();
return diffRes;
}
if (auto le = get<Luau::Unifiable::Error>(left))
{
// TODO: return debug-friendly result state
env.popVisiting();
return DifferResult{};
}
throw InternalCompilerError{"Unimplemented non-simple TypeId variant for diffing"};
}
static DifferResult diffTpi(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right)
{
left = follow(left);
right = follow(right);
// Canonicalize
std::pair<std::vector<TypeId>, std::optional<TypePackId>> leftFlatTpi = flatten(left);
std::pair<std::vector<TypeId>, std::optional<TypePackId>> rightFlatTpi = flatten(right);
// Check for shape equality
DifferResult diffResult = diffCanonicalTpShape(env, possibleNonNormalErrorKind, leftFlatTpi, rightFlatTpi);
if (diffResult.diffError.has_value())
{
return diffResult;
}
// Left and Right have the same shape
for (size_t i = 0; i < leftFlatTpi.first.size(); i++)
{
DifferResult differResult = diffUsingEnv(env, leftFlatTpi.first[i], rightFlatTpi.first[i]);
if (!differResult.diffError.has_value())
continue;
switch (possibleNonNormalErrorKind)
{
case DiffError::Kind::LengthMismatchInFnArgs:
{
differResult.wrapDiffPath(DiffPathNode::constructWithKindAndIndex(DiffPathNode::Kind::FunctionArgument, i));
return differResult;
}
case DiffError::Kind::LengthMismatchInFnRets:
{
differResult.wrapDiffPath(DiffPathNode::constructWithKindAndIndex(DiffPathNode::Kind::FunctionReturn, i));
return differResult;
}
default:
{
throw InternalCompilerError{"Unhandled Tpi diffing case with same shape"};
}
}
}
if (!leftFlatTpi.second.has_value())
return DifferResult{};
return diffHandleFlattenedTail(env, possibleNonNormalErrorKind, *leftFlatTpi.second, *rightFlatTpi.second);
}
static DifferResult diffCanonicalTpShape(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind,
const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& left, const std::pair<std::vector<TypeId>, std::optional<TypePackId>>& right)
{
if (left.first.size() == right.first.size() && left.second.has_value() == right.second.has_value())
return DifferResult{};
return DifferResult{DiffError{
possibleNonNormalErrorKind,
DiffPathNodeLeaf::detailsLength(int(left.first.size()), left.second.has_value()),
DiffPathNodeLeaf::detailsLength(int(right.first.size()), right.second.has_value()),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, TypePackId left, TypePackId right)
{
left = follow(left);
right = follow(right);
if (left->ty.index() != right->ty.index())
{
return DifferResult{DiffError{
DiffError::Kind::Normal,
DiffPathNodeLeaf::detailsNormal(env.visitingBegin()->first),
DiffPathNodeLeaf::detailsNormal(env.visitingBegin()->second),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
// Both left and right are the same variant
if (auto lv = get<VariadicTypePack>(left))
{
auto rv = get<VariadicTypePack>(right);
DifferResult differResult = diffUsingEnv(env, lv->ty, rv->ty);
if (!differResult.diffError.has_value())
return DifferResult{};
switch (possibleNonNormalErrorKind)
{
case DiffError::Kind::LengthMismatchInFnArgs:
{
differResult.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionArgument));
return differResult;
}
case DiffError::Kind::LengthMismatchInFnRets:
{
differResult.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionReturn));
return differResult;
}
default:
{
throw InternalCompilerError{"Unhandled flattened tail case for VariadicTypePack"};
}
}
}
if (auto lg = get<GenericTypePack>(left))
{
DifferResult diffRes = diffGenericTp(env, left, right);
if (!diffRes.diffError.has_value())
return DifferResult{};
switch (possibleNonNormalErrorKind)
{
case DiffError::Kind::LengthMismatchInFnArgs:
{
diffRes.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionArgument));
return diffRes;
}
case DiffError::Kind::LengthMismatchInFnRets:
{
diffRes.wrapDiffPath(DiffPathNode::constructWithKind(DiffPathNode::Kind::FunctionReturn));
return diffRes;
}
default:
{
throw InternalCompilerError{"Unhandled flattened tail case for GenericTypePack"};
}
}
}
throw InternalCompilerError{"Unhandled tail type pack variant for flattened tails"};
}
static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypePackId right)
{
LUAU_ASSERT(get<GenericTypePack>(left));
LUAU_ASSERT(get<GenericTypePack>(right));
// Try to pair up the generics
bool isLeftFree = !env.genericTpMatchedPairs.contains(left);
bool isRightFree = !env.genericTpMatchedPairs.contains(right);
if (isLeftFree && isRightFree)
{
env.genericTpMatchedPairs[left] = right;
env.genericTpMatchedPairs[right] = left;
return DifferResult{};
}
else if (isLeftFree || isRightFree)
{
return DifferResult{DiffError{
DiffError::Kind::IncompatibleGeneric,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
// Both generics are already paired up
if (*env.genericTpMatchedPairs.find(left) == right)
return DifferResult{};
return DifferResult{DiffError{
DiffError::Kind::IncompatibleGeneric,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
bool DifferEnvironment::isProvenEqual(TypeId left, TypeId right) const
{
return provenEqual.find({left, right}) != provenEqual.end();
}
bool DifferEnvironment::isAssumedEqual(TypeId left, TypeId right) const
{
return visiting.find({left, right}) != visiting.end();
}
void DifferEnvironment::recordProvenEqual(TypeId left, TypeId right)
{
provenEqual.insert({left, right});
provenEqual.insert({right, left});
}
void DifferEnvironment::pushVisiting(TypeId left, TypeId right)
{
LUAU_ASSERT(visiting.find({left, right}) == visiting.end());
LUAU_ASSERT(visiting.find({right, left}) == visiting.end());
visitingStack.push_back({left, right});
visiting.insert({left, right});
visiting.insert({right, left});
}
void DifferEnvironment::popVisiting()
{
auto tyPair = visitingStack.back();
visiting.erase({tyPair.first, tyPair.second});
visiting.erase({tyPair.second, tyPair.first});
visitingStack.pop_back();
}
std::vector<std::pair<TypeId, TypeId>>::const_reverse_iterator DifferEnvironment::visitingBegin() const
{
return visitingStack.crbegin();
}
std::vector<std::pair<TypeId, TypeId>>::const_reverse_iterator DifferEnvironment::visitingEnd() const
{
return visitingStack.crend();
}
DifferResult diff(TypeId ty1, TypeId ty2)
{
DifferEnvironment differEnv{ty1, ty2, std::nullopt, std::nullopt};
return diffUsingEnv(differEnv, ty1, ty2);
}
DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional<std::string> symbol1, std::optional<std::string> symbol2)
{
DifferEnvironment differEnv{ty1, ty2, symbol1, symbol2};
return diffUsingEnv(differEnv, ty1, ty2);
}
bool isSimple(TypeId ty)
{
ty = follow(ty);
// TODO: think about GenericType, etc.
return get<PrimitiveType>(ty) || get<SingletonType>(ty) || get<AnyType>(ty) || get<NegationType>(ty) || get<ClassType>(ty) ||
get<UnknownType>(ty) || get<NeverType>(ty);
}
} // namespace Luau

View file

@ -1,9 +1,76 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAGVARIABLE(LuauBufferDefinitions, false)
LUAU_FASTFLAGVARIABLE(LuauBufferTypeck, false)
namespace Luau
{
static const std::string kBuiltinDefinitionBufferSrc_DEPRECATED = R"BUILTIN_SRC(
-- TODO: this will be replaced with a built-in primitive type
declare class buffer end
declare buffer: {
create: (size: number) -> buffer,
fromstring: (str: string) -> buffer,
tostring: () -> string,
len: (b: buffer) -> number,
copy: (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (),
fill: (b: buffer, offset: number, value: number, count: number?) -> (),
readi8: (b: buffer, offset: number) -> number,
readu8: (b: buffer, offset: number) -> number,
readi16: (b: buffer, offset: number) -> number,
readu16: (b: buffer, offset: number) -> number,
readi32: (b: buffer, offset: number) -> number,
readu32: (b: buffer, offset: number) -> number,
readf32: (b: buffer, offset: number) -> number,
readf64: (b: buffer, offset: number) -> number,
writei8: (b: buffer, offset: number, value: number) -> (),
writeu8: (b: buffer, offset: number, value: number) -> (),
writei16: (b: buffer, offset: number, value: number) -> (),
writeu16: (b: buffer, offset: number, value: number) -> (),
writei32: (b: buffer, offset: number, value: number) -> (),
writeu32: (b: buffer, offset: number, value: number) -> (),
writef32: (b: buffer, offset: number, value: number) -> (),
writef64: (b: buffer, offset: number, value: number) -> (),
readstring: (b: buffer, offset: number, count: number) -> string,
writestring: (b: buffer, offset: number, value: string, count: number?) -> (),
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionBufferSrc = R"BUILTIN_SRC(
declare buffer: {
create: (size: number) -> buffer,
fromstring: (str: string) -> buffer,
tostring: (b: buffer) -> string,
len: (b: buffer) -> number,
copy: (target: buffer, targetOffset: number, source: buffer, sourceOffset: number?, count: number?) -> (),
fill: (b: buffer, offset: number, value: number, count: number?) -> (),
readi8: (b: buffer, offset: number) -> number,
readu8: (b: buffer, offset: number) -> number,
readi16: (b: buffer, offset: number) -> number,
readu16: (b: buffer, offset: number) -> number,
readi32: (b: buffer, offset: number) -> number,
readu32: (b: buffer, offset: number) -> number,
readf32: (b: buffer, offset: number) -> number,
readf64: (b: buffer, offset: number) -> number,
writei8: (b: buffer, offset: number, value: number) -> (),
writeu8: (b: buffer, offset: number, value: number) -> (),
writei16: (b: buffer, offset: number, value: number) -> (),
writeu16: (b: buffer, offset: number, value: number) -> (),
writei32: (b: buffer, offset: number, value: number) -> (),
writeu32: (b: buffer, offset: number, value: number) -> (),
writef32: (b: buffer, offset: number, value: number) -> (),
writef64: (b: buffer, offset: number, value: number) -> (),
readstring: (b: buffer, offset: number, count: number) -> string,
writestring: (b: buffer, offset: number, value: string, count: number?) -> (),
}
)BUILTIN_SRC";
static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC(
declare bit32: {
@ -21,6 +88,7 @@ declare bit32: {
replace: (n: number, v: number, field: number, width: number?) -> number,
countlz: (n: number) -> number,
countrz: (n: number) -> number,
byteswap: (n: number) -> number,
}
declare math: {
@ -92,7 +160,7 @@ type DateTypeResult = {
declare os: {
time: (time: DateTypeArg?) -> number,
date: (formatString: string?, time: number?) -> DateTypeResult | string,
date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string),
difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number,
clock: () -> number,
}
@ -148,8 +216,7 @@ declare coroutine: {
resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
running: () -> thread,
status: (co: thread) -> "dead" | "running" | "normal" | "suspended",
-- FIXME: This technically returns a function, but we can't represent this yet.
wrap: <A..., R...>(f: (A...) -> R...) -> any,
wrap: <A..., R...>(f: (A...) -> R...) -> ((A...) -> R...),
yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean,
close: (co: thread) -> (boolean, any)
@ -199,6 +266,12 @@ declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
std::string getBuiltinDefinitionSource()
{
std::string result = kBuiltinDefinitionLuaSrc;
if (FFlag::LuauBufferTypeck)
result = kBuiltinDefinitionBufferSrc + result;
else if (FFlag::LuauBufferDefinitions)
result = kBuiltinDefinitionBufferSrc_DEPRECATED + result;
return result;
}

View file

@ -4,13 +4,15 @@
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/FileResolver.h"
#include "Luau/NotNull.h"
#include "Luau/StringUtils.h"
#include "Luau/ToString.h"
#include <optional>
#include <stdexcept>
#include <type_traits>
LUAU_FASTFLAGVARIABLE(LuauTypeMismatchInvarianceInError, false)
LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10)
static std::string wrongNumberOfArgsString(
size_t expectedCount, std::optional<size_t> maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
@ -67,6 +69,20 @@ struct ErrorConverter
std::string result;
auto quote = [&](std::string s) {
return "'" + s + "'";
};
auto constructErrorMessage = [&](std::string givenType, std::string wantedType, std::optional<std::string> givenModule,
std::optional<std::string> wantedModule) -> std::string {
std::string given = givenModule ? quote(givenType) + " from " + quote(*givenModule) : quote(givenType);
std::string wanted = wantedModule ? quote(wantedType) + " from " + quote(*wantedModule) : quote(wantedType);
size_t luauIndentTypeMismatchMaxTypeLength = size_t(FInt::LuauIndentTypeMismatchMaxTypeLength);
if (givenType.length() <= luauIndentTypeMismatchMaxTypeLength || wantedType.length() <= luauIndentTypeMismatchMaxTypeLength)
return "Type " + given + " could not be converted into " + wanted;
return "Type\n " + given + "\ncould not be converted into\n " + wanted;
};
if (givenTypeName == wantedTypeName)
{
if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType))
@ -77,20 +93,18 @@ struct ErrorConverter
{
std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule);
std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule);
result = "Type '" + givenTypeName + "' from '" + givenModuleName + "' could not be converted into '" + wantedTypeName +
"' from '" + wantedModuleName + "'";
result = constructErrorMessage(givenTypeName, wantedTypeName, givenModuleName, wantedModuleName);
}
else
{
result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName +
"' from '" + *wantedDefinitionModule + "'";
result = constructErrorMessage(givenTypeName, wantedTypeName, *givenDefinitionModule, *wantedDefinitionModule);
}
}
}
}
if (result.empty())
result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'";
result = constructErrorMessage(givenTypeName, wantedTypeName, std::nullopt, std::nullopt);
if (tm.error)
@ -98,7 +112,7 @@ struct ErrorConverter
result += "\ncaused by:\n ";
if (!tm.reason.empty())
result += tm.reason + " ";
result += tm.reason + "\n";
result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver});
}
@ -106,7 +120,7 @@ struct ErrorConverter
{
result += "; " + tm.reason;
}
else if (FFlag::LuauTypeMismatchInvarianceInError && tm.context == TypeMismatch::InvariantContext)
else if (tm.context == TypeMismatch::InvariantContext)
{
result += " in an invariant context";
}
@ -349,7 +363,10 @@ struct ErrorConverter
else
s += " -> ";
s += name;
if (fileResolver != nullptr)
s += fileResolver->getHumanReadableModuleName(name);
else
s += name;
}
return s;
@ -473,13 +490,49 @@ struct ErrorConverter
std::string operator()(const TypePackMismatch& e) const
{
return "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'";
std::string ss = "Type pack '" + toString(e.givenTp) + "' could not be converted into '" + toString(e.wantedTp) + "'";
if (!e.reason.empty())
ss += "; " + e.reason;
return ss;
}
std::string operator()(const DynamicPropertyLookupOnClassesUnsafe& e) const
{
return "Attempting a dynamic property access on type '" + Luau::toString(e.ty) + "' is unsafe and may cause exceptions at runtime";
}
std::string operator()(const UninhabitedTypeFamily& e) const
{
return "Type family instance " + Luau::toString(e.ty) + " is uninhabited";
}
std::string operator()(const UninhabitedTypePackFamily& e) const
{
return "Type pack family instance " + Luau::toString(e.tp) + " is uninhabited";
}
std::string operator()(const WhereClauseNeeded& e) const
{
return "Type family instance " + Luau::toString(e.ty) +
" depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this "
"time";
}
std::string operator()(const PackWhereClauseNeeded& e) const
{
return "Type pack family instance " + Luau::toString(e.tp) +
" depends on generic function parameters but does not appear in the function signature; this construct cannot be type-checked at this "
"time";
}
std::string operator()(const CheckedFunctionCallError& e) const
{
// TODO: What happens if checkedFunctionName cannot be found??
return "Function '" + e.checkedFunctionName + "' expects '" + toString(e.expected) + "' at argument #" + std::to_string(e.argumentIndex) +
", but got '" + Luau::toString(e.passed) + "'";
}
};
struct InvalidNameChecker
@ -782,6 +835,32 @@ bool DynamicPropertyLookupOnClassesUnsafe::operator==(const DynamicPropertyLooku
return ty == rhs.ty;
}
bool UninhabitedTypeFamily::operator==(const UninhabitedTypeFamily& rhs) const
{
return ty == rhs.ty;
}
bool UninhabitedTypePackFamily::operator==(const UninhabitedTypePackFamily& rhs) const
{
return tp == rhs.tp;
}
bool WhereClauseNeeded::operator==(const WhereClauseNeeded& rhs) const
{
return ty == rhs.ty;
}
bool PackWhereClauseNeeded::operator==(const PackWhereClauseNeeded& rhs) const
{
return tp == rhs.tp;
}
bool CheckedFunctionCallError::operator==(const CheckedFunctionCallError& rhs) const
{
return *expected == *rhs.expected && *passed == *rhs.passed && checkedFunctionName == rhs.checkedFunctionName &&
argumentIndex == rhs.argumentIndex;
}
std::string toString(const TypeError& error)
{
return toString(error, TypeErrorToStringOptions{});
@ -799,7 +878,7 @@ bool containsParseErrorName(const TypeError& error)
}
template<typename T>
void copyError(T& e, TypeArena& destArena, CloneState cloneState)
void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
{
auto clone = [&](auto&& ty) {
return ::Luau::clone(ty, destArena, cloneState);
@ -940,13 +1019,26 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState)
}
else if constexpr (std::is_same_v<T, DynamicPropertyLookupOnClassesUnsafe>)
e.ty = clone(e.ty);
else if constexpr (std::is_same_v<T, UninhabitedTypeFamily>)
e.ty = clone(e.ty);
else if constexpr (std::is_same_v<T, UninhabitedTypePackFamily>)
e.tp = clone(e.tp);
else if constexpr (std::is_same_v<T, WhereClauseNeeded>)
e.ty = clone(e.ty);
else if constexpr (std::is_same_v<T, PackWhereClauseNeeded>)
e.tp = clone(e.tp);
else if constexpr (std::is_same_v<T, CheckedFunctionCallError>)
{
e.expected = clone(e.expected);
e.passed = clone(e.passed);
}
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}
void copyErrors(ErrorVec& errors, TypeArena& destArena)
void copyErrors(ErrorVec& errors, TypeArena& destArena, NotNull<BuiltinTypes> builtinTypes)
{
CloneState cloneState;
CloneState cloneState{builtinTypes};
auto visitErrorData = [&](auto&& e) {
copyError(e, destArena, cloneState);

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,37 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/GlobalTypes.h"
LUAU_FASTFLAG(LuauInitializeStringMetatableInGlobalTypes)
LUAU_FASTFLAG(LuauBufferTypeck)
namespace Luau
{
GlobalTypes::GlobalTypes(NotNull<BuiltinTypes> builtinTypes)
: builtinTypes(builtinTypes)
{
globalScope = std::make_shared<Scope>(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}));
globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType});
globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType});
globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType});
globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType});
globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType});
globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType});
if (FFlag::LuauBufferTypeck)
globalScope->addBuiltinTypeBinding("buffer", TypeFun{{}, builtinTypes->bufferType});
globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType});
globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType});
if (FFlag::LuauInitializeStringMetatableInGlobalTypes)
{
unfreeze(*builtinTypes->arena);
TypeId stringMetatableTy = makeStringMetatable(builtinTypes);
asMutable(builtinTypes->stringType)->ty.emplace<PrimitiveType>(PrimitiveType::String, stringMetatableTy);
persist(stringMetatableTy);
freeze(*builtinTypes->arena);
}
}
} // namespace Luau

View file

@ -1,10 +1,15 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Common.h"
#include "Luau/Instantiation.h"
#include "Luau/Common.h"
#include "Luau/ToString.h"
#include "Luau/TxnLog.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeCheckLimits.h"
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
#include <algorithm>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
namespace Luau
{
@ -13,7 +18,7 @@ bool Instantiation::isDirty(TypeId ty)
{
if (const FunctionType* ftv = log->getMutable<FunctionType>(ty))
{
if (ftv->hasNoGenerics)
if (ftv->hasNoFreeOrGenericTypes)
return false;
return true;
@ -33,7 +38,7 @@ bool Instantiation::ignoreChildren(TypeId ty)
{
if (log->getMutable<FunctionType>(ty))
return true;
else if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassType>(ty))
else if (get<ClassType>(ty))
return true;
else
return false;
@ -54,7 +59,7 @@ TypeId Instantiation::clean(TypeId ty)
// Annoyingly, we have to do this even if there are no generics,
// to replace any generic tables.
ReplaceGenerics replaceGenerics{log, arena, level, scope, ftv->generics, ftv->genericPacks};
ReplaceGenerics replaceGenerics{log, arena, builtinTypes, level, scope, ftv->generics, ftv->genericPacks};
// TODO: What to do if this returns nullopt?
// We don't have access to the error-reporting machinery
@ -74,7 +79,7 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty)
{
if (const FunctionType* ftv = log->getMutable<FunctionType>(ty))
{
if (ftv->hasNoGenerics)
if (ftv->hasNoFreeOrGenericTypes)
return true;
// We aren't recursing in the case of a generic function which
@ -84,7 +89,7 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty)
// whenever we quantify, so the vectors overlap if and only if they are equal.
return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks);
}
else if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassType>(ty))
else if (get<ClassType>(ty))
return true;
else
{
@ -120,14 +125,94 @@ TypeId ReplaceGenerics::clean(TypeId ty)
clone.definitionLocation = ttv->definitionLocation;
return addType(std::move(clone));
}
else if (FFlag::DebugLuauDeferredConstraintResolution)
{
TypeId res = freshType(NotNull{arena}, builtinTypes, scope);
getMutable<FreeType>(res)->level = level;
return res;
}
else
{
return addType(FreeType{scope, level});
}
}
TypePackId ReplaceGenerics::clean(TypePackId tp)
{
LUAU_ASSERT(isDirty(tp));
return addTypePack(TypePackVar(FreeTypePack{level}));
return addTypePack(TypePackVar(FreeTypePack{scope, level}));
}
struct Replacer : Substitution
{
DenseHashMap<TypeId, TypeId> replacements;
DenseHashMap<TypePackId, TypePackId> replacementPacks;
Replacer(NotNull<TypeArena> arena, DenseHashMap<TypeId, TypeId> replacements, DenseHashMap<TypePackId, TypePackId> replacementPacks)
: Substitution(TxnLog::empty(), arena)
, replacements(std::move(replacements))
, replacementPacks(std::move(replacementPacks))
{
}
bool isDirty(TypeId ty) override
{
return replacements.find(ty) != nullptr;
}
bool isDirty(TypePackId tp) override
{
return replacementPacks.find(tp) != nullptr;
}
TypeId clean(TypeId ty) override
{
return replacements[ty];
}
TypePackId clean(TypePackId tp) override
{
return replacementPacks[tp];
}
};
std::optional<TypeId> instantiate(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, TypeId ty)
{
ty = follow(ty);
const FunctionType* ft = get<FunctionType>(ty);
if (!ft)
return ty;
if (ft->generics.empty() && ft->genericPacks.empty())
return ty;
DenseHashMap<TypeId, TypeId> replacements{nullptr};
DenseHashMap<TypePackId, TypePackId> replacementPacks{nullptr};
for (TypeId g : ft->generics)
replacements[g] = freshType(arena, builtinTypes, scope);
for (TypePackId g : ft->genericPacks)
replacementPacks[g] = arena->freshTypePack(scope);
Replacer r{arena, std::move(replacements), std::move(replacementPacks)};
if (limits->instantiationChildLimit)
r.childLimit = *limits->instantiationChildLimit;
std::optional<TypeId> res = r.substitute(ty);
if (!res)
return res;
FunctionType* ft2 = getMutable<FunctionType>(*res);
LUAU_ASSERT(ft != ft2);
ft2->generics.clear();
ft2->genericPacks.clear();
return res;
}
} // namespace Luau

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/IostreamHelpers.h"
#include "Luau/ToString.h"
#include "Luau/TypePath.h"
namespace Luau
{
@ -192,6 +193,17 @@ static void errorToString(std::ostream& stream, const T& err)
stream << "TypePackMismatch { wanted = '" + toString(err.wantedTp) + "', given = '" + toString(err.givenTp) + "' }";
else if constexpr (std::is_same_v<T, DynamicPropertyLookupOnClassesUnsafe>)
stream << "DynamicPropertyLookupOnClassesUnsafe { " << toString(err.ty) << " }";
else if constexpr (std::is_same_v<T, UninhabitedTypeFamily>)
stream << "UninhabitedTypeFamily { " << toString(err.ty) << " }";
else if constexpr (std::is_same_v<T, UninhabitedTypePackFamily>)
stream << "UninhabitedTypePackFamily { " << toString(err.tp) << " }";
else if constexpr (std::is_same_v<T, WhereClauseNeeded>)
stream << "WhereClauseNeeded { " << toString(err.ty) << " }";
else if constexpr (std::is_same_v<T, PackWhereClauseNeeded>)
stream << "PackWhereClauseNeeded { " << toString(err.tp) << " }";
else if constexpr (std::is_same_v<T, CheckedFunctionCallError>)
stream << "CheckedFunctionCallError { expected = '" << toString(err.expected) << "', passed = '" << toString(err.passed)
<< "', checkedFunctionName = " << err.checkedFunctionName << ", argumentIndex = " << std::to_string(err.argumentIndex) << " }";
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}
@ -225,4 +237,34 @@ std::ostream& operator<<(std::ostream& stream, const TypePackVar& tv)
return stream << toString(tv);
}
std::ostream& operator<<(std::ostream& stream, TypeId ty)
{
// we commonly use a null pointer when a type may not be present; we need to
// account for that here.
if (!ty)
return stream << "<nullptr>";
return stream << toString(ty);
}
std::ostream& operator<<(std::ostream& stream, TypePackId tp)
{
// we commonly use a null pointer when a type may not be present; we need to
// account for that here.
if (!tp)
return stream << "<nullptr>";
return stream << toString(tp);
}
namespace TypePath
{
std::ostream& operator<<(std::ostream& stream, const Path& path)
{
return stream << toString(path);
}
} // namespace TypePath
} // namespace Luau

View file

@ -14,48 +14,11 @@
LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4)
LUAU_FASTFLAGVARIABLE(LuauImproveDeprecatedApiLint, false)
LUAU_FASTFLAG(LuauBufferTypeck)
namespace Luau
{
// clang-format off
static const char* kWarningNames[] = {
"Unknown",
"UnknownGlobal",
"DeprecatedGlobal",
"GlobalUsedAsLocal",
"LocalShadow",
"SameLineStatement",
"MultiLineStatement",
"LocalUnused",
"FunctionUnused",
"ImportUnused",
"BuiltinGlobalWrite",
"PlaceholderRead",
"UnreachableCode",
"UnknownType",
"ForRange",
"UnbalancedAssignment",
"ImplicitReturn",
"DuplicateLocal",
"FormatString",
"TableLiteral",
"UninitializedLocal",
"DuplicateFunction",
"DeprecatedApi",
"TableOperations",
"DuplicateCondition",
"MisleadingAndOr",
"CommentDirective",
"IntegerParsing",
"ComparisonPrecedence",
};
// clang-format on
static_assert(std::size(kWarningNames) == unsigned(LintWarning::Code__Count), "did you forget to add warning to the list?");
struct LintContext
{
struct Global
@ -1144,7 +1107,7 @@ private:
TypeKind getTypeKind(const std::string& name)
{
if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" ||
name == "function" || name == "thread")
name == "function" || name == "thread" || (FFlag::LuauBufferTypeck && name == "buffer"))
return Kind_Primitive;
if (name == "vector")
@ -2102,9 +2065,6 @@ class LintDeprecatedApi : AstVisitor
public:
LUAU_NOINLINE static void process(LintContext& context)
{
if (!FFlag::LuauImproveDeprecatedApiLint && !context.module)
return;
LintDeprecatedApi pass{&context};
context.root->visit(&pass);
}
@ -2122,8 +2082,33 @@ private:
if (std::optional<TypeId> ty = context->getType(node->expr))
check(node, follow(*ty));
else if (AstExprGlobal* global = node->expr->as<AstExprGlobal>())
if (FFlag::LuauImproveDeprecatedApiLint)
check(node->location, global->name, node->index);
check(node->location, global->name, node->index);
return true;
}
bool visit(AstExprCall* node) override
{
// getfenv/setfenv are deprecated, however they are still used in some test frameworks and don't have a great general replacement
// for now we warn about the deprecation only when they are used with a numeric first argument; this produces fewer warnings and makes use
// of getfenv/setfenv a little more localized
if (!node->self && node->args.size >= 1)
{
if (AstExprGlobal* fenv = node->func->as<AstExprGlobal>(); fenv && (fenv->name == "getfenv" || fenv->name == "setfenv"))
{
AstExpr* level = node->args.data[0];
std::optional<TypeId> ty = context->getType(level);
if ((ty && isNumber(*ty)) || level->is<AstExprConstantNumber>())
{
// some common uses of getfenv(n) can be replaced by debug.info if the goal is to get the caller's identity
const char* suggestion = (fenv->name == "getfenv") ? "; consider using 'debug.info' instead" : "";
emitWarning(
*context, LintWarning::Code_DeprecatedApi, node->location, "Function '%s' is deprecated%s", fenv->name.value, suggestion);
}
}
}
return true;
}
@ -2144,7 +2129,7 @@ private:
if (prop != tty->props.end() && prop->second.deprecated)
{
// strip synthetic typeof() for builtin tables
if (FFlag::LuauImproveDeprecatedApiLint && tty->name && tty->name->compare(0, 7, "typeof(") == 0 && tty->name->back() == ')')
if (tty->name && tty->name->compare(0, 7, "typeof(") == 0 && tty->name->back() == ')')
report(node->location, prop->second, tty->name->substr(7, tty->name->length() - 8).c_str(), node->index.value);
else
report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value);
@ -2197,16 +2182,49 @@ private:
{
}
bool visit(AstExprUnary* node) override
{
if (node->op == AstExprUnary::Len)
checkIndexer(node, node->expr, "#");
return true;
}
bool visit(AstExprCall* node) override
{
AstExprIndexName* func = node->func->as<AstExprIndexName>();
if (!func)
return true;
if (AstExprGlobal* func = node->func->as<AstExprGlobal>())
{
if (func->name == "ipairs" && node->args.size == 1)
checkIndexer(node, node->args.data[0], "ipairs");
}
else if (AstExprIndexName* func = node->func->as<AstExprIndexName>())
{
if (AstExprGlobal* tablib = func->expr->as<AstExprGlobal>(); tablib && tablib->name == "table")
checkTableCall(node, func);
}
AstExprGlobal* tablib = func->expr->as<AstExprGlobal>();
if (!tablib || tablib->name != "table")
return true;
return true;
}
void checkIndexer(AstExpr* node, AstExpr* expr, const char* op)
{
std::optional<Luau::TypeId> ty = context->getType(expr);
if (!ty)
return;
const TableType* tty = get<TableType>(follow(*ty));
if (!tty)
return;
if (!tty->indexer && !tty->props.empty() && tty->state != TableState::Generic)
emitWarning(
*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table without an array part is likely a bug", op);
else if (tty->indexer && isString(tty->indexer->indexType)) // note: to avoid complexity of subtype tests we just check if the key is a string
emitWarning(*context, LintWarning::Code_TableOperations, node->location, "Using '%s' on a table with string keys is likely a bug", op);
}
void checkTableCall(AstExprCall* node, AstExprIndexName* func)
{
AstExpr** args = node->args.data;
if (func->index == "insert" && node->args.size == 2)
@ -2288,8 +2306,6 @@ private:
emitWarning(*context, LintWarning::Code_TableOperations, as->expr->location,
"table.create with a table literal will reuse the same object for all elements; consider using a for loop instead");
}
return true;
}
bool isConstant(AstExpr* expr, double value)
@ -2635,13 +2651,17 @@ private:
case ConstantNumberParseResult::Ok:
case ConstantNumberParseResult::Malformed:
break;
case ConstantNumberParseResult::Imprecise:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Number literal exceeded available precision and was truncated to closest representable number");
break;
case ConstantNumberParseResult::BinOverflow:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Binary number literal exceeded available precision and has been truncated to 2^64");
"Binary number literal exceeded available precision and was truncated to 2^64");
break;
case ConstantNumberParseResult::HexOverflow:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Hexadecimal number literal exceeded available precision and has been truncated to 2^64");
"Hexadecimal number literal exceeded available precision and was truncated to 2^64");
break;
}
@ -2831,6 +2851,12 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
"optimize directive uses unknown optimization level '%s', 0..2 expected", level);
}
}
else if (first == "native")
{
if (space != std::string::npos)
emitWarning(
context, LintWarning::Code_CommentDirective, hc.location, "native directive has extra symbols at the end of the line");
}
else
{
static const char* kHotComments[] = {
@ -2839,6 +2865,7 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
"nonstrict",
"strict",
"optimize",
"native",
};
if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments)))
@ -2852,12 +2879,6 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
}
}
void LintOptions::setDefaults()
{
// By default, we enable all warnings
warningMask = ~0ull;
}
std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module,
const std::vector<HotComment>& hotcomments, const LintOptions& options)
{
@ -2949,54 +2970,6 @@ std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const Sc
return context.result;
}
const char* LintWarning::getName(Code code)
{
LUAU_ASSERT(unsigned(code) < Code__Count);
return kWarningNames[code];
}
LintWarning::Code LintWarning::parseName(const char* name)
{
for (int code = Code_Unknown; code < Code__Count; ++code)
if (strcmp(name, getName(Code(code))) == 0)
return Code(code);
return Code_Unknown;
}
uint64_t LintWarning::parseMask(const std::vector<HotComment>& hotcomments)
{
uint64_t result = 0;
for (const HotComment& hc : hotcomments)
{
if (!hc.header)
continue;
if (hc.content.compare(0, 6, "nolint") != 0)
continue;
size_t name = hc.content.find_first_not_of(" \t", 6);
// --!nolint disables everything
if (name == std::string::npos)
return ~0ull;
// --!nolint needs to be followed by a whitespace character
if (name == 6)
continue;
// --!nolint name disables the specific lint
LintWarning::Code code = LintWarning::parseName(hc.content.c_str() + name);
if (code != LintWarning::Code_Unknown)
result |= 1ull << int(code);
}
return result;
}
std::vector<AstName> getDeprecatedGlobals(const AstNameTable& names)
{
LintContext context;

View file

@ -3,23 +3,18 @@
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/ConstraintGenerator.h"
#include "Luau/Normalize.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/Type.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeReduction.h"
#include "Luau/VisitType.h"
#include <algorithm>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false);
LUAU_FASTFLAG(LuauSubstitutionReentrant);
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution);
LUAU_FASTFLAG(LuauSubstitutionFixMissingFields);
namespace Luau
{
@ -37,14 +32,14 @@ static bool contains(Position pos, Comment comment)
return false;
}
bool isWithinComment(const SourceModule& sourceModule, Position pos)
static bool isWithinComment(const std::vector<Comment>& commentLocations, Position pos)
{
auto iter = std::lower_bound(sourceModule.commentLocations.begin(), sourceModule.commentLocations.end(),
Comment{Lexeme::Comment, Location{pos, pos}}, [](const Comment& a, const Comment& b) {
auto iter = std::lower_bound(
commentLocations.begin(), commentLocations.end(), Comment{Lexeme::Comment, Location{pos, pos}}, [](const Comment& a, const Comment& b) {
return a.location.end < b.location.end;
});
if (iter == sourceModule.commentLocations.end())
if (iter == commentLocations.end())
return false;
if (contains(pos, *iter))
@ -53,12 +48,22 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos)
// Due to the nature of std::lower_bound, it is possible that iter points at a comment that ends
// at pos. We'll try the next comment, if it exists.
++iter;
if (iter == sourceModule.commentLocations.end())
if (iter == commentLocations.end())
return false;
return contains(pos, *iter);
}
bool isWithinComment(const SourceModule& sourceModule, Position pos)
{
return isWithinComment(sourceModule.commentLocations, pos);
}
bool isWithinComment(const ParseResult& result, Position pos)
{
return isWithinComment(result.commentLocations, pos);
}
struct ClonePublicInterface : Substitution
{
NotNull<BuiltinTypes> builtinTypes;
@ -89,6 +94,22 @@ struct ClonePublicInterface : Substitution
return tp->owningArena == &module->internalTypes;
}
bool ignoreChildrenVisit(TypeId ty) override
{
if (ty->owningArena != &module->internalTypes)
return true;
return false;
}
bool ignoreChildrenVisit(TypePackId tp) override
{
if (tp->owningArena != &module->internalTypes)
return true;
return false;
}
TypeId clean(TypeId ty) override
{
TypeId result = clone(ty);
@ -108,8 +129,6 @@ struct ClonePublicInterface : Substitution
TypeId cloneType(TypeId ty)
{
LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields);
std::optional<TypeId> result = substitute(ty);
if (result)
{
@ -124,8 +143,6 @@ struct ClonePublicInterface : Substitution
TypePackId cloneTypePack(TypePackId tp)
{
LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields);
std::optional<TypePackId> result = substitute(tp);
if (result)
{
@ -140,8 +157,6 @@ struct ClonePublicInterface : Substitution
TypeFun cloneTypeFun(const TypeFun& tf)
{
LUAU_ASSERT(FFlag::LuauSubstitutionReentrant && FFlag::LuauSubstitutionFixMissingFields);
std::vector<GenericTypeDefinition> typeParams;
std::vector<GenericTypePackDefinition> typePackParams;
@ -184,7 +199,7 @@ void Module::clonePublicInterface(NotNull<BuiltinTypes> builtinTypes, InternalEr
LUAU_ASSERT(interfaceTypes.types.empty());
LUAU_ASSERT(interfaceTypes.typePacks.empty());
CloneState cloneState;
CloneState cloneState{builtinTypes};
ScopePtr moduleScope = getModuleScope();
@ -194,43 +209,28 @@ void Module::clonePublicInterface(NotNull<BuiltinTypes> builtinTypes, InternalEr
TxnLog log;
ClonePublicInterface clonePublicInterface{&log, builtinTypes, this};
if (FFlag::LuauClonePublicInterfaceLess)
returnType = clonePublicInterface.cloneTypePack(returnType);
else
returnType = clone(returnType, interfaceTypes, cloneState);
returnType = clonePublicInterface.cloneTypePack(returnType);
moduleScope->returnType = returnType;
if (varargPack)
{
if (FFlag::LuauClonePublicInterfaceLess)
varargPack = clonePublicInterface.cloneTypePack(*varargPack);
else
varargPack = clone(*varargPack, interfaceTypes, cloneState);
varargPack = clonePublicInterface.cloneTypePack(*varargPack);
moduleScope->varargPack = varargPack;
}
for (auto& [name, tf] : moduleScope->exportedTypeBindings)
{
if (FFlag::LuauClonePublicInterfaceLess)
tf = clonePublicInterface.cloneTypeFun(tf);
else
tf = clone(tf, interfaceTypes, cloneState);
tf = clonePublicInterface.cloneTypeFun(tf);
}
for (auto& [name, ty] : declaredGlobals)
{
if (FFlag::LuauClonePublicInterfaceLess)
ty = clonePublicInterface.cloneType(ty);
else
ty = clone(ty, interfaceTypes, cloneState);
ty = clonePublicInterface.cloneType(ty);
}
// Copy external stuff over to Module itself
this->returnType = moduleScope->returnType;
if (FFlag::DebugLuauDeferredConstraintResolution)
this->exportedTypeBindings = moduleScope->exportedTypeBindings;
else
this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings);
this->exportedTypeBindings = moduleScope->exportedTypeBindings;
}
bool Module::hasModuleScope() const

View file

@ -0,0 +1,619 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/NonStrictTypeChecker.h"
#include "Luau/Ast.h"
#include "Luau/Common.h"
#include "Luau/Simplify.h"
#include "Luau/Type.h"
#include "Luau/Simplify.h"
#include "Luau/Subtyping.h"
#include "Luau/Normalize.h"
#include "Luau/Error.h"
#include "Luau/TypeArena.h"
#include "Luau/TypeFamily.h"
#include "Luau/Def.h"
#include <iostream>
namespace Luau
{
/* Push a scope onto the end of a stack for the lifetime of the StackPusher instance.
* NonStrictTypeChecker uses this to maintain knowledge about which scope encloses every
* given AstNode.
*/
struct StackPusher
{
std::vector<NotNull<Scope>>* stack;
NotNull<Scope> scope;
explicit StackPusher(std::vector<NotNull<Scope>>& stack, Scope* scope)
: stack(&stack)
, scope(scope)
{
stack.push_back(NotNull{scope});
}
~StackPusher()
{
if (stack)
{
LUAU_ASSERT(stack->back() == scope);
stack->pop_back();
}
}
StackPusher(const StackPusher&) = delete;
StackPusher&& operator=(const StackPusher&) = delete;
StackPusher(StackPusher&& other)
: stack(std::exchange(other.stack, nullptr))
, scope(other.scope)
{
}
};
struct NonStrictContext
{
std::unordered_map<const Def*, TypeId> context;
NonStrictContext() = default;
NonStrictContext(const NonStrictContext&) = delete;
NonStrictContext& operator=(const NonStrictContext&) = delete;
NonStrictContext(NonStrictContext&&) = default;
NonStrictContext& operator=(NonStrictContext&&) = default;
static NonStrictContext disjunction(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, const NonStrictContext& left, const NonStrictContext& right)
{
// disjunction implements union over the domain of keys
// if the default value for a defId not in the map is `never`
// then never | T is T
NonStrictContext disj{};
for (auto [def, leftTy] : left.context)
{
if (std::optional<TypeId> rightTy = right.find(def))
disj.context[def] = simplifyUnion(builtinTypes, arena, leftTy, *rightTy).result;
else
disj.context[def] = leftTy;
}
for (auto [def, rightTy] : right.context)
{
if (!left.find(def).has_value())
disj.context[def] = rightTy;
}
return disj;
}
static NonStrictContext conjunction(
NotNull<BuiltinTypes> builtins, NotNull<TypeArena> arena, const NonStrictContext& left, const NonStrictContext& right)
{
NonStrictContext conj{};
for (auto [def, leftTy] : left.context)
{
if (std::optional<TypeId> rightTy = right.find(def))
conj.context[def] = simplifyIntersection(builtins, arena, leftTy, *rightTy).result;
}
return conj;
}
void removeFromContext(const std::vector<DefId>& defs)
{
// TODO: unimplemented
}
std::optional<TypeId> find(const DefId& def) const
{
const Def* d = def.get();
return find(d);
}
private:
std::optional<TypeId> find(const Def* d) const
{
auto it = context.find(d);
if (it != context.end())
return {it->second};
return {};
}
};
struct NonStrictTypeChecker
{
NotNull<BuiltinTypes> builtinTypes;
const NotNull<InternalErrorReporter> ice;
TypeArena arena;
Module* module;
Normalizer normalizer;
Subtyping subtyping;
NotNull<const DataFlowGraph> dfg;
DenseHashSet<TypeId> noTypeFamilyErrors{nullptr};
std::vector<NotNull<Scope>> stack;
const NotNull<TypeCheckLimits> limits;
NonStrictTypeChecker(NotNull<BuiltinTypes> builtinTypes, const NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, Module* module)
: builtinTypes(builtinTypes)
, ice(ice)
, module(module)
, normalizer{&arena, builtinTypes, unifierState, /* cache inhabitance */ true}
, subtyping{builtinTypes, NotNull{&arena}, NotNull(&normalizer), ice, NotNull{module->getModuleScope().get()}}
, dfg(dfg)
, limits(limits)
{
}
std::optional<StackPusher> pushStack(AstNode* node)
{
if (Scope** scope = module->astScopes.find(node))
return StackPusher{stack, *scope};
else
return std::nullopt;
}
TypeId flattenPack(TypePackId pack)
{
pack = follow(pack);
if (auto fst = first(pack, /*ignoreHiddenVariadics*/ false))
return *fst;
else if (auto ftp = get<FreeTypePack>(pack))
{
TypeId result = arena.addType(FreeType{ftp->scope});
TypePackId freeTail = arena.addTypePack(FreeTypePack{ftp->scope});
TypePack& resultPack = asMutable(pack)->ty.emplace<TypePack>();
resultPack.head.assign(1, result);
resultPack.tail = freeTail;
return result;
}
else if (get<Unifiable::Error>(pack))
return builtinTypes->errorRecoveryType();
else if (finite(pack) && size(pack) == 0)
return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil`
else
ice->ice("flattenPack got a weird pack!");
}
TypeId checkForFamilyInhabitance(TypeId instance, Location location)
{
if (noTypeFamilyErrors.find(instance))
return instance;
ErrorVec errors = reduceFamilies(
instance, location, TypeFamilyContext{NotNull{&arena}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true)
.errors;
if (errors.empty())
noTypeFamilyErrors.insert(instance);
// TODO??
// if (!isErrorSuppressing(location, instance))
// reportErrors(std::move(errors));
return instance;
}
TypeId lookupType(AstExpr* expr)
{
TypeId* ty = module->astTypes.find(expr);
if (ty)
return checkForFamilyInhabitance(follow(*ty), expr->location);
TypePackId* tp = module->astTypePacks.find(expr);
if (tp)
return checkForFamilyInhabitance(flattenPack(*tp), expr->location);
return builtinTypes->anyType;
}
NonStrictContext visit(AstStat* stat)
{
auto pusher = pushStack(stat);
if (auto s = stat->as<AstStatBlock>())
return visit(s);
else if (auto s = stat->as<AstStatIf>())
return visit(s);
else if (auto s = stat->as<AstStatWhile>())
return visit(s);
else if (auto s = stat->as<AstStatRepeat>())
return visit(s);
else if (auto s = stat->as<AstStatBreak>())
return visit(s);
else if (auto s = stat->as<AstStatContinue>())
return visit(s);
else if (auto s = stat->as<AstStatReturn>())
return visit(s);
else if (auto s = stat->as<AstStatExpr>())
return visit(s);
else if (auto s = stat->as<AstStatLocal>())
return visit(s);
else if (auto s = stat->as<AstStatFor>())
return visit(s);
else if (auto s = stat->as<AstStatForIn>())
return visit(s);
else if (auto s = stat->as<AstStatAssign>())
return visit(s);
else if (auto s = stat->as<AstStatCompoundAssign>())
return visit(s);
else if (auto s = stat->as<AstStatFunction>())
return visit(s);
else if (auto s = stat->as<AstStatLocalFunction>())
return visit(s);
else if (auto s = stat->as<AstStatTypeAlias>())
return visit(s);
else if (auto s = stat->as<AstStatDeclareFunction>())
return visit(s);
else if (auto s = stat->as<AstStatDeclareGlobal>())
return visit(s);
else if (auto s = stat->as<AstStatDeclareClass>())
return visit(s);
else if (auto s = stat->as<AstStatError>())
return visit(s);
else
{
LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown statement type");
ice->ice("NonStrictTypeChecker encountered an unknown statement type");
}
}
NonStrictContext visit(AstStatBlock* block)
{
auto StackPusher = pushStack(block);
NonStrictContext ctx;
for (AstStat* statement : block->body)
ctx = NonStrictContext::disjunction(builtinTypes, NotNull{&arena}, ctx, visit(statement));
return ctx;
}
NonStrictContext visit(AstStatIf* ifStatement)
{
NonStrictContext condB = visit(ifStatement->condition);
NonStrictContext branchContext;
// If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error
if (ifStatement->elsebody)
{
NonStrictContext thenBody = visit(ifStatement->thenbody);
NonStrictContext elseBody = visit(ifStatement->elsebody);
branchContext = NonStrictContext::conjunction(builtinTypes, NotNull{&arena}, thenBody, elseBody);
}
return NonStrictContext::disjunction(builtinTypes, NotNull{&arena}, condB, branchContext);
}
NonStrictContext visit(AstStatWhile* whileStatement)
{
return {};
}
NonStrictContext visit(AstStatRepeat* repeatStatement)
{
return {};
}
NonStrictContext visit(AstStatBreak* breakStatement)
{
return {};
}
NonStrictContext visit(AstStatContinue* continueStatement)
{
return {};
}
NonStrictContext visit(AstStatReturn* returnStatement)
{
return {};
}
NonStrictContext visit(AstStatExpr* expr)
{
return visit(expr->expr);
}
NonStrictContext visit(AstStatLocal* local)
{
for (AstExpr* rhs : local->values)
visit(rhs);
return {};
}
NonStrictContext visit(AstStatFor* forStatement)
{
return {};
}
NonStrictContext visit(AstStatForIn* forInStatement)
{
return {};
}
NonStrictContext visit(AstStatAssign* assign)
{
return {};
}
NonStrictContext visit(AstStatCompoundAssign* compoundAssign)
{
return {};
}
NonStrictContext visit(AstStatFunction* statFn)
{
return visit(statFn->func);
}
NonStrictContext visit(AstStatLocalFunction* localFn)
{
return visit(localFn->func);
}
NonStrictContext visit(AstStatTypeAlias* typeAlias)
{
return {};
}
NonStrictContext visit(AstStatDeclareFunction* declFn)
{
return {};
}
NonStrictContext visit(AstStatDeclareGlobal* declGlobal)
{
return {};
}
NonStrictContext visit(AstStatDeclareClass* declClass)
{
return {};
}
NonStrictContext visit(AstStatError* error)
{
return {};
}
NonStrictContext visit(AstExpr* expr)
{
auto pusher = pushStack(expr);
if (auto e = expr->as<AstExprGroup>())
return visit(e);
else if (auto e = expr->as<AstExprConstantNil>())
return visit(e);
else if (auto e = expr->as<AstExprConstantBool>())
return visit(e);
else if (auto e = expr->as<AstExprConstantNumber>())
return visit(e);
else if (auto e = expr->as<AstExprConstantString>())
return visit(e);
else if (auto e = expr->as<AstExprLocal>())
return visit(e);
else if (auto e = expr->as<AstExprGlobal>())
return visit(e);
else if (auto e = expr->as<AstExprVarargs>())
return visit(e);
else if (auto e = expr->as<AstExprCall>())
return visit(e);
else if (auto e = expr->as<AstExprIndexName>())
return visit(e);
else if (auto e = expr->as<AstExprIndexExpr>())
return visit(e);
else if (auto e = expr->as<AstExprFunction>())
return visit(e);
else if (auto e = expr->as<AstExprTable>())
return visit(e);
else if (auto e = expr->as<AstExprUnary>())
return visit(e);
else if (auto e = expr->as<AstExprBinary>())
return visit(e);
else if (auto e = expr->as<AstExprTypeAssertion>())
return visit(e);
else if (auto e = expr->as<AstExprIfElse>())
return visit(e);
else if (auto e = expr->as<AstExprInterpString>())
return visit(e);
else if (auto e = expr->as<AstExprError>())
return visit(e);
else
{
LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown expression type");
ice->ice("NonStrictTypeChecker encountered an unknown expression type");
}
}
NonStrictContext visit(AstExprGroup* group)
{
return {};
}
NonStrictContext visit(AstExprConstantNil* expr)
{
return {};
}
NonStrictContext visit(AstExprConstantBool* expr)
{
return {};
}
NonStrictContext visit(AstExprConstantNumber* expr)
{
return {};
}
NonStrictContext visit(AstExprConstantString* expr)
{
return {};
}
NonStrictContext visit(AstExprLocal* local)
{
return {};
}
NonStrictContext visit(AstExprGlobal* global)
{
return {};
}
NonStrictContext visit(AstExprVarargs* global)
{
return {};
}
NonStrictContext visit(AstExprCall* call)
{
NonStrictContext fresh{};
TypeId* originalCallTy = module->astOriginalCallTypes.find(call);
if (!originalCallTy)
return fresh;
TypeId fnTy = *originalCallTy;
if (auto fn = get<FunctionType>(follow(fnTy)))
{
if (fn->isCheckedFunction)
{
// We know fn is a checked function, which means it looks like:
// (S1, ... SN) -> T &
// (~S1, unknown^N-1) -> error &
// (unknown, ~S2, unknown^N-2) -> error
// ...
// ...
// (unknown^N-1, ~S_N) -> error
std::vector<TypeId> argTypes;
for (TypeId ty : fn->argTypes)
argTypes.push_back(ty);
// For a checked function, these gotta be the same size
LUAU_ASSERT(call->args.size == argTypes.size());
for (size_t i = 0; i < call->args.size; i++)
{
// For example, if the arg is "hi"
// The actual arg type is string
// The expected arg type is number
// The type of the argument in the overload is ~number
// We will compare arg and ~number
AstExpr* arg = call->args.data[i];
TypeId expectedArgType = argTypes[i];
DefId def = dfg->getDef(arg);
// TODO: Cache negations created here!!!
// See Jira Ticket: https://roblox.atlassian.net/browse/CLI-87539
TypeId runTimeErrorTy = arena.addType(NegationType{expectedArgType});
fresh.context[def.get()] = runTimeErrorTy;
}
// Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types
AstName name = getIdentifier(call->func);
for (size_t i = 0; i < call->args.size; i++)
{
AstExpr* arg = call->args.data[i];
if (auto runTimeFailureType = willRunTimeError(arg, fresh))
reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, name.value, i}, arg->location);
}
}
}
return fresh;
}
NonStrictContext visit(AstExprIndexName* indexName)
{
return {};
}
NonStrictContext visit(AstExprIndexExpr* indexExpr)
{
return {};
}
NonStrictContext visit(AstExprFunction* exprFn)
{
auto pusher = pushStack(exprFn);
return visit(exprFn->body);
}
NonStrictContext visit(AstExprTable* table)
{
return {};
}
NonStrictContext visit(AstExprUnary* unary)
{
return {};
}
NonStrictContext visit(AstExprBinary* binary)
{
return {};
}
NonStrictContext visit(AstExprTypeAssertion* typeAssertion)
{
return {};
}
NonStrictContext visit(AstExprIfElse* ifElse)
{
NonStrictContext condB = visit(ifElse->condition);
NonStrictContext thenB = visit(ifElse->trueExpr);
NonStrictContext elseB = visit(ifElse->falseExpr);
return NonStrictContext::disjunction(
builtinTypes, NotNull{&arena}, condB, NonStrictContext::conjunction(builtinTypes, NotNull{&arena}, thenB, elseB));
}
NonStrictContext visit(AstExprInterpString* interpString)
{
return {};
}
NonStrictContext visit(AstExprError* error)
{
return {};
}
void reportError(TypeErrorData data, const Location& location)
{
module->errors.emplace_back(location, module->name, std::move(data));
// TODO: weave in logger here?
}
// If this fragment of the ast will run time error, return the type that causes this
std::optional<TypeId> willRunTimeError(AstExpr* fragment, const NonStrictContext& context)
{
DefId def = dfg->getDef(fragment);
if (std::optional<TypeId> contextTy = context.find(def))
{
TypeId actualType = lookupType(fragment);
SubtypingResult r = subtyping.isSubtype(actualType, *contextTy);
if (r.normalizationTooComplex)
reportError(NormalizationTooComplex{}, fragment->location);
if (r.isSubtype)
return {actualType};
}
return {};
}
};
void checkNonStrict(NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, NotNull<UnifierSharedState> unifierState,
NotNull<const DataFlowGraph> dfg, NotNull<TypeCheckLimits> limits, const SourceModule& sourceModule, Module* module)
{
// TODO: unimplemented
NonStrictTypeChecker typeChecker{builtinTypes, ice, unifierState, dfg, limits, module};
typeChecker.visit(sourceModule.root);
unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes, builtinTypes);
freeze(module->interfaceTypes);
}
} // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -9,8 +9,6 @@
#include "Luau/VisitType.h"
LUAU_FASTFLAG(DebugLuauSharedSelf)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
namespace Luau
{
@ -27,7 +25,6 @@ struct Quantifier final : TypeOnceVisitor
explicit Quantifier(TypeLevel level)
: level(level)
{
LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution);
}
/// @return true if outer encloses inner
@ -117,7 +114,7 @@ void quantify(TypeId ty, TypeLevel level)
for (const auto& [_, prop] : ttv->props)
{
auto ftv = getMutable<FunctionType>(follow(prop.type));
auto ftv = getMutable<FunctionType>(follow(prop.type()));
if (!ftv || !ftv->hasSelf)
continue;
@ -137,7 +134,7 @@ void quantify(TypeId ty, TypeLevel level)
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType)
ftv->hasNoGenerics = true;
ftv->hasNoFreeOrGenericTypes = true;
}
}
else
@ -155,8 +152,8 @@ void quantify(TypeId ty, TypeLevel level)
struct PureQuantifier : Substitution
{
Scope* scope;
std::vector<TypeId> insertedGenerics;
std::vector<TypePackId> insertedGenericPacks;
OrderedMap<TypeId, TypeId> insertedGenerics;
OrderedMap<TypePackId, TypePackId> insertedGenericPacks;
bool seenMutableType = false;
bool seenGenericType = false;
@ -204,7 +201,7 @@ struct PureQuantifier : Substitution
if (auto ftv = get<FreeType>(ty))
{
TypeId result = arena->addType(GenericType{scope});
insertedGenerics.push_back(result);
insertedGenerics.push(ty, result);
return result;
}
else if (auto ttv = get<TableType>(ty))
@ -218,7 +215,10 @@ struct PureQuantifier : Substitution
resultTable->scope = scope;
if (ttv->state == TableState::Free)
{
resultTable->state = TableState::Generic;
insertedGenerics.push(ty, result);
}
else if (ttv->state == TableState::Unsealed)
resultTable->state = TableState::Sealed;
@ -232,8 +232,8 @@ struct PureQuantifier : Substitution
{
if (auto ftp = get<FreeTypePack>(tp))
{
TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{}});
insertedGenericPacks.push_back(result);
TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{scope}});
insertedGenericPacks.push(tp, result);
return result;
}
@ -242,7 +242,7 @@ struct PureQuantifier : Substitution
bool ignoreChildren(TypeId ty) override
{
if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassType>(ty))
if (get<ClassType>(ty))
return true;
return ty->persistent;
@ -253,7 +253,7 @@ struct PureQuantifier : Substitution
}
};
std::optional<TypeId> quantify(TypeArena* arena, TypeId ty, Scope* scope)
std::optional<QuantifierResult> quantify(TypeArena* arena, TypeId ty, Scope* scope)
{
PureQuantifier quantifier{arena, scope};
std::optional<TypeId> result = quantifier.substitute(ty);
@ -263,11 +263,20 @@ std::optional<TypeId> quantify(TypeArena* arena, TypeId ty, Scope* scope)
FunctionType* ftv = getMutable<FunctionType>(*result);
LUAU_ASSERT(ftv);
ftv->scope = scope;
ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end());
ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType;
return *result;
for (auto k : quantifier.insertedGenerics.keys)
{
TypeId g = quantifier.insertedGenerics.pairings[k];
if (get<GenericType>(g))
ftv->generics.push_back(g);
}
for (auto k : quantifier.insertedGenericPacks.keys)
ftv->genericPacks.push_back(quantifier.insertedGenericPacks.pairings[k]);
ftv->hasNoFreeOrGenericTypes = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType;
return std::optional<QuantifierResult>({*result, std::move(quantifier.insertedGenerics), std::move(quantifier.insertedGenericPacks)});
}
} // namespace Luau

View file

@ -1,37 +1,60 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Refinement.h"
#include <algorithm>
namespace Luau
{
RefinementId RefinementArena::variadic(const std::vector<RefinementId>& refis)
{
bool hasRefinements = false;
for (RefinementId r : refis)
hasRefinements |= bool(r);
if (!hasRefinements)
return nullptr;
return NotNull{allocator.allocate(Variadic{refis})};
}
RefinementId RefinementArena::negation(RefinementId refinement)
{
if (!refinement)
return nullptr;
return NotNull{allocator.allocate(Negation{refinement})};
}
RefinementId RefinementArena::conjunction(RefinementId lhs, RefinementId rhs)
{
if (!lhs && !rhs)
return nullptr;
return NotNull{allocator.allocate(Conjunction{lhs, rhs})};
}
RefinementId RefinementArena::disjunction(RefinementId lhs, RefinementId rhs)
{
if (!lhs && !rhs)
return nullptr;
return NotNull{allocator.allocate(Disjunction{lhs, rhs})};
}
RefinementId RefinementArena::equivalence(RefinementId lhs, RefinementId rhs)
{
if (!lhs && !rhs)
return nullptr;
return NotNull{allocator.allocate(Equivalence{lhs, rhs})};
}
RefinementId RefinementArena::proposition(BreadcrumbId breadcrumb, TypeId discriminantTy)
RefinementId RefinementArena::proposition(const RefinementKey* key, TypeId discriminantTy)
{
return NotNull{allocator.allocate(Proposition{breadcrumb, discriminantTy})};
if (!key)
return nullptr;
return NotNull{allocator.allocate(Proposition{key, discriminantTy})};
}
} // namespace Luau

Some files were not shown because too many files have changed in this diff Show more