Merge branch 'master' into rfc-shape-types

This commit is contained in:
ajeffrey@roblox.com 2022-06-30 13:42:31 -07:00
commit 0f61c85537
241 changed files with 23664 additions and 6556 deletions

270
.github/workflows/benchmark.yml vendored Normal file
View file

@ -0,0 +1,270 @@
name: benchmark
on:
push:
branches:
- master
paths-ignore:
- "docs/**"
- "papers/**"
- "rfcs/**"
- "*.md"
- "prototyping/**"
jobs:
windows:
name: windows-${{matrix.arch}}
strategy:
fail-fast: false
matrix:
os: [windows-latest]
arch: [Win32, x64]
bench:
- {
script: "run-benchmarks",
timeout: 12,
title: "Luau Benchmarks",
cachegrindTitle: "Performance",
cachegrindIterCount: 20,
}
benchResultsRepo:
- { name: "luau-lang/benchmark-data", branch: "main" }
runs-on: ${{ matrix.os }}
steps:
- name: Checkout Luau repository
uses: actions/checkout@v3
- name: Build Luau
shell: bash # necessary for fail-fast
run: |
mkdir build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release
cmake --build . --target Luau.Repl.CLI --config Release
cmake --build . --target Luau.Analyze.CLI --config Release
- name: Move build files to root
run: |
move build/Release/* .
- uses: actions/setup-python@v3
with:
python-version: "3.9"
architecture: "x64"
- name: Install python dependencies
run: |
python -m pip install requests
python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose
- name: Run benchmark
run: |
python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt
- name: Checkout Benchmark Results repository
uses: actions/checkout@v3
with:
repository: ${{ matrix.benchResultsRepo.name }}
ref: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
- name: Store ${{ matrix.bench.title }} result
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: ${{ matrix.bench.title }} (Windows ${{matrix.arch}})
tool: "benchmarkluau"
output-file-path: ./${{ matrix.bench.script }}-output.txt
external-data-json-path: ./gh-pages/dev/bench/data.json
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Push benchmark results
if: github.event_name == 'push'
run: |
echo "Pushing benchmark results..."
cd gh-pages
git config user.name github-actions
git config user.email github@users.noreply.github.com
git add ./dev/bench/data.json
git commit -m "Add benchmarks results for ${{ github.sha }}"
git push
cd ..
unix:
name: ${{matrix.os}}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
bench:
- {
script: "run-benchmarks",
timeout: 12,
title: "Luau Benchmarks",
cachegrindTitle: "Performance",
cachegrindIterCount: 20,
}
benchResultsRepo:
- { name: "luau-lang/benchmark-data", branch: "main" }
runs-on: ${{ matrix.os }}
steps:
- name: Checkout Luau repository
uses: actions/checkout@v3
- name: Build Luau
run: make config=release luau luau-analyze
- uses: actions/setup-python@v3
with:
python-version: "3.9"
architecture: "x64"
- name: Install python dependencies
run: |
python -m pip install requests
python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose
- name: Run benchmark
run: |
python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt
- name: Install valgrind
if: matrix.os == 'ubuntu-latest'
run: |
sudo apt-get install valgrind
- name: Run ${{ matrix.bench.title }} (Cold Cachegrind)
if: matrix.os == 'ubuntu-latest'
run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/bench.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 | tee -a ${{ matrix.bench.script }}-output.txt
- name: Run ${{ matrix.bench.title }} (Warm Cachegrind)
if: matrix.os == 'ubuntu-latest'
run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/bench.py "${{ matrix.bench.cachegrindTitle }}" ${{ matrix.bench.cachegrindIterCount }} | tee -a ${{ matrix.bench.script }}-output.txt
- name: Checkout Benchmark Results repository
uses: actions/checkout@v3
with:
repository: ${{ matrix.benchResultsRepo.name }}
ref: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
- name: Store ${{ matrix.bench.title }} result
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: ${{ matrix.bench.title }}
tool: "benchmarkluau"
output-file-path: ./${{ matrix.bench.script }}-output.txt
external-data-json-path: ./gh-pages/dev/bench/data.json
github-token: ${{ secrets.BENCH_GITHUB_TOKEN }}
- name: Store ${{ matrix.bench.title }} result (CacheGrind)
if: matrix.os == 'ubuntu-latest'
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: ${{ matrix.bench.title }} (CacheGrind)
tool: "roblox"
output-file-path: ./${{ matrix.bench.script }}-output.txt
external-data-json-path: ./gh-pages/dev/bench/data.json
github-token: ${{ secrets.BENCH_GITHUB_TOKEN }}
- name: Push benchmark results
if: github.event_name == 'push'
run: |
echo "Pushing benchmark results..."
cd gh-pages
git config user.name github-actions
git config user.email github@users.noreply.github.com
git add ./dev/bench/data.json
git commit -m "Add benchmarks results for ${{ github.sha }}"
git push
cd ..
static-analysis:
name: luau-analyze
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
bench:
- {
script: "run-analyze",
timeout: 12,
title: "Luau Analyze",
cachegrindTitle: "Performance",
cachegrindIterCount: 20,
}
benchResultsRepo:
- { name: "luau-lang/benchmark-data", branch: "main" }
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
with:
token: "${{ secrets.BENCH_GITHUB_TOKEN }}"
- name: Build Luau
run: make config=release luau luau-analyze
- uses: actions/setup-python@v4
with:
python-version: "3.9"
architecture: "x64"
- name: Install python dependencies
run: |
sudo pip install requests numpy scipy matplotlib ipython jupyter pandas sympy nose
- name: Install valgrind
run: |
sudo apt-get install valgrind
- name: Run Luau Analyze on static file
run: sudo python ./bench/measure_time.py ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee ${{ matrix.bench.script }}-output.txt
- name: Run ${{ matrix.bench.title }} (Cold Cachegrind)
run: sudo ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt
- name: Run ${{ matrix.bench.title }} (Warm Cachegrind)
run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt
- name: Checkout Benchmark Results repository
uses: actions/checkout@v3
with:
repository: ${{ matrix.benchResultsRepo.name }}
ref: ${{ matrix.benchResultsRepo.branch }}
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
- name: Store ${{ matrix.bench.title }} result
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: ${{ matrix.bench.title }}
tool: "benchmarkluau"
gh-pages-branch: "main"
output-file-path: ./${{ matrix.bench.script }}-output.txt
external-data-json-path: ./gh-pages/dev/bench/data.json
github-token: ${{ secrets.BENCH_GITHUB_TOKEN }}
- name: Store ${{ matrix.bench.title }} result (CacheGrind)
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: ${{ matrix.bench.title }}
tool: "roblox"
gh-pages-branch: "main"
output-file-path: ./${{ matrix.bench.script }}-output.txt
external-data-json-path: ./gh-pages/dev/bench/data.json
github-token: ${{ secrets.BENCH_GITHUB_TOKEN }}
- name: Push benchmark results
if: github.event_name == 'push'
run: |
echo "Pushing benchmark results..."
cd gh-pages
git config user.name github-actions
git config user.email github@users.noreply.github.com
git add ./dev/bench/data.json
git commit -m "Add benchmarks results for ${{ github.sha }}"
git push
cd ..

View file

@ -10,7 +10,9 @@ jobs:
linux:
strategy:
matrix:
agda: [2.6.2.1]
agda: [2.6.2.2]
hackageDate: ["2022-04-07"]
hackageTime: ["23:06:28"]
name: prototyping
runs-on: ubuntu-latest
steps:
@ -18,7 +20,7 @@ jobs:
- uses: actions/cache@v2
with:
path: ~/.cabal/store
key: prototyping-${{ runner.os }}-${{ matrix.agda }}
key: "prototyping-${{ runner.os }}-${{ matrix.agda }}-${{ matrix.hackageDate }}-${{ matrix.hackageTime }}"
- uses: actions/cache@v2
id: luau-ast-cache
with:
@ -28,12 +30,12 @@ jobs:
run: sudo apt-get install -y cabal-install
- name: cabal update
working-directory: prototyping
run: cabal update
run: cabal v2-update "hackage.haskell.org,${{ matrix.hackageDate }}T${{ matrix.hackageTime }}Z"
- name: cabal install
working-directory: prototyping
run: |
cabal install Agda-${{ matrix.agda }}
cabal install --lib scientific vector aeson --package-env .
cabal install --allow-newer Agda-${{ matrix.agda }}
- name: check targets
working-directory: prototyping
run: |

View file

@ -0,0 +1,30 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeArena.h"
#include "Luau/TypeVar.h"
#include <unordered_map>
namespace Luau
{
// Only exposed so they can be unit tested.
using SeenTypes = std::unordered_map<TypeId, TypeId>;
using SeenTypePacks = std::unordered_map<TypePackId, TypePackId>;
struct CloneState
{
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
int recursionCount = 0;
};
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState);
TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState);
TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log);
} // namespace Luau

View file

@ -0,0 +1,88 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/NotNull.h"
#include "Luau/Variant.h"
#include <string>
#include <memory>
#include <vector>
namespace Luau
{
struct Scope2;
struct TypeVar;
using TypeId = const TypeVar*;
struct TypePackVar;
using TypePackId = const TypePackVar*;
// subType <: superType
struct SubtypeConstraint
{
TypeId subType;
TypeId superType;
};
// subPack <: superPack
struct PackSubtypeConstraint
{
TypePackId subPack;
TypePackId superPack;
};
// subType ~ gen superType
struct GeneralizationConstraint
{
TypeId generalizedType;
TypeId sourceType;
Scope2* scope;
};
// subType ~ inst superType
struct InstantiationConstraint
{
TypeId subType;
TypeId superType;
};
// name(namedType) = name
struct NameConstraint
{
TypeId namedType;
std::string name;
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, NameConstraint>;
using ConstraintPtr = std::unique_ptr<struct Constraint>;
struct Constraint
{
explicit Constraint(ConstraintV&& c);
Constraint(const Constraint&) = delete;
Constraint& operator=(const Constraint&) = delete;
ConstraintV c;
std::vector<NotNull<Constraint>> dependencies;
};
inline Constraint& asMutable(const Constraint& c)
{
return const_cast<Constraint&>(c);
}
template<typename T>
T* getMutable(Constraint& c)
{
return ::Luau::get_if<T>(&c.c);
}
template<typename T>
const T* get(const Constraint& c)
{
return getMutable<T>(asMutable(c));
}
} // namespace Luau

View file

@ -0,0 +1,150 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <memory>
#include <vector>
#include <unordered_map>
#include "Luau/Ast.h"
#include "Luau/Constraint.h"
#include "Luau/Module.h"
#include "Luau/NotNull.h"
#include "Luau/Symbol.h"
#include "Luau/TypeVar.h"
#include "Luau/Variant.h"
namespace Luau
{
struct Scope2;
struct ConstraintGraphBuilder
{
// A list of all the scopes in the module. This vector holds ownership of the
// scope pointers; the scopes themselves borrow pointers to other scopes to
// define the scope hierarchy.
std::vector<std::pair<Location, std::unique_ptr<Scope2>>> scopes;
SingletonTypes& singletonTypes;
TypeArena* const arena;
// The root scope of the module we're generating constraints for.
Scope2* rootScope;
// A mapping of AST node to TypeId.
DenseHashMap<const AstExpr*, TypeId> astTypes{nullptr};
// A mapping of AST node to TypePackId.
DenseHashMap<const AstExpr*, TypePackId> astTypePacks{nullptr};
DenseHashMap<const AstExpr*, TypeId> astOriginalCallTypes{nullptr};
// Types resolved from type annotations. Analogous to astTypes.
DenseHashMap<const AstType*, TypeId> astResolvedTypes{nullptr};
// Type packs resolved from type annotations. Analogous to astTypePacks.
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
explicit ConstraintGraphBuilder(TypeArena* arena);
/**
* Fabricates a new free type belonging to a given scope.
* @param scope the scope the free type belongs to. Must not be null.
*/
TypeId freshType(Scope2* scope);
/**
* Fabricates a new free type pack belonging to a given scope.
* @param scope the scope the free type pack belongs to. Must not be null.
*/
TypePackId freshTypePack(Scope2* scope);
/**
* Fabricates a scope that is a child of another scope.
* @param location the lexical extent of the scope in the source code.
* @param parent the parent scope of the new scope. Must not be null.
*/
Scope2* childScope(Location location, Scope2* parent);
/**
* Adds a new constraint with no dependencies to a given scope.
* @param scope the scope to add the constraint to. Must not be null.
* @param cv the constraint variant to add.
*/
void addConstraint(Scope2* scope, ConstraintV cv);
/**
* Adds a constraint to a given scope.
* @param scope the scope to add the constraint to. Must not be null.
* @param c the constraint to add.
*/
void addConstraint(Scope2* scope, std::unique_ptr<Constraint> c);
/**
* The entry point to the ConstraintGraphBuilder. This will construct a set
* of scopes, constraints, and free types that can be solved later.
* @param block the root block to generate constraints for.
*/
void visit(AstStatBlock* block);
void visit(Scope2* scope, AstStat* stat);
void visit(Scope2* scope, AstStatBlock* block);
void visit(Scope2* scope, AstStatLocal* local);
void visit(Scope2* scope, AstStatLocalFunction* function);
void visit(Scope2* scope, AstStatFunction* function);
void visit(Scope2* scope, AstStatReturn* ret);
void visit(Scope2* scope, AstStatAssign* assign);
void visit(Scope2* scope, AstStatIf* ifStatement);
void visit(Scope2* scope, AstStatTypeAlias* alias);
TypePackId checkExprList(Scope2* scope, const AstArray<AstExpr*>& exprs);
TypePackId checkPack(Scope2* scope, AstArray<AstExpr*> exprs);
TypePackId checkPack(Scope2* scope, AstExpr* expr);
/**
* Checks an expression that is expected to evaluate to one type.
* @param scope the scope the expression is contained within.
* @param expr the expression to check.
* @return the type of the expression.
*/
TypeId check(Scope2* scope, AstExpr* expr);
TypeId checkExprTable(Scope2* scope, AstExprTable* expr);
TypeId check(Scope2* scope, AstExprIndexName* indexName);
std::pair<TypeId, Scope2*> checkFunctionSignature(Scope2* parent, AstExprFunction* fn);
/**
* Checks the body of a function expression.
* @param scope the interior scope of the body of the function.
* @param fn the function expression to check.
*/
void checkFunctionBody(Scope2* scope, AstExprFunction* fn);
/**
* Resolves a type from its AST annotation.
* @param scope the scope that the type annotation appears within.
* @param ty the AST annotation to resolve.
* @return the type of the AST annotation.
**/
TypeId resolveType(Scope2* scope, AstType* ty);
/**
* Resolves a type pack from its AST annotation.
* @param scope the scope that the type annotation appears within.
* @param tp the AST annotation to resolve.
* @return the type pack of the AST annotation.
**/
TypePackId resolveTypePack(Scope2* scope, AstTypePack* tp);
TypePackId resolveTypePack(Scope2* scope, const AstTypeList& list);
};
/**
* Collects a vector of borrowed constraints from the scope and all its child
* scopes. It is important to only call this function when you're done adding
* constraints to the scope or its descendants, lest the borrowed pointers
* become invalid due to a container reallocation.
* @param rootScope the root scope of the scope graph to collect constraints
* from.
* @return a list of pointers to constraints contained within the scope graph.
* None of these pointers should be null.
*/
std::vector<NotNull<Constraint>> collectConstraints(Scope2* rootScope);
} // namespace Luau

View file

@ -0,0 +1,120 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Error.h"
#include "Luau/Variant.h"
#include "Luau/Constraint.h"
#include "Luau/ConstraintSolverLogger.h"
#include "Luau/TypeVar.h"
#include <vector>
namespace Luau
{
// TypeId, TypePackId, or Constraint*. It is impossible to know which, but we
// never dereference this pointer.
using BlockedConstraintId = const void*;
struct ConstraintSolver
{
TypeArena* arena;
InternalErrorReporter iceReporter;
// The entire set of constraints that the solver is trying to resolve. It
// is important to not add elements to this vector, lest the underlying
// storage that we retain pointers to be mutated underneath us.
const std::vector<NotNull<Constraint>> constraints;
Scope2* rootScope;
// This includes every constraint that has not been fully solved.
// A constraint can be both blocked and unsolved, for instance.
std::vector<NotNull<const Constraint>> unsolvedConstraints;
// A mapping of constraint pointer to how many things the constraint is
// blocked on. Can be empty or 0 for constraints that are not blocked on
// anything.
std::unordered_map<NotNull<const Constraint>, size_t> blockedConstraints;
// A mapping of type/pack pointers to the constraints they block.
std::unordered_map<BlockedConstraintId, std::vector<NotNull<const Constraint>>> blocked;
ConstraintSolverLogger logger;
explicit ConstraintSolver(TypeArena* arena, Scope2* rootScope);
/**
* Attempts to dispatch all pending constraints and reach a type solution
* that satisfies all of the constraints.
**/
void run();
bool done();
bool tryDispatch(NotNull<const Constraint> c, bool force);
bool tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const InstantiationConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint);
void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);
/**
* Block a constraint on the resolution of a TypeVar.
* @returns false always. This is just to allow tryDispatch to return the result of block()
*/
bool block(TypeId target, NotNull<const Constraint> constraint);
bool block(TypePackId target, NotNull<const Constraint> constraint);
void unblock(NotNull<const Constraint> progressed);
void unblock(TypeId progressed);
void unblock(TypePackId progressed);
/**
* @returns true if the TypeId is in a blocked state.
*/
bool isBlocked(TypeId ty);
/**
* Returns whether the constraint is blocked on anything.
* @param constraint the constraint to check.
*/
bool isBlocked(NotNull<const Constraint> constraint);
/**
* Creates a new Unifier and performs a single unification operation. Commits
* the result.
* @param subType the sub-type to unify.
* @param superType the super-type to unify.
*/
void unify(TypeId subType, TypeId superType);
/**
* Creates a new Unifier and performs a single unification operation. Commits
* the result.
* @param subPack the sub-type pack to unify.
* @param superPack the super-type pack to unify.
*/
void unify(TypePackId subPack, TypePackId superPack);
private:
/**
* Marks a constraint as being blocked on a type or type pack. The constraint
* solver will not attempt to dispatch blocked constraints until their
* dependencies have made progress.
* @param target the type or type pack pointer that the constraint is blocked on.
* @param constraint the constraint to block.
**/
void block_(BlockedConstraintId target, NotNull<const Constraint> constraint);
/**
* Informs the solver that progress has been made on a type or type pack. The
* solver will wake up all constraints that are blocked on the type or type pack,
* and will resume attempting to dispatch them.
* @param progressed the type or type pack pointer that has progressed.
**/
void unblock_(BlockedConstraintId progressed);
};
void dump(Scope2* rootScope, struct ToStringOptions& opts);
} // namespace Luau

View file

@ -0,0 +1,28 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Constraint.h"
#include "Luau/NotNull.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include <optional>
#include <string>
#include <vector>
namespace Luau
{
struct ConstraintSolverLogger
{
std::string compileOutput();
void captureBoundarySnapshot(const Scope2* rootScope, std::vector<NotNull<const Constraint>>& unsolvedConstraints);
void prepareStepSnapshot(const Scope2* rootScope, NotNull<const Constraint> current, std::vector<NotNull<const Constraint>>& unsolvedConstraints);
void commitPreparedStepSnapshot();
private:
std::vector<std::string> snapshots;
std::optional<std::string> preparedSnapshot;
ToStringOptions opts;
};
} // namespace Luau

View file

@ -5,6 +5,7 @@
#include "Luau/Location.h"
#include "Luau/TypeVar.h"
#include "Luau/Variant.h"
#include "Luau/TypeArena.h"
namespace Luau
{
@ -108,9 +109,6 @@ struct FunctionDoesNotTakeSelf
struct FunctionRequiresSelf
{
// TODO: Delete with LuauAnyInIsOptionalIsOptional
int requiredExtraNils = 0;
bool operator==(const FunctionRequiresSelf& rhs) const;
};
@ -171,6 +169,13 @@ struct GenericError
bool operator==(const GenericError& rhs) const;
};
struct InternalError
{
std::string message;
bool operator==(const InternalError& rhs) const;
};
struct CannotCallNonFunction
{
TypeId ty;
@ -287,12 +292,20 @@ struct TypesAreUnrelated
bool operator==(const TypesAreUnrelated& rhs) const;
};
using TypeErrorData =
Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods, DuplicateTypeDefinition,
CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire, IncorrectGenericParameterCount, SyntaxError,
CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, CannotCallNonFunction, ExtraInformation, DeprecatedApiUsed,
ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning, DuplicateGenericParameter, CannotInferBinaryOperation,
MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, TypesAreUnrelated>;
struct NormalizationTooComplex
{
bool operator==(const NormalizationTooComplex&) const
{
return true;
}
};
using TypeErrorData = Variant<TypeMismatch, UnknownSymbol, UnknownProperty, NotATable, CannotExtendTable, OnlyTablesCanHaveMethods,
DuplicateTypeDefinition, CountMismatch, FunctionDoesNotTakeSelf, FunctionRequiresSelf, OccursCheckFailed, UnknownRequire,
IncorrectGenericParameterCount, SyntaxError, CodeTooComplex, UnificationTooComplex, UnknownPropButFoundLikeProp, GenericError, InternalError,
CannotCallNonFunction, ExtraInformation, DeprecatedApiUsed, ModuleHasCyclicDependency, IllegalRequire, FunctionExitsWithoutReturning,
DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty,
TypesAreUnrelated, NormalizationTooComplex>;
struct TypeError
{
@ -333,7 +346,13 @@ T* get(TypeError& e)
using ErrorVec = std::vector<TypeError>;
struct TypeErrorToStringOptions
{
FileResolver* fileResolver = nullptr;
};
std::string toString(const TypeError& error);
std::string toString(const TypeError& error, TypeErrorToStringOptions options);
bool containsParseErrorName(const TypeError& error);
@ -350,4 +369,24 @@ struct InternalErrorReporter
[[noreturn]] void ice(const std::string& message);
};
class InternalCompilerError : public std::exception {
public:
explicit InternalCompilerError(const std::string& message, const std::string& moduleName)
: message(message)
, moduleName(moduleName)
{
}
explicit InternalCompilerError(const std::string& message, const std::string& moduleName, const Location& location)
: message(message)
, moduleName(moduleName)
, location(location)
{
}
virtual const char* what() const throw();
const std::string message;
const std::string moduleName;
const std::optional<Location> location;
};
} // namespace Luau

View file

@ -55,10 +55,23 @@ std::optional<ModuleName> pathExprToModuleName(const ModuleName& currentModuleNa
struct SourceNode
{
bool hasDirtySourceModule() const
{
return dirtySourceModule;
}
bool hasDirtyModule(bool forAutocomplete) const
{
return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule;
}
ModuleName name;
std::unordered_set<ModuleName> requires;
std::unordered_set<ModuleName> requireSet;
std::vector<std::pair<ModuleName, Location>> requireLocations;
bool dirty = true;
bool dirtySourceModule = true;
bool dirtyModule = true;
bool dirtyModuleForAutocomplete = true;
double autocompleteLimitsMult = 1.0;
};
struct FrontendOptions
@ -69,14 +82,14 @@ struct FrontendOptions
// is complete.
bool retainFullTypeGraphs = false;
// When true, we run typechecking twice, once in the regular mode, and once in strict mode
// in order to get more precise type information (e.g. for autocomplete).
bool typecheckTwice = false;
// Run typechecking only in mode required for autocomplete (strict mode in order to get more precise type information)
bool forAutocomplete = false;
};
struct CheckResult
{
std::vector<TypeError> errors;
std::vector<ModuleName> timeoutHits;
};
struct FrontendModuleResolver : ModuleResolver
@ -120,10 +133,9 @@ struct Frontend
*/
std::pair<SourceModule, LintResult> lintFragment(std::string_view source, std::optional<LintOptions> enabledLintWarnings = {});
CheckResult check(const SourceModule& module); // OLD. TODO KILL
LintResult lint(const SourceModule& module, std::optional<LintOptions> enabledLintWarnings = {});
bool isDirty(const ModuleName& name) const;
bool isDirty(const ModuleName& name, bool forAutocomplete = false) const;
void markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty = nullptr);
/** Borrow a pointer into the SourceModule cache.
@ -147,10 +159,12 @@ struct Frontend
void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName);
private:
ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope);
std::pair<SourceNode*, SourceModule*> getSourceNode(CheckResult& checkResult, const ModuleName& name);
SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions);
bool parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& checkResult, const ModuleName& root);
bool parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete);
static LintResult classifyLints(const std::vector<LintWarning>& warnings, const Config& config);
@ -172,7 +186,7 @@ public:
std::unordered_map<ModuleName, SourceNode> sourceNodes;
std::unordered_map<ModuleName, SourceModule> sourceModules;
std::unordered_map<ModuleName, RequireTraceResult> requires;
std::unordered_map<ModuleName, RequireTraceResult> requireTrace;
Stats stats = {};
};

View file

@ -0,0 +1,53 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Substitution.h"
#include "Luau/TypeVar.h"
#include "Luau/Unifiable.h"
namespace Luau
{
struct TypeArena;
struct TxnLog;
// A substitution which replaces generic types in a given set by free types.
struct ReplaceGenerics : Substitution
{
ReplaceGenerics(
const TxnLog* log, TypeArena* arena, TypeLevel level, const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks)
: Substitution(log, arena)
, level(level)
, generics(generics)
, genericPacks(genericPacks)
{
}
TypeLevel level;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
bool ignoreChildren(TypeId ty) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
};
// A substitution which replaces generic functions by monomorphic functions
struct Instantiation : Substitution
{
Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level)
: Substitution(log, arena)
, level(level)
{
}
TypeLevel level;
bool ignoreChildren(TypeId ty) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
};
} // namespace Luau

View file

@ -30,6 +30,7 @@ std::ostream& operator<<(std::ostream& lhs, const OccursCheckFailed& error);
std::ostream& operator<<(std::ostream& lhs, const UnknownRequire& error);
std::ostream& operator<<(std::ostream& lhs, const UnknownPropButFoundLikeProp& e);
std::ostream& operator<<(std::ostream& lhs, const GenericError& error);
std::ostream& operator<<(std::ostream& lhs, const InternalError& error);
std::ostream& operator<<(std::ostream& lhs, const FunctionExitsWithoutReturning& error);
std::ostream& operator<<(std::ostream& lhs, const MissingProperties& error);
std::ostream& operator<<(std::ostream& lhs, const IllegalRequire& error);

View file

@ -34,8 +34,8 @@ const LValue* baseof(const LValue& lvalue);
std::optional<LValue> tryGetLValue(const class AstExpr& expr);
// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys.
std::pair<Symbol, std::vector<std::string>> getFullName(const LValue& lvalue);
// Utility function: breaks down an LValue to get at the Symbol
Symbol getBaseSymbol(const LValue& lvalue);
template<typename T>
const T* get(const LValue& lvalue)

View file

@ -1,12 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/FileResolver.h"
#include "Luau/TypePack.h"
#include "Luau/TypedAllocator.h"
#include "Luau/ParseOptions.h"
#include "Luau/Error.h"
#include "Luau/FileResolver.h"
#include "Luau/ParseOptions.h"
#include "Luau/ParseResult.h"
#include "Luau/Scope.h"
#include "Luau/TypeArena.h"
#include <memory>
#include <vector>
@ -21,6 +21,9 @@ struct Module;
using ScopePtr = std::shared_ptr<struct Scope>;
using ModulePtr = std::shared_ptr<Module>;
class AstType;
class AstTypePack;
/// Root of the AST of a parsed source file
struct SourceModule
{
@ -29,8 +32,8 @@ struct SourceModule
std::optional<std::string> environmentName;
bool cyclic = false;
std::unique_ptr<Allocator> allocator;
std::unique_ptr<AstNameTable> names;
std::shared_ptr<Allocator> allocator;
std::shared_ptr<AstNameTable> names;
std::vector<ParseError> parseErrors;
AstStatBlock* root = nullptr;
@ -48,49 +51,12 @@ struct SourceModule
bool isWithinComment(const SourceModule& sourceModule, Position pos);
struct TypeArena
struct RequireCycle
{
TypedAllocator<TypeVar> typeVars;
TypedAllocator<TypePackVar> typePacks;
void clear();
template<typename T>
TypeId addType(T tv)
{
if constexpr (std::is_same_v<T, UnionTypeVar>)
LUAU_ASSERT(tv.options.size() >= 2);
return addTV(TypeVar(std::move(tv)));
}
TypeId addTV(TypeVar&& tv);
TypeId freshType(TypeLevel level);
TypePackId addTypePack(std::initializer_list<TypeId> types);
TypePackId addTypePack(std::vector<TypeId> types);
TypePackId addTypePack(TypePack pack);
TypePackId addTypePack(TypePackVar pack);
Location location;
std::vector<ModuleName> path; // one of the paths for a require() to go all the way back to the originating module
};
void freeze(TypeArena& arena);
void unfreeze(TypeArena& arena);
// Only exposed so they can be unit tested.
using SeenTypes = std::unordered_map<TypeId, TypeId>;
using SeenTypePacks = std::unordered_map<TypePackId, TypePackId>;
struct CloneState
{
int recursionCount = 0;
bool encounteredFreeType = false;
};
TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState);
TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState);
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState);
struct Module
{
~Module();
@ -98,25 +64,33 @@ struct Module
TypeArena interfaceTypes;
TypeArena internalTypes;
// Scopes and AST types refer to parse data, so we need to keep that alive
std::shared_ptr<Allocator> allocator;
std::shared_ptr<AstNameTable> names;
std::vector<std::pair<Location, ScopePtr>> scopes; // never empty
std::vector<std::pair<Location, std::unique_ptr<Scope2>>> scope2s; // never empty
DenseHashMap<const AstExpr*, TypeId> astTypes{nullptr};
DenseHashMap<const AstExpr*, TypePackId> astTypePacks{nullptr};
DenseHashMap<const AstExpr*, TypeId> astExpectedTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astOriginalCallTypes{nullptr};
DenseHashMap<const AstExpr*, TypeId> astOverloadResolvedTypes{nullptr};
DenseHashMap<const AstType*, TypeId> astResolvedTypes{nullptr};
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
std::unordered_map<Name, TypeId> declaredGlobals;
ErrorVec errors;
Mode mode;
SourceCode::Type type;
bool timeout = false;
ScopePtr getModuleScope() const;
Scope2* getModuleScope2() const;
// Once a module has been typechecked, we clone its public interface into a separate arena.
// This helps us to force TypeVar ownership into a DAG rather than a DCG.
// Returns true if there were any free types encountered in the public interface. This
// indicates a bug in the type checker that we want to surface.
bool clonePublicInterface();
void clonePublicInterface(InternalErrorReporter& ice);
};
} // namespace Luau

View file

@ -0,0 +1,20 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Substitution.h"
#include "Luau/TypeVar.h"
#include "Luau/Module.h"
namespace Luau
{
struct InternalErrorReporter;
bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice);
bool isSubtype(TypePackId subTy, TypePackId superTy, InternalErrorReporter& ice);
std::pair<TypeId, bool> normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice);
std::pair<TypeId, bool> normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(TypePackId ty, TypeArena& arena, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(TypePackId ty, const ModulePtr& module, InternalErrorReporter& ice);
} // namespace Luau

View file

@ -0,0 +1,88 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include <functional>
namespace Luau
{
/** A non-owning, non-null pointer to a T.
*
* A NotNull<T> is notionally identical to a T* with the added restriction that
* it can never store nullptr.
*
* The sole conversion rule from T* to NotNull<T> is the single-argument
* constructor, which is intentionally marked explicit. This constructor
* performs a runtime test to verify that the passed pointer is never nullptr.
*
* Pointer arithmetic, increment, decrement, and array indexing are all
* forbidden.
*
* An implicit coersion from NotNull<T> to T* is afforded, as are the pointer
* indirection and member access operators. (*p and p->prop)
*
* The explicit delete statement is permitted (but not recommended) on a
* NotNull<T> through this implicit conversion.
*/
template <typename T>
struct NotNull
{
explicit NotNull(T* t)
: ptr(t)
{
LUAU_ASSERT(t);
}
explicit NotNull(std::nullptr_t) = delete;
void operator=(std::nullptr_t) = delete;
template <typename U>
NotNull(NotNull<U> other)
: ptr(other.get())
{}
operator T*() const noexcept
{
return ptr;
}
T& operator*() const noexcept
{
return *ptr;
}
T* operator->() const noexcept
{
return ptr;
}
T& operator[](int) = delete;
T& operator+(int) = delete;
T& operator-(int) = delete;
T* get() const noexcept
{
return ptr;
}
private:
T* ptr;
};
}
namespace std
{
template <typename T> struct hash<Luau::NotNull<T>>
{
size_t operator()(const Luau::NotNull<T>& p) const
{
return std::hash<T*>()(p.get());
}
};
}

View file

@ -6,6 +6,10 @@
namespace Luau
{
struct TypeArena;
struct Scope2;
void quantify(TypeId ty, TypeLevel level);
TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope);
} // namespace Luau

View file

@ -4,10 +4,19 @@
#include "Luau/Common.h"
#include <stdexcept>
#include <exception>
namespace Luau
{
struct RecursionLimitException : public std::exception
{
const char* what() const noexcept
{
return "Internal recursion counter limit exceeded";
}
};
struct RecursionCounter
{
RecursionCounter(int* count)
@ -32,7 +41,9 @@ struct RecursionLimiter : RecursionCounter
: RecursionCounter(count)
{
if (limit > 0 && *count > limit)
throw std::runtime_error("Internal recursion counter limit exceeded");
{
throw RecursionLimitException();
}
}
};

View file

@ -19,7 +19,7 @@ struct RequireTraceResult
{
DenseHashMap<const AstExpr*, ModuleInfo> exprs{nullptr};
std::vector<std::pair<ModuleName, Location>> requires;
std::vector<std::pair<ModuleName, Location>> requireList;
};
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName);

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Constraint.h"
#include "Luau/Location.h"
#include "Luau/TypeVar.h"
@ -64,4 +65,21 @@ struct Scope
std::unordered_map<Name, TypePackId> typeAliasTypePackParameters;
};
struct Scope2
{
// The parent scope of this scope. Null if there is no parent (i.e. this
// is the module-level scope).
Scope2* parent = nullptr;
// All the children of this scope.
std::vector<Scope2*> children;
std::unordered_map<Symbol, TypeId> bindings; // TODO: I think this can be a DenseHashMap
std::unordered_map<Name, TypeId> typeBindings;
TypePackId returnType;
// All constraints belonging to this scope.
std::vector<ConstraintPtr> constraints;
std::optional<TypeId> lookup(Symbol sym);
std::optional<TypeId> lookupTypeBinding(const Name& name);
};
} // namespace Luau

View file

@ -1,8 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Module.h"
#include "Luau/ModuleResolver.h"
#include "Luau/TypeArena.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
#include "Luau/DenseHash.h"
@ -90,6 +89,7 @@ struct Tarjan
std::vector<int> lowlink;
int childCount = 0;
int childLimit = 0;
// This should never be null; ensure you initialize it before calling
// substitution methods.

View file

@ -30,6 +30,9 @@ struct Symbol
{
}
template<typename T>
Symbol(const T&) = delete;
AstLocal* local;
AstName global;

View file

@ -3,6 +3,7 @@
#include "Luau/Common.h"
#include "Luau/TypeVar.h"
#include "Luau/ConstraintGraphBuilder.h"
#include <unordered_map>
#include <optional>
@ -28,6 +29,8 @@ struct ToStringOptions
bool functionTypeArguments = false; // If true, output function type argument names when they are available
bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}'
bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level.
bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self
bool indent = false;
size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars
size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength);
std::optional<ToStringNameMap> nameMap;
@ -51,6 +54,7 @@ ToStringResult toStringDetailed(TypePackId ty, const ToStringOptions& opts = {})
std::string toString(TypeId ty, const ToStringOptions& opts);
std::string toString(TypePackId ty, const ToStringOptions& opts);
std::string toString(const Constraint& c, ToStringOptions& opts);
// These are offered as overloads rather than a default parameter so that they can be easily invoked from within the MSVC debugger.
// You can use them in watch expressions!
@ -62,6 +66,11 @@ inline std::string toString(TypePackId ty)
{
return toString(ty, ToStringOptions{});
}
inline std::string toString(const Constraint& c)
{
ToStringOptions opts;
return toString(c, opts);
}
std::string toString(const TypeVar& tv, const ToStringOptions& opts = {});
std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {});
@ -72,6 +81,9 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp
// These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression
std::string dump(TypeId ty);
std::string dump(TypePackId ty);
std::string dump(const Constraint& c);
std::string dump(const std::shared_ptr<Scope>& scope, const char* name);
std::string generateName(size_t n);

View file

@ -7,8 +7,6 @@
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
LUAU_FASTFLAG(LuauShareTxnSeen);
namespace Luau
{
@ -64,13 +62,17 @@ T* getMutable(PendingTypePack* pending)
struct TxnLog
{
TxnLog()
: ownedSeen()
: typeVarChanges(nullptr)
, typePackChanges(nullptr)
, ownedSeen()
, sharedSeen(&ownedSeen)
{
}
explicit TxnLog(TxnLog* parent)
: parent(parent)
: typeVarChanges(nullptr)
, typePackChanges(nullptr)
, parent(parent)
{
if (parent)
{
@ -83,12 +85,8 @@ struct TxnLog
}
explicit TxnLog(std::vector<std::pair<TypeOrPackId, TypeOrPackId>>* sharedSeen)
: sharedSeen(sharedSeen)
{
}
TxnLog(TxnLog* parent, std::vector<std::pair<TypeOrPackId, TypeOrPackId>>* sharedSeen)
: parent(parent)
: typeVarChanges(nullptr)
, typePackChanges(nullptr)
, sharedSeen(sharedSeen)
{
}
@ -243,6 +241,12 @@ struct TxnLog
return Luau::getMutable<T>(ty);
}
template<typename T, typename TID>
const T* get(TID ty) const
{
return this->getMutable<T>(ty);
}
// Returns whether a given type or type pack is a given state, respecting the
// log's pending state.
//
@ -263,11 +267,8 @@ private:
// unique_ptr is used to give us stable pointers across insertions into the
// map. Otherwise, it would be really easy to accidentally invalidate the
// pointers returned from queue/pending.
//
// We can't use a DenseHashMap here because we need a non-const iterator
// over the map when we concatenate.
std::unordered_map<TypeId, std::unique_ptr<PendingType>, DenseHashPointer> typeVarChanges;
std::unordered_map<TypePackId, std::unique_ptr<PendingTypePack>, DenseHashPointer> typePackChanges;
DenseHashMap<TypeId, std::unique_ptr<PendingType>> typeVarChanges;
DenseHashMap<TypePackId, std::unique_ptr<PendingTypePack>> typePackChanges;
TxnLog* parent = nullptr;

View file

@ -0,0 +1,42 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypedAllocator.h"
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
#include <vector>
namespace Luau
{
struct TypeArena
{
TypedAllocator<TypeVar> typeVars;
TypedAllocator<TypePackVar> typePacks;
void clear();
template<typename T>
TypeId addType(T tv)
{
if constexpr (std::is_same_v<T, UnionTypeVar>)
LUAU_ASSERT(tv.options.size() >= 2);
return addTV(TypeVar(std::move(tv)));
}
TypeId addTV(TypeVar&& tv);
TypeId freshType(TypeLevel level);
TypePackId addTypePack(std::initializer_list<TypeId> types);
TypePackId addTypePack(std::vector<TypeId> types);
TypePackId addTypePack(TypePack pack);
TypePackId addTypePack(TypePackVar pack);
};
void freeze(TypeArena& arena);
void unfreeze(TypeArena& arena);
} // namespace Luau

View file

@ -0,0 +1,13 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
#include "Luau/Module.h"
namespace Luau
{
void check(const SourceModule& sourceModule, Module* module);
} // namespace Luau

View file

@ -34,61 +34,35 @@ const AstStat* getFallthrough(const AstStat* node);
struct UnifierOptions;
struct Unifier;
// A substitution which replaces generic types in a given set by free types.
struct ReplaceGenerics : Substitution
{
ReplaceGenerics(
const TxnLog* log, TypeArena* arena, TypeLevel level, const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks)
: Substitution(log, arena)
, level(level)
, generics(generics)
, genericPacks(genericPacks)
{
}
TypeLevel level;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
bool ignoreChildren(TypeId ty) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
};
// A substitution which replaces generic functions by monomorphic functions
struct Instantiation : Substitution
{
Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level)
: Substitution(log, arena)
, level(level)
{
}
TypeLevel level;
bool ignoreChildren(TypeId ty) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
};
// A substitution which replaces free types by any
struct Anyification : Substitution
{
Anyification(TypeArena* arena, TypeId anyType, TypePackId anyTypePack)
Anyification(TypeArena* arena, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack)
: Substitution(TxnLog::empty(), arena)
, iceHandler(iceHandler)
, anyType(anyType)
, anyTypePack(anyTypePack)
{
}
InternalErrorReporter* iceHandler;
TypeId anyType;
TypePackId anyTypePack;
bool normalizationTooComplex = false;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
bool ignoreChildren(TypeId ty) override
{
return ty->persistent;
}
bool ignoreChildren(TypePackId ty) override
{
return ty->persistent;
}
};
// A substitution which replaces the type parameters of a type function by arguments
@ -124,6 +98,12 @@ struct HashBoolNamePair
size_t operator()(const std::pair<bool, Name>& pair) const;
};
class TimeLimitError : public std::exception
{
public:
virtual const char* what() const throw();
};
// All TypeVars are retained via Environment::typeVars. All TypeIds
// within a program are borrowed pointers into this set.
struct TypeChecker
@ -133,6 +113,7 @@ struct TypeChecker
TypeChecker& operator=(const TypeChecker&) = delete;
ModulePtr check(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope = std::nullopt);
ModulePtr checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional<ScopePtr> environmentScope = std::nullopt);
std::vector<std::pair<Location, ScopePtr>> getScopes() const;
@ -154,27 +135,28 @@ struct TypeChecker
void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
void checkBlock(const ScopePtr& scope, const AstStatBlock& statement);
void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement);
void checkBlockTypeAliases(const ScopePtr& scope, std::vector<AstStat*>& sorted);
ExprResult<TypeId> checkExpr(
WithPredicate<TypeId> checkExpr(
const ScopePtr& scope, const AstExpr& expr, std::optional<TypeId> expectedType = std::nullopt, bool forceSingleton = false);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprLocal& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprGlobal& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprVarargs& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprCall& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprIndexName& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional<TypeId> expectedType = std::nullopt);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional<TypeId> expectedType = std::nullopt);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprUnary& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprLocal& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprGlobal& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprVarargs& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprCall& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprIndexName& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional<TypeId> expectedType = std::nullopt);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional<TypeId> expectedType = std::nullopt);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprUnary& expr);
TypeId checkRelationalOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {});
TypeId checkBinaryOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {});
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprBinary& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprError& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional<TypeId> expectedType = std::nullopt);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprBinary& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprError& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional<TypeId> expectedType = std::nullopt);
TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector<std::pair<TypeId, TypeId>>& fieldTypes,
std::optional<TypeId> expectedType);
@ -197,11 +179,11 @@ struct TypeChecker
void checkArgumentList(
const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector<Location>& argLocations);
ExprResult<TypePackId> checkExprPack(const ScopePtr& scope, const AstExpr& expr);
ExprResult<TypePackId> checkExprPack(const ScopePtr& scope, const AstExprCall& expr);
WithPredicate<TypePackId> checkExprPack(const ScopePtr& scope, const AstExpr& expr);
WithPredicate<TypePackId> checkExprPack(const ScopePtr& scope, const AstExprCall& expr);
std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall);
std::optional<ExprResult<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const ExprResult<TypePackId>& argListResult,
std::optional<WithPredicate<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors);
bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations,
const std::vector<OverloadErrorEntry>& errors);
@ -209,7 +191,7 @@ struct TypeChecker
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount,
const std::vector<OverloadErrorEntry>& errors);
ExprResult<TypePackId> checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs,
WithPredicate<TypePackId> checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs,
bool substituteFreeForNil = false, const std::vector<bool>& lhsAnnotations = {},
const std::vector<std::optional<TypeId>>& expectedTypes = {});
@ -252,6 +234,8 @@ struct TypeChecker
ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location);
ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location);
void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location);
std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry, const Location& location);
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location);
@ -371,7 +355,7 @@ private:
const AstArray<AstGenericType>& genericNames, const AstArray<AstGenericTypePack>& genericPackNames, bool useCache = false);
public:
ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);
void resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);
private:
void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate);
@ -379,16 +363,17 @@ private:
std::optional<TypeId> resolveLValue(const ScopePtr& scope, const LValue& lvalue);
std::optional<TypeId> resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue);
void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false);
void resolve(const Predicate& predicate, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr);
void resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr);
void resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const PredicateVec& predicates, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false);
void resolve(const Predicate& predicate, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr);
void resolve(const TruthyPredicate& truthyP, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr);
void resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const OrPredicate& orP, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const TypeGuardPredicate& typeguardP, RefinementMap& refis, const ScopePtr& scope, bool sense);
void resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense);
bool isNonstrictMode() const;
bool useConstrainedIntersections() const;
public:
/** Extract the types in a type pack, given the assumption that the pack must have some exact length.
@ -413,6 +398,13 @@ public:
UnifierSharedState unifierState;
std::vector<RequireCycle> requireCycles;
// Type inference limits
std::optional<double> finishTime;
std::optional<int> instantiationChildLimit;
std::optional<int> unifierIterationLimit;
public:
const TypeId nilType;
const TypeId numberType;
@ -420,7 +412,6 @@ public:
const TypeId booleanType;
const TypeId threadType;
const TypeId anyType;
const TypeId optionalNumberType;
const TypePackId anyTypePack;

View file

@ -40,6 +40,7 @@ struct TypePack
struct VariadicTypePack
{
TypeId ty;
bool hidden = false; // if true, we don't display this when toString()ing a pack with this variadic as its tail.
};
struct TypePackVar
@ -47,13 +48,24 @@ struct TypePackVar
explicit TypePackVar(const TypePackVariant& ty);
explicit TypePackVar(TypePackVariant&& ty);
TypePackVar(TypePackVariant&& ty, bool persistent);
bool operator==(const TypePackVar& rhs) const;
TypePackVar& operator=(TypePackVariant&& tp);
TypePackVar& operator=(const TypePackVar& rhs);
// Re-assignes the content of the pack, but doesn't change the owning arena and can't make pack persistent.
void reassign(const TypePackVar& rhs)
{
ty = rhs.ty;
}
TypePackVariant ty;
bool persistent = false;
// Pointer to the type arena that allocated this type.
// Pointer to the type arena that allocated this pack.
TypeArena* owningArena = nullptr;
};
@ -109,10 +121,10 @@ private:
};
TypePackIterator begin(TypePackId tp);
TypePackIterator begin(TypePackId tp, TxnLog* log);
TypePackIterator begin(TypePackId tp, const TxnLog* log);
TypePackIterator end(TypePackId tp);
using SeenSet = std::set<std::pair<void*, void*>>;
using SeenSet = std::set<std::pair<const void*, const void*>>;
bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs);
@ -122,7 +134,7 @@ TypePackId follow(TypePackId tp, std::function<TypePackId(TypePackId)> mapper);
size_t size(TypePackId tp, TxnLog* log = nullptr);
bool finite(TypePackId tp, TxnLog* log = nullptr);
size_t size(const TypePack& tp, TxnLog* log = nullptr);
std::optional<TypeId> first(TypePackId tp);
std::optional<TypeId> first(TypePackId tp, bool ignoreHiddenVariadics = true);
TypePackVar* asMutable(TypePackId tp);
TypePack* asMutable(const TypePack* tp);
@ -154,5 +166,12 @@ bool isEmpty(TypePackId tp);
/// Flattens out a type pack. Also returns a valid TypePackId tail if the type pack's full size is not known
std::pair<std::vector<TypeId>, std::optional<TypePackId>> flatten(TypePackId tp);
std::pair<std::vector<TypeId>, std::optional<TypePackId>> flatten(TypePackId tp, const TxnLog& log);
/// Returs true if the type pack arose from a function that is declared to be variadic.
/// Returns *false* for function argument packs that are inferred to be safe to oversaturate!
bool isVariadic(TypePackId tp);
bool isVariadic(TypePackId tp, const TxnLog& log);
} // namespace Luau

View file

@ -24,6 +24,7 @@ namespace Luau
{
struct TypeArena;
struct Scope2;
/**
* There are three kinds of type variables:
@ -83,6 +84,24 @@ using Tags = std::vector<std::string>;
using ModuleName = std::string;
/** A TypeVar that cannot be computed.
*
* BlockedTypeVars essentially serve as a way to encode partial ordering on the
* constraint graph. Until a BlockedTypeVar is unblocked by its owning
* constraint, nothing at all can be said about it. Constraints that need to
* process a BlockedTypeVar cannot be dispatched.
*
* Whenever a BlockedTypeVar is added to the graph, we also record a constraint
* that will eventually unblock it.
*/
struct BlockedTypeVar
{
BlockedTypeVar();
int index;
static int nextIndex;
};
struct PrimitiveTypeVar
{
enum Type
@ -109,6 +128,24 @@ struct PrimitiveTypeVar
}
};
struct ConstrainedTypeVar
{
explicit ConstrainedTypeVar(TypeLevel level)
: level(level)
{
}
explicit ConstrainedTypeVar(TypeLevel level, const std::vector<TypeId>& parts)
: parts(parts)
, level(level)
{
}
std::vector<TypeId> parts;
TypeLevel level;
Scope2* scope = nullptr;
};
// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md
// Types for true and false
struct BooleanSingleton
@ -212,42 +249,44 @@ struct FunctionDefinition
// TODO: Do we actually need this? We'll find out later if we can delete this.
// Does not exactly belong in TypeVar.h, but this is the only way to appease the compiler.
template<typename T>
struct ExprResult
struct WithPredicate
{
T type;
PredicateVec predicates;
};
using MagicFunction = std::function<std::optional<ExprResult<TypePackId>>(
struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, ExprResult<TypePackId>)>;
using MagicFunction = std::function<std::optional<WithPredicate<TypePackId>>(
struct TypeChecker&, const std::shared_ptr<struct Scope>&, const class AstExprCall&, WithPredicate<TypePackId>)>;
struct FunctionTypeVar
{
// Global monomorphic function
FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
// Global polymorphic function
FunctionTypeVar(std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retType,
FunctionTypeVar(std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retTypes,
std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
// Local monomorphic function
FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
// Local polymorphic function
FunctionTypeVar(TypeLevel level, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retType,
FunctionTypeVar(TypeLevel level, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retTypes,
std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
TypeLevel level;
Scope2* scope = nullptr;
/// These should all be generic
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
TypePackId argTypes;
std::vector<std::optional<FunctionArgument>> argNames;
TypePackId retType;
TypePackId retTypes;
std::optional<FunctionDefinition> definition;
MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr.
bool hasSelf;
Tags tags;
bool hasNoGenerics = false;
};
enum class TableState
@ -305,13 +344,13 @@ struct TableTypeVar
TableState state = TableState::Unsealed;
TypeLevel level;
Scope2* scope = nullptr;
std::optional<std::string> name;
// Sometimes we throw a type on a name to make for nicer error messages, but without creating any entry in the type namespace
// We need to know which is which when we stringify types.
std::optional<std::string> syntheticName;
std::map<Name, Location> methodDefinitionLocations;
std::vector<TypeId> instantiatedTypeParams;
std::vector<TypePackId> instantiatedTypePackParams;
ModuleName definitionModuleName;
@ -355,15 +394,17 @@ struct ClassTypeVar
std::optional<TypeId> metatable; // metaclass?
Tags tags;
std::shared_ptr<ClassUserData> userData;
ModuleName definitionModuleName;
ClassTypeVar(
Name name, Props props, std::optional<TypeId> parent, std::optional<TypeId> metatable, Tags tags, std::shared_ptr<ClassUserData> userData)
ClassTypeVar(Name name, Props props, std::optional<TypeId> parent, std::optional<TypeId> metatable, Tags tags,
std::shared_ptr<ClassUserData> userData, ModuleName definitionModuleName)
: name(name)
, props(props)
, parent(parent)
, metatable(metatable)
, tags(tags)
, userData(userData)
, definitionModuleName(definitionModuleName)
{
}
};
@ -418,8 +459,8 @@ struct LazyTypeVar
using ErrorTypeVar = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar, MetatableTypeVar, ClassTypeVar,
AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar>;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, ConstrainedTypeVar, BlockedTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar,
MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar>;
struct TypeVar final
{
@ -436,9 +477,18 @@ struct TypeVar final
TypeVar(const TypeVariant& ty, bool persistent)
: ty(ty)
, persistent(persistent)
, normal(persistent) // We assume that all persistent types are irreducable.
{
}
// Re-assignes the content of the type, but doesn't change the owning arena and can't make type persistent.
void reassign(const TypeVar& rhs)
{
ty = rhs.ty;
normal = rhs.normal;
documentationSymbol = rhs.documentationSymbol;
}
TypeVariant ty;
// Kludge: A persistent TypeVar is one that belongs to the global scope.
@ -446,6 +496,10 @@ struct TypeVar final
// Persistent TypeVars do not get cloned.
bool persistent = false;
// Normalization sets this for types that are fully normalized.
// This implies that they are transitively immutable.
bool normal = false;
std::optional<std::string> documentationSymbol;
// Pointer to the type arena that allocated this type.
@ -456,9 +510,11 @@ struct TypeVar final
TypeVar& operator=(const TypeVariant& rhs);
TypeVar& operator=(TypeVariant&& rhs);
TypeVar& operator=(const TypeVar& rhs);
};
using SeenSet = std::set<std::pair<void*, void*>>;
using SeenSet = std::set<std::pair<const void*, const void*>>;
bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs);
// Follow BoundTypeVars until we get to something real
@ -513,8 +569,9 @@ struct SingletonTypes
const TypeId stringType;
const TypeId booleanType;
const TypeId threadType;
const TypeId trueType;
const TypeId falseType;
const TypeId anyType;
const TypeId optionalNumberType;
const TypePackId anyTypePack;
@ -543,6 +600,8 @@ void persist(TypePackId tp);
const TypeLevel* getLevel(TypeId ty);
TypeLevel* getMutableLevel(TypeId ty);
std::optional<TypeLevel> getLevel(TypePackId tp);
const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name);
bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent);

View file

@ -23,6 +23,12 @@ public:
currentBlockSize = kBlockSize;
}
TypedAllocator(const TypedAllocator&) = delete;
TypedAllocator& operator=(const TypedAllocator&) = delete;
TypedAllocator(TypedAllocator&&) = default;
TypedAllocator& operator=(TypedAllocator&&) = default;
~TypedAllocator()
{
if (frozen)

View file

@ -8,6 +8,8 @@
namespace Luau
{
struct Scope2;
/**
* The 'level' of a TypeVar is an indirect way to talk about the scope that it 'belongs' too.
* To start, read http://okmij.org/ftp/ML/generalization.html
@ -56,6 +58,14 @@ struct TypeLevel
}
};
inline TypeLevel max(const TypeLevel& a, const TypeLevel& b)
{
if (a.subsumes(b))
return b;
else
return a;
}
inline TypeLevel min(const TypeLevel& a, const TypeLevel& b)
{
if (a.subsumes(b))
@ -64,7 +74,9 @@ inline TypeLevel min(const TypeLevel& a, const TypeLevel& b)
return b;
}
namespace Unifiable
} // namespace Luau
namespace Luau::Unifiable
{
using Name = std::string;
@ -72,9 +84,11 @@ using Name = std::string;
struct Free
{
explicit Free(TypeLevel level);
explicit Free(Scope2* scope);
int index;
TypeLevel level;
Scope2* scope = nullptr;
// True if this free type variable is part of a mutually
// recursive type alias whose definitions haven't been
// resolved yet.
@ -101,12 +115,15 @@ struct Generic
Generic();
explicit Generic(TypeLevel level);
explicit Generic(const Name& name);
explicit Generic(Scope2* scope);
Generic(TypeLevel level, const Name& name);
Generic(Scope2* scope, const Name& name);
int index;
TypeLevel level;
Scope2* scope = nullptr;
Name name;
bool explicitName;
bool explicitName = false;
private:
static int nextIndex;
@ -125,7 +142,6 @@ private:
};
template<typename Id, typename... Value>
using Variant = Variant<Free, Bound<Id>, Generic, Error, Value...>;
using Variant = Luau::Variant<Free, Bound<Id>, Generic, Error, Value...>;
} // namespace Unifiable
} // namespace Luau
} // namespace Luau::Unifiable

View file

@ -5,7 +5,7 @@
#include "Luau/Location.h"
#include "Luau/TxnLog.h"
#include "Luau/TypeInfer.h"
#include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header.
#include "Luau/TypeArena.h"
#include "Luau/UnifierSharedState.h"
#include <unordered_set>
@ -32,6 +32,9 @@ struct Widen : Substitution
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId ty) override;
bool ignoreChildren(TypeId ty) override;
TypeId operator()(TypeId ty);
TypePackId operator()(TypePackId ty);
};
// TODO: Use this more widely.
@ -49,14 +52,12 @@ struct Unifier
ErrorVec errors;
Location location;
Variance variance = Covariant;
bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once.
CountMismatch::Context ctx = CountMismatch::Arg;
UnifierSharedState& sharedState;
Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState,
TxnLog* parentLog = nullptr);
Unifier(TypeArena* types, Mode mode, std::vector<std::pair<TypeOrPackId, TypeOrPackId>>* sharedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr);
Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr);
// Test whether the two type vars unify. Never commits the result.
ErrorVec canUnify(TypeId subTy, TypeId superTy);
@ -78,12 +79,8 @@ private:
void tryUnifySingletons(TypeId subTy, TypeId superTy);
void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false);
void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false);
void DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false);
void tryUnifyFreeTable(TypeId subTy, TypeId superTy);
void tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection);
void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer);
TypeId widen(TypeId ty);
TypePackId widen(TypePackId tp);
@ -92,7 +89,6 @@ private:
bool canCacheResult(TypeId subTy, TypeId superTy);
void cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount);
void cacheResult_DEPRECATED(TypeId subTy, TypeId superTy);
public:
void tryUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false);
@ -106,7 +102,12 @@ private:
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name);
void tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy);
void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy);
public:
void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel);
// Report an "infinite type error" if the type "needle" already occurs within "haystack"
void occursCheck(TypeId needle, TypeId haystack);
void occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack);
@ -115,12 +116,7 @@ public:
Unifier makeChildUnifier();
// A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error.
void reportError(TypeError error)
{
errors.push_back(error);
}
void reportError(TypeError err);
private:
bool isNonstrictMode() const;
@ -135,4 +131,6 @@ private:
std::optional<int> firstPackErrorPos;
};
void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, TypePackId tp);
} // namespace Luau

View file

@ -28,7 +28,9 @@ struct TypeIdPairHash
struct UnifierCounters
{
int recursionCount = 0;
int recursionLimit = 0;
int iterationCount = 0;
int iterationLimit = 0;
};
struct UnifierSharedState
@ -40,7 +42,6 @@ struct UnifierSharedState
InternalErrorReporter* iceHandler;
DenseHashSet<void*> seenAny{nullptr};
DenseHashMap<TypeId, bool> skipCacheForType{nullptr};
DenseHashSet<std::pair<TypeId, TypeId>, TypeIdPairHash> cachedUnify{{nullptr, nullptr}};
DenseHashMap<std::pair<TypeId, TypeId>, TypeErrorData, TypeIdPairHash> cachedUnifyError{{nullptr, nullptr}};

View file

@ -2,45 +2,15 @@
#pragma once
#include "Luau/Common.h"
#ifndef LUAU_USE_STD_VARIANT
#define LUAU_USE_STD_VARIANT 0
#endif
#if LUAU_USE_STD_VARIANT
#include <variant>
#else
#include <new>
#include <type_traits>
#include <initializer_list>
#include <stddef.h>
#endif
#include <utility>
namespace Luau
{
#if LUAU_USE_STD_VARIANT
template<typename... Ts>
using Variant = std::variant<Ts...>;
template<class Visitor, class Variant>
auto visit(Visitor&& vis, Variant&& var)
{
// This change resolves the ABI issues with std::variant on libc++; std::visit normally throws bad_variant_access
// but it requires an update to libc++.dylib which ships with macOS 10.14. To work around this, we assert on valueless
// variants since we will never generate them and call into a libc++ function that doesn't throw.
LUAU_ASSERT(!var.valueless_by_exception());
#ifdef __APPLE__
// See https://stackoverflow.com/a/53868971/503215
return std::__variant_detail::__visitation::__variant::__visit_value(vis, var);
#else
return std::visit(vis, var);
#endif
}
using std::get_if;
#else
template<typename... Ts>
class Variant
{
@ -126,6 +96,20 @@ public:
return *this;
}
template<typename T, typename... Args>
T& emplace(Args&&... args)
{
using TT = std::decay_t<T>;
constexpr int tid = getTypeId<T>();
static_assert(tid >= 0, "unsupported T");
tableDtor[typeId](&storage);
typeId = tid;
new (&storage) TT(std::forward<Args>(args)...);
return *reinterpret_cast<T*>(&storage);
}
template<typename T>
const T* get_if() const
{
@ -248,6 +232,8 @@ static void fnVisitV(Visitor& vis, std::conditional_t<std::is_const_v<T>, const
template<class Visitor, typename... Ts>
auto visit(Visitor&& vis, const Variant<Ts...>& var)
{
static_assert(std::conjunction_v<std::is_invocable<Visitor, Ts>...>, "visitor must accept every alternative as an argument");
using Result = std::invoke_result_t<Visitor, typename Variant<Ts...>::first_alternative>;
static_assert(std::conjunction_v<std::is_same<Result, std::invoke_result_t<Visitor, Ts>>...>,
"visitor result type must be consistent between alternatives");
@ -273,6 +259,8 @@ auto visit(Visitor&& vis, const Variant<Ts...>& var)
template<class Visitor, typename... Ts>
auto visit(Visitor&& vis, Variant<Ts...>& var)
{
static_assert(std::conjunction_v<std::is_invocable<Visitor, Ts&>...>, "visitor must accept every alternative as an argument");
using Result = std::invoke_result_t<Visitor, typename Variant<Ts...>::first_alternative&>;
static_assert(std::conjunction_v<std::is_same<Result, std::invoke_result_t<Visitor, Ts&>>...>,
"visitor result type must be consistent between alternatives");
@ -294,7 +282,6 @@ auto visit(Visitor&& vis, Variant<Ts...>& var)
return res;
}
}
#endif
template<class>
inline constexpr bool always_false_v = false;

View file

@ -1,9 +1,15 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <unordered_set>
#include "Luau/DenseHash.h"
#include "Luau/TypeVar.h"
#include "Luau/RecursionCounter.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
LUAU_FASTINT(LuauVisitRecursionLimit)
LUAU_FASTFLAG(LuauNormalizeFlagIsConservative)
namespace Luau
{
@ -52,182 +58,296 @@ inline void unsee(std::unordered_set<void*>& seen, const void* tv)
inline void unsee(DenseHashSet<void*>& seen, const void* tv)
{
// When DenseHashSet is used for 'visitOnce', where don't forget visited elements
}
template<typename F, typename Set>
void visit(TypePackId tp, F& f, Set& seen);
template<typename F, typename Set>
void visit(TypeId ty, F& f, Set& seen)
{
if (visit_detail::hasSeen(seen, ty))
{
f.cycle(ty);
return;
}
if (auto btv = get<BoundTypeVar>(ty))
{
if (apply(ty, *btv, seen, f))
visit(btv->boundTo, f, seen);
}
else if (auto ftv = get<FreeTypeVar>(ty))
apply(ty, *ftv, seen, f);
else if (auto gtv = get<GenericTypeVar>(ty))
apply(ty, *gtv, seen, f);
else if (auto etv = get<ErrorTypeVar>(ty))
apply(ty, *etv, seen, f);
else if (auto ptv = get<PrimitiveTypeVar>(ty))
apply(ty, *ptv, seen, f);
else if (auto ftv = get<FunctionTypeVar>(ty))
{
if (apply(ty, *ftv, seen, f))
{
visit(ftv->argTypes, f, seen);
visit(ftv->retType, f, seen);
}
}
else if (auto ttv = get<TableTypeVar>(ty))
{
// Some visitors want to see bound tables, that's why we visit the original type
if (apply(ty, *ttv, seen, f))
{
if (ttv->boundTo)
{
visit(*ttv->boundTo, f, seen);
}
else
{
for (auto& [_name, prop] : ttv->props)
visit(prop.type, f, seen);
if (ttv->indexer)
{
visit(ttv->indexer->indexType, f, seen);
visit(ttv->indexer->indexResultType, f, seen);
}
}
}
}
else if (auto mtv = get<MetatableTypeVar>(ty))
{
if (apply(ty, *mtv, seen, f))
{
visit(mtv->table, f, seen);
visit(mtv->metatable, f, seen);
}
}
else if (auto ctv = get<ClassTypeVar>(ty))
{
if (apply(ty, *ctv, seen, f))
{
for (const auto& [name, prop] : ctv->props)
visit(prop.type, f, seen);
if (ctv->parent)
visit(*ctv->parent, f, seen);
if (ctv->metatable)
visit(*ctv->metatable, f, seen);
}
}
else if (auto atv = get<AnyTypeVar>(ty))
apply(ty, *atv, seen, f);
else if (auto utv = get<UnionTypeVar>(ty))
{
if (apply(ty, *utv, seen, f))
{
for (TypeId optTy : utv->options)
visit(optTy, f, seen);
}
}
else if (auto itv = get<IntersectionTypeVar>(ty))
{
if (apply(ty, *itv, seen, f))
{
for (TypeId partTy : itv->parts)
visit(partTy, f, seen);
}
}
visit_detail::unsee(seen, ty);
}
template<typename F, typename Set>
void visit(TypePackId tp, F& f, Set& seen)
{
if (visit_detail::hasSeen(seen, tp))
{
f.cycle(tp);
return;
}
if (auto btv = get<BoundTypePack>(tp))
{
if (apply(tp, *btv, seen, f))
visit(btv->boundTo, f, seen);
}
else if (auto ftv = get<Unifiable::Free>(tp))
apply(tp, *ftv, seen, f);
else if (auto gtv = get<Unifiable::Generic>(tp))
apply(tp, *gtv, seen, f);
else if (auto etv = get<Unifiable::Error>(tp))
apply(tp, *etv, seen, f);
else if (auto pack = get<TypePack>(tp))
{
apply(tp, *pack, seen, f);
for (TypeId ty : pack->head)
visit(ty, f, seen);
if (pack->tail)
visit(*pack->tail, f, seen);
}
else if (auto pack = get<VariadicTypePack>(tp))
{
apply(tp, *pack, seen, f);
visit(pack->ty, f, seen);
}
visit_detail::unsee(seen, tp);
// When DenseHashSet is used for 'visitTypeVarOnce', where don't forget visited elements
}
} // namespace visit_detail
template<typename TID, typename F>
void visitTypeVar(TID ty, F& f, std::unordered_set<void*>& seen)
template<typename S>
struct GenericTypeVarVisitor
{
visit_detail::visit(ty, f, seen);
}
using Set = S;
template<typename TID, typename F>
void visitTypeVar(TID ty, F& f)
{
std::unordered_set<void*> seen;
visit_detail::visit(ty, f, seen);
}
Set seen;
int recursionCounter = 0;
template<typename TID, typename F>
void visitTypeVarOnce(TID ty, F& f, DenseHashSet<void*>& seen)
GenericTypeVarVisitor() = default;
explicit GenericTypeVarVisitor(Set seen)
: seen(std::move(seen))
{
}
virtual void cycle(TypeId) {}
virtual void cycle(TypePackId) {}
virtual bool visit(TypeId ty)
{
return true;
}
virtual bool visit(TypeId ty, const BoundTypeVar& btv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const FreeTypeVar& ftv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const GenericTypeVar& gtv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const ErrorTypeVar& etv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const ConstrainedTypeVar& ctv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const FunctionTypeVar& ftv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const TableTypeVar& ttv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const MetatableTypeVar& mtv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const ClassTypeVar& ctv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const AnyTypeVar& atv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const UnionTypeVar& utv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const IntersectionTypeVar& itv)
{
return visit(ty);
}
virtual bool visit(TypePackId tp)
{
return true;
}
virtual bool visit(TypePackId tp, const BoundTypePack& btp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const FreeTypePack& ftp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const GenericTypePack& gtp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const Unifiable::Error& etp)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const TypePack& pack)
{
return visit(tp);
}
virtual bool visit(TypePackId tp, const VariadicTypePack& vtp)
{
return visit(tp);
}
void traverse(TypeId ty)
{
RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit};
if (visit_detail::hasSeen(seen, ty))
{
cycle(ty);
return;
}
if (auto btv = get<BoundTypeVar>(ty))
{
if (visit(ty, *btv))
traverse(btv->boundTo);
}
else if (auto ftv = get<FreeTypeVar>(ty))
visit(ty, *ftv);
else if (auto gtv = get<GenericTypeVar>(ty))
visit(ty, *gtv);
else if (auto etv = get<ErrorTypeVar>(ty))
visit(ty, *etv);
else if (auto ctv = get<ConstrainedTypeVar>(ty))
{
if (visit(ty, *ctv))
{
for (TypeId part : ctv->parts)
traverse(part);
}
}
else if (auto ptv = get<PrimitiveTypeVar>(ty))
visit(ty, *ptv);
else if (auto ftv = get<FunctionTypeVar>(ty))
{
if (visit(ty, *ftv))
{
traverse(ftv->argTypes);
traverse(ftv->retTypes);
}
}
else if (auto ttv = get<TableTypeVar>(ty))
{
// Some visitors want to see bound tables, that's why we traverse the original type
if (visit(ty, *ttv))
{
if (ttv->boundTo)
{
traverse(*ttv->boundTo);
}
else
{
for (auto& [_name, prop] : ttv->props)
traverse(prop.type);
if (ttv->indexer)
{
traverse(ttv->indexer->indexType);
traverse(ttv->indexer->indexResultType);
}
}
}
}
else if (auto mtv = get<MetatableTypeVar>(ty))
{
if (visit(ty, *mtv))
{
traverse(mtv->table);
traverse(mtv->metatable);
}
}
else if (auto ctv = get<ClassTypeVar>(ty))
{
if (visit(ty, *ctv))
{
for (const auto& [name, prop] : ctv->props)
traverse(prop.type);
if (ctv->parent)
traverse(*ctv->parent);
if (ctv->metatable)
traverse(*ctv->metatable);
}
}
else if (auto atv = get<AnyTypeVar>(ty))
visit(ty, *atv);
else if (auto utv = get<UnionTypeVar>(ty))
{
if (visit(ty, *utv))
{
for (TypeId optTy : utv->options)
traverse(optTy);
}
}
else if (auto itv = get<IntersectionTypeVar>(ty))
{
if (visit(ty, *itv))
{
for (TypeId partTy : itv->parts)
traverse(partTy);
}
}
visit_detail::unsee(seen, ty);
}
void traverse(TypePackId tp)
{
if (visit_detail::hasSeen(seen, tp))
{
cycle(tp);
return;
}
if (auto btv = get<BoundTypePack>(tp))
{
if (visit(tp, *btv))
traverse(btv->boundTo);
}
else if (auto ftv = get<Unifiable::Free>(tp))
visit(tp, *ftv);
else if (auto gtv = get<Unifiable::Generic>(tp))
visit(tp, *gtv);
else if (auto etv = get<Unifiable::Error>(tp))
visit(tp, *etv);
else if (auto pack = get<TypePack>(tp))
{
bool res = visit(tp, *pack);
if (!FFlag::LuauNormalizeFlagIsConservative || res)
{
for (TypeId ty : pack->head)
traverse(ty);
if (pack->tail)
traverse(*pack->tail);
}
}
else if (auto pack = get<VariadicTypePack>(tp))
{
bool res = visit(tp, *pack);
if (!FFlag::LuauNormalizeFlagIsConservative || res)
traverse(pack->ty);
}
else
LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypePackId) is not exhaustive!");
visit_detail::unsee(seen, tp);
}
};
/** Visit each type under a given type. Skips over cycles and keeps recursion depth under control.
*
* The same type may be visited multiple times if there are multiple distinct paths to it. If this is undesirable, use
* TypeVarOnceVisitor.
*/
struct TypeVarVisitor : GenericTypeVarVisitor<std::unordered_set<void*>>
{
seen.clear();
visit_detail::visit(ty, f, seen);
}
};
/// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it.
struct TypeVarOnceVisitor : GenericTypeVarVisitor<DenseHashSet<void*>>
{
TypeVarOnceVisitor()
: GenericTypeVarVisitor{DenseHashSet<void*>{nullptr}}
{
}
};
} // namespace Luau

View file

@ -71,9 +71,11 @@ struct FindFullAncestry final : public AstVisitor
{
std::vector<AstNode*> nodes;
Position pos;
Position documentEnd;
explicit FindFullAncestry(Position pos)
explicit FindFullAncestry(Position pos, Position documentEnd)
: pos(pos)
, documentEnd(documentEnd)
{
}
@ -84,6 +86,16 @@ struct FindFullAncestry final : public AstVisitor
nodes.push_back(node);
return true;
}
// Edge case: If we ask for the node at the position that is the very end of the document
// return the innermost AST element that ends at that position.
if (node->location.end == documentEnd && pos >= documentEnd)
{
nodes.push_back(node);
return true;
}
return false;
}
};
@ -92,7 +104,11 @@ struct FindFullAncestry final : public AstVisitor
std::vector<AstNode*> findAstAncestryOfPosition(const SourceModule& source, Position pos)
{
FindFullAncestry finder(pos);
const Position end = source.root->location.end;
if (pos > end)
pos = end;
FindFullAncestry finder(pos, end);
source.root->visit(&finder);
return std::move(finder.nodes);
}

View file

@ -13,8 +13,7 @@
#include <unordered_set>
#include <utility>
LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false);
LUAU_FASTFLAG(LuauSelfCallAutocompleteFix)
LUAU_FASTFLAG(LuauSelfCallAutocompleteFix2)
static const std::unordered_set<std::string> kStatementStartingKeywords = {
"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
@ -150,8 +149,12 @@ static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionTyp
auto idxExpr = nodes.back()->as<AstExprIndexName>();
bool hasImplicitSelf = idxExpr && idxExpr->op == ':';
auto args = Luau::flatten(func->argTypes);
bool noArgFunction = (args.first.empty() || (hasImplicitSelf && args.first.size() == 1)) && !args.second.has_value();
auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes);
if (argVariadicPack.has_value() && isVariadic(*argVariadicPack))
return ParenthesesRecommendation::CursorInside;
bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1);
return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside;
}
@ -243,7 +246,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
ty = follow(ty);
auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) {
LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix);
LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2);
InternalErrorReporter iceReporter;
UnifierSharedState unifierState(&iceReporter);
@ -262,16 +265,16 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
TypeId expectedType = follow(*typeAtPosition);
auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) {
if (FFlag::LuauSelfCallAutocompleteFix)
if (FFlag::LuauSelfCallAutocompleteFix2)
{
if (std::optional<TypeId> firstRetTy = first(ftv->retType))
if (std::optional<TypeId> firstRetTy = first(ftv->retTypes))
return checkTypeMatch(typeArena, *firstRetTy, expectedType);
return false;
}
else
{
auto [retHead, retTail] = flatten(ftv->retType);
auto [retHead, retTail] = flatten(ftv->retTypes);
if (!retHead.empty() && canUnify(retHead.front(), expectedType))
return true;
@ -303,7 +306,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
}
}
if (FFlag::LuauSelfCallAutocompleteFix)
if (FFlag::LuauSelfCallAutocompleteFix2)
return checkTypeMatch(typeArena, ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None;
else
return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None;
@ -320,7 +323,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
const std::vector<AstNode*>& nodes, AutocompleteEntryMap& result, std::unordered_set<TypeId>& seen,
std::optional<const ClassTypeVar*> containingClass = std::nullopt)
{
if (FFlag::LuauSelfCallAutocompleteFix)
if (FFlag::LuauSelfCallAutocompleteFix2)
rootTy = follow(rootTy);
ty = follow(ty);
@ -330,7 +333,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
seen.insert(ty);
auto isWrongIndexer_DEPRECATED = [indexType, useStrictFunctionIndexers = !!get<ClassTypeVar>(ty)](Luau::TypeId type) {
LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix);
LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2);
if (indexType == PropIndexType::Key)
return false;
@ -363,7 +366,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
}
};
auto isWrongIndexer = [typeArena, rootTy, indexType](Luau::TypeId type) {
LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix);
LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix2);
if (indexType == PropIndexType::Key)
return false;
@ -377,10 +380,15 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
return calledWithSelf == ftv->hasSelf;
}
if (std::optional<TypeId> firstArgTy = first(ftv->argTypes))
// If a call is made with ':', it is invalid if a function has incompatible first argument or no arguments at all
// If a call is made with '.', but it was declared with 'self', it is considered invalid if first argument is compatible
if (calledWithSelf || ftv->hasSelf)
{
if (checkTypeMatch(typeArena, rootTy, *firstArgTy))
return calledWithSelf;
if (std::optional<TypeId> firstArgTy = first(ftv->argTypes))
{
if (checkTypeMatch(typeArena, rootTy, *firstArgTy))
return calledWithSelf;
}
}
return !calledWithSelf;
@ -422,7 +430,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
AutocompleteEntryKind::Property,
type,
prop.deprecated,
FFlag::LuauSelfCallAutocompleteFix ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type),
FFlag::LuauSelfCallAutocompleteFix2 ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type),
typeCorrect,
containingClass,
&prop,
@ -445,7 +453,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
}
else if (auto indexFunction = get<FunctionTypeVar>(followed))
{
std::optional<TypeId> indexFunctionResult = first(indexFunction->retType);
std::optional<TypeId> indexFunctionResult = first(indexFunction->retTypes);
if (indexFunctionResult)
autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen);
}
@ -457,7 +465,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
containingClass = containingClass.value_or(cls);
fillProps(cls->props);
if (cls->parent)
autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, cls);
autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass);
}
else if (auto tbl = get<TableTypeVar>(ty))
fillProps(tbl->props);
@ -465,7 +473,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
{
autocompleteProps(module, typeArena, rootTy, mt->table, indexType, nodes, result, seen);
if (FFlag::LuauSelfCallAutocompleteFix)
if (FFlag::LuauSelfCallAutocompleteFix2)
{
if (auto mtable = get<TableTypeVar>(mt->metatable))
fillMetatableProps(mtable);
@ -484,7 +492,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
autocompleteProps(module, typeArena, rootTy, followed, indexType, nodes, result, seen);
else if (auto indexFunction = get<FunctionTypeVar>(followed))
{
std::optional<TypeId> indexFunctionResult = first(indexFunction->retType);
std::optional<TypeId> indexFunctionResult = first(indexFunction->retTypes);
if (indexFunctionResult)
autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen);
}
@ -531,7 +539,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
AutocompleteEntryMap inner;
std::unordered_set<TypeId> innerSeen;
if (!FFlag::LuauSelfCallAutocompleteFix)
if (!FFlag::LuauSelfCallAutocompleteFix2)
innerSeen = seen;
if (isNil(*iter))
@ -557,7 +565,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
++iter;
}
}
else if (auto pt = get<PrimitiveTypeVar>(ty); pt && FFlag::LuauSelfCallAutocompleteFix)
else if (auto pt = get<PrimitiveTypeVar>(ty); pt && FFlag::LuauSelfCallAutocompleteFix2)
{
if (pt->metatable)
{
@ -565,7 +573,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
fillMetatableProps(mtable);
}
}
else if (FFlag::LuauSelfCallAutocompleteFix && get<StringSingleton>(get<SingletonTypeVar>(ty)))
else if (FFlag::LuauSelfCallAutocompleteFix2 && get<StringSingleton>(get<SingletonTypeVar>(ty)))
{
autocompleteProps(module, typeArena, rootTy, getSingletonTypes().stringType, indexType, nodes, result, seen);
}
@ -625,6 +633,31 @@ AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position posi
return result;
}
static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AutocompleteEntryMap& result)
{
auto formatKey = [addQuotes](const std::string& key) {
if (addQuotes)
return "\"" + escape(key) + "\"";
return escape(key);
};
ty = follow(ty);
if (auto ss = get<StringSingleton>(get<SingletonTypeVar>(ty)))
{
result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct};
}
else if (auto uty = get<UnionTypeVar>(ty))
{
for (auto el : uty)
{
if (auto ss = get<StringSingleton>(get<SingletonTypeVar>(el)))
result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct};
}
}
};
static bool canSuggestInferredType(ScopePtr scope, TypeId ty)
{
ty = follow(ty);
@ -708,7 +741,7 @@ static std::optional<TypeId> findTypeElementAt(AstType* astType, TypeId ty, Posi
if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position))
return element;
if (auto element = findTypeElementAt(type->returnTypes, ftv->retType, position))
if (auto element = findTypeElementAt(type->returnTypes, ftv->retTypes, position))
return element;
}
@ -924,7 +957,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(follow(*it)))
{
if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos))
if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos))
inferredType = *ty;
}
}
@ -1016,7 +1049,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
{
if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node))
{
if (auto ty = tryGetTypePackTypeAt(ftv->retType, i))
if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i))
tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position);
}
@ -1033,7 +1066,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
{
if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node))
{
if (auto ty = tryGetTypePackTypeAt(ftv->retType, ~0u))
if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u))
tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position);
}
}
@ -1232,7 +1265,7 @@ static bool autocompleteIfElseExpression(
if (!parent)
return false;
if (FFlag::LuauIfElseExprFixCompletionIssue && node->is<AstExprIfElse>())
if (node->is<AstExprIfElse>())
{
// Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else
// expression.
@ -1310,16 +1343,20 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul
}
TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType);
TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType);
TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().trueType);
TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().falseType);
TypeCorrectKind correctForFunction =
functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None;
result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false};
result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean};
result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean};
result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForTrue};
result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForFalse};
result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil};
result["not"] = {AutocompleteEntryKind::Keyword};
result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction};
if (auto ty = findExpectedTypeAt(module, node, position))
autocompleteStringSingleton(*ty, true, result);
}
}
@ -1466,7 +1503,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
TypeId ty = follow(*it);
PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point;
if (!FFlag::LuauSelfCallAutocompleteFix && isString(ty))
if (!FFlag::LuauSelfCallAutocompleteFix2 && isString(ty))
return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry),
finder.ancestry};
else
@ -1625,17 +1662,29 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
}
else if (node->is<AstExprConstantString>())
{
AutocompleteEntryMap result;
if (auto it = module->astExpectedTypes.find(node->asExpr()))
autocompleteStringSingleton(*it, false, result);
if (finder.ancestry.size() >= 2)
{
if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as<AstExprIndexExpr>())
{
if (auto it = module->astTypes.find(idxExpr->expr))
autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry, result);
}
else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as<AstExprBinary>())
{
if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe)
{
return {autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry), finder.ancestry};
if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left))
autocompleteStringSingleton(*it, false, result);
}
}
}
return {};
return {result, finder.ancestry};
}
if (node->is<AstExprConstantNumber>())
@ -1655,16 +1704,16 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName
{
// FIXME: We can improve performance here by parsing without checking.
// The old type graph is probably fine. (famous last words!)
// FIXME: We don't need to typecheck for script analysis here, just for autocomplete.
frontend.check(moduleName);
FrontendOptions opts;
opts.forAutocomplete = true;
frontend.check(moduleName, opts);
const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
if (!sourceModule)
return {};
TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
ModulePtr module = (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName)
: frontend.moduleResolver.getModule(moduleName));
TypeChecker& typeChecker = frontend.typeCheckerForAutocomplete;
ModulePtr module = frontend.moduleResolverForAutocomplete.getModule(moduleName);
if (!module)
return {};
@ -1692,8 +1741,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view
sourceModule->mode = Mode::Strict;
sourceModule->commentLocations = std::move(result.commentLocations);
TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker);
TypeChecker& typeChecker = frontend.typeCheckerForAutocomplete;
ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict);
OwningAutocompleteResult autocompleteResult = {

View file

@ -8,8 +8,6 @@
#include <algorithm>
LUAU_FASTFLAG(LuauAssertStripsFalsyTypes)
LUAU_FASTFLAGVARIABLE(LuauTableCloneType, false)
LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false)
/** FIXME: Many of these type definitions are not quite completely accurate.
@ -21,16 +19,16 @@ LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false)
namespace Luau
{
static std::optional<ExprResult<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult);
static std::optional<ExprResult<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult);
static std::optional<ExprResult<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult);
static std::optional<ExprResult<TypePackId>> magicFunctionPack(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult);
static std::optional<ExprResult<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult);
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types)
{
@ -181,44 +179,13 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
LUAU_ASSERT(!typeChecker.globalTypes.typeVars.isFrozen());
LUAU_ASSERT(!typeChecker.globalTypes.typePacks.isFrozen());
TypeId numberType = typeChecker.numberType;
TypeId booleanType = typeChecker.booleanType;
TypeId nilType = typeChecker.nilType;
TypeArena& arena = typeChecker.globalTypes;
TypePackId oneNumberPack = arena.addTypePack({numberType});
TypePackId oneBooleanPack = arena.addTypePack({booleanType});
TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}});
TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList});
TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{
listOfAtLeastOneNumber,
oneNumberPack,
});
TypeId listOfAtLeastZeroNumbersToNumberType = arena.addType(FunctionTypeVar{numberVariadicList, oneNumberPack});
LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau");
LUAU_ASSERT(loadResult.success);
TypeId mathLibType = getGlobalBinding(typeChecker, "math");
if (TableTypeVar* ttv = getMutable<TableTypeVar>(mathLibType))
{
ttv->props["min"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.min");
ttv->props["max"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.max");
}
TypeId bit32LibType = getGlobalBinding(typeChecker, "bit32");
if (TableTypeVar* ttv = getMutable<TableTypeVar>(bit32LibType))
{
ttv->props["band"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.band");
ttv->props["bor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bor");
ttv->props["bxor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bxor");
ttv->props["btest"] = makeProperty(arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneBooleanPack}), "@luau/global/bit32.btest");
}
TypeId genericK = arena.addType(GenericTypeVar{"K"});
TypeId genericV = arena.addType(GenericTypeVar{"V"});
TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic});
@ -233,7 +200,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
addGlobalBinding(typeChecker, "string", it->second.type, "@luau");
// next<K, V>(t: Table<K, V>, i: K | nil) -> (K, V)
// next<K, V>(t: Table<K, V>, i: K?) -> (K, V)
TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}});
addGlobalBinding(typeChecker, "next",
arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau");
@ -243,8 +210,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})});
TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}});
// NOTE we are missing 'i: K | nil' argument in the first return types' argument.
// pairs<K, V>(t: Table<K, V>) -> ((Table<K, V>) -> (K, V), Table<K, V>, nil)
// pairs<K, V>(t: Table<K, V>) -> ((Table<K, V>, K?) -> (K, V), Table<K, V>, nil)
addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau");
TypeId genericMT = arena.addType(GenericTypeVar{"MT"});
@ -289,9 +255,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
{
// tabTy is a generic table type which we can't express via declaration syntax yet
ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze");
if (FFlag::LuauTableCloneType)
ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone");
ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone");
attachMagicFunction(ttv->props["pack"].type, magicFunctionPack);
}
@ -299,10 +263,10 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire);
}
static std::optional<ExprResult<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult)
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = exprResult;
auto [paramPack, _predicates] = withPredicate;
(void)scope;
@ -323,10 +287,10 @@ static std::optional<ExprResult<TypePackId>> magicFunctionSelect(
if (size_t(offset) < v.size())
{
std::vector<TypeId> result(v.begin() + offset, v.end());
return ExprResult<TypePackId>{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})};
return WithPredicate<TypePackId>{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})};
}
else if (tail)
return ExprResult<TypePackId>{*tail};
return WithPredicate<TypePackId>{*tail};
}
typechecker.reportError(TypeError{arg1->location, GenericError{"bad argument #1 to select (index out of range)"}});
@ -334,16 +298,16 @@ static std::optional<ExprResult<TypePackId>> magicFunctionSelect(
else if (AstExprConstantString* str = arg1->as<AstExprConstantString>())
{
if (str->value.size == 1 && str->value.data[0] == '#')
return ExprResult<TypePackId>{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})};
return WithPredicate<TypePackId>{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})};
}
return std::nullopt;
}
static std::optional<ExprResult<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult)
static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = exprResult;
auto [paramPack, _predicates] = withPredicate;
TypeArena& arena = typechecker.currentModule->internalTypes;
@ -379,7 +343,7 @@ static std::optional<ExprResult<TypePackId>> magicFunctionSetMetaTable(
if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1)
{
return ExprResult<TypePackId>{};
return WithPredicate<TypePackId>{};
}
if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self)
@ -392,7 +356,7 @@ static std::optional<ExprResult<TypePackId>> magicFunctionSetMetaTable(
}
}
return ExprResult<TypePackId>{arena.addTypePack({mtTy})};
return WithPredicate<TypePackId>{arena.addTypePack({mtTy})};
}
}
else if (get<AnyTypeVar>(target) || get<ErrorTypeVar>(target) || isTableIntersection(target))
@ -403,55 +367,43 @@ static std::optional<ExprResult<TypePackId>> magicFunctionSetMetaTable(
typechecker.reportError(TypeError{expr.location, GenericError{"setmetatable should take a table"}});
}
return ExprResult<TypePackId>{arena.addTypePack({target})};
return WithPredicate<TypePackId>{arena.addTypePack({target})};
}
static std::optional<ExprResult<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult)
static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, predicates] = exprResult;
auto [paramPack, predicates] = withPredicate;
if (FFlag::LuauAssertStripsFalsyTypes)
TypeArena& arena = typechecker.currentModule->internalTypes;
auto [head, tail] = flatten(paramPack);
if (head.empty() && tail)
{
TypeArena& arena = typechecker.currentModule->internalTypes;
auto [head, tail] = flatten(paramPack);
if (head.empty() && tail)
{
std::optional<TypeId> fst = first(*tail);
if (!fst)
return ExprResult<TypePackId>{paramPack};
head.push_back(*fst);
}
typechecker.reportErrors(typechecker.resolve(predicates, scope, true));
if (head.size() > 0)
{
std::optional<TypeId> newhead = typechecker.pickTypesFromSense(head[0], true);
if (!newhead)
head = {typechecker.nilType};
else
head[0] = *newhead;
}
return ExprResult<TypePackId>{arena.addTypePack(TypePack{std::move(head), tail})};
std::optional<TypeId> fst = first(*tail);
if (!fst)
return WithPredicate<TypePackId>{paramPack};
head.push_back(*fst);
}
else
typechecker.resolve(predicates, scope, true);
if (head.size() > 0)
{
if (expr.args.size < 1)
return ExprResult<TypePackId>{paramPack};
typechecker.reportErrors(typechecker.resolve(predicates, scope, true));
return ExprResult<TypePackId>{paramPack};
std::optional<TypeId> newhead = typechecker.pickTypesFromSense(head[0], true);
if (!newhead)
head = {typechecker.nilType};
else
head[0] = *newhead;
}
return WithPredicate<TypePackId>{arena.addTypePack(TypePack{std::move(head), tail})};
}
static std::optional<ExprResult<TypePackId>> magicFunctionPack(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult)
static std::optional<WithPredicate<TypePackId>> magicFunctionPack(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = exprResult;
auto [paramPack, _predicates] = withPredicate;
TypeArena& arena = typechecker.currentModule->internalTypes;
@ -484,7 +436,7 @@ static std::optional<ExprResult<TypePackId>> magicFunctionPack(
TypeId packedTable = arena.addType(
TableTypeVar{{{"n", {typechecker.numberType}}}, TableIndexer(typechecker.numberType, result), scope->level, TableState::Sealed});
return ExprResult<TypePackId>{arena.addTypePack({packedTable})};
return WithPredicate<TypePackId>{arena.addTypePack({packedTable})};
}
static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
@ -509,8 +461,8 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr)
return good;
}
static std::optional<ExprResult<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult)
static std::optional<WithPredicate<TypePackId>> magicFunctionRequire(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
TypeArena& arena = typechecker.currentModule->internalTypes;
@ -524,7 +476,7 @@ static std::optional<ExprResult<TypePackId>> magicFunctionRequire(
return std::nullopt;
if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, expr))
return ExprResult<TypePackId>{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})};
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})};
return std::nullopt;
}

450
Analysis/src/Clone.cpp Normal file
View file

@ -0,0 +1,450 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Clone.h"
#include "Luau/RecursionCounter.h"
#include "Luau/TxnLog.h"
#include "Luau/TypePack.h"
#include "Luau/Unifiable.h"
LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing)
LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300)
namespace Luau
{
namespace
{
struct TypePackCloner;
/*
* Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set.
* They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage.
*/
struct TypeCloner
{
TypeCloner(TypeArena& dest, TypeId typeId, CloneState& cloneState)
: dest(dest)
, typeId(typeId)
, seenTypes(cloneState.seenTypes)
, seenTypePacks(cloneState.seenTypePacks)
, cloneState(cloneState)
{
}
TypeArena& dest;
TypeId typeId;
SeenTypes& seenTypes;
SeenTypePacks& seenTypePacks;
CloneState& cloneState;
template<typename T>
void defaultClone(const T& t);
void operator()(const Unifiable::Free& t);
void operator()(const Unifiable::Generic& t);
void operator()(const Unifiable::Bound<TypeId>& t);
void operator()(const Unifiable::Error& t);
void operator()(const BlockedTypeVar& t);
void operator()(const PrimitiveTypeVar& t);
void operator()(const ConstrainedTypeVar& t);
void operator()(const SingletonTypeVar& t);
void operator()(const FunctionTypeVar& t);
void operator()(const TableTypeVar& t);
void operator()(const MetatableTypeVar& t);
void operator()(const ClassTypeVar& t);
void operator()(const AnyTypeVar& t);
void operator()(const UnionTypeVar& t);
void operator()(const IntersectionTypeVar& t);
void operator()(const LazyTypeVar& t);
};
struct TypePackCloner
{
TypeArena& dest;
TypePackId typePackId;
SeenTypes& seenTypes;
SeenTypePacks& seenTypePacks;
CloneState& cloneState;
TypePackCloner(TypeArena& dest, TypePackId typePackId, CloneState& cloneState)
: dest(dest)
, typePackId(typePackId)
, seenTypes(cloneState.seenTypes)
, seenTypePacks(cloneState.seenTypePacks)
, cloneState(cloneState)
{
}
template<typename T>
void defaultClone(const T& t)
{
TypePackId cloned = dest.addTypePack(TypePackVar{t});
seenTypePacks[typePackId] = cloned;
}
void operator()(const Unifiable::Free& t)
{
defaultClone(t);
}
void operator()(const Unifiable::Generic& t)
{
defaultClone(t);
}
void operator()(const Unifiable::Error& t)
{
defaultClone(t);
}
// While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter.
// We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer.
void operator()(const Unifiable::Bound<TypePackId>& t)
{
TypePackId cloned = clone(t.boundTo, dest, cloneState);
if (FFlag::DebugLuauCopyBeforeNormalizing)
cloned = dest.addTypePack(TypePackVar{BoundTypePack{cloned}});
seenTypePacks[typePackId] = cloned;
}
void operator()(const VariadicTypePack& t)
{
TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, cloneState), /*hidden*/ t.hidden}});
seenTypePacks[typePackId] = cloned;
}
void operator()(const TypePack& t)
{
TypePackId cloned = dest.addTypePack(TypePack{});
TypePack* destTp = getMutable<TypePack>(cloned);
LUAU_ASSERT(destTp != nullptr);
seenTypePacks[typePackId] = cloned;
for (TypeId ty : t.head)
destTp->head.push_back(clone(ty, dest, cloneState));
if (t.tail)
destTp->tail = clone(*t.tail, dest, cloneState);
}
};
template<typename T>
void TypeCloner::defaultClone(const T& t)
{
TypeId cloned = dest.addType(t);
seenTypes[typeId] = cloned;
}
void TypeCloner::operator()(const Unifiable::Free& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const Unifiable::Generic& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const Unifiable::Bound<TypeId>& t)
{
TypeId boundTo = clone(t.boundTo, dest, cloneState);
if (FFlag::DebugLuauCopyBeforeNormalizing)
boundTo = dest.addType(BoundTypeVar{boundTo});
seenTypes[typeId] = boundTo;
}
void TypeCloner::operator()(const Unifiable::Error& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const BlockedTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const PrimitiveTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const ConstrainedTypeVar& t)
{
TypeId res = dest.addType(ConstrainedTypeVar{t.level});
ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(res);
LUAU_ASSERT(ctv);
seenTypes[typeId] = res;
std::vector<TypeId> parts;
for (TypeId part : t.parts)
parts.push_back(clone(part, dest, cloneState));
ctv->parts = std::move(parts);
}
void TypeCloner::operator()(const SingletonTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const FunctionTypeVar& t)
{
TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf});
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(result);
LUAU_ASSERT(ftv != nullptr);
seenTypes[typeId] = result;
for (TypeId generic : t.generics)
ftv->generics.push_back(clone(generic, dest, cloneState));
for (TypePackId genericPack : t.genericPacks)
ftv->genericPacks.push_back(clone(genericPack, dest, cloneState));
ftv->tags = t.tags;
ftv->argTypes = clone(t.argTypes, dest, cloneState);
ftv->argNames = t.argNames;
ftv->retTypes = clone(t.retTypes, dest, cloneState);
ftv->hasNoGenerics = t.hasNoGenerics;
}
void TypeCloner::operator()(const TableTypeVar& t)
{
// If table is now bound to another one, we ignore the content of the original
if (!FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo)
{
TypeId boundTo = clone(*t.boundTo, dest, cloneState);
seenTypes[typeId] = boundTo;
return;
}
TypeId result = dest.addType(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(result);
LUAU_ASSERT(ttv != nullptr);
*ttv = t;
seenTypes[typeId] = result;
ttv->level = TypeLevel{0, 0};
if (FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, cloneState);
for (const auto& [name, prop] : t.props)
ttv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags};
if (t.indexer)
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)};
for (TypeId& arg : ttv->instantiatedTypeParams)
arg = clone(arg, dest, cloneState);
for (TypePackId& arg : ttv->instantiatedTypePackParams)
arg = clone(arg, dest, cloneState);
ttv->definitionModuleName = t.definitionModuleName;
ttv->tags = t.tags;
}
void TypeCloner::operator()(const MetatableTypeVar& t)
{
TypeId result = dest.addType(MetatableTypeVar{});
MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(result);
seenTypes[typeId] = result;
mtv->table = clone(t.table, dest, cloneState);
mtv->metatable = clone(t.metatable, dest, cloneState);
}
void TypeCloner::operator()(const ClassTypeVar& t)
{
TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData, t.definitionModuleName});
ClassTypeVar* ctv = getMutable<ClassTypeVar>(result);
seenTypes[typeId] = result;
for (const auto& [name, prop] : t.props)
ctv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags};
if (t.parent)
ctv->parent = clone(*t.parent, dest, cloneState);
if (t.metatable)
ctv->metatable = clone(*t.metatable, dest, cloneState);
}
void TypeCloner::operator()(const AnyTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const UnionTypeVar& t)
{
std::vector<TypeId> options;
options.reserve(t.options.size());
for (TypeId ty : t.options)
options.push_back(clone(ty, dest, cloneState));
TypeId result = dest.addType(UnionTypeVar{std::move(options)});
seenTypes[typeId] = result;
}
void TypeCloner::operator()(const IntersectionTypeVar& t)
{
TypeId result = dest.addType(IntersectionTypeVar{});
seenTypes[typeId] = result;
IntersectionTypeVar* option = getMutable<IntersectionTypeVar>(result);
LUAU_ASSERT(option != nullptr);
for (TypeId ty : t.parts)
option->parts.push_back(clone(ty, dest, cloneState));
}
void TypeCloner::operator()(const LazyTypeVar& t)
{
defaultClone(t);
}
} // anonymous namespace
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
{
if (tp->persistent)
return tp;
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
TypePackId& res = cloneState.seenTypePacks[tp];
if (res == nullptr)
{
TypePackCloner cloner{dest, tp, cloneState};
Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into.
}
return res;
}
TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
{
if (typeId->persistent)
return typeId;
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
TypeId& res = cloneState.seenTypes[typeId];
if (res == nullptr)
{
TypeCloner cloner{dest, typeId, cloneState};
Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into.
// Persistent types are not being cloned and we get the original type back which might be read-only
if (!res->persistent)
{
asMutable(res)->documentationSymbol = typeId->documentationSymbol;
asMutable(res)->normal = typeId->normal;
}
}
return res;
}
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
{
TypeFun result;
for (auto param : typeFun.typeParams)
{
TypeId ty = clone(param.ty, dest, cloneState);
std::optional<TypeId> defaultValue;
if (param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, cloneState);
result.typeParams.push_back({ty, defaultValue});
}
for (auto param : typeFun.typePackParams)
{
TypePackId tp = clone(param.tp, dest, cloneState);
std::optional<TypePackId> defaultValue;
if (param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, cloneState);
result.typePackParams.push_back({tp, defaultValue});
}
result.type = clone(typeFun.type, dest, cloneState);
return result;
}
TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log)
{
ty = log->follow(ty);
TypeId result = ty;
if (auto pty = log->pending(ty))
ty = &pty->pending;
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{
FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf};
clone.generics = ftv->generics;
clone.genericPacks = ftv->genericPacks;
clone.magicFunction = ftv->magicFunction;
clone.tags = ftv->tags;
clone.argNames = ftv->argNames;
result = dest.addType(std::move(clone));
}
else if (const TableTypeVar* ttv = get<TableTypeVar>(ty))
{
LUAU_ASSERT(!ttv->boundTo);
TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state};
clone.definitionModuleName = ttv->definitionModuleName;
clone.name = ttv->name;
clone.syntheticName = ttv->syntheticName;
clone.instantiatedTypeParams = ttv->instantiatedTypeParams;
clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams;
clone.tags = ttv->tags;
result = dest.addType(std::move(clone));
}
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
{
MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable};
clone.syntheticName = mtv->syntheticName;
result = dest.addType(std::move(clone));
}
else if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
{
UnionTypeVar clone;
clone.options = utv->options;
result = dest.addType(std::move(clone));
}
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(ty))
{
IntersectionTypeVar clone;
clone.parts = itv->parts;
result = dest.addType(std::move(clone));
}
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
ConstrainedTypeVar clone{ctv->level, ctv->parts};
result = dest.addType(std::move(clone));
}
else
return result;
asMutable(result)->documentationSymbol = ty->documentationSymbol;
return result;
}
} // namespace Luau

View file

@ -0,0 +1,13 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Constraint.h"
namespace Luau
{
Constraint::Constraint(ConstraintV&& c)
: c(std::move(c))
{
}
} // namespace Luau

View file

@ -0,0 +1,773 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/Scope.h"
namespace Luau
{
const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp
ConstraintGraphBuilder::ConstraintGraphBuilder(TypeArena* arena)
: singletonTypes(getSingletonTypes())
, arena(arena)
, rootScope(nullptr)
{
LUAU_ASSERT(arena);
}
TypeId ConstraintGraphBuilder::freshType(Scope2* scope)
{
LUAU_ASSERT(scope);
return arena->addType(FreeTypeVar{scope});
}
TypePackId ConstraintGraphBuilder::freshTypePack(Scope2* scope)
{
LUAU_ASSERT(scope);
FreeTypePack f{scope};
return arena->addTypePack(TypePackVar{std::move(f)});
}
Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent)
{
LUAU_ASSERT(parent);
auto scope = std::make_unique<Scope2>();
Scope2* borrow = scope.get();
scopes.emplace_back(location, std::move(scope));
borrow->parent = parent;
borrow->returnType = parent->returnType;
parent->children.push_back(borrow);
return borrow;
}
void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv)
{
LUAU_ASSERT(scope);
scope->constraints.emplace_back(new Constraint{std::move(cv)});
}
void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr<Constraint> c)
{
LUAU_ASSERT(scope);
scope->constraints.emplace_back(std::move(c));
}
void ConstraintGraphBuilder::visit(AstStatBlock* block)
{
LUAU_ASSERT(scopes.empty());
LUAU_ASSERT(rootScope == nullptr);
scopes.emplace_back(block->location, std::make_unique<Scope2>());
rootScope = scopes.back().second.get();
rootScope->returnType = freshTypePack(rootScope);
// TODO: We should share the global scope.
rootScope->typeBindings["nil"] = singletonTypes.nilType;
rootScope->typeBindings["number"] = singletonTypes.numberType;
rootScope->typeBindings["string"] = singletonTypes.stringType;
rootScope->typeBindings["boolean"] = singletonTypes.booleanType;
rootScope->typeBindings["thread"] = singletonTypes.threadType;
visit(rootScope, block);
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat)
{
LUAU_ASSERT(scope);
if (auto s = stat->as<AstStatBlock>())
visit(scope, s);
else if (auto s = stat->as<AstStatLocal>())
visit(scope, s);
else if (auto f = stat->as<AstStatFunction>())
visit(scope, f);
else if (auto f = stat->as<AstStatLocalFunction>())
visit(scope, f);
else if (auto r = stat->as<AstStatReturn>())
visit(scope, r);
else if (auto a = stat->as<AstStatAssign>())
visit(scope, a);
else if (auto e = stat->as<AstStatExpr>())
checkPack(scope, e->expr);
else if (auto i = stat->as<AstStatIf>())
visit(scope, i);
else if (auto a = stat->as<AstStatTypeAlias>())
visit(scope, a);
else
LUAU_ASSERT(0);
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local)
{
LUAU_ASSERT(scope);
std::vector<TypeId> varTypes;
for (AstLocal* local : local->vars)
{
TypeId ty = freshType(scope);
if (local->annotation)
{
TypeId annotation = resolveType(scope, local->annotation);
addConstraint(scope, SubtypeConstraint{ty, annotation});
}
varTypes.push_back(ty);
scope->bindings[local] = ty;
}
for (size_t i = 0; i < local->values.size; ++i)
{
if (local->values.data[i]->is<AstExprConstantNil>())
{
// HACK: we leave nil-initialized things floating under the assumption that they will later be populated.
// See the test TypeInfer/infer_locals_with_nil_value.
// Better flow awareness should make this obsolete.
}
else if (i == local->values.size - 1)
{
TypePackId exprPack = checkPack(scope, local->values.data[i]);
if (i < local->vars.size)
{
std::vector<TypeId> tailValues{varTypes.begin() + i, varTypes.end()};
TypePackId tailPack = arena->addTypePack(std::move(tailValues));
addConstraint(scope, PackSubtypeConstraint{exprPack, tailPack});
}
}
else
{
TypeId exprType = check(scope, local->values.data[i]);
if (i < varTypes.size())
addConstraint(scope, SubtypeConstraint{varTypes[i], exprType});
}
}
}
void addConstraints(Constraint* constraint, Scope2* scope)
{
LUAU_ASSERT(scope);
scope->constraints.reserve(scope->constraints.size() + scope->constraints.size());
for (const auto& c : scope->constraints)
constraint->dependencies.push_back(NotNull{c.get()});
for (Scope2* childScope : scope->children)
addConstraints(constraint, childScope);
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function)
{
LUAU_ASSERT(scope);
// Local
// Global
// Dotted path
// Self?
TypeId functionType = nullptr;
auto ty = scope->lookup(function->name);
if (ty.has_value())
{
// TODO: This is duplicate definition of a local function. Is this allowed?
functionType = *ty;
}
else
{
functionType = arena->addType(BlockedTypeVar{});
scope->bindings[function->name] = functionType;
}
auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func);
innerScope->bindings[function->name] = actualFunctionType;
checkFunctionBody(innerScope, function->func);
std::unique_ptr<Constraint> c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}};
addConstraints(c.get(), innerScope);
addConstraint(scope, std::move(c));
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function)
{
// Name could be AstStatLocal, AstStatGlobal, AstStatIndexName.
// With or without self
TypeId functionType = nullptr;
auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func);
if (AstExprLocal* localName = function->name->as<AstExprLocal>())
{
std::optional<TypeId> existingFunctionTy = scope->lookup(localName->local);
if (existingFunctionTy)
{
// Duplicate definition
functionType = *existingFunctionTy;
}
else
{
functionType = arena->addType(BlockedTypeVar{});
scope->bindings[localName->local] = functionType;
}
innerScope->bindings[localName->local] = actualFunctionType;
}
else if (AstExprGlobal* globalName = function->name->as<AstExprGlobal>())
{
std::optional<TypeId> existingFunctionTy = scope->lookup(globalName->name);
if (existingFunctionTy)
{
// Duplicate definition
functionType = *existingFunctionTy;
}
else
{
functionType = arena->addType(BlockedTypeVar{});
rootScope->bindings[globalName->name] = functionType;
}
innerScope->bindings[globalName->name] = actualFunctionType;
}
else if (AstExprIndexName* indexName = function->name->as<AstExprIndexName>())
{
LUAU_ASSERT(0); // not yet implemented
}
checkFunctionBody(innerScope, function->func);
std::unique_ptr<Constraint> c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}};
addConstraints(c.get(), innerScope);
addConstraint(scope, std::move(c));
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret)
{
LUAU_ASSERT(scope);
TypePackId exprTypes = checkPack(scope, ret->list);
addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType});
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block)
{
LUAU_ASSERT(scope);
// In order to enable mutually-recursive type aliases, we need to
// populate the type bindings before we actually check any of the
// alias statements. Since we're not ready to actually resolve
// any of the annotations, we just use a fresh type for now.
for (AstStat* stat : block->body)
{
if (auto alias = stat->as<AstStatTypeAlias>())
{
TypeId initialType = freshType(scope);
scope->typeBindings[alias->name.value] = initialType;
}
}
for (AstStat* stat : block->body)
visit(scope, stat);
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign)
{
TypePackId varPackId = checkExprList(scope, assign->vars);
TypePackId valuePack = checkPack(scope, assign->values);
addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId});
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement)
{
check(scope, ifStatement->condition);
Scope2* thenScope = childScope(ifStatement->thenbody->location, scope);
visit(thenScope, ifStatement->thenbody);
if (ifStatement->elsebody)
{
Scope2* elseScope = childScope(ifStatement->elsebody->location, scope);
visit(elseScope, ifStatement->elsebody);
}
}
void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias)
{
// TODO: Exported type aliases
// TODO: Generic type aliases
auto it = scope->typeBindings.find(alias->name.value);
// This should always be here since we do a separate pass over the
// AST to set up typeBindings. If it's not, we've somehow skipped
// this alias in that first pass.
LUAU_ASSERT(it != scope->typeBindings.end());
TypeId ty = resolveType(scope, alias->type);
// Rather than using a subtype constraint, we instead directly bind
// the free type we generated in the first pass to the resolved type.
// This prevents a case where you could cause another constraint to
// bind the free alias type to an unrelated type, causing havoc.
asMutable(it->second)->ty.emplace<BoundTypeVar>(ty);
addConstraint(scope, NameConstraint{ty, alias->name.value});
}
TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray<AstExpr*> exprs)
{
LUAU_ASSERT(scope);
if (exprs.size == 0)
return arena->addTypePack({});
std::vector<TypeId> types;
TypePackId last = nullptr;
for (size_t i = 0; i < exprs.size; ++i)
{
if (i < exprs.size - 1)
types.push_back(check(scope, exprs.data[i]));
else
last = checkPack(scope, exprs.data[i]);
}
LUAU_ASSERT(last != nullptr);
return arena->addTypePack(TypePack{std::move(types), last});
}
TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray<AstExpr*>& exprs)
{
TypePackId result = arena->addTypePack({});
TypePack* resultPack = getMutable<TypePack>(result);
LUAU_ASSERT(resultPack);
for (size_t i = 0; i < exprs.size; ++i)
{
AstExpr* expr = exprs.data[i];
if (i < exprs.size - 1)
resultPack->head.push_back(check(scope, expr));
else
resultPack->tail = checkPack(scope, expr);
}
if (resultPack->head.empty() && resultPack->tail)
return *resultPack->tail;
else
return result;
}
TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr)
{
LUAU_ASSERT(scope);
TypePackId result = nullptr;
if (AstExprCall* call = expr->as<AstExprCall>())
{
std::vector<TypeId> args;
for (AstExpr* arg : call->args)
{
args.push_back(check(scope, arg));
}
// TODO self
TypeId fnType = check(scope, call->func);
astOriginalCallTypes[call->func] = fnType;
TypeId instantiatedType = freshType(scope);
addConstraint(scope, InstantiationConstraint{instantiatedType, fnType});
TypePackId rets = freshTypePack(scope);
FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets);
TypeId inferredFnType = arena->addType(ftv);
addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType});
result = rets;
}
else
{
TypeId t = check(scope, expr);
result = arena->addTypePack({t});
}
LUAU_ASSERT(result);
astTypePacks[expr] = result;
return result;
}
TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr)
{
LUAU_ASSERT(scope);
TypeId result = nullptr;
if (auto group = expr->as<AstExprGroup>())
result = check(scope, group->expr);
else if (expr->is<AstExprConstantString>())
result = singletonTypes.stringType;
else if (expr->is<AstExprConstantNumber>())
result = singletonTypes.numberType;
else if (expr->is<AstExprConstantBool>())
result = singletonTypes.booleanType;
else if (expr->is<AstExprConstantNil>())
result = singletonTypes.nilType;
else if (auto a = expr->as<AstExprLocal>())
{
std::optional<TypeId> ty = scope->lookup(a->local);
if (ty)
result = *ty;
else
result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point?
}
else if (auto g = expr->as<AstExprGlobal>())
{
std::optional<TypeId> ty = scope->lookup(g->name);
if (ty)
result = *ty;
else
result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point?
}
else if (auto a = expr->as<AstExprCall>())
{
TypePackId packResult = checkPack(scope, expr);
if (auto f = first(packResult))
return *f;
else if (get<FreeTypePack>(packResult))
{
TypeId typeResult = freshType(scope);
TypePack onePack{{typeResult}, freshTypePack(scope)};
TypePackId oneTypePack = arena->addTypePack(std::move(onePack));
addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack});
return typeResult;
}
}
else if (auto a = expr->as<AstExprFunction>())
{
auto [fnType, functionScope] = checkFunctionSignature(scope, a);
checkFunctionBody(functionScope, a);
return fnType;
}
else if (auto indexName = expr->as<AstExprIndexName>())
{
result = check(scope, indexName);
}
else if (auto table = expr->as<AstExprTable>())
{
result = checkExprTable(scope, table);
}
else
{
LUAU_ASSERT(0);
result = freshType(scope);
}
LUAU_ASSERT(result);
astTypes[expr] = result;
return result;
}
TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName)
{
TypeId obj = check(scope, indexName->expr);
TypeId result = freshType(scope);
TableTypeVar::Props props{{indexName->index.value, Property{result}}};
const std::optional<TableIndexer> indexer;
TableTypeVar ttv{std::move(props), indexer, TypeLevel{}, TableState::Free};
TypeId expectedTableType = arena->addType(std::move(ttv));
addConstraint(scope, SubtypeConstraint{obj, expectedTableType});
return result;
}
TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr)
{
TypeId ty = arena->addType(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(ty);
LUAU_ASSERT(ttv);
auto createIndexer = [this, scope, ttv](TypeId currentIndexType, TypeId currentResultType) {
if (!ttv->indexer)
{
TypeId indexType = this->freshType(scope);
TypeId resultType = this->freshType(scope);
ttv->indexer = TableIndexer{indexType, resultType};
}
addConstraint(scope, SubtypeConstraint{ttv->indexer->indexType, currentIndexType});
addConstraint(scope, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType});
};
for (const AstExprTable::Item& item : expr->items)
{
TypeId itemTy = check(scope, item.value);
if (item.key)
{
// Even though we don't need to use the type of the item's key if
// it's a string constant, we still want to check it to populate
// astTypes.
TypeId keyTy = check(scope, item.key);
if (AstExprConstantString* key = item.key->as<AstExprConstantString>())
{
ttv->props[key->value.begin()] = {itemTy};
}
else
{
createIndexer(keyTy, itemTy);
}
}
else
{
TypeId numberType = singletonTypes.numberType;
createIndexer(numberType, itemTy);
}
}
return ty;
}
std::pair<TypeId, Scope2*> ConstraintGraphBuilder::checkFunctionSignature(Scope2* parent, AstExprFunction* fn)
{
Scope2* innerScope = childScope(fn->body->location, parent);
TypePackId returnType = freshTypePack(innerScope);
innerScope->returnType = returnType;
if (fn->returnAnnotation)
{
TypePackId annotatedRetType = resolveTypePack(innerScope, *fn->returnAnnotation);
addConstraint(innerScope, PackSubtypeConstraint{returnType, annotatedRetType});
}
std::vector<TypeId> argTypes;
for (AstLocal* local : fn->args)
{
TypeId t = freshType(innerScope);
argTypes.push_back(t);
innerScope->bindings[local] = t;
if (local->annotation)
{
TypeId argAnnotation = resolveType(innerScope, local->annotation);
addConstraint(innerScope, SubtypeConstraint{t, argAnnotation});
}
}
// TODO: Vararg annotation.
FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType};
TypeId actualFunctionType = arena->addType(std::move(actualFunction));
LUAU_ASSERT(actualFunctionType);
astTypes[fn] = actualFunctionType;
return {actualFunctionType, innerScope};
}
void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* fn)
{
for (AstStat* stat : fn->body->body)
visit(scope, stat);
// If it is possible for execution to reach the end of the function, the return type must be compatible with ()
if (nullptr != getFallthrough(fn->body))
{
TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever
addConstraint(scope, PackSubtypeConstraint{scope->returnType, empty});
}
}
TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty)
{
TypeId result = nullptr;
if (auto ref = ty->as<AstTypeReference>())
{
// TODO: Support imported types w/ require tracing.
// TODO: Support generic type references.
LUAU_ASSERT(!ref->prefix);
LUAU_ASSERT(!ref->hasParameterList);
// TODO: If it doesn't exist, should we introduce a free binding?
// This is probably important for handling type aliases.
result = scope->lookupTypeBinding(ref->name.value).value_or(singletonTypes.errorRecoveryType());
}
else if (auto tab = ty->as<AstTypeTable>())
{
TableTypeVar::Props props;
std::optional<TableIndexer> indexer;
for (const AstTableProp& prop : tab->props)
{
std::string name = prop.name.value;
// TODO: Recursion limit.
TypeId propTy = resolveType(scope, prop.type);
// TODO: Fill in location.
props[name] = {propTy};
}
if (tab->indexer)
{
// TODO: Recursion limit.
indexer = TableIndexer{
resolveType(scope, tab->indexer->indexType),
resolveType(scope, tab->indexer->resultType),
};
}
// TODO: Remove TypeLevel{} here, we don't need it.
result = arena->addType(TableTypeVar{props, indexer, TypeLevel{}, TableState::Sealed});
}
else if (auto fn = ty->as<AstTypeFunction>())
{
// TODO: Generic functions.
// TODO: Scope (though it may not be needed).
// TODO: Recursion limit.
TypePackId argTypes = resolveTypePack(scope, fn->argTypes);
TypePackId returnTypes = resolveTypePack(scope, fn->returnTypes);
// TODO: Is this the right constructor to use?
result = arena->addType(FunctionTypeVar{argTypes, returnTypes});
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(result);
ftv->argNames.reserve(fn->argNames.size);
for (const auto& el : fn->argNames)
{
if (el)
{
const auto& [name, location] = *el;
ftv->argNames.push_back(FunctionArgument{name.value, location});
}
else
{
ftv->argNames.push_back(std::nullopt);
}
}
}
else if (auto tof = ty->as<AstTypeTypeof>())
{
// TODO: Recursion limit.
TypeId exprType = check(scope, tof->expr);
result = exprType;
}
else if (auto unionAnnotation = ty->as<AstTypeUnion>())
{
std::vector<TypeId> parts;
for (AstType* part : unionAnnotation->types)
{
// TODO: Recursion limit.
parts.push_back(resolveType(scope, part));
}
result = arena->addType(UnionTypeVar{parts});
}
else if (auto intersectionAnnotation = ty->as<AstTypeIntersection>())
{
std::vector<TypeId> parts;
for (AstType* part : intersectionAnnotation->types)
{
// TODO: Recursion limit.
parts.push_back(resolveType(scope, part));
}
result = arena->addType(IntersectionTypeVar{parts});
}
else if (auto boolAnnotation = ty->as<AstTypeSingletonBool>())
{
result = arena->addType(SingletonTypeVar(BooleanSingleton{boolAnnotation->value}));
}
else if (auto stringAnnotation = ty->as<AstTypeSingletonString>())
{
result = arena->addType(SingletonTypeVar(StringSingleton{std::string(stringAnnotation->value.data, stringAnnotation->value.size)}));
}
else if (ty->is<AstTypeError>())
{
result = singletonTypes.errorRecoveryType();
}
else
{
LUAU_ASSERT(0);
result = singletonTypes.errorRecoveryType();
}
astResolvedTypes[ty] = result;
return result;
}
TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, AstTypePack* tp)
{
TypePackId result;
if (auto expl = tp->as<AstTypePackExplicit>())
{
result = resolveTypePack(scope, expl->typeList);
}
else if (auto var = tp->as<AstTypePackVariadic>())
{
TypeId ty = resolveType(scope, var->variadicType);
result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}});
}
else if (auto gen = tp->as<AstTypePackGeneric>())
{
result = arena->addTypePack(TypePackVar{GenericTypePack{scope, gen->genericName.value}});
}
else
{
LUAU_ASSERT(0);
result = singletonTypes.errorRecoveryTypePack();
}
astResolvedTypePacks[tp] = result;
return result;
}
TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, const AstTypeList& list)
{
std::vector<TypeId> head;
for (AstType* headTy : list.types)
{
head.push_back(resolveType(scope, headTy));
}
std::optional<TypePackId> tail = std::nullopt;
if (list.tailType)
{
tail = resolveTypePack(scope, list.tailType);
}
return arena->addTypePack(TypePack{head, tail});
}
void collectConstraints(std::vector<NotNull<Constraint>>& result, Scope2* scope)
{
for (const auto& c : scope->constraints)
result.push_back(NotNull{c.get()});
for (Scope2* child : scope->children)
collectConstraints(result, child);
}
std::vector<NotNull<Constraint>> collectConstraints(Scope2* rootScope)
{
std::vector<NotNull<Constraint>> result;
collectConstraints(result, rootScope);
return result;
}
} // namespace Luau

View file

@ -0,0 +1,361 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ConstraintSolver.h"
#include "Luau/Instantiation.h"
#include "Luau/Location.h"
#include "Luau/Quantify.h"
#include "Luau/ToString.h"
#include "Luau/Unifier.h"
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false);
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false);
namespace Luau
{
[[maybe_unused]] static void dumpBindings(Scope2* scope, ToStringOptions& opts)
{
for (const auto& [k, v] : scope->bindings)
{
auto d = toStringDetailed(v, opts);
opts.nameMap = d.nameMap;
printf("\t%s : %s\n", k.c_str(), d.name.c_str());
}
for (Scope2* child : scope->children)
dumpBindings(child, opts);
}
static void dumpConstraints(Scope2* scope, ToStringOptions& opts)
{
for (const ConstraintPtr& c : scope->constraints)
{
printf("\t%s\n", toString(*c, opts).c_str());
}
for (Scope2* child : scope->children)
dumpConstraints(child, opts);
}
void dump(Scope2* rootScope, ToStringOptions& opts)
{
printf("constraints:\n");
dumpConstraints(rootScope, opts);
}
void dump(ConstraintSolver* cs, ToStringOptions& opts)
{
printf("constraints:\n");
for (const Constraint* c : cs->unsolvedConstraints)
{
printf("\t%s\n", toString(*c, opts).c_str());
for (const Constraint* dep : c->dependencies)
printf("\t\t%s\n", toString(*dep, opts).c_str());
}
}
ConstraintSolver::ConstraintSolver(TypeArena* arena, Scope2* rootScope)
: arena(arena)
, constraints(collectConstraints(rootScope))
, rootScope(rootScope)
{
for (NotNull<Constraint> c : constraints)
{
unsolvedConstraints.push_back(c);
for (NotNull<const Constraint> dep : c->dependencies)
{
block(dep, c);
}
}
}
void ConstraintSolver::run()
{
if (done())
return;
ToStringOptions opts;
if (FFlag::DebugLuauLogSolver)
{
printf("Starting solver\n");
dump(this, opts);
}
if (FFlag::DebugLuauLogSolverToJson)
{
logger.captureBoundarySnapshot(rootScope, unsolvedConstraints);
}
auto runSolverPass = [&](bool force) {
bool progress = false;
size_t i = 0;
while (i < unsolvedConstraints.size())
{
NotNull<const Constraint> c = unsolvedConstraints[i];
if (!force && isBlocked(c))
{
++i;
continue;
}
std::string saveMe = FFlag::DebugLuauLogSolver ? toString(*c, opts) : std::string{};
if (FFlag::DebugLuauLogSolverToJson)
{
logger.prepareStepSnapshot(rootScope, c, unsolvedConstraints);
}
bool success = tryDispatch(c, force);
progress |= success;
if (success)
{
unsolvedConstraints.erase(unsolvedConstraints.begin() + i);
if (FFlag::DebugLuauLogSolverToJson)
{
logger.commitPreparedStepSnapshot();
}
if (FFlag::DebugLuauLogSolver)
{
if (force)
printf("Force ");
printf("Dispatched\n\t%s\n", saveMe.c_str());
dump(this, opts);
}
}
else
++i;
if (force && success)
return true;
}
return progress;
};
bool progress = false;
do
{
progress = runSolverPass(false);
if (!progress)
progress |= runSolverPass(true);
} while (progress);
if (FFlag::DebugLuauLogSolver)
{
dumpBindings(rootScope, opts);
}
if (FFlag::DebugLuauLogSolverToJson)
{
logger.captureBoundarySnapshot(rootScope, unsolvedConstraints);
printf("Logger output:\n%s\n", logger.compileOutput().c_str());
}
}
bool ConstraintSolver::done()
{
return unsolvedConstraints.empty();
}
bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool force)
{
if (!force && isBlocked(constraint))
return false;
bool success = false;
if (auto sc = get<SubtypeConstraint>(*constraint))
success = tryDispatch(*sc, constraint, force);
else if (auto psc = get<PackSubtypeConstraint>(*constraint))
success = tryDispatch(*psc, constraint, force);
else if (auto gc = get<GeneralizationConstraint>(*constraint))
success = tryDispatch(*gc, constraint, force);
else if (auto ic = get<InstantiationConstraint>(*constraint))
success = tryDispatch(*ic, constraint, force);
else if (auto nc = get<NameConstraint>(*constraint))
success = tryDispatch(*nc, constraint);
else
LUAU_ASSERT(0);
if (success)
{
unblock(constraint);
}
return success;
}
bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint, bool force)
{
if (isBlocked(c.subType))
return block(c.subType, constraint);
else if (isBlocked(c.superType))
return block(c.superType, constraint);
unify(c.subType, c.superType);
unblock(c.subType);
unblock(c.superType);
return true;
}
bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force)
{
unify(c.subPack, c.superPack);
unblock(c.subPack);
unblock(c.superPack);
return true;
}
bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull<const Constraint> constraint, bool force)
{
if (isBlocked(c.sourceType))
return block(c.sourceType, constraint);
if (isBlocked(c.generalizedType))
asMutable(c.generalizedType)->ty.emplace<BoundTypeVar>(c.sourceType);
else
unify(c.generalizedType, c.sourceType);
TypeId generalized = quantify(arena, c.sourceType, c.scope);
*asMutable(c.sourceType) = *generalized;
unblock(c.generalizedType);
unblock(c.sourceType);
return true;
}
bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull<const Constraint> constraint, bool force)
{
if (isBlocked(c.superType))
return block(c.superType, constraint);
Instantiation inst(TxnLog::empty(), arena, TypeLevel{});
std::optional<TypeId> instantiated = inst.substitute(c.superType);
LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS
unify(c.subType, *instantiated);
unblock(c.subType);
return true;
}
bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint)
{
if (isBlocked(c.namedType))
return block(c.namedType, constraint);
TypeId target = follow(c.namedType);
if (TableTypeVar* ttv = getMutable<TableTypeVar>(target))
ttv->name = c.name;
else if (MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(target))
mtv->syntheticName = c.name;
else
return block(c.namedType, constraint);
return true;
}
void ConstraintSolver::block_(BlockedConstraintId target, NotNull<const Constraint> constraint)
{
blocked[target].push_back(constraint);
auto& count = blockedConstraints[constraint];
count += 1;
}
void ConstraintSolver::block(NotNull<const Constraint> target, NotNull<const Constraint> constraint)
{
block_(target, constraint);
}
bool ConstraintSolver::block(TypeId target, NotNull<const Constraint> constraint)
{
block_(target, constraint);
return false;
}
bool ConstraintSolver::block(TypePackId target, NotNull<const Constraint> constraint)
{
block_(target, constraint);
return false;
}
void ConstraintSolver::unblock_(BlockedConstraintId progressed)
{
auto it = blocked.find(progressed);
if (it == blocked.end())
return;
// unblocked should contain a value always, because of the above check
for (NotNull<const Constraint> unblockedConstraint : it->second)
{
auto& count = blockedConstraints[unblockedConstraint];
// This assertion being hit indicates that `blocked` and
// `blockedConstraints` desynchronized at some point. This is problematic
// because we rely on this count being correct to skip over blocked
// constraints.
LUAU_ASSERT(count > 0);
count -= 1;
}
blocked.erase(it);
}
void ConstraintSolver::unblock(NotNull<const Constraint> progressed)
{
return unblock_(progressed);
}
void ConstraintSolver::unblock(TypeId progressed)
{
return unblock_(progressed);
}
void ConstraintSolver::unblock(TypePackId progressed)
{
return unblock_(progressed);
}
bool ConstraintSolver::isBlocked(TypeId ty)
{
return nullptr != get<BlockedTypeVar>(follow(ty));
}
bool ConstraintSolver::isBlocked(NotNull<const Constraint> constraint)
{
auto blockedIt = blockedConstraints.find(constraint);
return blockedIt != blockedConstraints.end() && blockedIt->second > 0;
}
void ConstraintSolver::unify(TypeId subType, TypeId superType)
{
UnifierSharedState sharedState{&iceReporter};
Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState};
u.tryUnify(subType, superType);
u.log.commit();
}
void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack)
{
UnifierSharedState sharedState{&iceReporter};
Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState};
u.tryUnify(subPack, superPack);
u.log.commit();
}
} // namespace Luau

View file

@ -0,0 +1,139 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ConstraintSolverLogger.h"
namespace Luau
{
static std::string dumpScopeAndChildren(const Scope2* scope, ToStringOptions& opts)
{
std::string output = "{\"bindings\":{";
bool comma = false;
for (const auto& [name, type] : scope->bindings)
{
if (comma)
output += ",";
output += "\"";
output += name.c_str();
output += "\": \"";
ToStringResult result = toStringDetailed(type, opts);
opts.nameMap = std::move(result.nameMap);
output += result.name;
output += "\"";
comma = true;
}
output += "},\"children\":[";
comma = false;
for (const Scope2* child : scope->children)
{
if (comma)
output += ",";
output += dumpScopeAndChildren(child, opts);
comma = true;
}
output += "]}";
return output;
}
static std::string dumpConstraintsToDot(std::vector<NotNull<const Constraint>>& constraints, ToStringOptions& opts)
{
std::string result = "digraph Constraints {\\n";
std::unordered_set<NotNull<const Constraint>> contained;
for (NotNull<const Constraint> c : constraints)
{
contained.insert(c);
}
for (NotNull<const Constraint> c : constraints)
{
std::string id = std::to_string(reinterpret_cast<size_t>(c.get()));
result += id;
result += " [label=\\\"";
result += toString(*c, opts).c_str();
result += "\\\"];\\n";
for (NotNull<const Constraint> dep : c->dependencies)
{
if (contained.count(dep) == 0)
continue;
result += std::to_string(reinterpret_cast<size_t>(dep.get()));
result += " -> ";
result += id;
result += ";\\n";
}
}
result += "}";
return result;
}
std::string ConstraintSolverLogger::compileOutput()
{
std::string output = "[";
bool comma = false;
for (const std::string& snapshot : snapshots)
{
if (comma)
output += ",";
output += snapshot;
comma = true;
}
output += "]";
return output;
}
void ConstraintSolverLogger::captureBoundarySnapshot(const Scope2* rootScope, std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
std::string snapshot = "{\"type\":\"boundary\",\"rootScope\":";
snapshot += dumpScopeAndChildren(rootScope, opts);
snapshot += ",\"constraintGraph\":\"";
snapshot += dumpConstraintsToDot(unsolvedConstraints, opts);
snapshot += "\"}";
snapshots.push_back(std::move(snapshot));
}
void ConstraintSolverLogger::prepareStepSnapshot(
const Scope2* rootScope, NotNull<const Constraint> current, std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
// LUAU_ASSERT(!preparedSnapshot);
std::string snapshot = "{\"type\":\"step\",\"rootScope\":";
snapshot += dumpScopeAndChildren(rootScope, opts);
snapshot += ",\"constraintGraph\":\"";
snapshot += dumpConstraintsToDot(unsolvedConstraints, opts);
snapshot += "\",\"currentId\":\"";
snapshot += std::to_string(reinterpret_cast<size_t>(current.get()));
snapshot += "\",\"current\":\"";
snapshot += toString(*current, opts);
snapshot += "\"}";
preparedSnapshot = std::move(snapshot);
}
void ConstraintSolverLogger::commitPreparedStepSnapshot()
{
if (preparedSnapshot)
{
snapshots.push_back(std::move(*preparedSnapshot));
preparedSnapshot = std::nullopt;
}
}
} // namespace Luau

View file

@ -7,7 +7,10 @@ namespace Luau
static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC(
declare bit32: {
-- band, bor, bxor, and btest are declared in C++
band: (...number) -> number,
bor: (...number) -> number,
bxor: (...number) -> number,
btest: (number, ...number) -> boolean,
rrotate: (number, number) -> number,
lrotate: (number, number) -> number,
lshift: (number, number) -> number,
@ -50,7 +53,8 @@ declare math: {
asin: (number) -> number,
atan2: (number, number) -> number,
-- min and max are declared in C++.
min: (number, ...number) -> number,
max: (number, ...number) -> number,
pi: number,
huge: number,
@ -143,7 +147,7 @@ declare coroutine: {
create: <A..., R...>((A...) -> R...) -> thread,
resume: <A..., R...>(thread, A...) -> (boolean, R...),
running: () -> thread,
status: (thread) -> string,
status: (thread) -> "dead" | "running" | "normal" | "suspended",
-- FIXME: This technically returns a function, but we can't represent this yet.
wrap: <A..., R...>((A...) -> R...) -> any,
yield: <A..., R...>(A...) -> R...,
@ -179,7 +183,7 @@ declare debug: {
}
declare utf8: {
char: (number, ...number) -> string,
char: (...number) -> string,
charpattern: string,
codes: (string) -> ((string, number) -> (number, number), string, number),
-- FIXME

View file

@ -1,14 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Error.h"
#include "Luau/Module.h"
#include "Luau/Clone.h"
#include "Luau/StringUtils.h"
#include "Luau/ToString.h"
#include <stdexcept>
LUAU_FASTFLAGVARIABLE(BetterDiagnosticCodesInStudio, false);
LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleName, false);
LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false)
LUAU_FASTFLAGVARIABLE(LuauUseInternalCompilerErrorException, false)
static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
{
@ -52,6 +52,8 @@ namespace Luau
struct ErrorConverter
{
FileResolver* fileResolver = nullptr;
std::string operator()(const Luau::TypeMismatch& tm) const
{
std::string givenTypeName = Luau::toString(tm.givenType);
@ -59,27 +61,30 @@ struct ErrorConverter
std::string result;
if (FFlag::LuauTypeMismatchModuleName)
if (givenTypeName == wantedTypeName)
{
if (givenTypeName == wantedTypeName)
if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType))
{
if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType))
if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType))
{
if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType))
if (FFlag::LuauTypeMismatchModuleNameResolution && fileResolver != nullptr)
{
std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule);
std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule);
result = "Type '" + givenTypeName + "' from '" + givenModuleName + "' could not be converted into '" + wantedTypeName +
"' from '" + wantedModuleName + "'";
}
else
{
result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName +
"' from '" + *wantedDefinitionModule + "'";
}
}
}
}
if (result.empty())
result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'";
}
else
{
if (result.empty())
result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'";
}
if (tm.error)
{
@ -88,7 +93,14 @@ struct ErrorConverter
if (!tm.reason.empty())
result += tm.reason + " ";
result += Luau::toString(*tm.error);
if (FFlag::LuauTypeMismatchModuleNameResolution)
{
result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver});
}
else
{
result += Luau::toString(*tm.error);
}
}
else if (!tm.reason.empty())
{
@ -187,15 +199,7 @@ struct ErrorConverter
std::string operator()(const Luau::FunctionRequiresSelf& e) const
{
if (e.requiredExtraNils)
{
const char* plural = e.requiredExtraNils == 1 ? "" : "s";
return format("This function was declared to accept self, but you did not pass enough arguments. Use a colon instead of a dot or "
"pass %i extra nil%s to suppress this warning",
e.requiredExtraNils, plural);
}
else
return "This function must be called with self. Did you mean to use a colon instead of a dot?";
return "This function must be called with self. Did you mean to use a colon instead of a dot?";
}
std::string operator()(const Luau::OccursCheckFailed&) const
@ -251,14 +255,7 @@ struct ErrorConverter
std::string operator()(const Luau::SyntaxError& e) const
{
if (FFlag::BetterDiagnosticCodesInStudio)
{
return e.message;
}
else
{
return "Syntax error: " + e.message;
}
return e.message;
}
std::string operator()(const Luau::CodeTooComplex&) const
@ -305,6 +302,11 @@ struct ErrorConverter
return e.message;
}
std::string operator()(const Luau::InternalError& e) const
{
return e.message;
}
std::string operator()(const Luau::CannotCallNonFunction& e) const
{
return "Cannot call non-function " + toString(e.ty);
@ -450,6 +452,11 @@ struct ErrorConverter
{
return "Cannot cast '" + toString(e.left) + "' into '" + toString(e.right) + "' because the types are unrelated";
}
std::string operator()(const NormalizationTooComplex&) const
{
return "Code is too complex to typecheck! Consider simplifying the code around this area";
}
};
struct InvalidNameChecker
@ -550,7 +557,7 @@ bool FunctionDoesNotTakeSelf::operator==(const FunctionDoesNotTakeSelf&) const
bool FunctionRequiresSelf::operator==(const FunctionRequiresSelf& e) const
{
return requiredExtraNils == e.requiredExtraNils;
return true;
}
bool OccursCheckFailed::operator==(const OccursCheckFailed&) const
@ -618,6 +625,11 @@ bool GenericError::operator==(const GenericError& rhs) const
return message == rhs.message;
}
bool InternalError::operator==(const InternalError& rhs) const
{
return message == rhs.message;
}
bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const
{
return ty == rhs.ty;
@ -705,7 +717,12 @@ bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const
std::string toString(const TypeError& error)
{
ErrorConverter converter;
return toString(error, TypeErrorToStringOptions{});
}
std::string toString(const TypeError& error, TypeErrorToStringOptions options)
{
ErrorConverter converter{options.fileResolver};
return Luau::visit(converter, error.data);
}
@ -715,14 +732,14 @@ bool containsParseErrorName(const TypeError& error)
}
template<typename T>
void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState cloneState)
void copyError(T& e, TypeArena& destArena, CloneState cloneState)
{
auto clone = [&](auto&& ty) {
return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks, cloneState);
return ::Luau::clone(ty, destArena, cloneState);
};
auto visitErrorData = [&](auto&& e) {
copyError(e, destArena, seenTypes, seenTypePacks, cloneState);
copyError(e, destArena, cloneState);
};
if constexpr (false)
@ -793,6 +810,9 @@ void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks&
else if constexpr (std::is_same_v<T, GenericError>)
{
}
else if constexpr (std::is_same_v<T, InternalError>)
{
}
else if constexpr (std::is_same_v<T, CannotCallNonFunction>)
{
e.ty = clone(e.ty);
@ -843,18 +863,19 @@ void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks&
e.left = clone(e.left);
e.right = clone(e.right);
}
else if constexpr (std::is_same_v<T, NormalizationTooComplex>)
{
}
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}
void copyErrors(ErrorVec& errors, TypeArena& destArena)
{
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
auto visitErrorData = [&](auto&& e) {
copyError(e, destArena, seenTypes, seenTypePacks, cloneState);
copyError(e, destArena, cloneState);
};
LUAU_ASSERT(!destArena.typeVars.isFrozen());
@ -866,22 +887,51 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena)
void InternalErrorReporter::ice(const std::string& message, const Location& location)
{
std::runtime_error error("Internal error in " + moduleName + " at " + toString(location) + ": " + message);
if (FFlag::LuauUseInternalCompilerErrorException)
{
InternalCompilerError error(message, moduleName, location);
if (onInternalError)
onInternalError(error.what());
if (onInternalError)
onInternalError(error.what());
throw error;
throw error;
}
else
{
std::runtime_error error("Internal error in " + moduleName + " at " + toString(location) + ": " + message);
if (onInternalError)
onInternalError(error.what());
throw error;
}
}
void InternalErrorReporter::ice(const std::string& message)
{
std::runtime_error error("Internal error in " + moduleName + ": " + message);
if (FFlag::LuauUseInternalCompilerErrorException)
{
InternalCompilerError error(message, moduleName);
if (onInternalError)
onInternalError(error.what());
if (onInternalError)
onInternalError(error.what());
throw error;
throw error;
}
else
{
std::runtime_error error("Internal error in " + moduleName + ": " + message);
if (onInternalError)
onInternalError(error.what());
throw error;
}
}
const char* InternalCompilerError::what() const throw()
{
return this->message.data();
}
} // namespace Luau

View file

@ -1,23 +1,31 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Frontend.h"
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/Config.h"
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/FileResolver.h"
#include "Luau/Parser.h"
#include "Luau/Scope.h"
#include "Luau/StringUtils.h"
#include "Luau/TimeTrace.h"
#include "Luau/TypeChecker2.h"
#include "Luau/TypeInfer.h"
#include "Luau/Variant.h"
#include "Luau/Common.h"
#include <algorithm>
#include <chrono>
#include <stdexcept>
LUAU_FASTINT(LuauTypeInferIterationLimit)
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false)
LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100)
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
namespace Luau
{
@ -93,13 +101,11 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
if (checkedModule->errors.size() > 0)
return LoadDefinitionFileResult{false, parseResult, checkedModule};
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
for (const auto& [name, ty] : checkedModule->declaredGlobals)
{
TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState);
TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/global/" + name;
generateDocumentationSymbols(globalTy, documentationSymbol);
targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
@ -109,7 +115,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
{
TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState);
TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState);
std::string documentationSymbol = packageName + "/globaltype/" + name;
generateDocumentationSymbols(globalTy.type, documentationSymbol);
targetScope->exportedTypeBindings[name] = globalTy;
@ -211,7 +217,7 @@ ErrorVec accumulateErrors(
continue;
const SourceNode& sourceNode = it->second;
queue.insert(queue.end(), sourceNode.requires.begin(), sourceNode.requires.end());
queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end());
// FIXME: If a module has a syntax error, we won't be able to re-report it here.
// The solution is probably to move errors from Module to SourceNode
@ -234,12 +240,6 @@ ErrorVec accumulateErrors(
return result;
}
struct RequireCycle
{
Location location;
std::vector<ModuleName> path; // one of the paths for a require() to go all the way back to the originating module
};
// Given a source node (start), find all requires that start a transitive dependency path that ends back at start
// For each such path, record the full path and the location of the require in the starting module.
// Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V)
@ -356,33 +356,44 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
FrontendOptions frontendOptions = optionOverride.value_or(options);
CheckResult checkResult;
auto it = sourceNodes.find(name);
if (it != sourceNodes.end() && !it->second.dirty)
if (it != sourceNodes.end() && !it->second.hasDirtyModule(frontendOptions.forAutocomplete))
{
// No recheck required.
auto it2 = moduleResolver.modules.find(name);
if (it2 == moduleResolver.modules.end() || it2->second == nullptr)
throw std::runtime_error("Frontend::modules does not have data for " + name);
if (frontendOptions.forAutocomplete)
{
auto it2 = moduleResolverForAutocomplete.modules.find(name);
if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr)
throw std::runtime_error("Frontend::modules does not have data for " + name);
}
else
{
auto it2 = moduleResolver.modules.find(name);
if (it2 == moduleResolver.modules.end() || it2->second == nullptr)
throw std::runtime_error("Frontend::modules does not have data for " + name);
}
return CheckResult{accumulateErrors(sourceNodes, moduleResolver.modules, name)};
return CheckResult{
accumulateErrors(sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)};
}
std::vector<ModuleName> buildQueue;
bool cycleDetected = parseGraph(buildQueue, checkResult, name);
FrontendOptions frontendOptions = optionOverride.value_or(options);
bool cycleDetected = parseGraph(buildQueue, checkResult, name, frontendOptions.forAutocomplete);
// Keep track of which AST nodes we've reported cycles in
std::unordered_set<AstNode*> reportedCycles;
double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0;
for (const ModuleName& moduleName : buildQueue)
{
LUAU_ASSERT(sourceNodes.count(moduleName));
SourceNode& sourceNode = sourceNodes[moduleName];
if (!sourceNode.dirty)
if (!sourceNode.hasDirtyModule(frontendOptions.forAutocomplete))
continue;
LUAU_ASSERT(sourceModules.count(moduleName));
@ -408,17 +419,64 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
// This is used by the type checker to replace the resulting type of cyclic modules with any
sourceModule.cyclic = !requireCycles.empty();
ModulePtr module = typeChecker.check(sourceModule, mode, environmentScope);
// If we're typechecking twice, we do so.
// The second typecheck is always in strict mode with DM awareness
// to provide better typen information for IDE features.
if (frontendOptions.typecheckTwice)
if (frontendOptions.forAutocomplete)
{
// The autocomplete typecheck is always in strict mode with DM awareness
// to provide better type information for IDE features
typeCheckerForAutocomplete.requireCycles = requireCycles;
if (autocompleteTimeLimit != 0.0)
typeCheckerForAutocomplete.finishTime = TimeTrace::getClock() + autocompleteTimeLimit;
else
typeCheckerForAutocomplete.finishTime = std::nullopt;
if (FFlag::LuauAutocompleteDynamicLimits)
{
// TODO: This is a dirty ad hoc solution for autocomplete timeouts
// We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit
// so that we'll have type information for the whole file at lower quality instead of a full abort in the middle
if (FInt::LuauTarjanChildLimit > 0)
typeCheckerForAutocomplete.instantiationChildLimit =
std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt;
if (FInt::LuauTypeInferIterationLimit > 0)
typeCheckerForAutocomplete.unifierIterationLimit =
std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt;
}
ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict);
moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete;
double duration = getTimestamp() - timestamp;
if (moduleForAutocomplete->timeout)
{
checkResult.timeoutHits.push_back(moduleName);
if (FFlag::LuauAutocompleteDynamicLimits)
sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0;
}
else if (FFlag::LuauAutocompleteDynamicLimits && duration < autocompleteTimeLimit / 2.0)
{
sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0);
}
stats.timeCheck += duration;
stats.filesStrict += 1;
sourceNode.dirtyModuleForAutocomplete = false;
continue;
}
typeChecker.requireCycles = requireCycles;
ModulePtr module = FFlag::DebugLuauDeferredConstraintResolution ? check(sourceModule, mode, environmentScope)
: typeChecker.check(sourceModule, mode, environmentScope);
stats.timeCheck += getTimestamp() - timestamp;
stats.filesStrict += mode == Mode::Strict;
stats.filesNonstrict += mode == Mode::Nonstrict;
@ -461,13 +519,13 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
checkResult.errors.insert(checkResult.errors.end(), module->errors.begin(), module->errors.end());
moduleResolver.modules[moduleName] = std::move(module);
sourceNode.dirty = false;
sourceNode.dirtyModule = false;
}
return checkResult;
}
bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& checkResult, const ModuleName& root)
bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete)
{
LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend");
LUAU_TIMETRACE_ARGUMENT("root", root.c_str());
@ -529,7 +587,7 @@ bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& chec
path.push_back(top);
// push children
for (const ModuleName& dep : top->requires)
for (const ModuleName& dep : top->requireSet)
{
auto it = sourceNodes.find(dep);
if (it != sourceNodes.end())
@ -538,7 +596,7 @@ bool Frontend::parseGraph(std::vector<ModuleName>& buildQueue, CheckResult& chec
// this relies on the fact that markDirty marks reverse-dependencies dirty as well
// thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need
// to be built, *and* can't form a cycle with any nodes we did process.
if (!it->second.dirty)
if (!it->second.hasDirtyModule(forAutocomplete))
continue;
// note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization
@ -625,30 +683,6 @@ std::pair<SourceModule, LintResult> Frontend::lintFragment(std::string_view sour
return {std::move(sourceModule), classifyLints(warnings, config)};
}
CheckResult Frontend::check(const SourceModule& module)
{
LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend");
LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str());
const Config& config = configResolver->getConfig(module.name);
Mode mode = module.mode.value_or(config.mode);
double timestamp = getTimestamp();
ModulePtr checkedModule = typeChecker.check(module, mode);
stats.timeCheck += getTimestamp() - timestamp;
stats.filesStrict += mode == Mode::Strict;
stats.filesNonstrict += mode == Mode::Nonstrict;
if (checkedModule == nullptr)
throw std::runtime_error("Frontend::check produced a nullptr module for module " + module.name);
moduleResolver.modules[module.name] = checkedModule;
return CheckResult{checkedModule->errors};
}
LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend");
@ -685,10 +719,10 @@ LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOp
return classifyLints(warnings, config);
}
bool Frontend::isDirty(const ModuleName& name) const
bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
{
auto it = sourceNodes.find(name);
return it == sourceNodes.end() || it->second.dirty;
return it == sourceNodes.end() || it->second.hasDirtyModule(forAutocomplete);
}
/*
@ -699,13 +733,13 @@ bool Frontend::isDirty(const ModuleName& name) const
*/
void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty)
{
if (!moduleResolver.modules.count(name))
if (!moduleResolver.modules.count(name) && !moduleResolverForAutocomplete.modules.count(name))
return;
std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps;
for (const auto& module : sourceNodes)
{
for (const auto& dep : module.second.requires)
for (const auto& dep : module.second.requireSet)
reverseDeps[dep].push_back(module.first);
}
@ -722,10 +756,12 @@ void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* marked
if (markedDirty)
markedDirty->push_back(next);
if (sourceNode.dirty)
if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete)
continue;
sourceNode.dirty = true;
sourceNode.dirtySourceModule = true;
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
if (0 == reverseDeps.count(name))
continue;
@ -751,6 +787,30 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons
return const_cast<Frontend*>(this)->getSourceModule(moduleName);
}
ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope)
{
ModulePtr result = std::make_shared<Module>();
ConstraintGraphBuilder cgb{&result->internalTypes};
cgb.visit(sourceModule.root);
ConstraintSolver cs{&result->internalTypes, cgb.rootScope};
cs.run();
result->scope2s = std::move(cgb.scopes);
result->astTypes = std::move(cgb.astTypes);
result->astTypePacks = std::move(cgb.astTypePacks);
result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes);
result->astResolvedTypes = std::move(cgb.astResolvedTypes);
result->astResolvedTypePacks = std::move(cgb.astResolvedTypePacks);
result->clonePublicInterface(iceHandler);
Luau::check(sourceModule, result.get());
return result;
}
// Read AST into sourceModules if necessary. Trace require()s. Report parse errors.
std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name)
{
@ -758,7 +818,7 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& check
LUAU_TIMETRACE_ARGUMENT("name", name.c_str());
auto it = sourceNodes.find(name);
if (it != sourceNodes.end() && !it->second.dirty)
if (it != sourceNodes.end() && !it->second.hasDirtySourceModule())
{
auto moduleIt = sourceModules.find(name);
if (moduleIt != sourceModules.end())
@ -789,8 +849,8 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& check
SourceModule result = parse(name, source->source, opts);
result.type = source->type;
RequireTraceResult& requireTrace = requires[name];
requireTrace = traceRequires(fileResolver, result.root, name);
RequireTraceResult& require = requireTrace[name];
require = traceRequires(fileResolver, result.root, name);
SourceNode& sourceNode = sourceNodes[name];
SourceModule& sourceModule = sourceModules[name];
@ -799,14 +859,20 @@ std::pair<SourceNode*, SourceModule*> Frontend::getSourceNode(CheckResult& check
sourceModule.environmentName = environmentName;
sourceNode.name = name;
sourceNode.requires.clear();
sourceNode.requireSet.clear();
sourceNode.requireLocations.clear();
sourceNode.dirty = true;
sourceNode.dirtySourceModule = false;
for (const auto& [moduleName, location] : requireTrace.requires)
sourceNode.requires.insert(moduleName);
if (it == sourceNodes.end())
{
sourceNode.dirtyModule = true;
sourceNode.dirtyModuleForAutocomplete = true;
}
sourceNode.requireLocations = requireTrace.requires;
for (const auto& [moduleName, location] : require.requireList)
sourceNode.requireSet.insert(moduleName);
sourceNode.requireLocations = require.requireList;
return {&sourceNode, &sourceModule};
}
@ -867,8 +933,8 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const
std::optional<ModuleInfo> FrontendModuleResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr)
{
// FIXME I think this can be pushed into the FileResolver.
auto it = frontend->requires.find(currentModuleName);
if (it == frontend->requires.end())
auto it = frontend->requireTrace.find(currentModuleName);
if (it == frontend->requireTrace.end())
{
// CLI-43699
// If we can't find the current module name, that's because we bypassed the frontend's initializer
@ -967,7 +1033,7 @@ void Frontend::clear()
sourceModules.clear();
moduleResolver.modules.clear();
moduleResolverForAutocomplete.modules.clear();
requires.clear();
requireTrace.clear();
}
} // namespace Luau

View file

@ -0,0 +1,124 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Common.h"
#include "Luau/Instantiation.h"
#include "Luau/TxnLog.h"
#include "Luau/TypeArena.h"
namespace Luau
{
bool Instantiation::isDirty(TypeId ty)
{
if (const FunctionTypeVar* ftv = log->getMutable<FunctionTypeVar>(ty))
{
if (ftv->hasNoGenerics)
return false;
return true;
}
else
{
return false;
}
}
bool Instantiation::isDirty(TypePackId tp)
{
return false;
}
bool Instantiation::ignoreChildren(TypeId ty)
{
if (log->getMutable<FunctionTypeVar>(ty))
return true;
else
return false;
}
TypeId Instantiation::clean(TypeId ty)
{
const FunctionTypeVar* ftv = log->getMutable<FunctionTypeVar>(ty);
LUAU_ASSERT(ftv);
FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf};
clone.magicFunction = ftv->magicFunction;
clone.tags = ftv->tags;
clone.argNames = ftv->argNames;
TypeId result = addType(std::move(clone));
// Annoyingly, we have to do this even if there are no generics,
// to replace any generic tables.
ReplaceGenerics replaceGenerics{log, arena, level, ftv->generics, ftv->genericPacks};
// TODO: What to do if this returns nullopt?
// We don't have access to the error-reporting machinery
result = replaceGenerics.substitute(result).value_or(result);
asMutable(result)->documentationSymbol = ty->documentationSymbol;
return result;
}
TypePackId Instantiation::clean(TypePackId tp)
{
LUAU_ASSERT(false);
return tp;
}
bool ReplaceGenerics::ignoreChildren(TypeId ty)
{
if (const FunctionTypeVar* ftv = log->getMutable<FunctionTypeVar>(ty))
{
if (ftv->hasNoGenerics)
return true;
// We aren't recursing in the case of a generic function which
// binds the same generics. This can happen if, for example, there's recursive types.
// If T = <a>(a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'.
// It's OK to use vector equality here, since we always generate fresh generics
// whenever we quantify, so the vectors overlap if and only if they are equal.
return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks);
}
else
{
return false;
}
}
bool ReplaceGenerics::isDirty(TypeId ty)
{
if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
return ttv->state == TableState::Generic;
else if (log->getMutable<GenericTypeVar>(ty))
return std::find(generics.begin(), generics.end(), ty) != generics.end();
else
return false;
}
bool ReplaceGenerics::isDirty(TypePackId tp)
{
if (log->getMutable<GenericTypePack>(tp))
return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end();
else
return false;
}
TypeId ReplaceGenerics::clean(TypeId ty)
{
LUAU_ASSERT(isDirty(ty));
if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
{
TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free};
clone.definitionModuleName = ttv->definitionModuleName;
return addType(std::move(clone));
}
else
return addType(FreeTypeVar{level});
}
TypePackId ReplaceGenerics::clean(TypePackId tp)
{
LUAU_ASSERT(isDirty(tp));
return addTypePack(TypePackVar(FreeTypePack{level}));
}
} // namespace Luau

View file

@ -23,9 +23,182 @@ std::ostream& operator<<(std::ostream& stream, const AstName& name)
return stream << "<empty>";
}
std::ostream& operator<<(std::ostream& stream, const TypeMismatch& tm)
template<typename T>
static void errorToString(std::ostream& stream, const T& err)
{
return stream << "TypeMismatch { " << toString(tm.wantedType) << ", " << toString(tm.givenType) << " }";
if constexpr (false)
{
}
else if constexpr (std::is_same_v<T, TypeMismatch>)
stream << "TypeMismatch { " << toString(err.wantedType) << ", " << toString(err.givenType) << " }";
else if constexpr (std::is_same_v<T, UnknownSymbol>)
stream << "UnknownSymbol { " << err.name << " , context " << err.context << " }";
else if constexpr (std::is_same_v<T, UnknownProperty>)
stream << "UnknownProperty { " << toString(err.table) << ", key = " << err.key << " }";
else if constexpr (std::is_same_v<T, NotATable>)
stream << "NotATable { " << toString(err.ty) << " }";
else if constexpr (std::is_same_v<T, CannotExtendTable>)
stream << "CannotExtendTable { " << toString(err.tableType) << ", context " << err.context << ", prop \"" << err.prop << "\" }";
else if constexpr (std::is_same_v<T, OnlyTablesCanHaveMethods>)
stream << "OnlyTablesCanHaveMethods { " << toString(err.tableType) << " }";
else if constexpr (std::is_same_v<T, DuplicateTypeDefinition>)
stream << "DuplicateTypeDefinition { " << err.name << " }";
else if constexpr (std::is_same_v<T, CountMismatch>)
stream << "CountMismatch { expected " << err.expected << ", got " << err.actual << ", context " << err.context << " }";
else if constexpr (std::is_same_v<T, FunctionDoesNotTakeSelf>)
stream << "FunctionDoesNotTakeSelf { }";
else if constexpr (std::is_same_v<T, FunctionRequiresSelf>)
stream << "FunctionRequiresSelf { }";
else if constexpr (std::is_same_v<T, OccursCheckFailed>)
stream << "OccursCheckFailed { }";
else if constexpr (std::is_same_v<T, UnknownRequire>)
stream << "UnknownRequire { " << err.modulePath << " }";
else if constexpr (std::is_same_v<T, IncorrectGenericParameterCount>)
{
stream << "IncorrectGenericParameterCount { name = " << err.name;
if (!err.typeFun.typeParams.empty() || !err.typeFun.typePackParams.empty())
{
stream << "<";
bool first = true;
for (auto param : err.typeFun.typeParams)
{
if (first)
first = false;
else
stream << ", ";
stream << toString(param.ty);
}
for (auto param : err.typeFun.typePackParams)
{
if (first)
first = false;
else
stream << ", ";
stream << toString(param.tp);
}
stream << ">";
}
stream << ", typeFun = " << toString(err.typeFun.type) << ", actualCount = " << err.actualParameters << " }";
}
else if constexpr (std::is_same_v<T, SyntaxError>)
stream << "SyntaxError { " << err.message << " }";
else if constexpr (std::is_same_v<T, CodeTooComplex>)
stream << "CodeTooComplex {}";
else if constexpr (std::is_same_v<T, UnificationTooComplex>)
stream << "UnificationTooComplex {}";
else if constexpr (std::is_same_v<T, UnknownPropButFoundLikeProp>)
{
stream << "UnknownPropButFoundLikeProp { key = '" << err.key << "', suggested = { ";
bool first = true;
for (Name name : err.candidates)
{
if (first)
first = false;
else
stream << ", ";
stream << "'" << name << "'";
}
stream << " }, table = " << toString(err.table) << " } ";
}
else if constexpr (std::is_same_v<T, GenericError>)
stream << "GenericError { " << err.message << " }";
else if constexpr (std::is_same_v<T, InternalError>)
stream << "InternalError { " << err.message << " }";
else if constexpr (std::is_same_v<T, CannotCallNonFunction>)
stream << "CannotCallNonFunction { " << toString(err.ty) << " }";
else if constexpr (std::is_same_v<T, ExtraInformation>)
stream << "ExtraInformation { " << err.message << " }";
else if constexpr (std::is_same_v<T, DeprecatedApiUsed>)
stream << "DeprecatedApiUsed { " << err.symbol << ", useInstead = " << err.useInstead << " }";
else if constexpr (std::is_same_v<T, ModuleHasCyclicDependency>)
{
stream << "ModuleHasCyclicDependency {";
bool first = true;
for (const ModuleName& name : err.cycle)
{
if (first)
first = false;
else
stream << ", ";
stream << name;
}
stream << "}";
}
else if constexpr (std::is_same_v<T, IllegalRequire>)
stream << "IllegalRequire { " << err.moduleName << ", reason = " << err.reason << " }";
else if constexpr (std::is_same_v<T, FunctionExitsWithoutReturning>)
stream << "FunctionExitsWithoutReturning {" << toString(err.expectedReturnType) << "}";
else if constexpr (std::is_same_v<T, DuplicateGenericParameter>)
stream << "DuplicateGenericParameter { " + err.parameterName + " }";
else if constexpr (std::is_same_v<T, CannotInferBinaryOperation>)
stream << "CannotInferBinaryOperation { op = " + toString(err.op) + ", suggested = '" +
(err.suggestedToAnnotate ? *err.suggestedToAnnotate : "") + "', kind "
<< err.kind << "}";
else if constexpr (std::is_same_v<T, MissingProperties>)
{
stream << "MissingProperties { superType = '" << toString(err.superType) << "', subType = '" << toString(err.subType) << "', properties = { ";
bool first = true;
for (Name name : err.properties)
{
if (first)
first = false;
else
stream << ", ";
stream << "'" << name << "'";
}
stream << " }, context " << err.context << " } ";
}
else if constexpr (std::is_same_v<T, SwappedGenericTypeParameter>)
stream << "SwappedGenericTypeParameter { name = '" + err.name + "', kind = " + std::to_string(err.kind) + " }";
else if constexpr (std::is_same_v<T, OptionalValueAccess>)
stream << "OptionalValueAccess { optional = '" + toString(err.optional) + "' }";
else if constexpr (std::is_same_v<T, MissingUnionProperty>)
{
stream << "MissingUnionProperty { type = '" + toString(err.type) + "', missing = { ";
bool first = true;
for (auto ty : err.missing)
{
if (first)
first = false;
else
stream << ", ";
stream << "'" << toString(ty) << "'";
}
stream << " }, key = '" + err.key + "' }";
}
else if constexpr (std::is_same_v<T, TypesAreUnrelated>)
stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }";
else if constexpr (std::is_same_v<T, NormalizationTooComplex>)
stream << "NormalizationTooComplex { }";
else
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}
std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data)
{
auto cb = [&](const auto& e) {
return errorToString(stream, e);
};
visit(cb, data);
return stream;
}
std::ostream& operator<<(std::ostream& stream, const TypeError& error)
@ -33,241 +206,6 @@ std::ostream& operator<<(std::ostream& stream, const TypeError& error)
return stream << "TypeError { \"" << error.moduleName << "\", " << error.location << ", " << error.data << " }";
}
std::ostream& operator<<(std::ostream& stream, const UnknownSymbol& error)
{
return stream << "UnknownSymbol { " << error.name << " , context " << error.context << " }";
}
std::ostream& operator<<(std::ostream& stream, const UnknownProperty& error)
{
return stream << "UnknownProperty { " << toString(error.table) << ", key = " << error.key << " }";
}
std::ostream& operator<<(std::ostream& stream, const NotATable& ge)
{
return stream << "NotATable { " << toString(ge.ty) << " }";
}
std::ostream& operator<<(std::ostream& stream, const CannotExtendTable& error)
{
return stream << "CannotExtendTable { " << toString(error.tableType) << ", context " << error.context << ", prop \"" << error.prop << "\" }";
}
std::ostream& operator<<(std::ostream& stream, const OnlyTablesCanHaveMethods& error)
{
return stream << "OnlyTablesCanHaveMethods { " << toString(error.tableType) << " }";
}
std::ostream& operator<<(std::ostream& stream, const DuplicateTypeDefinition& error)
{
return stream << "DuplicateTypeDefinition { " << error.name << " }";
}
std::ostream& operator<<(std::ostream& stream, const CountMismatch& error)
{
return stream << "CountMismatch { expected " << error.expected << ", got " << error.actual << ", context " << error.context << " }";
}
std::ostream& operator<<(std::ostream& stream, const FunctionDoesNotTakeSelf&)
{
return stream << "FunctionDoesNotTakeSelf { }";
}
std::ostream& operator<<(std::ostream& stream, const FunctionRequiresSelf& error)
{
return stream << "FunctionRequiresSelf { extraNils " << error.requiredExtraNils << " }";
}
std::ostream& operator<<(std::ostream& stream, const OccursCheckFailed&)
{
return stream << "OccursCheckFailed { }";
}
std::ostream& operator<<(std::ostream& stream, const UnknownRequire& error)
{
return stream << "UnknownRequire { " << error.modulePath << " }";
}
std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCount& error)
{
stream << "IncorrectGenericParameterCount { name = " << error.name;
if (!error.typeFun.typeParams.empty() || !error.typeFun.typePackParams.empty())
{
stream << "<";
bool first = true;
for (auto param : error.typeFun.typeParams)
{
if (first)
first = false;
else
stream << ", ";
stream << toString(param.ty);
}
for (auto param : error.typeFun.typePackParams)
{
if (first)
first = false;
else
stream << ", ";
stream << toString(param.tp);
}
stream << ">";
}
stream << ", typeFun = " << toString(error.typeFun.type) << ", actualCount = " << error.actualParameters << " }";
return stream;
}
std::ostream& operator<<(std::ostream& stream, const SyntaxError& ge)
{
return stream << "SyntaxError { " << ge.message << " }";
}
std::ostream& operator<<(std::ostream& stream, const CodeTooComplex&)
{
return stream << "CodeTooComplex {}";
}
std::ostream& operator<<(std::ostream& stream, const UnificationTooComplex&)
{
return stream << "UnificationTooComplex {}";
}
std::ostream& operator<<(std::ostream& stream, const UnknownPropButFoundLikeProp& e)
{
stream << "UnknownPropButFoundLikeProp { key = '" << e.key << "', suggested = { ";
bool first = true;
for (Name name : e.candidates)
{
if (first)
first = false;
else
stream << ", ";
stream << "'" << name << "'";
}
return stream << " }, table = " << toString(e.table) << " } ";
}
std::ostream& operator<<(std::ostream& stream, const GenericError& ge)
{
return stream << "GenericError { " << ge.message << " }";
}
std::ostream& operator<<(std::ostream& stream, const CannotCallNonFunction& e)
{
return stream << "CannotCallNonFunction { " << toString(e.ty) << " }";
}
std::ostream& operator<<(std::ostream& stream, const FunctionExitsWithoutReturning& error)
{
return stream << "FunctionExitsWithoutReturning {" << toString(error.expectedReturnType) << "}";
}
std::ostream& operator<<(std::ostream& stream, const ExtraInformation& e)
{
return stream << "ExtraInformation { " << e.message << " }";
}
std::ostream& operator<<(std::ostream& stream, const DeprecatedApiUsed& e)
{
return stream << "DeprecatedApiUsed { " << e.symbol << ", useInstead = " << e.useInstead << " }";
}
std::ostream& operator<<(std::ostream& stream, const ModuleHasCyclicDependency& e)
{
stream << "ModuleHasCyclicDependency {";
bool first = true;
for (const ModuleName& name : e.cycle)
{
if (first)
first = false;
else
stream << ", ";
stream << name;
}
return stream << "}";
}
std::ostream& operator<<(std::ostream& stream, const IllegalRequire& e)
{
return stream << "IllegalRequire { " << e.moduleName << ", reason = " << e.reason << " }";
}
std::ostream& operator<<(std::ostream& stream, const MissingProperties& e)
{
stream << "MissingProperties { superType = '" << toString(e.superType) << "', subType = '" << toString(e.subType) << "', properties = { ";
bool first = true;
for (Name name : e.properties)
{
if (first)
first = false;
else
stream << ", ";
stream << "'" << name << "'";
}
return stream << " }, context " << e.context << " } ";
}
std::ostream& operator<<(std::ostream& stream, const DuplicateGenericParameter& error)
{
return stream << "DuplicateGenericParameter { " + error.parameterName + " }";
}
std::ostream& operator<<(std::ostream& stream, const CannotInferBinaryOperation& error)
{
return stream << "CannotInferBinaryOperation { op = " + toString(error.op) + ", suggested = '" +
(error.suggestedToAnnotate ? *error.suggestedToAnnotate : "") + "', kind "
<< error.kind << "}";
}
std::ostream& operator<<(std::ostream& stream, const SwappedGenericTypeParameter& error)
{
return stream << "SwappedGenericTypeParameter { name = '" + error.name + "', kind = " + std::to_string(error.kind) + " }";
}
std::ostream& operator<<(std::ostream& stream, const OptionalValueAccess& error)
{
return stream << "OptionalValueAccess { optional = '" + toString(error.optional) + "' }";
}
std::ostream& operator<<(std::ostream& stream, const MissingUnionProperty& error)
{
stream << "MissingUnionProperty { type = '" + toString(error.type) + "', missing = { ";
bool first = true;
for (auto ty : error.missing)
{
if (first)
first = false;
else
stream << ", ";
stream << "'" << toString(ty) << "'";
}
return stream << " }, key = '" + error.key + "' }";
}
std::ostream& operator<<(std::ostream& stream, const TypesAreUnrelated& error)
{
stream << "TypesAreUnrelated { left = '" + toString(error.left) + "', right = '" + toString(error.right) + "' }";
return stream;
}
std::ostream& operator<<(std::ostream& stream, const TableState& tv)
{
return stream << static_cast<std::underlying_type<TableState>::type>(tv);
@ -283,15 +221,4 @@ std::ostream& operator<<(std::ostream& stream, const TypePackVar& tv)
return stream << toString(tv);
}
std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted)
{
Luau::visit(
[&](const auto& a) {
lhs << a;
},
ted);
return lhs;
}
} // namespace Luau

View file

@ -403,35 +403,26 @@ struct AstJsonEncoder : public AstVisitor
void write(const AstExprTable::Item& item)
{
writeRaw("{");
bool comma = pushComma();
bool c = pushComma();
write("kind", item.kind);
switch (item.kind)
{
case AstExprTable::Item::List:
write(item.value);
write("value", item.value);
break;
default:
write(item.key);
writeRaw(",");
write(item.value);
write("key", item.key);
write("value", item.value);
break;
}
popComma(comma);
popComma(c);
writeRaw("}");
}
void write(class AstExprTable* node)
{
writeNode(node, "AstExprTable", [&]() {
bool comma = false;
for (const auto& prop : node->items)
{
if (comma)
writeRaw(",");
else
comma = true;
write(prop);
}
PROP(items);
});
}

View file

@ -77,19 +77,15 @@ std::optional<LValue> tryGetLValue(const AstExpr& node)
return std::nullopt;
}
std::pair<Symbol, std::vector<std::string>> getFullName(const LValue& lvalue)
Symbol getBaseSymbol(const LValue& lvalue)
{
const LValue* current = &lvalue;
std::vector<std::string> keys;
while (auto field = get<Field>(*current))
{
keys.push_back(field->key);
current = baseof(*current);
}
const Symbol* symbol = get<Symbol>(*current);
LUAU_ASSERT(symbol);
return {*symbol, std::vector<std::string>(keys.rbegin(), keys.rend())};
return *symbol;
}
void merge(RefinementMap& l, const RefinementMap& r, std::function<TypeId(TypeId, TypeId)> f)

View file

@ -14,7 +14,6 @@
LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4)
LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false)
LUAU_FASTFLAGVARIABLE(LuauLintNoRobloxBits, false)
namespace Luau
{
@ -1140,25 +1139,8 @@ private:
Kind_Primitive, // primitive type supported by VM - boolean/userdata/etc. No differentiation between types of userdata.
Kind_Vector, // 'vector' but only used when type is used
Kind_Userdata, // custom userdata type
// TODO: remove these with LuauLintNoRobloxBits
Kind_Class, // custom userdata type that reflects Roblox Instance-derived hierarchy - Part/etc.
Kind_Enum, // custom userdata type referring to an enum item of enum classes, e.g. Enum.NormalId.Back/Enum.Axis.X/etc.
};
bool containsPropName(TypeId ty, const std::string& propName)
{
LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits);
if (auto ctv = get<ClassTypeVar>(ty))
return lookupClassProp(ctv, propName) != nullptr;
if (auto ttv = get<TableTypeVar>(ty))
return ttv->props.find(propName) != ttv->props.end();
return false;
}
TypeKind getTypeKind(const std::string& name)
{
if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" ||
@ -1168,23 +1150,10 @@ private:
if (name == "vector")
return Kind_Vector;
if (FFlag::LuauLintNoRobloxBits)
{
if (std::optional<TypeFun> maybeTy = context->scope->lookupType(name))
return Kind_Userdata;
if (std::optional<TypeFun> maybeTy = context->scope->lookupType(name))
return Kind_Userdata;
return Kind_Unknown;
}
else
{
if (std::optional<TypeFun> maybeTy = context->scope->lookupType(name))
// Kind_Userdata is probably not 100% precise but is close enough
return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata;
else if (std::optional<TypeFun> maybeTy = context->scope->lookupImportedType("Enum", name))
return Kind_Enum;
return Kind_Unknown;
}
return Kind_Unknown;
}
void validateType(AstExprConstantString* expr, std::initializer_list<TypeKind> expected, const char* expectedString)
@ -1202,67 +1171,11 @@ private:
{
if (kind == ek)
return;
// as a special case, Instance and EnumItem are both a userdata type (as returned by typeof) and a class type
if (!FFlag::LuauLintNoRobloxBits && ek == Kind_Userdata && (name == "Instance" || name == "EnumItem"))
return;
}
emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s' (expected %s)", name.c_str(), expectedString);
}
bool acceptsClassName(AstName method)
{
LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits);
return method.value[0] == 'F' && (method == "FindFirstChildOfClass" || method == "FindFirstChildWhichIsA" ||
method == "FindFirstAncestorOfClass" || method == "FindFirstAncestorWhichIsA");
}
bool visit(AstExprCall* node) override
{
// TODO: Simply remove the override
if (FFlag::LuauLintNoRobloxBits)
return true;
if (AstExprIndexName* index = node->func->as<AstExprIndexName>())
{
AstExprConstantString* arg0 = node->args.size > 0 ? node->args.data[0]->as<AstExprConstantString>() : NULL;
if (arg0)
{
if (node->self && index->index == "IsA" && node->args.size == 1)
{
validateType(arg0, {Kind_Class, Kind_Enum}, "class or enum type");
}
else if (node->self && (index->index == "GetService" || index->index == "FindService") && node->args.size == 1)
{
AstExprGlobal* g = index->expr->as<AstExprGlobal>();
if (g && (g->name == "game" || g->name == "Game"))
{
validateType(arg0, {Kind_Class}, "class type");
}
}
else if (node->self && acceptsClassName(index->index) && node->args.size == 1)
{
validateType(arg0, {Kind_Class}, "class type");
}
else if (!node->self && index->index == "new" && node->args.size <= 2)
{
AstExprGlobal* g = index->expr->as<AstExprGlobal>();
if (g && g->name == "Instance")
{
validateType(arg0, {Kind_Class}, "class type");
}
}
}
}
return true;
}
bool visit(AstExprBinary* node) override
{
if (node->op == AstExprBinary::CompareNe || node->op == AstExprBinary::CompareEq)
@ -2369,7 +2282,7 @@ private:
size_t getReturnCount(TypeId ty)
{
if (auto ftv = get<FunctionTypeVar>(ty))
return size(ftv->retType);
return size(ftv->retTypes);
if (auto itv = get<IntersectionTypeVar>(ty))
{
@ -2378,7 +2291,7 @@ private:
for (TypeId part : itv->parts)
if (auto ftv = get<FunctionTypeVar>(follow(part)))
result = std::max(result, size(ftv->retType));
result = std::max(result, size(ftv->retTypes));
return result;
}
@ -2740,12 +2653,12 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
}
else
{
std::string::size_type space = hc.content.find_first_of(" \t");
size_t space = hc.content.find_first_of(" \t");
std::string_view first = std::string_view(hc.content).substr(0, space);
if (first == "nolint")
{
std::string::size_type notspace = hc.content.find_first_not_of(" \t", space);
size_t notspace = hc.content.find_first_not_of(" \t", space);
if (space == std::string::npos || notspace == std::string::npos)
{
@ -2914,7 +2827,7 @@ uint64_t LintWarning::parseMask(const std::vector<HotComment>& hotcomments)
if (hc.content.compare(0, 6, "nolint") != 0)
continue;
std::string::size_type name = hc.content.find_first_not_of(" \t", 6);
size_t name = hc.content.find_first_not_of(" \t", 6);
// --!nolint disables everything
if (name == std::string::npos)

View file

@ -1,7 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Module.h"
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/ConstraintGraphBuilder.h"
#include "Luau/Normalize.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
@ -11,9 +14,9 @@
#include <algorithm>
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false)
LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300)
LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false)
LUAU_FASTFLAG(LuauLowerBoundsCalculation);
LUAU_FASTFLAG(LuauNormalizeFlagIsConservative);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
namespace Luau
{
@ -53,421 +56,121 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos)
return contains(pos, *iter);
}
void TypeArena::clear()
struct ForceNormal : TypeVarOnceVisitor
{
typeVars.clear();
typePacks.clear();
}
const TypeArena* typeArena = nullptr;
TypeId TypeArena::addTV(TypeVar&& tv)
{
TypeId allocated = typeVars.allocate(std::move(tv));
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(TypeLevel level)
{
TypeId allocated = typeVars.allocate(FreeTypeVar{level});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(std::initializer_list<TypeId> types)
{
TypePackId allocated = typePacks.allocate(TypePack{std::move(types)});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(std::vector<TypeId> types)
{
TypePackId allocated = typePacks.allocate(TypePack{std::move(types)});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(TypePack tp)
{
TypePackId allocated = typePacks.allocate(std::move(tp));
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(TypePackVar tp)
{
TypePackId allocated = typePacks.allocate(std::move(tp));
asMutable(allocated)->owningArena = this;
return allocated;
}
namespace
{
struct TypePackCloner;
/*
* Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set.
* They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage.
*/
struct TypeCloner
{
TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState)
: dest(dest)
, typeId(typeId)
, seenTypes(seenTypes)
, seenTypePacks(seenTypePacks)
, cloneState(cloneState)
ForceNormal(const TypeArena* typeArena)
: typeArena(typeArena)
{
}
TypeArena& dest;
TypeId typeId;
SeenTypes& seenTypes;
SeenTypePacks& seenTypePacks;
CloneState& cloneState;
template<typename T>
void defaultClone(const T& t);
void operator()(const Unifiable::Free& t);
void operator()(const Unifiable::Generic& t);
void operator()(const Unifiable::Bound<TypeId>& t);
void operator()(const Unifiable::Error& t);
void operator()(const PrimitiveTypeVar& t);
void operator()(const SingletonTypeVar& t);
void operator()(const FunctionTypeVar& t);
void operator()(const TableTypeVar& t);
void operator()(const MetatableTypeVar& t);
void operator()(const ClassTypeVar& t);
void operator()(const AnyTypeVar& t);
void operator()(const UnionTypeVar& t);
void operator()(const IntersectionTypeVar& t);
void operator()(const LazyTypeVar& t);
};
struct TypePackCloner
{
TypeArena& dest;
TypePackId typePackId;
SeenTypes& seenTypes;
SeenTypePacks& seenTypePacks;
CloneState& cloneState;
TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState)
: dest(dest)
, typePackId(typePackId)
, seenTypes(seenTypes)
, seenTypePacks(seenTypePacks)
, cloneState(cloneState)
bool visit(TypeId ty) override
{
if (ty->owningArena != typeArena)
return false;
asMutable(ty)->normal = true;
return true;
}
template<typename T>
void defaultClone(const T& t)
bool visit(TypeId ty, const FreeTypeVar& ftv) override
{
TypePackId cloned = dest.addTypePack(TypePackVar{t});
seenTypePacks[typePackId] = cloned;
visit(ty);
return true;
}
void operator()(const Unifiable::Free& t)
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
cloneState.encounteredFreeType = true;
TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack);
TypePackId cloned = dest.addTypePack(*err);
seenTypePacks[typePackId] = cloned;
}
void operator()(const Unifiable::Generic& t)
{
defaultClone(t);
}
void operator()(const Unifiable::Error& t)
{
defaultClone(t);
}
// While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter.
// We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer.
void operator()(const Unifiable::Bound<TypePackId>& t)
{
TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState);
seenTypePacks[typePackId] = cloned;
}
void operator()(const VariadicTypePack& t)
{
TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}});
seenTypePacks[typePackId] = cloned;
}
void operator()(const TypePack& t)
{
TypePackId cloned = dest.addTypePack(TypePack{});
TypePack* destTp = getMutable<TypePack>(cloned);
LUAU_ASSERT(destTp != nullptr);
seenTypePacks[typePackId] = cloned;
for (TypeId ty : t.head)
destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));
if (t.tail)
destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState);
return true;
}
};
template<typename T>
void TypeCloner::defaultClone(const T& t)
Module::~Module()
{
TypeId cloned = dest.addType(t);
seenTypes[typeId] = cloned;
unfreeze(interfaceTypes);
unfreeze(internalTypes);
}
void TypeCloner::operator()(const Unifiable::Free& t)
void Module::clonePublicInterface(InternalErrorReporter& ice)
{
cloneState.encounteredFreeType = true;
TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType);
TypeId cloned = dest.addType(*err);
seenTypes[typeId] = cloned;
}
LUAU_ASSERT(interfaceTypes.typeVars.empty());
LUAU_ASSERT(interfaceTypes.typePacks.empty());
void TypeCloner::operator()(const Unifiable::Generic& t)
{
defaultClone(t);
}
CloneState cloneState;
void TypeCloner::operator()(const Unifiable::Bound<TypeId>& t)
{
TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState);
seenTypes[typeId] = boundTo;
}
ScopePtr moduleScope = FFlag::DebugLuauDeferredConstraintResolution ? nullptr : getModuleScope();
Scope2* moduleScope2 = FFlag::DebugLuauDeferredConstraintResolution ? getModuleScope2() : nullptr;
void TypeCloner::operator()(const Unifiable::Error& t)
{
defaultClone(t);
}
TypePackId returnType = FFlag::DebugLuauDeferredConstraintResolution ? moduleScope2->returnType : moduleScope->returnType;
std::optional<TypePackId> varargPack = FFlag::DebugLuauDeferredConstraintResolution ? std::nullopt : moduleScope->varargPack;
std::unordered_map<Name, TypeFun>* exportedTypeBindings =
FFlag::DebugLuauDeferredConstraintResolution ? nullptr : &moduleScope->exportedTypeBindings;
void TypeCloner::operator()(const PrimitiveTypeVar& t)
{
defaultClone(t);
}
returnType = clone(returnType, interfaceTypes, cloneState);
void TypeCloner::operator()(const SingletonTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const FunctionTypeVar& t)
{
TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf});
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(result);
LUAU_ASSERT(ftv != nullptr);
seenTypes[typeId] = result;
for (TypeId generic : t.generics)
ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState));
for (TypePackId genericPack : t.genericPacks)
ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState));
ftv->tags = t.tags;
ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState);
ftv->argNames = t.argNames;
ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState);
}
void TypeCloner::operator()(const TableTypeVar& t)
{
// If table is now bound to another one, we ignore the content of the original
if (t.boundTo)
if (moduleScope)
{
TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState);
seenTypes[typeId] = boundTo;
return;
moduleScope->returnType = returnType;
if (varargPack)
{
varargPack = clone(*varargPack, interfaceTypes, cloneState);
moduleScope->varargPack = varargPack;
}
}
else
{
LUAU_ASSERT(moduleScope2);
moduleScope2->returnType = returnType; // TODO varargPack
}
TypeId result = dest.addType(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(result);
LUAU_ASSERT(ttv != nullptr);
*ttv = t;
seenTypes[typeId] = result;
ttv->level = TypeLevel{0, 0};
for (const auto& [name, prop] : t.props)
ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags};
if (t.indexer)
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState),
clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)};
for (TypeId& arg : ttv->instantiatedTypeParams)
arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState);
for (TypePackId& arg : ttv->instantiatedTypePackParams)
arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState);
if (ttv->state == TableState::Free)
if (FFlag::LuauLowerBoundsCalculation)
{
cloneState.encounteredFreeType = true;
ttv->state = TableState::Sealed;
normalize(returnType, interfaceTypes, ice);
if (varargPack)
normalize(*varargPack, interfaceTypes, ice);
}
ttv->definitionModuleName = t.definitionModuleName;
ttv->methodDefinitionLocations = t.methodDefinitionLocations;
ttv->tags = t.tags;
}
ForceNormal forceNormal{&interfaceTypes};
void TypeCloner::operator()(const MetatableTypeVar& t)
{
TypeId result = dest.addType(MetatableTypeVar{});
MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(result);
seenTypes[typeId] = result;
mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState);
mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState);
}
void TypeCloner::operator()(const ClassTypeVar& t)
{
TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData});
ClassTypeVar* ctv = getMutable<ClassTypeVar>(result);
seenTypes[typeId] = result;
for (const auto& [name, prop] : t.props)
ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags};
if (t.parent)
ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState);
if (t.metatable)
ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState);
}
void TypeCloner::operator()(const AnyTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const UnionTypeVar& t)
{
std::vector<TypeId> options;
options.reserve(t.options.size());
for (TypeId ty : t.options)
options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));
TypeId result = dest.addType(UnionTypeVar{std::move(options)});
seenTypes[typeId] = result;
}
void TypeCloner::operator()(const IntersectionTypeVar& t)
{
TypeId result = dest.addType(IntersectionTypeVar{});
seenTypes[typeId] = result;
IntersectionTypeVar* option = getMutable<IntersectionTypeVar>(result);
LUAU_ASSERT(option != nullptr);
for (TypeId ty : t.parts)
option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));
}
void TypeCloner::operator()(const LazyTypeVar& t)
{
defaultClone(t);
}
} // anonymous namespace
TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState)
{
if (tp->persistent)
return tp;
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
TypePackId& res = seenTypePacks[tp];
if (res == nullptr)
if (exportedTypeBindings)
{
TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState};
Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into.
for (auto& [name, tf] : *exportedTypeBindings)
{
tf = clone(tf, interfaceTypes, cloneState);
if (FFlag::LuauLowerBoundsCalculation)
{
normalize(tf.type, interfaceTypes, ice);
if (FFlag::LuauNormalizeFlagIsConservative)
{
// We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables
// won't be marked normal. If the types aren't normal by now, they never will be.
forceNormal.traverse(tf.type);
}
}
}
}
return res;
}
TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState)
{
if (typeId->persistent)
return typeId;
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
TypeId& res = seenTypes[typeId];
if (res == nullptr)
for (TypeId ty : returnType)
{
TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState};
Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into.
// Persistent types are not being cloned and we get the original type back which might be read-only
if (!res->persistent)
asMutable(res)->documentationSymbol = typeId->documentationSymbol;
if (get<GenericTypeVar>(follow(ty)))
{
auto t = asMutable(ty);
t->ty = AnyTypeVar{};
t->normal = true;
}
}
return res;
}
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState)
{
TypeFun result;
for (auto param : typeFun.typeParams)
for (auto& [name, ty] : declaredGlobals)
{
TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState);
std::optional<TypeId> defaultValue;
if (param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState);
result.typeParams.push_back({ty, defaultValue});
ty = clone(ty, interfaceTypes, cloneState);
if (FFlag::LuauLowerBoundsCalculation)
normalize(ty, interfaceTypes, ice);
}
for (auto param : typeFun.typePackParams)
{
TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState);
std::optional<TypePackId> defaultValue;
if (param.defaultValue)
defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState);
result.typePackParams.push_back({tp, defaultValue});
}
result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState);
return result;
freeze(internalTypes);
freeze(interfaceTypes);
}
ScopePtr Module::getModuleScope() const
@ -476,62 +179,10 @@ ScopePtr Module::getModuleScope() const
return scopes.front().second;
}
void freeze(TypeArena& arena)
Scope2* Module::getModuleScope2() const
{
if (!FFlag::DebugLuauFreezeArena)
return;
arena.typeVars.freeze();
arena.typePacks.freeze();
}
void unfreeze(TypeArena& arena)
{
if (!FFlag::DebugLuauFreezeArena)
return;
arena.typeVars.unfreeze();
arena.typePacks.unfreeze();
}
Module::~Module()
{
unfreeze(interfaceTypes);
unfreeze(internalTypes);
}
bool Module::clonePublicInterface()
{
LUAU_ASSERT(interfaceTypes.typeVars.empty());
LUAU_ASSERT(interfaceTypes.typePacks.empty());
SeenTypes seenTypes;
SeenTypePacks seenTypePacks;
CloneState cloneState;
ScopePtr moduleScope = getModuleScope();
moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, cloneState);
if (moduleScope->varargPack)
moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState);
for (auto& [name, tf] : moduleScope->exportedTypeBindings)
tf = clone(tf, interfaceTypes, seenTypes, seenTypePacks, cloneState);
for (TypeId ty : moduleScope->returnType)
if (get<GenericTypeVar>(follow(ty)))
*asMutable(ty) = AnyTypeVar{};
if (FFlag::LuauCloneDeclaredGlobals)
{
for (auto& [name, ty] : declaredGlobals)
ty = clone(ty, interfaceTypes, seenTypes, seenTypePacks, cloneState);
}
freeze(internalTypes);
freeze(interfaceTypes);
return cloneState.encounteredFreeType;
LUAU_ASSERT(!scope2s.empty());
return scope2s.front().second.get();
}
} // namespace Luau

859
Analysis/src/Normalize.cpp Normal file
View file

@ -0,0 +1,859 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Normalize.h"
#include <algorithm>
#include "Luau/Clone.h"
#include "Luau/Substitution.h"
#include "Luau/Unifier.h"
#include "Luau/VisitTypeVar.h"
LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false)
// This could theoretically be 2000 on amd64, but x86 requires this.
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false);
LUAU_FASTFLAGVARIABLE(LuauReplaceReplacer, false);
LUAU_FASTFLAG(LuauQuantifyConstrained)
namespace Luau
{
namespace
{
struct Replacer : Substitution
{
TypeId sourceType;
TypeId replacedType;
DenseHashMap<TypeId, TypeId> replacedTypes{nullptr};
DenseHashMap<TypePackId, TypePackId> replacedPacks{nullptr};
Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType)
: Substitution(TxnLog::empty(), arena)
, sourceType(sourceType)
, replacedType(replacedType)
{
}
bool isDirty(TypeId ty) override
{
if (!sourceType)
return false;
auto vecHasSourceType = [sourceType = sourceType](const auto& vec) {
return end(vec) != std::find(begin(vec), end(vec), sourceType);
};
// Walk every kind of TypeVar and find pointers to sourceType
if (auto t = get<FreeTypeVar>(ty))
return false;
else if (auto t = get<GenericTypeVar>(ty))
return false;
else if (auto t = get<ErrorTypeVar>(ty))
return false;
else if (auto t = get<PrimitiveTypeVar>(ty))
return false;
else if (auto t = get<ConstrainedTypeVar>(ty))
return vecHasSourceType(t->parts);
else if (auto t = get<SingletonTypeVar>(ty))
return false;
else if (auto t = get<FunctionTypeVar>(ty))
{
if (vecHasSourceType(t->generics))
return true;
return false;
}
else if (auto t = get<TableTypeVar>(ty))
{
if (t->boundTo)
return *t->boundTo == sourceType;
for (const auto& [_name, prop] : t->props)
{
if (prop.type == sourceType)
return true;
}
if (auto indexer = t->indexer)
{
if (indexer->indexType == sourceType || indexer->indexResultType == sourceType)
return true;
}
if (vecHasSourceType(t->instantiatedTypeParams))
return true;
return false;
}
else if (auto t = get<MetatableTypeVar>(ty))
return t->table == sourceType || t->metatable == sourceType;
else if (auto t = get<ClassTypeVar>(ty))
return false;
else if (auto t = get<AnyTypeVar>(ty))
return false;
else if (auto t = get<UnionTypeVar>(ty))
return vecHasSourceType(t->options);
else if (auto t = get<IntersectionTypeVar>(ty))
return vecHasSourceType(t->parts);
else if (auto t = get<LazyTypeVar>(ty))
return false;
LUAU_ASSERT(!"Luau::Replacer::isDirty internal error: Unknown TypeVar type");
LUAU_UNREACHABLE();
}
bool isDirty(TypePackId tp) override
{
if (auto it = replacedPacks.find(tp))
return false;
if (auto pack = get<TypePack>(tp))
{
for (TypeId ty : pack->head)
{
if (ty == sourceType)
return true;
}
return false;
}
else if (auto vtp = get<VariadicTypePack>(tp))
return vtp->ty == sourceType;
else
return false;
}
TypeId clean(TypeId ty) override
{
LUAU_ASSERT(sourceType && replacedType);
// Walk every kind of TypeVar and create a copy with sourceType replaced by replacedType
// Before returning, memoize the result for later use.
// Helpfully, Substitution::clone() only shallow-clones the kinds of types that we care to work with. This
// function returns the identity for things like primitives.
TypeId res = clone(ty);
if (auto t = get<FreeTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = get<GenericTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = get<ErrorTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = get<PrimitiveTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = getMutable<ConstrainedTypeVar>(res))
{
for (TypeId& part : t->parts)
{
if (part == sourceType)
part = replacedType;
}
}
else if (auto t = get<SingletonTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = getMutable<FunctionTypeVar>(res))
{
// The constituent typepacks are cleaned separately. We just need to walk the generics array.
for (TypeId& g : t->generics)
{
if (g == sourceType)
g = replacedType;
}
}
else if (auto t = getMutable<TableTypeVar>(res))
{
for (auto& [_key, prop] : t->props)
{
if (prop.type == sourceType)
prop.type = replacedType;
}
}
else if (auto t = getMutable<MetatableTypeVar>(res))
{
if (t->table == sourceType)
t->table = replacedType;
if (t->metatable == sourceType)
t->table = replacedType;
}
else if (auto t = get<ClassTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = get<AnyTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else if (auto t = getMutable<UnionTypeVar>(res))
{
for (TypeId& option : t->options)
{
if (option == sourceType)
option = replacedType;
}
}
else if (auto t = getMutable<IntersectionTypeVar>(res))
{
for (TypeId& part : t->parts)
{
if (part == sourceType)
part = replacedType;
}
}
else if (auto t = get<LazyTypeVar>(res))
LUAU_ASSERT(!"Impossible");
else
LUAU_ASSERT(!"Luau::Replacer::clean internal error: Unknown TypeVar type");
replacedTypes[ty] = res;
return res;
}
TypePackId clean(TypePackId tp) override
{
TypePackId res = clone(tp);
if (auto pack = getMutable<TypePack>(res))
{
for (TypeId& type : pack->head)
{
if (type == sourceType)
type = replacedType;
}
}
else if (auto vtp = getMutable<VariadicTypePack>(res))
{
if (vtp->ty == sourceType)
vtp->ty = replacedType;
}
replacedPacks[tp] = res;
return res;
}
TypeId smartClone(TypeId t)
{
if (FFlag::LuauReplaceReplacer)
{
// The new smartClone is just a memoized clone()
// TODO: Remove the Substitution base class and all other methods from this struct.
// Add DenseHashMap<TypeId, TypeId> newTypes;
t = log->follow(t);
TypeId* res = newTypes.find(t);
if (res)
return *res;
TypeId result = shallowClone(t, *arena, TxnLog::empty());
newTypes[t] = result;
newTypes[result] = result;
return result;
}
else
{
std::optional<TypeId> res = replace(t);
LUAU_ASSERT(res.has_value()); // TODO think about this
if (*res == t)
return clone(t);
return *res;
}
}
};
} // anonymous namespace
bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice)
{
UnifierSharedState sharedState{&ice};
TypeArena arena;
Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState};
u.anyIsTop = true;
u.tryUnify(subTy, superTy);
const bool ok = u.errors.empty() && u.log.empty();
return ok;
}
bool isSubtype(TypePackId subPack, TypePackId superPack, InternalErrorReporter& ice)
{
UnifierSharedState sharedState{&ice};
TypeArena arena;
Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState};
u.anyIsTop = true;
u.tryUnify(subPack, superPack);
const bool ok = u.errors.empty() && u.log.empty();
return ok;
}
template<typename T>
static bool areNormal_(const T& t, const std::unordered_set<void*>& seen, InternalErrorReporter& ice)
{
int count = 0;
auto isNormal = [&](TypeId ty) {
++count;
if (count >= FInt::LuauNormalizeIterationLimit)
ice.ice("Luau::areNormal hit iteration limit");
if (FFlag::LuauNormalizeFlagIsConservative)
return ty->normal;
else
{
// The follow is here because a bound type may not be normal, but the bound type is normal.
return ty->normal || follow(ty)->normal || seen.find(asMutable(ty)) != seen.end();
}
};
return std::all_of(begin(t), end(t), isNormal);
}
static bool areNormal(const std::vector<TypeId>& types, const std::unordered_set<void*>& seen, InternalErrorReporter& ice)
{
return areNormal_(types, seen, ice);
}
static bool areNormal(TypePackId tp, const std::unordered_set<void*>& seen, InternalErrorReporter& ice)
{
tp = follow(tp);
if (get<FreeTypePack>(tp))
return false;
auto [head, tail] = flatten(tp);
if (!areNormal_(head, seen, ice))
return false;
if (!tail)
return true;
if (auto vtp = get<VariadicTypePack>(*tail))
return vtp->ty->normal || follow(vtp->ty)->normal || seen.find(asMutable(vtp->ty)) != seen.end();
return true;
}
#define CHECK_ITERATION_LIMIT(...) \
do \
{ \
if (iterationLimit > FInt::LuauNormalizeIterationLimit) \
{ \
limitExceeded = true; \
return __VA_ARGS__; \
} \
++iterationLimit; \
} while (false)
struct Normalize final : TypeVarVisitor
{
using TypeVarVisitor::Set;
Normalize(TypeArena& arena, InternalErrorReporter& ice)
: arena(arena)
, ice(ice)
{
}
TypeArena& arena;
InternalErrorReporter& ice;
int iterationLimit = 0;
bool limitExceeded = false;
bool visit(TypeId ty, const FreeTypeVar&) override
{
LUAU_ASSERT(!ty->normal);
return false;
}
bool visit(TypeId ty, const BoundTypeVar& btv) override
{
// A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses.
// So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack.
if (seen.find(asMutable(btv.boundTo)) != seen.end())
return false;
// It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases.
LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal);
asMutable(ty)->normal = btv.boundTo->normal;
return !ty->normal;
}
bool visit(TypeId ty, const PrimitiveTypeVar&) override
{
LUAU_ASSERT(ty->normal);
return false;
}
bool visit(TypeId ty, const GenericTypeVar&) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const ErrorTypeVar&) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override
{
CHECK_ITERATION_LIMIT(false);
LUAU_ASSERT(!ty->normal);
ConstrainedTypeVar* ctv = const_cast<ConstrainedTypeVar*>(&ctvRef);
std::vector<TypeId> parts = std::move(ctv->parts);
// We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar
for (TypeId part : parts)
traverse(part);
std::vector<TypeId> newParts = normalizeUnion(parts);
if (FFlag::LuauQuantifyConstrained)
{
ctv->parts = std::move(newParts);
}
else
{
const bool normal = areNormal(newParts, seen, ice);
if (newParts.size() == 1)
*asMutable(ty) = BoundTypeVar{newParts[0]};
else
*asMutable(ty) = UnionTypeVar{std::move(newParts)};
asMutable(ty)->normal = normal;
}
return false;
}
bool visit(TypeId ty, const FunctionTypeVar& ftv) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
traverse(ftv.argTypes);
traverse(ftv.retTypes);
asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retTypes, seen, ice);
return false;
}
bool visit(TypeId ty, const TableTypeVar& ttv) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
bool normal = true;
auto checkNormal = [&](TypeId t) {
// if t is on the stack, it is possible that this type is normal.
// If t is not normal and it is not on the stack, this type is definitely not normal.
if (!t->normal && seen.find(asMutable(t)) == seen.end())
normal = false;
};
if (ttv.boundTo)
{
traverse(*ttv.boundTo);
asMutable(ty)->normal = (*ttv.boundTo)->normal;
return false;
}
for (const auto& [_name, prop] : ttv.props)
{
traverse(prop.type);
checkNormal(prop.type);
}
if (ttv.indexer)
{
traverse(ttv.indexer->indexType);
checkNormal(ttv.indexer->indexType);
traverse(ttv.indexer->indexResultType);
checkNormal(ttv.indexer->indexResultType);
}
// An unsealed table can never be normal, ditto for free tables iff the type it is bound to is also not normal.
if (FFlag::LuauQuantifyConstrained)
{
if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal))
asMutable(ty)->normal = normal;
}
else
asMutable(ty)->normal = normal;
return false;
}
bool visit(TypeId ty, const MetatableTypeVar& mtv) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
traverse(mtv.table);
traverse(mtv.metatable);
asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal;
return false;
}
bool visit(TypeId ty, const ClassTypeVar& ctv) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const AnyTypeVar&) override
{
LUAU_ASSERT(ty->normal);
return false;
}
bool visit(TypeId ty, const UnionTypeVar& utvRef) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
UnionTypeVar* utv = &const_cast<UnionTypeVar&>(utvRef);
std::vector<TypeId> options = std::move(utv->options);
// We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar
for (TypeId option : options)
traverse(option);
std::vector<TypeId> newOptions = normalizeUnion(options);
const bool normal = areNormal(newOptions, seen, ice);
LUAU_ASSERT(!newOptions.empty());
if (newOptions.size() == 1)
*asMutable(ty) = BoundTypeVar{newOptions[0]};
else
utv->options = std::move(newOptions);
asMutable(ty)->normal = normal;
return false;
}
bool visit(TypeId ty, const IntersectionTypeVar& itvRef) override
{
CHECK_ITERATION_LIMIT(false);
if (ty->normal)
return false;
IntersectionTypeVar* itv = &const_cast<IntersectionTypeVar&>(itvRef);
std::vector<TypeId> oldParts = std::move(itv->parts);
for (TypeId part : oldParts)
traverse(part);
std::vector<TypeId> tables;
for (TypeId part : oldParts)
{
part = follow(part);
if (get<TableTypeVar>(part))
tables.push_back(part);
else
{
Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD
combineIntoIntersection(replacer, itv, part);
}
}
// Don't allocate a new table if there's just one in the intersection.
if (tables.size() == 1)
itv->parts.push_back(tables[0]);
else if (!tables.empty())
{
const TableTypeVar* first = get<TableTypeVar>(tables[0]);
LUAU_ASSERT(first);
TypeId newTable = arena.addType(TableTypeVar{first->state, first->level});
TableTypeVar* ttv = getMutable<TableTypeVar>(newTable);
for (TypeId part : tables)
{
// Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need
// to be rewritten to point at 'newTable' in the clone.
Replacer replacer{&arena, part, newTable};
combineIntoTable(replacer, ttv, part);
}
itv->parts.push_back(newTable);
}
asMutable(ty)->normal = areNormal(itv->parts, seen, ice);
if (itv->parts.size() == 1)
{
TypeId part = itv->parts[0];
*asMutable(ty) = BoundTypeVar{part};
}
return false;
}
std::vector<TypeId> normalizeUnion(const std::vector<TypeId>& options)
{
if (options.size() == 1)
return options;
std::vector<TypeId> result;
for (TypeId part : options)
combineIntoUnion(result, part);
return result;
}
void combineIntoUnion(std::vector<TypeId>& result, TypeId ty)
{
ty = follow(ty);
if (auto utv = get<UnionTypeVar>(ty))
{
for (TypeId t : utv)
combineIntoUnion(result, t);
return;
}
for (TypeId& part : result)
{
if (isSubtype(ty, part, ice))
return; // no need to do anything
else if (isSubtype(part, ty, ice))
{
part = ty; // replace the less general type by the more general one
return;
}
}
result.push_back(ty);
}
/**
* @param replacer knows how to clone a type such that any recursive references point at the new containing type.
* @param result is an intersection that is safe for us to mutate in-place.
*/
void combineIntoIntersection(Replacer& replacer, IntersectionTypeVar* result, TypeId ty)
{
// Note: this check guards against running out of stack space
// so if you increase the size of a stack frame, you'll need to decrease the limit.
CHECK_ITERATION_LIMIT();
ty = follow(ty);
if (auto itv = get<IntersectionTypeVar>(ty))
{
for (TypeId part : itv->parts)
combineIntoIntersection(replacer, result, part);
return;
}
// Let's say that the last part of our result intersection is always a table, if any table is part of this intersection
if (get<TableTypeVar>(ty))
{
if (result->parts.empty())
result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}}));
TypeId theTable = result->parts.back();
if (!get<TableTypeVar>(follow(theTable)))
{
result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}}));
theTable = result->parts.back();
}
TypeId newTable = replacer.smartClone(theTable);
result->parts.back() = newTable;
combineIntoTable(replacer, getMutable<TableTypeVar>(newTable), ty);
}
else if (auto ftv = get<FunctionTypeVar>(ty))
{
bool merged = false;
for (TypeId& part : result->parts)
{
if (isSubtype(part, ty, ice))
{
merged = true;
break; // no need to do anything
}
else if (isSubtype(ty, part, ice))
{
merged = true;
part = ty; // replace the less general type by the more general one
break;
}
}
if (!merged)
result->parts.push_back(ty);
}
else
result->parts.push_back(ty);
}
TableState combineTableStates(TableState lhs, TableState rhs)
{
if (lhs == rhs)
return lhs;
if (lhs == TableState::Free || rhs == TableState::Free)
return TableState::Free;
if (lhs == TableState::Unsealed || rhs == TableState::Unsealed)
return TableState::Unsealed;
return lhs;
}
/**
* @param replacer gives us a way to clone a type such that recursive references are rewritten to the new
* "containing" type.
* @param table always points into a table that is safe for us to mutate.
*/
void combineIntoTable(Replacer& replacer, TableTypeVar* table, TypeId ty)
{
// Note: this check guards against running out of stack space
// so if you increase the size of a stack frame, you'll need to decrease the limit.
CHECK_ITERATION_LIMIT();
LUAU_ASSERT(table);
ty = follow(ty);
TableTypeVar* tyTable = getMutable<TableTypeVar>(ty);
LUAU_ASSERT(tyTable);
for (const auto& [propName, prop] : tyTable->props)
{
if (auto it = table->props.find(propName); it != table->props.end())
{
/**
* If we are going to recursively merge intersections of tables, we need to ensure that we never mutate
* a table that comes from somewhere else in the type graph.
*
* smarClone() does some nice things for us: It will perform a clone that is as shallow as possible
* while still rewriting any cyclic references back to the new 'root' table.
*
* replacer also keeps a mapping of types that have previously been copied, so we have the added
* advantage here of knowing that, whether or not a new copy was actually made, the resulting TypeVar is
* safe for us to mutate in-place.
*/
TypeId clone = replacer.smartClone(it->second.type);
it->second.type = combine(replacer, clone, prop.type);
}
else
table->props.insert({propName, prop});
}
table->state = combineTableStates(table->state, tyTable->state);
table->level = max(table->level, tyTable->level);
}
/**
* @param a is always cloned by the caller. It is safe to mutate in-place.
* @param b will never be mutated.
*/
TypeId combine(Replacer& replacer, TypeId a, TypeId b)
{
if (FFlag::LuauNormalizeCombineEqFix)
b = follow(b);
if (FFlag::LuauNormalizeCombineTableFix && a == b)
return a;
if (!get<IntersectionTypeVar>(a) && !get<TableTypeVar>(a))
{
if (!FFlag::LuauNormalizeCombineTableFix && a == b)
return a;
else
return arena.addType(IntersectionTypeVar{{a, b}});
}
if (auto itv = getMutable<IntersectionTypeVar>(a))
{
combineIntoIntersection(replacer, itv, b);
return a;
}
else if (auto ttv = getMutable<TableTypeVar>(a))
{
if (FFlag::LuauNormalizeCombineTableFix && !get<TableTypeVar>(FFlag::LuauNormalizeCombineEqFix ? b : follow(b)))
return arena.addType(IntersectionTypeVar{{a, b}});
combineIntoTable(replacer, ttv, b);
return a;
}
LUAU_ASSERT(!"Impossible");
LUAU_UNREACHABLE();
}
};
#undef CHECK_ITERATION_LIMIT
/**
* @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully)
*/
std::pair<TypeId, bool> normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice)
{
CloneState state;
if (FFlag::DebugLuauCopyBeforeNormalizing)
(void)clone(ty, arena, state);
Normalize n{arena, ice};
n.traverse(ty);
return {ty, !n.limitExceeded};
}
// TODO: Think about using a temporary arena and cloning types out of it so that we
// reclaim memory used by wantonly allocated intermediate types here.
// The main wrinkle here is that we don't want clone() to copy a type if the source and dest
// arena are the same.
std::pair<TypeId, bool> normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice)
{
return normalize(ty, module->internalTypes, ice);
}
/**
* @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully)
*/
std::pair<TypePackId, bool> normalize(TypePackId tp, TypeArena& arena, InternalErrorReporter& ice)
{
CloneState state;
if (FFlag::DebugLuauCopyBeforeNormalizing)
(void)clone(tp, arena, state);
Normalize n{arena, ice};
n.traverse(tp);
return {tp, !n.limitExceeded};
}
std::pair<TypePackId, bool> normalize(TypePackId tp, const ModulePtr& module, InternalErrorReporter& ice)
{
return normalize(tp, module->internalTypes, ice);
}
} // namespace Luau

View file

@ -2,59 +2,139 @@
#include "Luau/Quantify.h"
#include "Luau/Scope.h"
#include "Luau/Substitution.h"
#include "Luau/TxnLog.h"
#include "Luau/VisitTypeVar.h"
LUAU_FASTFLAG(LuauAlwaysQuantify);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false)
namespace Luau
{
struct Quantifier
/// @return true if outer encloses inner
static bool subsumes(Scope2* outer, Scope2* inner)
{
while (inner)
{
if (inner == outer)
return true;
inner = inner->parent;
}
return false;
}
struct Quantifier final : TypeVarOnceVisitor
{
TypeLevel level;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
Scope2* scope = nullptr;
bool seenGenericType = false;
bool seenMutableType = false;
Quantifier(TypeLevel level)
explicit Quantifier(TypeLevel level)
: level(level)
{
LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution);
}
void cycle(TypeId) {}
void cycle(TypePackId) {}
bool operator()(TypeId ty, const FreeTypeVar& ftv)
explicit Quantifier(Scope2* scope)
: scope(scope)
{
if (!level.subsumes(ftv.level))
LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution);
}
/// @return true if outer encloses inner
bool subsumes(Scope2* outer, Scope2* inner)
{
while (inner)
{
if (inner == outer)
return true;
inner = inner->parent;
}
return false;
}
bool visit(TypeId ty, const FreeTypeVar& ftv) override
{
seenMutableType = true;
if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ftv.scope) : !level.subsumes(ftv.level))
return false;
*asMutable(ty) = GenericTypeVar{level};
if (FFlag::DebugLuauDeferredConstraintResolution)
*asMutable(ty) = GenericTypeVar{scope};
else
*asMutable(ty) = GenericTypeVar{level};
generics.push_back(ty);
return false;
}
template<typename T>
bool operator()(TypeId ty, const T& t)
bool visit(TypeId ty, const ConstrainedTypeVar&) override
{
return true;
if (FFlag::LuauQuantifyConstrained)
{
ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(ty);
seenMutableType = true;
if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ctv->scope) : !level.subsumes(ctv->level))
return false;
std::vector<TypeId> opts = std::move(ctv->parts);
// We might transmute, so it's not safe to rely on the builtin traversal logic
for (TypeId opt : opts)
traverse(opt);
if (opts.size() == 1)
*asMutable(ty) = BoundTypeVar{opts[0]};
else
*asMutable(ty) = UnionTypeVar{std::move(opts)};
return false;
}
else
return true;
}
template<typename T>
bool operator()(TypePackId, const T&)
{
return true;
}
bool operator()(TypeId ty, const TableTypeVar&)
bool visit(TypeId ty, const TableTypeVar&) override
{
LUAU_ASSERT(getMutable<TableTypeVar>(ty));
TableTypeVar& ttv = *getMutable<TableTypeVar>(ty);
if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic)
return false;
if (!level.subsumes(ttv.level))
return false;
if (ttv.state == TableState::Generic)
seenGenericType = true;
if (ttv.state == TableState::Free)
seenMutableType = true;
if (!FFlag::LuauQuantifyConstrained)
{
if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic)
return false;
}
if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ttv.scope) : !level.subsumes(ttv.level))
{
if (ttv.state == TableState::Unsealed)
seenMutableType = true;
return false;
}
if (ttv.state == TableState::Free)
{
ttv.state = TableState::Generic;
seenGenericType = true;
}
else if (ttv.state == TableState::Unsealed)
ttv.state = TableState::Sealed;
@ -63,9 +143,11 @@ struct Quantifier
return true;
}
bool operator()(TypePackId tp, const FreeTypePack& ftp)
bool visit(TypePackId tp, const FreeTypePack& ftp) override
{
if (!level.subsumes(ftp.level))
seenMutableType = true;
if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ftp.scope) : !level.subsumes(ftp.level))
return false;
*asMutable(tp) = GenericTypePack{level};
@ -77,13 +159,145 @@ struct Quantifier
void quantify(TypeId ty, TypeLevel level)
{
Quantifier q{level};
DenseHashSet<void*> seen{nullptr};
visitTypeVarOnce(ty, q, seen);
q.traverse(ty);
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty);
LUAU_ASSERT(ftv);
ftv->generics = q.generics;
ftv->genericPacks = q.genericPacks;
if (FFlag::LuauAlwaysQuantify)
{
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
}
else
{
ftv->generics = q.generics;
ftv->genericPacks = q.genericPacks;
}
if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType)
ftv->hasNoGenerics = true;
}
void quantify(TypeId ty, Scope2* scope)
{
Quantifier q{scope};
q.traverse(ty);
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty);
LUAU_ASSERT(ftv);
if (FFlag::LuauAlwaysQuantify)
{
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
}
else
{
ftv->generics = q.generics;
ftv->genericPacks = q.genericPacks;
}
if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType)
ftv->hasNoGenerics = true;
}
struct PureQuantifier : Substitution
{
Scope2* scope;
std::vector<TypeId> insertedGenerics;
std::vector<TypePackId> insertedGenericPacks;
PureQuantifier(const TxnLog* log, TypeArena* arena, Scope2* scope)
: Substitution(log, arena)
, scope(scope)
{
}
bool isDirty(TypeId ty) override
{
LUAU_ASSERT(ty == follow(ty));
if (auto ftv = get<FreeTypeVar>(ty))
{
return subsumes(scope, ftv->scope);
}
else if (auto ttv = get<TableTypeVar>(ty))
{
return ttv->state == TableState::Free && subsumes(scope, ttv->scope);
}
return false;
}
bool isDirty(TypePackId tp) override
{
if (auto ftp = get<FreeTypePack>(tp))
{
return subsumes(scope, ftp->scope);
}
return false;
}
TypeId clean(TypeId ty) override
{
if (auto ftv = get<FreeTypeVar>(ty))
{
TypeId result = arena->addType(GenericTypeVar{});
insertedGenerics.push_back(result);
return result;
}
else if (auto ttv = get<TableTypeVar>(ty))
{
TypeId result = arena->addType(TableTypeVar{});
TableTypeVar* resultTable = getMutable<TableTypeVar>(result);
LUAU_ASSERT(resultTable);
*resultTable = *ttv;
resultTable->scope = nullptr;
resultTable->state = TableState::Generic;
return result;
}
return ty;
}
TypePackId clean(TypePackId tp) override
{
if (auto ftp = get<FreeTypePack>(tp))
{
TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{}});
insertedGenericPacks.push_back(result);
return result;
}
return tp;
}
bool ignoreChildren(TypeId ty) override
{
return ty->persistent;
}
bool ignoreChildren(TypePackId ty) override
{
return ty->persistent;
}
};
TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope)
{
PureQuantifier quantifier{TxnLog::empty(), arena, scope};
std::optional<TypeId> result = quantifier.substitute(ty);
LUAU_ASSERT(result);
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(*result);
LUAU_ASSERT(ftv);
ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end());
// TODO: Set hasNoGenerics.
return *result;
}
} // namespace Luau

View file

@ -28,7 +28,7 @@ struct RequireTracer : AstVisitor
AstExprGlobal* global = expr->func->as<AstExprGlobal>();
if (global && global->name == "require" && expr->args.size >= 1)
requires.push_back(expr);
requireCalls.push_back(expr);
return true;
}
@ -84,9 +84,9 @@ struct RequireTracer : AstVisitor
ModuleInfo moduleContext{currentModuleName};
// seed worklist with require arguments
work.reserve(requires.size());
work.reserve(requireCalls.size());
for (AstExprCall* require : requires)
for (AstExprCall* require : requireCalls)
work.push_back(require->args.data[0]);
// push all dependent expressions to the work stack; note that the vector is modified during traversal
@ -125,15 +125,15 @@ struct RequireTracer : AstVisitor
}
// resolve all requires according to their argument
result.requires.reserve(requires.size());
result.requireList.reserve(requireCalls.size());
for (AstExprCall* require : requires)
for (AstExprCall* require : requireCalls)
{
AstExpr* arg = require->args.data[0];
if (const ModuleInfo* info = result.exprs.find(arg))
{
result.requires.push_back({info->name, require->location});
result.requireList.push_back({info->name, require->location});
ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info!
result.exprs[require] = std::move(infoCopy);
@ -151,7 +151,7 @@ struct RequireTracer : AstVisitor
DenseHashMap<AstLocal*, AstExpr*> locals;
std::vector<AstExpr*> work;
std::vector<AstExprCall*> requires;
std::vector<AstExprCall*> requireCalls;
};
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName)

View file

@ -2,8 +2,6 @@
#include "Luau/Scope.h"
LUAU_FASTFLAG(LuauTwoPassAliasDefinitionFix);
namespace Luau
{
@ -19,8 +17,7 @@ Scope::Scope(const ScopePtr& parent, int subLevel)
, returnType(parent->returnType)
, level(parent->level.incr())
{
if (FFlag::LuauTwoPassAliasDefinitionFix)
level = level.incr();
level = level.incr();
level.subLevel = subLevel;
}
@ -124,4 +121,36 @@ std::optional<Binding> Scope::linearSearchForBinding(const std::string& name, bo
return std::nullopt;
}
std::optional<TypeId> Scope2::lookup(Symbol sym)
{
Scope2* s = this;
while (true)
{
auto it = s->bindings.find(sym);
if (it != s->bindings.end())
return it->second;
if (s->parent)
s = s->parent;
else
return std::nullopt;
}
}
std::optional<TypeId> Scope2::lookupTypeBinding(const Name& name)
{
Scope2* s = this;
while (s)
{
auto it = s->typeBindings.find(name);
if (it != s->typeBindings.end())
return it->second;
s = s->parent;
}
return std::nullopt;
}
} // namespace Luau

View file

@ -2,29 +2,34 @@
#include "Luau/Substitution.h"
#include "Luau/Common.h"
#include "Luau/Clone.h"
#include "Luau/TxnLog.h"
#include <algorithm>
#include <stdexcept>
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000)
LUAU_FASTFLAG(LuauLowerBoundsCalculation)
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000)
namespace Luau
{
void Tarjan::visitChildren(TypeId ty, int index)
{
ty = log->follow(ty);
LUAU_ASSERT(ty == log->follow(ty));
if (ignoreChildren(ty))
return;
if (const FunctionTypeVar* ftv = log->getMutable<FunctionTypeVar>(ty))
if (auto pty = log->pending(ty))
ty = &pty->pending;
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{
visitChild(ftv->argTypes);
visitChild(ftv->retType);
visitChild(ftv->retTypes);
}
else if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
else if (const TableTypeVar* ttv = get<TableTypeVar>(ty))
{
LUAU_ASSERT(!ttv->boundTo);
for (const auto& [name, prop] : ttv->props)
@ -41,38 +46,46 @@ void Tarjan::visitChildren(TypeId ty, int index)
for (TypePackId itp : ttv->instantiatedTypePackParams)
visitChild(itp);
}
else if (const MetatableTypeVar* mtv = log->getMutable<MetatableTypeVar>(ty))
else if (const MetatableTypeVar* mtv = get<MetatableTypeVar>(ty))
{
visitChild(mtv->table);
visitChild(mtv->metatable);
}
else if (const UnionTypeVar* utv = log->getMutable<UnionTypeVar>(ty))
else if (const UnionTypeVar* utv = get<UnionTypeVar>(ty))
{
for (TypeId opt : utv->options)
visitChild(opt);
}
else if (const IntersectionTypeVar* itv = log->getMutable<IntersectionTypeVar>(ty))
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(ty))
{
for (TypeId part : itv->parts)
visitChild(part);
}
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
for (TypeId part : ctv->parts)
visitChild(part);
}
}
void Tarjan::visitChildren(TypePackId tp, int index)
{
tp = log->follow(tp);
LUAU_ASSERT(tp == log->follow(tp));
if (ignoreChildren(tp))
return;
if (const TypePack* tpp = log->getMutable<TypePack>(tp))
if (auto ptp = log->pending(tp))
tp = &ptp->pending;
if (const TypePack* tpp = get<TypePack>(tp))
{
for (TypeId tv : tpp->head)
visitChild(tv);
if (tpp->tail)
visitChild(*tpp->tail);
}
else if (const VariadicTypePack* vtp = log->getMutable<VariadicTypePack>(tp))
else if (const VariadicTypePack* vtp = get<VariadicTypePack>(tp))
{
visitChild(vtp->ty);
}
@ -141,7 +154,7 @@ TarjanResult Tarjan::loop()
if (currEdge == -1)
{
++childCount;
if (FInt::LuauTarjanChildLimit > 0 && FInt::LuauTarjanChildLimit < childCount)
if (childLimit > 0 && childLimit < childCount)
return TarjanResult::TooManyChildren;
stack.push_back(index);
@ -229,6 +242,9 @@ TarjanResult Tarjan::loop()
TarjanResult Tarjan::visitRoot(TypeId ty)
{
childCount = 0;
if (childLimit == 0)
childLimit = FInt::LuauTarjanChildLimit;
ty = log->follow(ty);
auto [index, fresh] = indexify(ty);
@ -239,6 +255,9 @@ TarjanResult Tarjan::visitRoot(TypeId ty)
TarjanResult Tarjan::visitRoot(TypePackId tp)
{
childCount = 0;
if (childLimit == 0)
childLimit = FInt::LuauTarjanChildLimit;
tp = log->follow(tp);
auto [index, fresh] = indexify(tp);
@ -343,67 +362,24 @@ std::optional<TypePackId> Substitution::substitute(TypePackId tp)
TypeId Substitution::clone(TypeId ty)
{
ty = log->follow(ty);
TypeId result = ty;
if (const FunctionTypeVar* ftv = log->getMutable<FunctionTypeVar>(ty))
{
FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf};
clone.generics = ftv->generics;
clone.genericPacks = ftv->genericPacks;
clone.magicFunction = ftv->magicFunction;
clone.tags = ftv->tags;
clone.argNames = ftv->argNames;
result = addType(std::move(clone));
}
else if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
{
LUAU_ASSERT(!ttv->boundTo);
TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state};
clone.methodDefinitionLocations = ttv->methodDefinitionLocations;
clone.definitionModuleName = ttv->definitionModuleName;
clone.name = ttv->name;
clone.syntheticName = ttv->syntheticName;
clone.instantiatedTypeParams = ttv->instantiatedTypeParams;
clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams;
clone.tags = ttv->tags;
result = addType(std::move(clone));
}
else if (const MetatableTypeVar* mtv = log->getMutable<MetatableTypeVar>(ty))
{
MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable};
clone.syntheticName = mtv->syntheticName;
result = addType(std::move(clone));
}
else if (const UnionTypeVar* utv = log->getMutable<UnionTypeVar>(ty))
{
UnionTypeVar clone;
clone.options = utv->options;
result = addType(std::move(clone));
}
else if (const IntersectionTypeVar* itv = log->getMutable<IntersectionTypeVar>(ty))
{
IntersectionTypeVar clone;
clone.parts = itv->parts;
result = addType(std::move(clone));
}
asMutable(result)->documentationSymbol = ty->documentationSymbol;
return result;
return shallowClone(ty, *arena, log);
}
TypePackId Substitution::clone(TypePackId tp)
{
tp = log->follow(tp);
if (const TypePack* tpp = log->getMutable<TypePack>(tp))
if (auto ptp = log->pending(tp))
tp = &ptp->pending;
if (const TypePack* tpp = get<TypePack>(tp))
{
TypePack clone;
clone.head = tpp->head;
clone.tail = tpp->tail;
return addTypePack(std::move(clone));
}
else if (const VariadicTypePack* vtp = log->getMutable<VariadicTypePack>(tp))
else if (const VariadicTypePack* vtp = get<VariadicTypePack>(tp))
{
VariadicTypePack clone;
clone.ty = vtp->ty;
@ -416,24 +392,27 @@ TypePackId Substitution::clone(TypePackId tp)
void Substitution::foundDirty(TypeId ty)
{
ty = log->follow(ty);
if (isDirty(ty))
newTypes[ty] = clean(ty);
newTypes[ty] = follow(clean(ty));
else
newTypes[ty] = clone(ty);
newTypes[ty] = follow(clone(ty));
}
void Substitution::foundDirty(TypePackId tp)
{
tp = log->follow(tp);
if (isDirty(tp))
newPacks[tp] = clean(tp);
newPacks[tp] = follow(clean(tp));
else
newPacks[tp] = clone(tp);
newPacks[tp] = follow(clone(tp));
}
TypeId Substitution::replace(TypeId ty)
{
ty = log->follow(ty);
if (TypeId* prevTy = newTypes.find(ty))
return *prevTy;
else
@ -443,6 +422,7 @@ TypeId Substitution::replace(TypeId ty)
TypePackId Substitution::replace(TypePackId tp)
{
tp = log->follow(tp);
if (TypePackId* prevTp = newPacks.find(tp))
return *prevTp;
else
@ -451,7 +431,10 @@ TypePackId Substitution::replace(TypePackId tp)
void Substitution::replaceChildren(TypeId ty)
{
ty = log->follow(ty);
if (BoundTypeVar* btv = log->getMutable<BoundTypeVar>(ty); FFlag::LuauLowerBoundsCalculation && btv)
btv->boundTo = replace(btv->boundTo);
LUAU_ASSERT(ty == log->follow(ty));
if (ignoreChildren(ty))
return;
@ -459,7 +442,7 @@ void Substitution::replaceChildren(TypeId ty)
if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty))
{
ftv->argTypes = replace(ftv->argTypes);
ftv->retType = replace(ftv->retType);
ftv->retTypes = replace(ftv->retTypes);
}
else if (TableTypeVar* ttv = getMutable<TableTypeVar>(ty))
{
@ -493,11 +476,16 @@ void Substitution::replaceChildren(TypeId ty)
for (TypeId& part : itv->parts)
part = replace(part);
}
else if (ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(ty))
{
for (TypeId& part : ctv->parts)
part = replace(part);
}
}
void Substitution::replaceChildren(TypePackId tp)
{
tp = log->follow(tp);
LUAU_ASSERT(tp == log->follow(tp));
if (ignoreChildren(tp))
return;

View file

@ -154,7 +154,7 @@ void StateDot::visitChildren(TypeId ty, int index)
finishNode();
visitChild(ftv->argTypes, index, "arg");
visitChild(ftv->retType, index, "ret");
visitChild(ftv->retTypes, index, "ret");
}
else if (const TableTypeVar* ttv = get<TableTypeVar>(ty))
{
@ -237,6 +237,15 @@ void StateDot::visitChildren(TypeId ty, int index)
finishNodeLabel(ty);
finishNode();
}
else if (const ConstrainedTypeVar* ctv = get<ConstrainedTypeVar>(ty))
{
formatAppend(result, "ConstrainedTypeVar %d", index);
finishNodeLabel(ty);
finishNode();
for (TypeId part : ctv->parts)
visitChild(part, index);
}
else if (get<ErrorTypeVar>(ty))
{
formatAppend(result, "ErrorTypeVar %d", index);
@ -258,6 +267,28 @@ void StateDot::visitChildren(TypeId ty, int index)
if (ctv->metatable)
visitChild(*ctv->metatable, index, "[metatable]");
}
else if (const SingletonTypeVar* stv = get<SingletonTypeVar>(ty))
{
std::string res;
if (const StringSingleton* ss = get<StringSingleton>(stv))
{
// Don't put in quotes anywhere. If it's outside of the call to escape,
// then it's invalid syntax. If it's inside, then escaping is super noisy.
res = "string: " + escape(ss->value);
}
else if (const BooleanSingleton* bs = get<BooleanSingleton>(stv))
{
res = "boolean: ";
res += bs->value ? "true" : "false";
}
else
LUAU_ASSERT(!"unknown singleton type");
formatAppend(result, "SingletonTypeVar %s", res.c_str());
finishNodeLabel(ty);
finishNode();
}
else
{
LUAU_ASSERT(!"unknown type kind");
@ -296,7 +327,7 @@ void StateDot::visitChildren(TypePackId tp, int index)
}
else if (const VariadicTypePack* vtp = get<VariadicTypePack>(tp))
{
formatAppend(result, "VariadicTypePack %d", index);
formatAppend(result, "VariadicTypePack %s%d", vtp->hidden ? "hidden " : "", index);
finishNodeLabel(tp);
finishNode();

View file

@ -10,13 +10,15 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauLowerBoundsCalculation)
/*
* Prefix generic typenames with gen-
* Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4
* Fair warning: Setting this will break a lot of Luau unit tests.
*/
LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false)
LUAU_FASTFLAGVARIABLE(LuauDocFuncParameters, false)
LUAU_FASTFLAGVARIABLE(LuauToStringTableBracesNewlines, false)
namespace Luau
{
@ -24,7 +26,7 @@ namespace Luau
namespace
{
struct FindCyclicTypes
struct FindCyclicTypes final : TypeVarVisitor
{
FindCyclicTypes() = default;
FindCyclicTypes(const FindCyclicTypes&) = delete;
@ -33,28 +35,30 @@ struct FindCyclicTypes
bool exhaustive = false;
std::unordered_set<TypeId> visited;
std::unordered_set<TypePackId> visitedPacks;
std::unordered_set<TypeId> cycles;
std::unordered_set<TypePackId> cycleTPs;
std::set<TypeId> cycles;
std::set<TypePackId> cycleTPs;
void cycle(TypeId ty)
void cycle(TypeId ty) override
{
cycles.insert(ty);
}
void cycle(TypePackId tp)
void cycle(TypePackId tp) override
{
cycleTPs.insert(tp);
}
template<typename T>
bool operator()(TypeId ty, const T&)
bool visit(TypeId ty) override
{
return visited.insert(ty).second;
}
bool operator()(TypeId ty, const TableTypeVar& ttv) = delete;
bool visit(TypePackId tp) override
{
return visitedPacks.insert(tp).second;
}
bool operator()(TypeId ty, const TableTypeVar& ttv, std::unordered_set<void*>& seen)
bool visit(TypeId ty, const TableTypeVar& ttv) override
{
if (!visited.insert(ty).second)
return false;
@ -62,10 +66,10 @@ struct FindCyclicTypes
if (ttv.name || ttv.syntheticName)
{
for (TypeId itp : ttv.instantiatedTypeParams)
visitTypeVar(itp, *this, seen);
traverse(itp);
for (TypePackId itp : ttv.instantiatedTypePackParams)
visitTypeVar(itp, *this, seen);
traverse(itp);
return exhaustive;
}
@ -73,24 +77,18 @@ struct FindCyclicTypes
return true;
}
bool operator()(TypeId, const ClassTypeVar&)
bool visit(TypeId ty, const ClassTypeVar&) override
{
return false;
}
template<typename T>
bool operator()(TypePackId tp, const T&)
{
return visitedPacks.insert(tp).second;
}
};
template<typename TID>
void findCyclicTypes(std::unordered_set<TypeId>& cycles, std::unordered_set<TypePackId>& cycleTPs, TID ty, bool exhaustive)
void findCyclicTypes(std::set<TypeId>& cycles, std::set<TypePackId>& cycleTPs, TID ty, bool exhaustive)
{
FindCyclicTypes fct;
fct.exhaustive = exhaustive;
visitTypeVar(ty, fct);
fct.traverse(ty);
cycles = std::move(fct.cycles);
cycleTPs = std::move(fct.cycleTPs);
@ -124,6 +122,7 @@ struct StringifierState
std::unordered_map<TypePackId, std::string> cycleTpNames;
std::unordered_set<void*> seen;
std::unordered_set<std::string> usedNames;
size_t indentation = 0;
bool exhaustive;
@ -180,6 +179,8 @@ struct StringifierState
return generateName(s);
}
int previousNameIndex = 0;
std::string getName(TypePackId ty)
{
const size_t s = result.nameMap.typePacks.size();
@ -189,9 +190,10 @@ struct StringifierState
for (int count = 0; count < 256; ++count)
{
std::string candidate = generateName(usedNames.size() + count);
std::string candidate = generateName(previousNameIndex + count);
if (!usedNames.count(candidate))
{
previousNameIndex += count;
usedNames.insert(candidate);
n = candidate;
return candidate;
@ -209,6 +211,13 @@ struct StringifierState
result.name += s;
}
void emit(TypeLevel level)
{
emit(std::to_string(level.level));
emit("-");
emit(std::to_string(level.subLevel));
}
void emit(const char* s)
{
if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength)
@ -216,6 +225,39 @@ struct StringifierState
result.name += s;
}
void emit(int i)
{
emit(std::to_string(i).c_str());
}
void indent()
{
indentation += 4;
}
void dedent()
{
indentation -= 4;
}
void newline()
{
if (!opts.useLineBreaks)
return emit(" ");
emit("\n");
emitIndentation();
}
private:
void emitIndentation()
{
if (!opts.indent)
return;
emit(std::string(indentation, ' '));
}
};
struct TypeVarStringifier
@ -247,7 +289,8 @@ struct TypeVarStringifier
}
Luau::visit(
[this, tv](auto&& t) {
[this, tv](auto&& t)
{
return (*this)(tv, t);
},
tv->ty);
@ -312,7 +355,7 @@ struct TypeVarStringifier
if (FFlag::DebugLuauVerboseTypeNames)
{
state.emit("-");
state.emit(std::to_string(ftv.level.level));
state.emit(ftv.level);
}
}
@ -321,10 +364,11 @@ struct TypeVarStringifier
stringify(btv.boundTo);
}
void operator()(TypeId ty, const Unifiable::Generic& gtv)
void operator()(TypeId ty, const GenericTypeVar& gtv)
{
if (gtv.explicitName)
{
state.usedNames.insert(gtv.name);
state.result.nameMap.typeVars[ty] = gtv.name;
state.emit(gtv.name);
}
@ -332,6 +376,36 @@ struct TypeVarStringifier
state.emit(state.getName(ty));
}
void operator()(TypeId, const ConstrainedTypeVar& ctv)
{
state.result.invalid = true;
state.emit("[");
if (FFlag::DebugLuauVerboseTypeNames)
state.emit(ctv.level);
state.emit("[");
bool first = true;
for (TypeId ty : ctv.parts)
{
if (first)
first = false;
else
state.emit("|");
stringify(ty);
}
state.emit("]]");
}
void operator()(TypeId, const BlockedTypeVar& btv)
{
state.emit("*blocked-");
state.emit(btv.index);
state.emit("*");
}
void operator()(TypeId, const PrimitiveTypeVar& ptv)
{
switch (ptv.type)
@ -415,16 +489,31 @@ struct TypeVarStringifier
state.emit(") -> ");
bool plural = true;
if (auto retPack = get<TypePack>(follow(ftv.retType)))
if (FFlag::LuauLowerBoundsCalculation)
{
if (retPack->head.size() == 1 && !retPack->tail)
plural = false;
auto retBegin = begin(ftv.retTypes);
auto retEnd = end(ftv.retTypes);
if (retBegin != retEnd)
{
++retBegin;
if (retBegin == retEnd && !retBegin.tail())
plural = false;
}
}
else
{
if (auto retPack = get<TypePack>(follow(ftv.retTypes)))
{
if (retPack->head.size() == 1 && !retPack->tail)
plural = false;
}
}
if (plural)
state.emit("(");
stringify(ftv.retType);
stringify(ftv.retTypes);
if (plural)
state.emit(")");
@ -482,22 +571,54 @@ struct TypeVarStringifier
{
case TableState::Sealed:
state.result.invalid = true;
openbrace = "{| ";
closedbrace = " |}";
if (FFlag::LuauToStringTableBracesNewlines)
{
openbrace = "{|";
closedbrace = "|}";
}
else
{
openbrace = "{| ";
closedbrace = " |}";
}
break;
case TableState::Unsealed:
openbrace = "{ ";
closedbrace = " }";
if (FFlag::LuauToStringTableBracesNewlines)
{
openbrace = "{";
closedbrace = "}";
}
else
{
openbrace = "{ ";
closedbrace = " }";
}
break;
case TableState::Free:
state.result.invalid = true;
openbrace = "{- ";
closedbrace = " -}";
if (FFlag::LuauToStringTableBracesNewlines)
{
openbrace = "{-";
closedbrace = "-}";
}
else
{
openbrace = "{- ";
closedbrace = " -}";
}
break;
case TableState::Generic:
state.result.invalid = true;
openbrace = "{+ ";
closedbrace = " +}";
if (FFlag::LuauToStringTableBracesNewlines)
{
openbrace = "{+";
closedbrace = "+}";
}
else
{
openbrace = "{+ ";
closedbrace = " +}";
}
break;
}
@ -511,10 +632,13 @@ struct TypeVarStringifier
}
state.emit(openbrace);
state.indent();
bool comma = false;
if (ttv.indexer)
{
if (FFlag::LuauToStringTableBracesNewlines)
state.newline();
state.emit("[");
stringify(ttv.indexer->indexType);
state.emit("]: ");
@ -527,7 +651,14 @@ struct TypeVarStringifier
for (const auto& [name, prop] : ttv.props)
{
if (comma)
state.emit(state.opts.useLineBreaks ? ",\n" : ", ");
{
state.emit(",");
state.newline();
}
else if (FFlag::LuauToStringTableBracesNewlines)
{
state.newline();
}
size_t length = state.result.name.length() - oldLength;
@ -553,6 +684,14 @@ struct TypeVarStringifier
++index;
}
state.dedent();
if (FFlag::LuauToStringTableBracesNewlines)
{
if (comma)
state.newline();
else
state.emit(" ");
}
state.emit(closedbrace);
state.unsee(&ttv);
@ -563,7 +702,8 @@ struct TypeVarStringifier
state.result.invalid = true;
state.emit("{ @metatable ");
stringify(mtv.metatable);
state.emit(state.opts.useLineBreaks ? ",\n" : ", ");
state.emit(",");
state.newline();
stringify(mtv.table);
state.emit(" }");
}
@ -627,7 +767,10 @@ struct TypeVarStringifier
for (std::string& ss : results)
{
if (!first)
state.emit(" | ");
{
state.newline();
state.emit("| ");
}
state.emit(ss);
first = false;
}
@ -680,7 +823,10 @@ struct TypeVarStringifier
for (std::string& ss : results)
{
if (!first)
state.emit(" & ");
{
state.newline();
state.emit("& ");
}
state.emit(ss);
first = false;
}
@ -746,7 +892,8 @@ struct TypePackStringifier
}
Luau::visit(
[this, tp](auto&& t) {
[this, tp](auto&& t)
{
return (*this)(tp, t);
},
tp->ty);
@ -784,13 +931,16 @@ struct TypePackStringifier
if (tp.tail && !isEmpty(*tp.tail))
{
const auto& tail = *tp.tail;
if (first)
first = false;
else
state.emit(", ");
TypePackId tail = follow(*tp.tail);
if (auto vtp = get<VariadicTypePack>(tail); !vtp || (!FFlag::DebugLuauVerboseTypeNames && !vtp->hidden))
{
if (first)
first = false;
else
state.emit(", ");
stringify(tail);
stringify(tail);
}
}
state.unsee(&tp);
@ -805,6 +955,8 @@ struct TypePackStringifier
void operator()(TypePackId, const VariadicTypePack& pack)
{
state.emit("...");
if (FFlag::DebugLuauVerboseTypeNames && pack.hidden)
state.emit("<hidden>");
stringify(pack.ty);
}
@ -814,6 +966,7 @@ struct TypePackStringifier
state.emit("gen-");
if (pack.explicitName)
{
state.usedNames.insert(pack.name);
state.result.nameMap.typePacks[tp] = pack.name;
state.emit(pack.name);
}
@ -834,7 +987,7 @@ struct TypePackStringifier
if (FFlag::DebugLuauVerboseTypeNames)
{
state.emit("-");
state.emit(std::to_string(pack.level.level));
state.emit(pack.level);
}
state.emit("...");
@ -858,15 +1011,12 @@ void TypeVarStringifier::stringify(TypePackId tpid, const std::vector<std::optio
tps.stringify(tpid);
}
static void assignCycleNames(const std::unordered_set<TypeId>& cycles, const std::unordered_set<TypePackId>& cycleTPs,
static void assignCycleNames(const std::set<TypeId>& cycles, const std::set<TypePackId>& cycleTPs,
std::unordered_map<TypeId, std::string>& cycleNames, std::unordered_map<TypePackId, std::string>& cycleTpNames, bool exhaustive)
{
int nextIndex = 1;
std::vector<TypeId> sortedCycles{cycles.begin(), cycles.end()};
std::sort(sortedCycles.begin(), sortedCycles.end(), std::less<TypeId>{});
for (TypeId cycleTy : sortedCycles)
for (TypeId cycleTy : cycles)
{
std::string name;
@ -874,9 +1024,11 @@ static void assignCycleNames(const std::unordered_set<TypeId>& cycles, const std
if (auto ttv = get<TableTypeVar>(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name))
{
// If we have a cycle type in type parameters, assign a cycle name for this named table
if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) {
return cycles.count(follow(el));
}) != ttv->instantiatedTypeParams.end())
if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(),
[&](auto&& el)
{
return cycles.count(follow(el));
}) != ttv->instantiatedTypeParams.end())
cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName;
continue;
@ -888,10 +1040,7 @@ static void assignCycleNames(const std::unordered_set<TypeId>& cycles, const std
cycleNames[cycleTy] = std::move(name);
}
std::vector<TypePackId> sortedCycleTps{cycleTPs.begin(), cycleTPs.end()};
std::sort(sortedCycleTps.begin(), sortedCycleTps.end(), std::less<TypePackId>());
for (TypePackId tp : sortedCycleTps)
for (TypePackId tp : cycleTPs)
{
std::string name = "tp" + std::to_string(nextIndex);
++nextIndex;
@ -913,8 +1062,8 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
StringifierState state{opts, result, opts.nameMap};
std::unordered_set<TypeId> cycles;
std::unordered_set<TypePackId> cycleTPs;
std::set<TypeId> cycles;
std::set<TypePackId> cycleTPs;
findCyclicTypes(cycles, cycleTPs, ty, opts.exhaustive);
@ -975,9 +1124,11 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
state.exhaustive = true;
std::vector<std::pair<TypeId, std::string>> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()};
std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) {
return a.second < b.second;
});
std::sort(sortedCycleNames.begin(), sortedCycleNames.end(),
[](const auto& a, const auto& b)
{
return a.second < b.second;
});
bool semi = false;
for (const auto& [cycleTy, name] : sortedCycleNames)
@ -988,7 +1139,8 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
state.emit(name);
state.emit(" = ");
Luau::visit(
[&tvs, cycleTy = cycleTy](auto&& t) {
[&tvs, cycleTy = cycleTy](auto&& t)
{
return tvs(cycleTy, t);
},
cycleTy->ty);
@ -1016,8 +1168,8 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts)
ToStringResult result;
StringifierState state{opts, result, opts.nameMap};
std::unordered_set<TypeId> cycles;
std::unordered_set<TypePackId> cycleTPs;
std::set<TypeId> cycles;
std::set<TypePackId> cycleTPs;
findCyclicTypes(cycles, cycleTPs, tp, opts.exhaustive);
@ -1045,9 +1197,11 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts)
state.exhaustive = true;
std::vector<std::pair<TypeId, std::string>> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()};
std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) {
return a.second < b.second;
});
std::sort(sortedCycleNames.begin(), sortedCycleNames.end(),
[](const auto& a, const auto& b)
{
return a.second < b.second;
});
bool semi = false;
for (const auto& [cycleTy, name] : sortedCycleNames)
@ -1058,7 +1212,8 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts)
state.emit(name);
state.emit(" = ");
Luau::visit(
[&tvs, cycleTy = cycleTy](auto&& t) {
[&tvs, cycleTy = cycleTy](auto t)
{
return tvs(cycleTy, t);
},
cycleTy->ty);
@ -1108,81 +1263,66 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp
auto argPackIter = begin(ftv.argTypes);
bool first = true;
if (FFlag::LuauDocFuncParameters)
size_t idx = 0;
while (argPackIter != end(ftv.argTypes))
{
size_t idx = 0;
while (argPackIter != end(ftv.argTypes))
// ftv takes a self parameter as the first argument, skip it if specified in option
if (idx == 0 && ftv.hasSelf && opts.hideFunctionSelfArgument)
{
if (!first)
state.emit(", ");
first = false;
// We don't respect opts.functionTypeArguments
if (idx < opts.namedFunctionOverrideArgNames.size())
{
state.emit(opts.namedFunctionOverrideArgNames[idx] + ": ");
}
else if (idx < ftv.argNames.size() && ftv.argNames[idx])
{
state.emit(ftv.argNames[idx]->name + ": ");
}
else
{
state.emit("_: ");
}
tvs.stringify(*argPackIter);
++argPackIter;
++idx;
continue;
}
}
else
{
auto argNameIter = ftv.argNames.begin();
while (argPackIter != end(ftv.argTypes))
if (!first)
state.emit(", ");
first = false;
// We don't respect opts.functionTypeArguments
if (idx < opts.namedFunctionOverrideArgNames.size())
{
if (!first)
state.emit(", ");
first = false;
// We don't currently respect opts.functionTypeArguments. I don't think this function should.
if (argNameIter != ftv.argNames.end())
{
state.emit((*argNameIter ? (*argNameIter)->name : "_") + ": ");
++argNameIter;
}
else
{
state.emit("_: ");
}
tvs.stringify(*argPackIter);
++argPackIter;
state.emit(opts.namedFunctionOverrideArgNames[idx] + ": ");
}
else if (idx < ftv.argNames.size() && ftv.argNames[idx])
{
state.emit(ftv.argNames[idx]->name + ": ");
}
else
{
state.emit("_: ");
}
tvs.stringify(*argPackIter);
++argPackIter;
++idx;
}
if (argPackIter.tail())
{
if (!first)
state.emit(", ");
if (auto vtp = get<VariadicTypePack>(*argPackIter.tail()); !vtp || !vtp->hidden)
{
if (!first)
state.emit(", ");
state.emit("...: ");
if (auto vtp = get<VariadicTypePack>(*argPackIter.tail()))
tvs.stringify(vtp->ty);
else
tvs.stringify(*argPackIter.tail());
state.emit("...: ");
if (vtp)
tvs.stringify(vtp->ty);
else
tvs.stringify(*argPackIter.tail());
}
}
state.emit("): ");
size_t retSize = size(ftv.retType);
bool hasTail = !finite(ftv.retType);
bool wrap = get<TypePack>(follow(ftv.retType)) && (hasTail ? retSize != 0 : retSize != 1);
size_t retSize = size(ftv.retTypes);
bool hasTail = !finite(ftv.retTypes);
bool wrap = get<TypePack>(follow(ftv.retTypes)) && (hasTail ? retSize != 0 : retSize != 1);
if (wrap)
state.emit("(");
tvs.stringify(ftv.retType);
tvs.stringify(ftv.retTypes);
if (wrap)
state.emit(")");
@ -1210,6 +1350,24 @@ std::string dump(TypePackId ty)
return s;
}
std::string dump(const ScopePtr& scope, const char* name)
{
auto binding = scope->linearSearchForBinding(name);
if (!binding)
{
printf("No binding %s\n", name);
return {};
}
TypeId ty = binding->typeId;
ToStringOptions opts;
opts.exhaustive = true;
opts.functionTypeArguments = true;
std::string s = toString(ty, opts);
printf("%s\n", s.c_str());
return s;
}
std::string generateName(size_t i)
{
std::string n;
@ -1219,4 +1377,61 @@ std::string generateName(size_t i)
return n;
}
std::string toString(const Constraint& c, ToStringOptions& opts)
{
if (const SubtypeConstraint* sc = Luau::get_if<SubtypeConstraint>(&c.c))
{
ToStringResult subStr = toStringDetailed(sc->subType, opts);
opts.nameMap = std::move(subStr.nameMap);
ToStringResult superStr = toStringDetailed(sc->superType, opts);
opts.nameMap = std::move(superStr.nameMap);
return subStr.name + " <: " + superStr.name;
}
else if (const PackSubtypeConstraint* psc = Luau::get_if<PackSubtypeConstraint>(&c.c))
{
ToStringResult subStr = toStringDetailed(psc->subPack, opts);
opts.nameMap = std::move(subStr.nameMap);
ToStringResult superStr = toStringDetailed(psc->superPack, opts);
opts.nameMap = std::move(superStr.nameMap);
return subStr.name + " <: " + superStr.name;
}
else if (const GeneralizationConstraint* gc = Luau::get_if<GeneralizationConstraint>(&c.c))
{
ToStringResult subStr = toStringDetailed(gc->generalizedType, opts);
opts.nameMap = std::move(subStr.nameMap);
ToStringResult superStr = toStringDetailed(gc->sourceType, opts);
opts.nameMap = std::move(superStr.nameMap);
return subStr.name + " ~ gen " + superStr.name;
}
else if (const InstantiationConstraint* ic = Luau::get_if<InstantiationConstraint>(&c.c))
{
ToStringResult subStr = toStringDetailed(ic->subType, opts);
opts.nameMap = std::move(subStr.nameMap);
ToStringResult superStr = toStringDetailed(ic->superType, opts);
opts.nameMap = std::move(superStr.nameMap);
return subStr.name + " ~ inst " + superStr.name;
}
else if (const NameConstraint* nc = Luau::get<NameConstraint>(c))
{
ToStringResult namedStr = toStringDetailed(nc->namedType, opts);
opts.nameMap = std::move(namedStr.nameMap);
return "@name(" + namedStr.name + ") = " + nc->name;
}
else
{
LUAU_ASSERT(false);
return "";
}
}
std::string dump(const Constraint& c)
{
ToStringOptions opts;
opts.exhaustive = true;
opts.functionTypeArguments = true;
std::string s = toString(c, opts);
printf("%s\n", s.c_str());
return s;
}
} // namespace Luau

View file

@ -215,6 +215,7 @@ struct ArcCollector : public AstVisitor
}
}
// Adds a dependency from the current node to the named node.
void add(const Identifier& name)
{
Node** it = map.find(name);

View file

@ -1025,31 +1025,42 @@ struct Printer
}
else if (const auto& a = typeAnnotation.as<AstTypeTable>())
{
CommaSeparatorInserter comma(writer);
AstTypeReference* indexType = a->indexer ? a->indexer->indexType->as<AstTypeReference>() : nullptr;
writer.symbol("{");
for (std::size_t i = 0; i < a->props.size; ++i)
if (a->props.size == 0 && indexType && indexType->name == "number")
{
comma();
advance(a->props.data[i].location.begin);
writer.identifier(a->props.data[i].name.value);
if (a->props.data[i].type)
{
writer.symbol(":");
visualizeTypeAnnotation(*a->props.data[i].type);
}
}
if (a->indexer)
{
comma();
writer.symbol("[");
visualizeTypeAnnotation(*a->indexer->indexType);
writer.symbol("]");
writer.symbol(":");
writer.symbol("{");
visualizeTypeAnnotation(*a->indexer->resultType);
writer.symbol("}");
}
else
{
CommaSeparatorInserter comma(writer);
writer.symbol("{");
for (std::size_t i = 0; i < a->props.size; ++i)
{
comma();
advance(a->props.data[i].location.begin);
writer.identifier(a->props.data[i].name.value);
if (a->props.data[i].type)
{
writer.symbol(":");
visualizeTypeAnnotation(*a->props.data[i].type);
}
}
if (a->indexer)
{
comma();
writer.symbol("[");
visualizeTypeAnnotation(*a->indexer->indexType);
writer.symbol("]");
writer.symbol(":");
visualizeTypeAnnotation(*a->indexer->resultType);
}
writer.symbol("}");
}
writer.symbol("}");
}
else if (auto a = typeAnnotation.as<AstTypeTypeof>())
{

View file

@ -7,6 +7,8 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauNonCopyableTypeVarFields)
namespace Luau
{
@ -79,10 +81,34 @@ void TxnLog::concat(TxnLog rhs)
void TxnLog::commit()
{
for (auto& [ty, rep] : typeVarChanges)
*asMutable(ty) = rep.get()->pending;
{
if (FFlag::LuauNonCopyableTypeVarFields)
{
asMutable(ty)->reassign(rep.get()->pending);
}
else
{
TypeArena* owningArena = ty->owningArena;
TypeVar* mtv = asMutable(ty);
*mtv = rep.get()->pending;
mtv->owningArena = owningArena;
}
}
for (auto& [tp, rep] : typePackChanges)
*asMutable(tp) = rep.get()->pending;
{
if (FFlag::LuauNonCopyableTypeVarFields)
{
asMutable(tp)->reassign(rep.get()->pending);
}
else
{
TypeArena* owningArena = tp->owningArena;
TypePackVar* mpv = asMutable(tp);
*mpv = rep.get()->pending;
mpv->owningArena = owningArena;
}
}
clear();
}
@ -144,11 +170,6 @@ bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const
return true;
}
if (parent)
{
return parent->haveSeen(lhs, rhs);
}
return false;
}
@ -173,8 +194,13 @@ PendingType* TxnLog::queue(TypeId ty)
// about this type, we don't want to mutate the parent's state.
auto& pending = typeVarChanges[ty];
if (!pending)
{
pending = std::make_unique<PendingType>(*ty);
if (FFlag::LuauNonCopyableTypeVarFields)
pending->pending.owningArena = nullptr;
}
return pending.get();
}
@ -186,8 +212,13 @@ PendingTypePack* TxnLog::queue(TypePackId tp)
// about this type, we don't want to mutate the parent's state.
auto& pending = typePackChanges[tp];
if (!pending)
{
pending = std::make_unique<PendingTypePack>(*tp);
if (FFlag::LuauNonCopyableTypeVarFields)
pending->pending.owningArena = nullptr;
}
return pending.get();
}
@ -199,8 +230,8 @@ PendingType* TxnLog::pending(TypeId ty) const
for (const TxnLog* current = this; current; current = current->parent)
{
if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end())
return it->second.get();
if (auto it = current->typeVarChanges.find(ty))
return it->get();
}
return nullptr;
@ -214,8 +245,8 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const
for (const TxnLog* current = this; current; current = current->parent)
{
if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end())
return it->second.get();
if (auto it = current->typePackChanges.find(tp))
return it->get();
}
return nullptr;
@ -224,14 +255,24 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const
PendingType* TxnLog::replace(TypeId ty, TypeVar replacement)
{
PendingType* newTy = queue(ty);
newTy->pending = replacement;
if (FFlag::LuauNonCopyableTypeVarFields)
newTy->pending.reassign(replacement);
else
newTy->pending = replacement;
return newTy;
}
PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement)
{
PendingTypePack* newTp = queue(tp);
newTp->pending = replacement;
if (FFlag::LuauNonCopyableTypeVarFields)
newTp->pending.reassign(replacement);
else
newTp->pending = replacement;
return newTp;
}

View file

@ -0,0 +1,88 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeArena.h"
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false);
namespace Luau
{
void TypeArena::clear()
{
typeVars.clear();
typePacks.clear();
}
TypeId TypeArena::addTV(TypeVar&& tv)
{
TypeId allocated = typeVars.allocate(std::move(tv));
asMutable(allocated)->owningArena = this;
return allocated;
}
TypeId TypeArena::freshType(TypeLevel level)
{
TypeId allocated = typeVars.allocate(FreeTypeVar{level});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(std::initializer_list<TypeId> types)
{
TypePackId allocated = typePacks.allocate(TypePack{std::move(types)});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(std::vector<TypeId> types)
{
TypePackId allocated = typePacks.allocate(TypePack{std::move(types)});
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(TypePack tp)
{
TypePackId allocated = typePacks.allocate(std::move(tp));
asMutable(allocated)->owningArena = this;
return allocated;
}
TypePackId TypeArena::addTypePack(TypePackVar tp)
{
TypePackId allocated = typePacks.allocate(std::move(tp));
asMutable(allocated)->owningArena = this;
return allocated;
}
void freeze(TypeArena& arena)
{
if (!FFlag::DebugLuauFreezeArena)
return;
arena.typeVars.freeze();
arena.typePacks.freeze();
}
void unfreeze(TypeArena& arena)
{
if (!FFlag::DebugLuauFreezeArena)
return;
arena.typeVars.unfreeze();
arena.typePacks.unfreeze();
}
} // namespace Luau

View file

@ -94,6 +94,21 @@ public:
}
}
AstType* operator()(const BlockedTypeVar& btv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*blocked*"));
}
AstType* operator()(const ConstrainedTypeVar& ctv)
{
AstArray<AstType*> types;
types.size = ctv.parts.size();
types.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * ctv.parts.size()));
for (size_t i = 0; i < ctv.parts.size(); ++i)
types.data[i] = Luau::visit(*this, ctv.parts[i]->ty);
return allocator->alloc<AstTypeIntersection>(Location(), types);
}
AstType* operator()(const SingletonTypeVar& stv)
{
if (const BooleanSingleton* bs = get<BooleanSingleton>(&stv))
@ -261,7 +276,7 @@ public:
}
AstArray<AstType*> returnTypes;
const auto& [retVector, retTail] = flatten(ftv.retType);
const auto& [retVector, retTail] = flatten(ftv.retTypes);
returnTypes.size = retVector.size();
returnTypes.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * returnTypes.size));
for (size_t i = 0; i < returnTypes.size; ++i)
@ -364,6 +379,9 @@ public:
AstTypePack* operator()(const VariadicTypePack& vtp) const
{
if (vtp.hidden)
return nullptr;
return allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*typeVisitor, vtp.ty->ty));
}
@ -466,6 +484,20 @@ public:
{
return visitLocal(al->local);
}
virtual bool visit(AstStatFor* stat) override
{
visitLocal(stat->var);
return true;
}
virtual bool visit(AstStatForIn* stat) override
{
for (size_t i = 0; i < stat->vars.size; ++i)
visitLocal(stat->vars.data[i]);
return true;
}
virtual bool visit(AstExprFunction* fn) override
{
// TODO: add generics if the inferred type of the function is generic CLI-39908

View file

@ -0,0 +1,333 @@
#include "Luau/TypeChecker2.h"
#include <algorithm>
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
#include "Luau/Clone.h"
#include "Luau/Normalize.h"
#include "Luau/ConstraintGraphBuilder.h" // FIXME move Scope2 into its own header
#include "Luau/Unifier.h"
#include "Luau/ToString.h"
namespace Luau
{
struct TypeChecker2 : public AstVisitor
{
const SourceModule* sourceModule;
Module* module;
InternalErrorReporter ice; // FIXME accept a pointer from Frontend
TypeChecker2(const SourceModule* sourceModule, Module* module)
: sourceModule(sourceModule)
, module(module)
{
}
using AstVisitor::visit;
TypePackId lookupPack(AstExpr* expr)
{
TypePackId* tp = module->astTypePacks.find(expr);
LUAU_ASSERT(tp);
return follow(*tp);
}
TypeId lookupType(AstExpr* expr)
{
TypeId* ty = module->astTypes.find(expr);
LUAU_ASSERT(ty);
return follow(*ty);
}
TypeId lookupAnnotation(AstType* annotation)
{
TypeId* ty = module->astResolvedTypes.find(annotation);
LUAU_ASSERT(ty);
return follow(*ty);
}
TypePackId reconstructPack(AstArray<AstExpr*> exprs, TypeArena& arena)
{
std::vector<TypeId> head;
for (size_t i = 0; i < exprs.size - 1; ++i)
{
head.push_back(lookupType(exprs.data[i]));
}
TypePackId tail = lookupPack(exprs.data[exprs.size - 1]);
return arena.addTypePack(TypePack{head, tail});
}
Scope2* findInnermostScope(Location location)
{
Scope2* bestScope = module->getModuleScope2();
Location bestLocation = module->scope2s[0].first;
for (size_t i = 0; i < module->scope2s.size(); ++i)
{
auto& [scopeBounds, scope] = module->scope2s[i];
if (scopeBounds.encloses(location))
{
if (scopeBounds.begin > bestLocation.begin || scopeBounds.end < bestLocation.end)
{
bestScope = scope.get();
bestLocation = scopeBounds;
}
}
else
{
// TODO: Is this sound? This relies on the fact that scopes are inserted
// into the scope list in the order that they appear in the AST.
break;
}
}
return bestScope;
}
bool visit(AstStatLocal* local) override
{
for (size_t i = 0; i < local->values.size; ++i)
{
AstExpr* value = local->values.data[i];
if (i == local->values.size - 1)
{
if (i < local->values.size)
{
TypePackId valueTypes = lookupPack(value);
auto it = begin(valueTypes);
for (size_t j = i; j < local->vars.size; ++j)
{
if (it == end(valueTypes))
{
break;
}
AstLocal* var = local->vars.data[i];
if (var->annotation)
{
TypeId varType = lookupAnnotation(var->annotation);
if (!isSubtype(*it, varType, ice))
{
reportError(TypeMismatch{varType, *it}, value->location);
}
}
++it;
}
}
}
else
{
TypeId valueType = lookupType(value);
AstLocal* var = local->vars.data[i];
if (var->annotation)
{
TypeId varType = lookupAnnotation(var->annotation);
if (!isSubtype(varType, valueType, ice))
{
reportError(TypeMismatch{varType, valueType}, value->location);
}
}
}
}
return true;
}
bool visit(AstStatAssign* assign) override
{
size_t count = std::min(assign->vars.size, assign->values.size);
for (size_t i = 0; i < count; ++i)
{
AstExpr* lhs = assign->vars.data[i];
TypeId* lhsType = module->astTypes.find(lhs);
LUAU_ASSERT(lhsType);
AstExpr* rhs = assign->values.data[i];
TypeId* rhsType = module->astTypes.find(rhs);
LUAU_ASSERT(rhsType);
if (!isSubtype(*rhsType, *lhsType, ice))
{
reportError(TypeMismatch{*lhsType, *rhsType}, rhs->location);
}
}
return true;
}
bool visit(AstStatReturn* ret) override
{
Scope2* scope = findInnermostScope(ret->location);
TypePackId expectedRetType = scope->returnType;
TypeArena arena;
TypePackId actualRetType = reconstructPack(ret->list, arena);
UnifierSharedState sharedState{&ice};
Unifier u{&arena, Mode::Strict, ret->location, Covariant, sharedState};
u.anyIsTop = true;
u.tryUnify(actualRetType, expectedRetType);
const bool ok = u.errors.empty() && u.log.empty();
if (!ok)
{
for (const TypeError& e : u.errors)
module->errors.push_back(e);
}
return true;
}
bool visit(AstExprCall* call) override
{
TypePackId expectedRetType = lookupPack(call);
TypeId functionType = lookupType(call->func);
TypeArena arena;
TypePack args;
for (const auto& arg : call->args)
{
TypeId argTy = module->astTypes[arg];
LUAU_ASSERT(argTy);
args.head.push_back(argTy);
}
TypePackId argsTp = arena.addTypePack(args);
FunctionTypeVar ftv{argsTp, expectedRetType};
TypeId expectedType = arena.addType(ftv);
if (!isSubtype(expectedType, functionType, ice))
{
unfreeze(module->interfaceTypes);
CloneState cloneState;
expectedType = clone(expectedType, module->interfaceTypes, cloneState);
freeze(module->interfaceTypes);
reportError(TypeMismatch{expectedType, functionType}, call->location);
}
return true;
}
bool visit(AstExprFunction* fn) override
{
TypeId inferredFnTy = lookupType(fn);
const FunctionTypeVar* inferredFtv = get<FunctionTypeVar>(inferredFnTy);
LUAU_ASSERT(inferredFtv);
auto argIt = begin(inferredFtv->argTypes);
for (const auto& arg : fn->args)
{
if (argIt == end(inferredFtv->argTypes))
break;
if (arg->annotation)
{
TypeId inferredArgTy = *argIt;
TypeId annotatedArgTy = lookupAnnotation(arg->annotation);
if (!isSubtype(annotatedArgTy, inferredArgTy, ice))
{
reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location);
}
}
++argIt;
}
return true;
}
bool visit(AstExprIndexName* indexName) override
{
TypeId leftType = lookupType(indexName->expr);
TypeId resultType = lookupType(indexName);
// leftType must have a property called indexName->index
if (auto ttv = get<TableTypeVar>(leftType))
{
auto it = ttv->props.find(indexName->index.value);
if (it == ttv->props.end())
{
reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location);
}
else if (!isSubtype(resultType, it->second.type, ice))
{
reportError(TypeMismatch{resultType, it->second.type}, indexName->location);
}
}
else
{
reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location);
}
return true;
}
bool visit(AstExprConstantNumber* number) override
{
TypeId actualType = lookupType(number);
TypeId numberType = getSingletonTypes().numberType;
if (!isSubtype(actualType, numberType, ice))
{
reportError(TypeMismatch{actualType, numberType}, number->location);
}
return true;
}
bool visit(AstExprConstantString* string) override
{
TypeId actualType = lookupType(string);
TypeId stringType = getSingletonTypes().stringType;
if (!isSubtype(actualType, stringType, ice))
{
reportError(TypeMismatch{actualType, stringType}, string->location);
}
return true;
}
bool visit(AstType* ty) override
{
return true;
}
bool visit(AstTypeReference* ty) override
{
Scope2* scope = findInnermostScope(ty->location);
// TODO: Imported types
// TODO: Generic types
if (!scope->lookupTypeBinding(ty->name.value))
{
reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location);
}
return true;
}
void reportError(TypeErrorData&& data, const Location& location)
{
module->errors.emplace_back(location, sourceModule->name, std::move(data));
}
};
void check(const SourceModule& sourceModule, Module* module)
{
TypeChecker2 typeChecker{&sourceModule, module};
sourceModule.root->visit(&typeChecker);
}
} // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -5,6 +5,8 @@
#include <stdexcept>
LUAU_FASTFLAG(LuauNonCopyableTypeVarFields)
namespace Luau
{
@ -36,6 +38,25 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp)
return *this;
}
TypePackVar& TypePackVar::operator=(const TypePackVar& rhs)
{
if (FFlag::LuauNonCopyableTypeVarFields)
{
LUAU_ASSERT(owningArena == rhs.owningArena);
LUAU_ASSERT(!rhs.persistent);
reassign(rhs);
}
else
{
ty = rhs.ty;
persistent = rhs.persistent;
owningArena = rhs.owningArena;
}
return *this;
}
TypePackIterator::TypePackIterator(TypePackId typePack)
: TypePackIterator(typePack, TxnLog::empty())
{
@ -104,7 +125,7 @@ TypePackIterator begin(TypePackId tp)
return TypePackIterator{tp};
}
TypePackIterator begin(TypePackId tp, TxnLog* log)
TypePackIterator begin(TypePackId tp, const TxnLog* log)
{
return TypePackIterator{tp, log};
}
@ -256,7 +277,7 @@ size_t size(const TypePack& tp, TxnLog* log)
return result;
}
std::optional<TypeId> first(TypePackId tp)
std::optional<TypeId> first(TypePackId tp, bool ignoreHiddenVariadics)
{
auto it = begin(tp);
auto endIter = end(tp);
@ -266,7 +287,7 @@ std::optional<TypeId> first(TypePackId tp)
if (auto tail = it.tail())
{
if (auto vtp = get<VariadicTypePack>(*tail))
if (auto vtp = get<VariadicTypePack>(*tail); vtp && (!vtp->hidden || !ignoreHiddenVariadics))
return vtp->ty;
}
@ -299,6 +320,46 @@ std::pair<std::vector<TypeId>, std::optional<TypePackId>> flatten(TypePackId tp)
return {res, iter.tail()};
}
std::pair<std::vector<TypeId>, std::optional<TypePackId>> flatten(TypePackId tp, const TxnLog& log)
{
tp = log.follow(tp);
std::vector<TypeId> flattened;
std::optional<TypePackId> tail = std::nullopt;
TypePackIterator it(tp, &log);
for (; it != end(tp); ++it)
{
flattened.push_back(*it);
}
tail = it.tail();
return {flattened, tail};
}
bool isVariadic(TypePackId tp)
{
return isVariadic(tp, *TxnLog::empty());
}
bool isVariadic(TypePackId tp, const TxnLog& log)
{
std::optional<TypePackId> tail = flatten(tp, log).second;
if (!tail)
return false;
if (log.get<GenericTypePack>(*tail))
return true;
if (auto vtp = log.get<VariadicTypePack>(*tail); vtp && !vtp->hidden)
return true;
return false;
}
TypePackVar* asMutable(TypePackId tp)
{
return const_cast<TypePackVar*>(tp);

View file

@ -5,8 +5,6 @@
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h"
LUAU_FASTFLAGVARIABLE(LuauTerminateCyclicMetatableIndexLookup, false)
namespace Luau
{
@ -55,13 +53,10 @@ std::optional<TypeId> findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t
{
TypeId index = follow(*mtIndex);
if (FFlag::LuauTerminateCyclicMetatableIndexLookup)
{
if (count >= 100)
return std::nullopt;
if (count >= 100)
return std::nullopt;
++count;
}
++count;
if (const auto& itt = getTableType(index))
{
@ -71,7 +66,7 @@ std::optional<TypeId> findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t
}
else if (const auto& itf = get<FunctionTypeVar>(index))
{
std::optional<TypeId> r = first(follow(itf->retType));
std::optional<TypeId> r = first(follow(itf->retTypes));
if (!r)
return getSingletonTypes().nilType;
else

View file

@ -23,16 +23,13 @@ LUAU_FASTFLAG(DebugLuauFreezeArena)
LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauErrorRecoveryType)
LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables)
LUAU_FASTFLAG(LuauDiscriminableUnions2)
LUAU_FASTFLAGVARIABLE(LuauAnyInIsOptionalIsOptional, false)
LUAU_FASTFLAG(LuauNonCopyableTypeVarFields)
namespace Luau
{
std::optional<ExprResult<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult);
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
TypeId follow(TypeId t)
{
@ -174,22 +171,15 @@ bool isString(TypeId ty)
// Returns true when ty is a supertype of string
bool maybeString(TypeId ty)
{
if (FFlag::LuauSubtypingAddOptPropsToUnsealedTables)
{
ty = follow(ty);
if (isPrim(ty, PrimitiveTypeVar::String) || get<AnyTypeVar>(ty))
return true;
ty = follow(ty);
if (auto utv = get<UnionTypeVar>(ty))
return std::any_of(begin(utv), end(utv), maybeString);
if (isPrim(ty, PrimitiveTypeVar::String) || get<AnyTypeVar>(ty))
return true;
return false;
}
else
{
return isString(ty);
}
if (auto utv = get<UnionTypeVar>(ty))
return std::any_of(begin(utv), end(utv), maybeString);
return false;
}
bool isThread(TypeId ty)
@ -204,14 +194,14 @@ bool isOptional(TypeId ty)
ty = follow(ty);
if (FFlag::LuauAnyInIsOptionalIsOptional && get<AnyTypeVar>(ty))
if (get<AnyTypeVar>(ty))
return true;
auto utv = get<UnionTypeVar>(ty);
if (!utv)
return false;
return std::any_of(begin(utv), end(utv), FFlag::LuauAnyInIsOptionalIsOptional ? isOptional : isNil);
return std::any_of(begin(utv), end(utv), isOptional);
}
bool isTableIntersection(TypeId ty)
@ -304,6 +294,11 @@ std::optional<ModuleName> getDefinitionModuleName(TypeId type)
if (ftv->definition)
return ftv->definition->definitionModuleName;
}
else if (auto ctv = get<ClassTypeVar>(type))
{
if (!ctv->definitionModuleName.empty())
return ctv->definitionModuleName;
}
return std::nullopt;
}
@ -373,8 +368,7 @@ bool hasLength(TypeId ty, DenseHashSet<TypeId>& seen, int* recursionCount)
if (seen.contains(ty))
return true;
bool isStr = FFlag::LuauDiscriminableUnions2 ? isString(ty) : isPrim(ty, PrimitiveTypeVar::String);
if (isStr || get<AnyTypeVar>(ty) || get<TableTypeVar>(ty) || get<MetatableTypeVar>(ty))
if (isString(ty) || get<AnyTypeVar>(ty) || get<TableTypeVar>(ty) || get<MetatableTypeVar>(ty))
return true;
if (auto uty = get<UnionTypeVar>(ty))
@ -406,41 +400,48 @@ bool hasLength(TypeId ty, DenseHashSet<TypeId>& seen, int* recursionCount)
return false;
}
FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional<FunctionDefinition> defn, bool hasSelf)
BlockedTypeVar::BlockedTypeVar()
: index(++nextIndex)
{
}
int BlockedTypeVar::nextIndex = 0;
FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn, bool hasSelf)
: argTypes(argTypes)
, retType(retType)
, retTypes(retTypes)
, definition(std::move(defn))
, hasSelf(hasSelf)
{
}
FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional<FunctionDefinition> defn, bool hasSelf)
FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn, bool hasSelf)
: level(level)
, argTypes(argTypes)
, retType(retType)
, retTypes(retTypes)
, definition(std::move(defn))
, hasSelf(hasSelf)
{
}
FunctionTypeVar::FunctionTypeVar(std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retType,
FunctionTypeVar::FunctionTypeVar(std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes, TypePackId retTypes,
std::optional<FunctionDefinition> defn, bool hasSelf)
: generics(generics)
, genericPacks(genericPacks)
, argTypes(argTypes)
, retType(retType)
, retTypes(retTypes)
, definition(std::move(defn))
, hasSelf(hasSelf)
{
}
FunctionTypeVar::FunctionTypeVar(TypeLevel level, std::vector<TypeId> generics, std::vector<TypePackId> genericPacks, TypePackId argTypes,
TypePackId retType, std::optional<FunctionDefinition> defn, bool hasSelf)
TypePackId retTypes, std::optional<FunctionDefinition> defn, bool hasSelf)
: level(level)
, generics(generics)
, genericPacks(genericPacks)
, argTypes(argTypes)
, retType(retType)
, retTypes(retTypes)
, definition(std::move(defn))
, hasSelf(hasSelf)
{
@ -486,7 +487,7 @@ bool areEqual(SeenSet& seen, const FunctionTypeVar& lhs, const FunctionTypeVar&
if (!areEqual(seen, *lhs.argTypes, *rhs.argTypes))
return false;
if (!areEqual(seen, *lhs.retType, *rhs.retType))
if (!areEqual(seen, *lhs.retTypes, *rhs.retTypes))
return false;
return true;
@ -643,6 +644,26 @@ TypeVar& TypeVar::operator=(TypeVariant&& rhs)
return *this;
}
TypeVar& TypeVar::operator=(const TypeVar& rhs)
{
if (FFlag::LuauNonCopyableTypeVarFields)
{
LUAU_ASSERT(owningArena == rhs.owningArena);
LUAU_ASSERT(!rhs.persistent);
reassign(rhs);
}
else
{
ty = rhs.ty;
persistent = rhs.persistent;
normal = rhs.normal;
owningArena = rhs.owningArena;
}
return *this;
}
TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initializer_list<TypeId> generics,
std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes);
@ -652,9 +673,10 @@ static TypeVar numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persist
static TypeVar stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true};
static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true};
static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true};
static TypeVar anyType_{AnyTypeVar{}};
static TypeVar errorType_{ErrorTypeVar{}};
static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}};
static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true};
static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true};
static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true};
static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true};
static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true};
static TypePackVar errorTypePack_{Unifiable::Error{}};
@ -665,8 +687,9 @@ SingletonTypes::SingletonTypes()
, stringType(&stringType_)
, booleanType(&booleanType_)
, threadType(&threadType_)
, trueType(&trueType_)
, falseType(&falseType_)
, anyType(&anyType_)
, optionalNumberType(&optionalNumberType_)
, anyTypePack(&anyTypePack_)
, arena(new TypeArena)
{
@ -694,7 +717,7 @@ TypeId SingletonTypes::makeStringMetatable()
{
const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}});
const TypeId optionalString = arena->addType(UnionTypeVar{{nilType, stringType}});
const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, &booleanType_}});
const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, booleanType}});
const TypePackId oneStringPack = arena->addTypePack({stringType});
const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true});
@ -718,14 +741,16 @@ TypeId SingletonTypes::makeStringMetatable()
TableTypeVar::Props stringLib = {
{"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}},
{"char", {arena->addType(FunctionTypeVar{arena->addTypePack(TypePack{{numberType}, numberVariadicList}), arena->addTypePack({stringType})})}},
{"find", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber, optionalBoolean}, {}, {optionalNumber, optionalNumber})}},
{"char", {arena->addType(FunctionTypeVar{numberVariadicList, arena->addTypePack({stringType})})}},
{"find", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})})}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber}, {}, {optionalString})}},
{"match", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}),
arena->addTypePack(TypePackVar{VariadicTypePack{optionalString}})})}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
@ -765,18 +790,12 @@ TypePackId SingletonTypes::errorRecoveryTypePack()
TypeId SingletonTypes::errorRecoveryType(TypeId guess)
{
if (FFlag::LuauErrorRecoveryType)
return guess;
else
return &errorType_;
return guess;
}
TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess)
{
if (FFlag::LuauErrorRecoveryType)
return guess;
else
return &errorTypePack_;
return guess;
}
SingletonTypes& getSingletonTypes()
@ -798,13 +817,14 @@ void persist(TypeId ty)
continue;
asMutable(t)->persistent = true;
asMutable(t)->normal = true; // all persistent types are assumed to be normal
if (auto btv = get<BoundTypeVar>(t))
queue.push_back(btv->boundTo);
else if (auto ftv = get<FunctionTypeVar>(t))
{
persist(ftv->argTypes);
persist(ftv->retType);
persist(ftv->retTypes);
}
else if (auto ttv = get<TableTypeVar>(t))
{
@ -834,6 +854,11 @@ void persist(TypeId ty)
for (TypeId opt : itv->parts)
queue.push_back(opt);
}
else if (auto ctv = get<ConstrainedTypeVar>(t))
{
for (TypeId opt : ctv->parts)
queue.push_back(opt);
}
else if (auto mtv = get<MetatableTypeVar>(t))
{
queue.push_back(mtv->table);
@ -895,6 +920,16 @@ TypeLevel* getMutableLevel(TypeId ty)
return const_cast<TypeLevel*>(getLevel(ty));
}
std::optional<TypeLevel> getLevel(TypePackId tp)
{
tp = follow(tp);
if (auto ftv = get<Unifiable::Free>(tp))
return ftv->level;
else
return std::nullopt;
}
const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name)
{
while (cls)
@ -1064,10 +1099,10 @@ static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const cha
return result;
}
std::optional<ExprResult<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult)
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = exprResult;
auto [paramPack, _predicates] = withPredicate;
TypeArena& arena = typechecker.currentModule->internalTypes;
@ -1106,7 +1141,7 @@ std::optional<ExprResult<TypePackId>> magicFunctionFormat(
if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize))
typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}});
return ExprResult<TypePackId>{arena.addTypePack({typechecker.stringType})};
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})};
}
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate)

View file

@ -12,12 +12,16 @@ Free::Free(TypeLevel level)
{
}
Free::Free(Scope2* scope)
: scope(scope)
{
}
int Free::nextIndex = 0;
Generic::Generic()
: index(++nextIndex)
, name("g" + std::to_string(index))
, explicitName(false)
{
}
@ -25,7 +29,6 @@ Generic::Generic(TypeLevel level)
: index(++nextIndex)
, level(level)
, name("g" + std::to_string(index))
, explicitName(false)
{
}
@ -36,6 +39,12 @@ Generic::Generic(const Name& name)
{
}
Generic::Generic(Scope2* scope)
: index(++nextIndex)
, scope(scope)
{
}
Generic::Generic(TypeLevel level, const Name& name)
: index(++nextIndex)
, level(level)
@ -44,6 +53,14 @@ Generic::Generic(TypeLevel level, const Name& name)
{
}
Generic::Generic(Scope2* scope, const Name& name)
: index(++nextIndex)
, scope(scope)
, name(name)
, explicitName(true)
{
}
int Generic::nextIndex = 0;
Error::Error()

File diff suppressed because it is too large Load diff

View file

@ -313,7 +313,7 @@ template<typename T>
struct AstArray
{
T* data;
std::size_t size;
size_t size;
const T* begin() const
{

View file

@ -32,6 +32,7 @@ class DenseHashTable
{
public:
class const_iterator;
class iterator;
DenseHashTable(const Key& empty_key, size_t buckets = 0)
: count(0)
@ -43,7 +44,7 @@ public:
// don't move this to initializer list! this works around an MSVC codegen issue on AMD CPUs:
// https://developercommunity.visualstudio.com/t/stdvector-constructor-from-size-t-is-25-times-slow/1546547
if (buckets)
data.resize(buckets, ItemInterface::create(empty_key));
resize_data<Item>(buckets);
}
void clear()
@ -125,7 +126,7 @@ public:
if (data.empty() && data.capacity() >= newsize)
{
LUAU_ASSERT(count == 0);
data.resize(newsize, ItemInterface::create(empty_key));
resize_data<Item>(newsize);
return;
}
@ -169,6 +170,21 @@ public:
return const_iterator(this, data.size());
}
iterator begin()
{
size_t start = 0;
while (start < data.size() && eq(ItemInterface::getKey(data[start]), empty_key))
start++;
return iterator(this, start);
}
iterator end()
{
return iterator(this, data.size());
}
size_t size() const
{
return count;
@ -233,7 +249,82 @@ public:
size_t index;
};
class iterator
{
public:
iterator()
: set(0)
, index(0)
{
}
iterator(DenseHashTable<Key, Item, MutableItem, ItemInterface, Hash, Eq>* set, size_t index)
: set(set)
, index(index)
{
}
MutableItem& operator*() const
{
return *reinterpret_cast<MutableItem*>(&set->data[index]);
}
MutableItem* operator->() const
{
return reinterpret_cast<MutableItem*>(&set->data[index]);
}
bool operator==(const iterator& other) const
{
return set == other.set && index == other.index;
}
bool operator!=(const iterator& other) const
{
return set != other.set || index != other.index;
}
iterator& operator++()
{
size_t size = set->data.size();
do
{
index++;
} while (index < size && set->eq(ItemInterface::getKey(set->data[index]), set->empty_key));
return *this;
}
iterator operator++(int)
{
iterator res = *this;
++*this;
return res;
}
private:
DenseHashTable<Key, Item, MutableItem, ItemInterface, Hash, Eq>* set;
size_t index;
};
private:
template<typename T>
void resize_data(size_t count, typename std::enable_if_t<std::is_copy_assignable_v<T>>* dummy = nullptr)
{
data.resize(count, ItemInterface::create(empty_key));
}
template<typename T>
void resize_data(size_t count, typename std::enable_if_t<!std::is_copy_assignable_v<T>>* dummy = nullptr)
{
size_t size = data.size();
data.resize(count);
for (size_t i = size; i < count; i++)
data[i].first = empty_key;
}
std::vector<Item> data;
size_t count;
Key empty_key;
@ -290,6 +381,7 @@ class DenseHashSet
public:
typedef typename Impl::const_iterator const_iterator;
typedef typename Impl::iterator iterator;
DenseHashSet(const Key& empty_key, size_t buckets = 0)
: impl(empty_key, buckets)
@ -336,6 +428,16 @@ public:
{
return impl.end();
}
iterator begin()
{
return impl.begin();
}
iterator end()
{
return impl.end();
}
};
// This is a faster alternative of unordered_map, but it does not implement the same interface (i.e. it does not support erasing and has
@ -348,6 +450,7 @@ class DenseHashMap
public:
typedef typename Impl::const_iterator const_iterator;
typedef typename Impl::iterator iterator;
DenseHashMap(const Key& empty_key, size_t buckets = 0)
: impl(empty_key, buckets)
@ -401,10 +504,21 @@ public:
{
return impl.begin();
}
const_iterator end() const
{
return impl.end();
}
iterator begin()
{
return impl.begin();
}
iterator end()
{
return impl.end();
}
};
} // namespace Luau

View file

@ -173,7 +173,7 @@ public:
}
const Lexeme& next();
const Lexeme& next(bool skipComments);
const Lexeme& next(bool skipComments, bool updatePrevLocation);
void nextline();
Lexeme lookahead();

View file

@ -19,6 +19,7 @@ std::string format(const char* fmt, ...) LUAU_PRINTF_ATTR(1, 2);
std::string vformat(const char* fmt, va_list args);
void formatAppend(std::string& str, const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3);
void vformatAppend(std::string& ret, const char* fmt, va_list args);
std::string join(const std::vector<std::string_view>& segments, std::string_view delimiter);
std::string join(const std::vector<std::string>& segments, std::string_view delimiter);

View file

@ -1,7 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Common.h"
#include "Luau/Common.h"
#include <vector>
@ -9,14 +9,21 @@
LUAU_FASTFLAG(DebugLuauTimeTracing)
namespace Luau
{
namespace TimeTrace
{
double getClock();
uint32_t getClockMicroseconds();
} // namespace TimeTrace
} // namespace Luau
#if defined(LUAU_ENABLE_TIME_TRACE)
namespace Luau
{
namespace TimeTrace
{
uint32_t getClockMicroseconds();
struct Token
{
const char* name;
@ -130,8 +137,8 @@ ThreadContext& getThreadContext();
struct Scope
{
explicit Scope(ThreadContext& context, uint16_t token)
: context(context)
explicit Scope(uint16_t token)
: context(getThreadContext())
{
if (!FFlag::DebugLuauTimeTracing)
return;
@ -152,8 +159,8 @@ struct Scope
struct OptionalTailScope
{
explicit OptionalTailScope(ThreadContext& context, uint16_t token, uint32_t threshold)
: context(context)
explicit OptionalTailScope(uint16_t token, uint32_t threshold)
: context(getThreadContext())
, token(token)
, threshold(threshold)
{
@ -188,27 +195,27 @@ struct OptionalTailScope
uint32_t pos;
};
LUAU_NOINLINE std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeData(const char* name, const char* category);
LUAU_NOINLINE uint16_t createScopeData(const char* name, const char* category);
} // namespace TimeTrace
} // namespace Luau
// Regular scope
#define LUAU_TIMETRACE_SCOPE(name, category) \
static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first)
static uint16_t lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::Scope lttScope(lttScopeStatic)
// A scope without nested scopes that may be skipped if the time it took is less than the threshold
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \
static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec)
static uint16_t lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail, microsec)
// Extra key/value data can be added to regular scopes
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
if (FFlag::DebugLuauTimeTracing) \
lttScopeStatic.second.eventArgument(name, value); \
lttScope.context.eventArgument(name, value); \
} while (false)
#else

View file

@ -347,10 +347,10 @@ void Lexer::setReadNames(bool read)
const Lexeme& Lexer::next()
{
return next(this->skipComments);
return next(this->skipComments, true);
}
const Lexeme& Lexer::next(bool skipComments)
const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation)
{
// in skipComments mode we reject valid comments
do
@ -359,9 +359,11 @@ const Lexeme& Lexer::next(bool skipComments)
while (isSpace(peekch()))
consume();
prevLocation = lexeme.location;
if (updatePrevLocation)
prevLocation = lexeme.location;
lexeme = readNext();
updatePrevLocation = false;
} while (skipComments && (lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment));
return lexeme;

View file

@ -11,6 +11,9 @@
LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000)
LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false)
LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false)
namespace Luau
{
@ -165,6 +168,7 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc
Function top;
top.vararg = true;
functionStack.reserve(8);
functionStack.push_back(top);
nameSelf = names.addStatic("self");
@ -184,6 +188,13 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc
// all hot comments parsed after the first non-comment lexeme are special in that they don't affect type checking / linting mode
hotcommentHeader = false;
// preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays
localStack.reserve(16);
scratchStat.reserve(16);
scratchExpr.reserve(16);
scratchLocal.reserve(16);
scratchBinding.reserve(16);
}
bool Parser::blockFollow(const Lexeme& l)
@ -1108,8 +1119,12 @@ AstTypePack* Parser::parseTypeList(TempVector<AstType*>& result, TempVector<std:
std::optional<AstTypeList> Parser::parseOptionalReturnTypeAnnotation()
{
if (options.allowTypeAnnotations && lexer.current().type == ':')
if (options.allowTypeAnnotations &&
(lexer.current().type == ':' || (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow)))
{
if (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow)
report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'");
nextLexeme();
unsigned int oldRecursionCount = recursionCounter;
@ -1340,8 +1355,12 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack)
AstArray<AstType*> paramTypes = copy(params);
bool returnTypeIntroducer =
FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false;
// Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element
if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow)
if (params.size() == 1 && !varargAnnotation && monomorphic &&
(FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow))
{
if (allowPack)
return {{}, allocator.alloc<AstTypePackExplicit>(begin.location, AstTypeList{paramTypes, nullptr})};
@ -1349,7 +1368,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack)
return {params[0], {}};
}
if (lexer.current().type != Lexeme::SkinnyArrow && monomorphic && allowPack)
if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && monomorphic && allowPack)
return {{}, allocator.alloc<AstTypePackExplicit>(begin.location, AstTypeList{paramTypes, varargAnnotation})};
AstArray<std::optional<AstArgumentName>> paramNames = copy(names);
@ -1363,8 +1382,13 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray<A
{
incrementRecursionCounter("type annotation");
if (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == ':')
{
report(lexer.current().location, "Return types in function type annotations are written after '->' instead of ':'");
lexer.next();
}
// Users occasionally write '()' as the 'unit' type when they actually want to use 'nil', here we'll try to give a more specific error
if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0)
else if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0)
{
report(Location(begin.location, lexer.previousLocation()), "Expected '->' after '()' when parsing function type; did you mean 'nil'?");
@ -1420,6 +1444,11 @@ AstType* Parser::parseTypeAnnotation(TempVector<AstType*>& parts, const Location
parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type);
isIntersection = true;
}
else if (c == Lexeme::Dot3)
{
report(lexer.current().location, "Unexpected '...' after type annotation");
nextLexeme();
}
else
break;
}
@ -1536,6 +1565,11 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
prefix = name.name;
name = parseIndexName("field name", pointPosition);
}
else if (lexer.current().type == Lexeme::Dot3)
{
report(lexer.current().location, "Unexpected '...' after type name; type pack is not allowed in this context");
nextLexeme();
}
else if (name.name == "typeof")
{
Lexeme typeofBegin = lexer.current();
@ -1571,6 +1605,17 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
{
return parseFunctionTypeAnnotation(allowPack);
}
else if (FFlag::LuauParserFunctionKeywordAsTypeHelp && lexer.current().type == Lexeme::ReservedFunction)
{
Location location = lexer.current().location;
nextLexeme();
return {reportTypeAnnotationError(location, {}, /*isMissing*/ false,
"Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> "
"...any'"),
{}};
}
else
{
Location location = lexer.current().location;
@ -2778,7 +2823,7 @@ void Parser::nextLexeme()
{
if (options.captureComments)
{
Lexeme::Type type = lexer.next(/* skipComments= */ false).type;
Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type;
while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment)
{
@ -2802,7 +2847,7 @@ void Parser::nextLexeme()
hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)});
}
type = lexer.next(/* skipComments= */ false).type;
type = lexer.next(/* skipComments= */ false, /* updatePrevLocation= */ false).type;
}
}
else

View file

@ -11,7 +11,7 @@
namespace Luau
{
static void vformatAppend(std::string& ret, const char* fmt, va_list args)
void vformatAppend(std::string& ret, const char* fmt, va_list args)
{
va_list argscopy;
va_copy(argscopy, args);

View file

@ -26,9 +26,6 @@
#include <time.h>
LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false)
#if defined(LUAU_ENABLE_TIME_TRACE)
namespace Luau
{
namespace TimeTrace
@ -67,6 +64,14 @@ static double getClockTimestamp()
#endif
}
double getClock()
{
static double period = getClockPeriod();
static double start = getClockTimestamp();
return (getClockTimestamp() - start) * period;
}
uint32_t getClockMicroseconds()
{
static double period = getClockPeriod() * 1e6;
@ -74,7 +79,15 @@ uint32_t getClockMicroseconds()
return uint32_t((getClockTimestamp() - start) * period);
}
} // namespace TimeTrace
} // namespace Luau
#if defined(LUAU_ENABLE_TIME_TRACE)
namespace Luau
{
namespace TimeTrace
{
struct GlobalContext
{
GlobalContext() = default;
@ -246,10 +259,9 @@ ThreadContext& getThreadContext()
return context;
}
std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeData(const char* name, const char* category)
uint16_t createScopeData(const char* name, const char* category)
{
uint16_t token = createToken(Luau::TimeTrace::getGlobalContext(), name, category);
return {token, Luau::TimeTrace::getThreadContext()};
return createToken(Luau::TimeTrace::getGlobalContext(), name, category);
}
} // namespace TimeTrace
} // namespace Luau

View file

@ -9,6 +9,7 @@
#include "FileUtils.h"
LUAU_FASTFLAG(DebugLuauTimeTracing)
LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution)
enum class ReportFormat
{
@ -49,6 +50,9 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con
if (const Luau::SyntaxError* syntaxError = Luau::get_if<Luau::SyntaxError>(&error.data))
report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str());
else if (FFlag::LuauTypeMismatchModuleNameResolution)
report(format, humanReadableName.c_str(), error.location, "TypeError",
Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str());
else
report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str());
}

View file

@ -240,7 +240,7 @@ std::optional<std::string> getParentPath(const std::string& path)
return std::nullopt;
#endif
std::string::size_type slash = path.find_last_of("\\/", path.size() - 1);
size_t slash = path.find_last_of("\\/", path.size() - 1);
if (slash == 0)
return "/";
@ -253,7 +253,7 @@ std::optional<std::string> getParentPath(const std::string& path)
static std::string getExtension(const std::string& path)
{
std::string::size_type dot = path.find_last_of(".\\/");
size_t dot = path.find_last_of(".\\/");
if (dot == std::string::npos || path[dot] != '.')
return "";

View file

@ -21,6 +21,8 @@
#include <fcntl.h>
#endif
#include <locale.h>
LUAU_FASTFLAG(DebugLuauTimeTracing)
enum class CliMode
@ -34,7 +36,8 @@ enum class CliMode
enum class CompileFormat
{
Text,
Binary
Binary,
Null
};
constexpr int MaxTraversalLimit = 50;
@ -434,6 +437,9 @@ static void runReplImpl(lua_State* L)
{
ic_set_default_completer(completeRepl, L);
// Reset the locale to C
setlocale(LC_ALL, "C");
// Make brace matching easier to see
ic_style_def("ic-bracematch", "teal");
@ -579,7 +585,8 @@ static bool compileFile(const char* name, CompileFormat format)
if (format == CompileFormat::Text)
{
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals);
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals |
Luau::BytecodeBuilder::Dump_Remarks);
bcb.setDumpSource(*source);
}
@ -593,6 +600,8 @@ static bool compileFile(const char* name, CompileFormat format)
case CompileFormat::Binary:
fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout);
break;
case CompileFormat::Null:
break;
}
return true;
@ -636,13 +645,60 @@ static int assertionHandler(const char* expr, const char* file, int line, const
return 1;
}
static void setLuauFlags(bool state)
{
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
{
if (strncmp(flag->name, "Luau", 4) == 0)
flag->value = state;
}
}
static void setFlag(std::string_view name, bool state)
{
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
{
if (name == flag->name)
{
flag->value = state;
return;
}
}
fprintf(stderr, "Warning: --fflag unrecognized flag '%.*s'.\n\n", int(name.length()), name.data());
}
static void applyFlagKeyValue(std::string_view element)
{
if (size_t separator = element.find('='); separator != std::string_view::npos)
{
std::string_view key = element.substr(0, separator);
std::string_view value = element.substr(separator + 1);
if (value == "true")
setFlag(key, true);
else if (value == "false")
setFlag(key, false);
else
fprintf(stderr, "Warning: --fflag unrecognized value '%.*s' for flag '%.*s'.\n\n", int(value.length()), value.data(), int(key.length()),
key.data());
}
else
{
if (element == "true")
setLuauFlags(true);
else if (element == "false")
setLuauFlags(false);
else
setFlag(element, true);
}
}
int replMain(int argc, char** argv)
{
Luau::assertHandler() = assertionHandler;
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
if (strncmp(flag->name, "Luau", 4) == 0)
flag->value = true;
setLuauFlags(true);
CliMode mode = CliMode::Unknown;
CompileFormat compileFormat{};
@ -668,6 +724,10 @@ int replMain(int argc, char** argv)
{
compileFormat = CompileFormat::Text;
}
else if (strcmp(argv[1], "--compile=null") == 0)
{
compileFormat = CompileFormat::Null;
}
else
{
fprintf(stderr, "Error: Unrecognized value for '--compile' specified.\n");
@ -727,6 +787,22 @@ int replMain(int argc, char** argv)
return 1;
#endif
}
else if (strncmp(argv[i], "--fflags=", 9) == 0)
{
std::string_view list = argv[i] + 9;
while (!list.empty())
{
size_t ending = list.find(",");
applyFlagKeyValue(list.substr(0, ending));
if (ending != std::string_view::npos)
list.remove_prefix(ending + 1);
else
break;
}
}
else if (argv[i][0] == '-')
{
fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]);

View file

@ -11,6 +11,7 @@ option(LUAU_BUILD_TESTS "Build tests" ON)
option(LUAU_BUILD_WEB "Build Web module" OFF)
option(LUAU_WERROR "Warnings as errors" OFF)
option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF)
option(LUAU_EXTERN_C "Use extern C for all APIs" OFF)
if(LUAU_STATIC_CRT)
cmake_minimum_required(VERSION 3.15)
@ -19,9 +20,11 @@ if(LUAU_STATIC_CRT)
endif()
project(Luau LANGUAGES CXX C)
add_library(Luau.Common INTERFACE)
add_library(Luau.Ast STATIC)
add_library(Luau.Compiler STATIC)
add_library(Luau.Analysis STATIC)
add_library(Luau.CodeGen STATIC)
add_library(Luau.VM STATIC)
add_library(isocline STATIC)
@ -48,8 +51,11 @@ endif()
include(Sources.cmake)
target_include_directories(Luau.Common INTERFACE Common/include)
target_compile_features(Luau.Ast PUBLIC cxx_std_17)
target_include_directories(Luau.Ast PUBLIC Ast/include)
target_link_libraries(Luau.Ast PUBLIC Luau.Common)
target_compile_features(Luau.Compiler PUBLIC cxx_std_17)
target_include_directories(Luau.Compiler PUBLIC Compiler/include)
@ -59,8 +65,13 @@ target_compile_features(Luau.Analysis PUBLIC cxx_std_17)
target_include_directories(Luau.Analysis PUBLIC Analysis/include)
target_link_libraries(Luau.Analysis PUBLIC Luau.Ast)
target_compile_features(Luau.CodeGen PRIVATE cxx_std_17)
target_include_directories(Luau.CodeGen PUBLIC CodeGen/include)
target_link_libraries(Luau.CodeGen PUBLIC Luau.Common)
target_compile_features(Luau.VM PRIVATE cxx_std_11)
target_include_directories(Luau.VM PUBLIC VM/include)
target_link_libraries(Luau.VM PUBLIC Luau.Common)
target_include_directories(isocline PUBLIC extern/isocline/include)
@ -73,6 +84,12 @@ else()
list(APPEND LUAU_OPTIONS -Wall) # All warnings
endif()
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
# Some gcc versions treat var in `if (type var = val)` as unused
# Some gcc versions treat variables used in constexpr if blocks as unused
list(APPEND LUAU_OPTIONS -Wno-unused)
endif()
# Enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere
if(LUAU_WERROR)
if(MSVC)
@ -95,19 +112,35 @@ endif()
target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS})
target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS})
target_compile_options(Luau.CodeGen PRIVATE ${LUAU_OPTIONS})
target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS})
target_compile_options(isocline PRIVATE ${LUAU_OPTIONS} ${ISOCLINE_OPTIONS})
if(LUAU_EXTERN_C)
# enable extern "C" for VM (lua.h, lualib.h) and Compiler (luacode.h) to make Luau friendlier to use from non-C++ languages
# note that we enable LUA_USE_LONGJMP=1 as well; otherwise functions like luaL_error will throw C++ exceptions, which can't be done from extern "C" functions
target_compile_definitions(Luau.VM PUBLIC LUA_USE_LONGJMP=1)
target_compile_definitions(Luau.VM PUBLIC LUA_API=extern\"C\")
target_compile_definitions(Luau.Compiler PUBLIC LUACODE_API=extern\"C\")
endif()
if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924)
# disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022:
# https://developercommunity.visualstudio.com/t/performance-regression-on-a-complex-interpreter-lo/1631863
set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-)
endif()
if(MSVC AND LUAU_BUILD_CLI)
# the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger
set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152)
set_target_properties(Luau.Repl.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152)
endif()
# embed .natvis inside the library debug information
if(MSVC)
target_link_options(Luau.Ast INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Ast.natvis)
target_link_options(Luau.Analysis INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Analysis.natvis)
target_link_options(Luau.CodeGen INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/CodeGen.natvis)
target_link_options(Luau.VM INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/VM.natvis)
endif()
@ -115,6 +148,7 @@ endif()
if(MSVC_IDE)
target_sources(Luau.Ast PRIVATE tools/natvis/Ast.natvis)
target_sources(Luau.Analysis PRIVATE tools/natvis/Analysis.natvis)
target_sources(Luau.CodeGen PRIVATE tools/natvis/CodeGen.natvis)
target_sources(Luau.VM PRIVATE tools/natvis/VM.natvis)
endif()
@ -142,7 +176,7 @@ endif()
if(LUAU_BUILD_TESTS)
target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS})
target_include_directories(Luau.UnitTest PRIVATE extern)
target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler)
target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen)
target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS})
target_include_directories(Luau.Conformance PRIVATE extern)

View file

@ -0,0 +1,169 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include "Luau/Condition.h"
#include "Luau/Label.h"
#include "Luau/OperandX64.h"
#include "Luau/RegisterX64.h"
#include <string>
#include <vector>
namespace Luau
{
namespace CodeGen
{
class AssemblyBuilderX64
{
public:
explicit AssemblyBuilderX64(bool logText);
~AssemblyBuilderX64();
// Base two operand instructions with 9 opcode selection
void add(OperandX64 lhs, OperandX64 rhs);
void sub(OperandX64 lhs, OperandX64 rhs);
void cmp(OperandX64 lhs, OperandX64 rhs);
void and_(OperandX64 lhs, OperandX64 rhs);
void or_(OperandX64 lhs, OperandX64 rhs);
void xor_(OperandX64 lhs, OperandX64 rhs);
// Binary shift instructions with special rhs handling
void sal(OperandX64 lhs, OperandX64 rhs);
void sar(OperandX64 lhs, OperandX64 rhs);
void shl(OperandX64 lhs, OperandX64 rhs);
void shr(OperandX64 lhs, OperandX64 rhs);
// Two operand mov instruction has additional specialized encodings
void mov(OperandX64 lhs, OperandX64 rhs);
void mov64(RegisterX64 lhs, int64_t imm);
// Base one operand instruction with 2 opcode selection
void div(OperandX64 op);
void idiv(OperandX64 op);
void mul(OperandX64 op);
void neg(OperandX64 op);
void not_(OperandX64 op);
void test(OperandX64 lhs, OperandX64 rhs);
void lea(OperandX64 lhs, OperandX64 rhs);
void push(OperandX64 op);
void pop(OperandX64 op);
void ret();
// Control flow
void jcc(Condition cond, Label& label);
void jmp(Label& label);
void jmp(OperandX64 op);
// AVX
void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vaddsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vaddss(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vsqrtpd(OperandX64 dst, OperandX64 src);
void vsqrtps(OperandX64 dst, OperandX64 src);
void vsqrtsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vsqrtss(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vmovsd(OperandX64 dst, OperandX64 src);
void vmovsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vmovss(OperandX64 dst, OperandX64 src);
void vmovss(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vmovapd(OperandX64 dst, OperandX64 src);
void vmovaps(OperandX64 dst, OperandX64 src);
void vmovupd(OperandX64 dst, OperandX64 src);
void vmovups(OperandX64 dst, OperandX64 src);
// Run final checks
void finalize();
// Places a label at current location and returns it
Label setLabel();
// Assigns label position to the current location
void setLabel(Label& label);
// Constant allocation (uses rip-relative addressing)
OperandX64 i64(int64_t value);
OperandX64 f32(float value);
OperandX64 f64(double value);
OperandX64 f32x4(float x, float y, float z, float w);
// Resulting data and code that need to be copied over one after the other
// The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code'
std::vector<uint8_t> data;
std::vector<uint8_t> code;
std::string text;
private:
// Instruction archetypes
void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev,
uint8_t coderev, uint8_t code8, uint8_t code, uint8_t opreg);
void placeBinaryRegMemAndImm(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code, uint8_t codeImm8, uint8_t opreg);
void placeBinaryRegAndRegMem(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code);
void placeBinaryRegMemAndReg(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code);
void placeUnaryModRegMem(const char* name, OperandX64 op, uint8_t code8, uint8_t code, uint8_t opreg);
void placeShift(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t opreg);
void placeJcc(const char* name, Label& label, uint8_t cc);
void placeAvx(const char* name, OperandX64 dst, OperandX64 src, uint8_t code, bool setW, uint8_t mode, uint8_t prefix);
void placeAvx(const char* name, OperandX64 dst, OperandX64 src, uint8_t code, uint8_t coderev, bool setW, uint8_t mode, uint8_t prefix);
void placeAvx(const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t code, bool setW, uint8_t mode, uint8_t prefix);
// Instruction components
void placeRegAndModRegMem(OperandX64 lhs, OperandX64 rhs);
void placeModRegMem(OperandX64 rhs, uint8_t regop);
void placeRex(RegisterX64 op);
void placeRex(OperandX64 op);
void placeRex(RegisterX64 lhs, OperandX64 rhs);
void placeVex(OperandX64 dst, OperandX64 src1, OperandX64 src2, bool setW, uint8_t mode, uint8_t prefix);
void placeImm8Or32(int32_t imm);
void placeImm8(int32_t imm);
void placeImm32(int32_t imm);
void placeImm64(int64_t imm);
void placeLabel(Label& label);
void place(uint8_t byte);
void commit();
LUAU_NOINLINE void extend();
uint32_t getCodeSize();
// Data
size_t allocateData(size_t size, size_t align);
// Logging of assembly in text form (Intel asm with VS disassembly formatting)
LUAU_NOINLINE void log(const char* opcode);
LUAU_NOINLINE void log(const char* opcode, OperandX64 op);
LUAU_NOINLINE void log(const char* opcode, OperandX64 op1, OperandX64 op2);
LUAU_NOINLINE void log(const char* opcode, OperandX64 op1, OperandX64 op2, OperandX64 op3);
LUAU_NOINLINE void log(Label label);
LUAU_NOINLINE void log(const char* opcode, Label label);
void log(OperandX64 op);
void logAppend(const char* fmt, ...);
const char* getSizeName(SizeX64 size);
const char* getRegisterName(RegisterX64 reg);
uint32_t nextLabel = 1;
std::vector<Label> pendingLabels;
std::vector<uint32_t> labelLocations;
bool logText = false;
bool finalized = false;
size_t dataPos = 0;
uint8_t* codePos = nullptr;
uint8_t* codeEnd = nullptr;
};
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,46 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
namespace Luau
{
namespace CodeGen
{
enum class Condition
{
Overflow,
NoOverflow,
Carry,
NoCarry,
Below,
BelowEqual,
Above,
AboveEqual,
Equal,
Less,
LessEqual,
Greater,
GreaterEqual,
NotBelow,
NotBelowEqual,
NotAbove,
NotAboveEqual,
NotEqual,
NotLess,
NotLessEqual,
NotGreater,
NotGreaterEqual,
Zero,
NotZero,
// TODO: ordered and unordered floating-point conditions
Count
};
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,18 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <stdint.h>
namespace Luau
{
namespace CodeGen
{
struct Label
{
uint32_t id = 0;
uint32_t location = ~0u;
};
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,136 @@
#pragma once
#include "Luau/Common.h"
#include "Luau/RegisterX64.h"
#include <stdint.h>
namespace Luau
{
namespace CodeGen
{
enum class CategoryX64 : uint8_t
{
reg,
mem,
imm,
};
struct OperandX64
{
constexpr OperandX64(RegisterX64 reg)
: cat(CategoryX64::reg)
, index(noreg)
, base(reg)
, memSize(SizeX64::none)
, scale(1)
, imm(0)
{
}
constexpr OperandX64(int32_t imm)
: cat(CategoryX64::imm)
, index(noreg)
, base(noreg)
, memSize(SizeX64::none)
, scale(1)
, imm(imm)
{
}
constexpr explicit OperandX64(SizeX64 size, RegisterX64 index, uint8_t scale, RegisterX64 base, int32_t disp)
: cat(CategoryX64::mem)
, index(index)
, base(base)
, memSize(size)
, scale(scale)
, imm(disp)
{
}
// Fields are carefully placed to make this struct fit into an 8 byte register
CategoryX64 cat;
RegisterX64 index;
RegisterX64 base;
SizeX64 memSize : 4;
uint8_t scale : 4;
int32_t imm;
constexpr OperandX64 operator[](OperandX64&& addr) const
{
LUAU_ASSERT(cat == CategoryX64::mem);
LUAU_ASSERT(memSize != SizeX64::none && index == noreg && scale == 1 && base == noreg && imm == 0);
LUAU_ASSERT(addr.memSize == SizeX64::none);
addr.cat = CategoryX64::mem;
addr.memSize = memSize;
return addr;
}
};
constexpr OperandX64 byte{SizeX64::byte, noreg, 1, noreg, 0};
constexpr OperandX64 word{SizeX64::word, noreg, 1, noreg, 0};
constexpr OperandX64 dword{SizeX64::dword, noreg, 1, noreg, 0};
constexpr OperandX64 qword{SizeX64::qword, noreg, 1, noreg, 0};
constexpr OperandX64 xmmword{SizeX64::xmmword, noreg, 1, noreg, 0};
constexpr OperandX64 ymmword{SizeX64::ymmword, noreg, 1, noreg, 0};
constexpr OperandX64 ptr{sizeof(void*) == 4 ? SizeX64::dword : SizeX64::qword, noreg, 1, noreg, 0};
constexpr OperandX64 operator*(RegisterX64 reg, uint8_t scale)
{
if (scale == 1)
return OperandX64(reg);
LUAU_ASSERT(scale == 1 || scale == 2 || scale == 4 || scale == 8);
LUAU_ASSERT(reg.index != 0b100 && "can't scale SP");
return OperandX64(SizeX64::none, reg, scale, noreg, 0);
}
constexpr OperandX64 operator+(RegisterX64 reg, int32_t disp)
{
return OperandX64(SizeX64::none, noreg, 1, reg, disp);
}
constexpr OperandX64 operator+(RegisterX64 base, RegisterX64 index)
{
LUAU_ASSERT(index.index != 4 && "sp cannot be used as index");
LUAU_ASSERT(base.size == index.size);
return OperandX64(SizeX64::none, index, 1, base, 0);
}
constexpr OperandX64 operator+(OperandX64 op, int32_t disp)
{
LUAU_ASSERT(op.cat == CategoryX64::mem);
LUAU_ASSERT(op.memSize == SizeX64::none);
op.imm += disp;
return op;
}
constexpr OperandX64 operator+(OperandX64 op, RegisterX64 base)
{
LUAU_ASSERT(op.cat == CategoryX64::mem);
LUAU_ASSERT(op.memSize == SizeX64::none);
LUAU_ASSERT(op.base == noreg);
LUAU_ASSERT(op.index == noreg || op.index.size == base.size);
op.base = base;
return op;
}
constexpr OperandX64 operator+(RegisterX64 base, OperandX64 op)
{
LUAU_ASSERT(op.cat == CategoryX64::mem);
LUAU_ASSERT(op.memSize == SizeX64::none);
LUAU_ASSERT(op.base == noreg);
LUAU_ASSERT(op.index == noreg || op.index.size == base.size);
op.base = base;
return op;
}
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,116 @@
#pragma once
#include "Luau/Common.h"
#include <stdint.h>
namespace Luau
{
namespace CodeGen
{
enum class SizeX64 : uint8_t
{
none,
byte,
word,
dword,
qword,
xmmword,
ymmword,
};
struct RegisterX64
{
SizeX64 size : 3;
uint8_t index : 5;
constexpr bool operator==(RegisterX64 rhs) const
{
return size == rhs.size && index == rhs.index;
}
constexpr bool operator!=(RegisterX64 rhs) const
{
return !(*this == rhs);
}
};
constexpr RegisterX64 noreg{SizeX64::none, 16};
constexpr RegisterX64 rip{SizeX64::none, 0};
constexpr RegisterX64 al{SizeX64::byte, 0};
constexpr RegisterX64 cl{SizeX64::byte, 1};
constexpr RegisterX64 dl{SizeX64::byte, 2};
constexpr RegisterX64 bl{SizeX64::byte, 3};
constexpr RegisterX64 eax{SizeX64::dword, 0};
constexpr RegisterX64 ecx{SizeX64::dword, 1};
constexpr RegisterX64 edx{SizeX64::dword, 2};
constexpr RegisterX64 ebx{SizeX64::dword, 3};
constexpr RegisterX64 esp{SizeX64::dword, 4};
constexpr RegisterX64 ebp{SizeX64::dword, 5};
constexpr RegisterX64 esi{SizeX64::dword, 6};
constexpr RegisterX64 edi{SizeX64::dword, 7};
constexpr RegisterX64 r8d{SizeX64::dword, 8};
constexpr RegisterX64 r9d{SizeX64::dword, 9};
constexpr RegisterX64 r10d{SizeX64::dword, 10};
constexpr RegisterX64 r11d{SizeX64::dword, 11};
constexpr RegisterX64 r12d{SizeX64::dword, 12};
constexpr RegisterX64 r13d{SizeX64::dword, 13};
constexpr RegisterX64 r14d{SizeX64::dword, 14};
constexpr RegisterX64 r15d{SizeX64::dword, 15};
constexpr RegisterX64 rax{SizeX64::qword, 0};
constexpr RegisterX64 rcx{SizeX64::qword, 1};
constexpr RegisterX64 rdx{SizeX64::qword, 2};
constexpr RegisterX64 rbx{SizeX64::qword, 3};
constexpr RegisterX64 rsp{SizeX64::qword, 4};
constexpr RegisterX64 rbp{SizeX64::qword, 5};
constexpr RegisterX64 rsi{SizeX64::qword, 6};
constexpr RegisterX64 rdi{SizeX64::qword, 7};
constexpr RegisterX64 r8{SizeX64::qword, 8};
constexpr RegisterX64 r9{SizeX64::qword, 9};
constexpr RegisterX64 r10{SizeX64::qword, 10};
constexpr RegisterX64 r11{SizeX64::qword, 11};
constexpr RegisterX64 r12{SizeX64::qword, 12};
constexpr RegisterX64 r13{SizeX64::qword, 13};
constexpr RegisterX64 r14{SizeX64::qword, 14};
constexpr RegisterX64 r15{SizeX64::qword, 15};
constexpr RegisterX64 xmm0{SizeX64::xmmword, 0};
constexpr RegisterX64 xmm1{SizeX64::xmmword, 1};
constexpr RegisterX64 xmm2{SizeX64::xmmword, 2};
constexpr RegisterX64 xmm3{SizeX64::xmmword, 3};
constexpr RegisterX64 xmm4{SizeX64::xmmword, 4};
constexpr RegisterX64 xmm5{SizeX64::xmmword, 5};
constexpr RegisterX64 xmm6{SizeX64::xmmword, 6};
constexpr RegisterX64 xmm7{SizeX64::xmmword, 7};
constexpr RegisterX64 xmm8{SizeX64::xmmword, 8};
constexpr RegisterX64 xmm9{SizeX64::xmmword, 9};
constexpr RegisterX64 xmm10{SizeX64::xmmword, 10};
constexpr RegisterX64 xmm11{SizeX64::xmmword, 11};
constexpr RegisterX64 xmm12{SizeX64::xmmword, 12};
constexpr RegisterX64 xmm13{SizeX64::xmmword, 13};
constexpr RegisterX64 xmm14{SizeX64::xmmword, 14};
constexpr RegisterX64 xmm15{SizeX64::xmmword, 15};
constexpr RegisterX64 ymm0{SizeX64::ymmword, 0};
constexpr RegisterX64 ymm1{SizeX64::ymmword, 1};
constexpr RegisterX64 ymm2{SizeX64::ymmword, 2};
constexpr RegisterX64 ymm3{SizeX64::ymmword, 3};
constexpr RegisterX64 ymm4{SizeX64::ymmword, 4};
constexpr RegisterX64 ymm5{SizeX64::ymmword, 5};
constexpr RegisterX64 ymm6{SizeX64::ymmword, 6};
constexpr RegisterX64 ymm7{SizeX64::ymmword, 7};
constexpr RegisterX64 ymm8{SizeX64::ymmword, 8};
constexpr RegisterX64 ymm9{SizeX64::ymmword, 9};
constexpr RegisterX64 ymm10{SizeX64::ymmword, 10};
constexpr RegisterX64 ymm11{SizeX64::ymmword, 11};
constexpr RegisterX64 ymm12{SizeX64::ymmword, 12};
constexpr RegisterX64 ymm13{SizeX64::ymmword, 13};
constexpr RegisterX64 ymm14{SizeX64::ymmword, 14};
constexpr RegisterX64 ymm15{SizeX64::ymmword, 15};
} // namespace CodeGen
} // namespace Luau

File diff suppressed because it is too large Load diff

View file

@ -7,7 +7,7 @@
// Creating the bytecode is outside the scope of this file and is handled by bytecode builder (BytecodeBuilder.h) and bytecode compiler (Compiler.h)
// Note that ALL enums declared in this file are order-sensitive since the values are baked into bytecode that needs to be processed by legacy clients.
// Bytecode definitions
// # Bytecode definitions
// Bytecode instructions are using "word code" - each instruction is one or many 32-bit words.
// The first word in the instruction is always the instruction header, and *must* contain the opcode (enum below) in the least significant byte.
//
@ -19,7 +19,7 @@
// Instruction word is sometimes followed by one extra word, indicated as AUX - this is just a 32-bit word and is decoded according to the specification for each opcode.
// For each opcode the encoding is *static* - that is, based on the opcode you know a-priory how large the instruction is, with the exception of NEWCLOSURE
// Bytecode indices
// # Bytecode indices
// Bytecode instructions commonly refer to integer values that define offsets or indices for various entities. For each type, there's a maximum encodable value.
// Note that in some cases, the compiler will set a lower limit than the maximum encodable value is to prevent fragile code into bumping against the limits whenever we change the compilation details.
// Additionally, in some specific instructions such as ANDK, the limit on the encoded value is smaller; this means that if a value is larger, a different instruction must be selected.
@ -29,6 +29,15 @@
// Constants: 0-2^23-1. Constants are stored in a table allocated with each proto; to allow for future bytecode tweaks the encodable value is limited to 23 bits.
// Closures: 0-2^15-1. Closures are created from child protos via a child index; the limit is for the number of closures immediately referenced in each function.
// Jumps: -2^23..2^23. Jump offsets are specified in word increments, so jumping over an instruction may sometimes require an offset of 2 or more.
// # Bytecode versions
// Bytecode serialized format embeds a version number, that dictates both the serialized form as well as the allowed instructions. As long as the bytecode version falls into supported
// range (indicated by LBC_BYTECODE_MIN / LBC_BYTECODE_MAX) and was produced by Luau compiler, it should load and execute correctly.
//
// Note that Luau runtime doesn't provide indefinite bytecode compatibility: support for older versions gets removed over time. As such, bytecode isn't a durable storage format and it's expected
// that Luau users can recompile bytecode from source on Luau version upgrades if necessary.
// Bytecode opcode, part of the instruction header
enum LuauOpcode
{
// NOP: noop
@ -353,6 +362,11 @@ enum LuauOpcode
// AUX: constant index
LOP_FASTCALL2K,
// FORGPREP: prepare loop variables for a generic for loop, jump to the loop backedge unconditionally
// A: target register; generic for loops assume a register layout [generator, state, index, variables...]
// D: jump offset (-32768..32767)
LOP_FORGPREP,
// Enum entry for number of opcodes, not a valid opcode by itself!
LOP__COUNT
};
@ -375,9 +389,10 @@ enum LuauOpcode
// Bytecode tags, used internally for bytecode encoded as a string
enum LuauBytecodeTag
{
// Bytecode version
LBC_VERSION = 1,
LBC_VERSION_FUTURE = 2, // TODO: This will be removed in favor of LBC_VERSION with LuauBytecodeV2Force
// Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled
LBC_VERSION_MIN = 2,
LBC_VERSION_MAX = 2,
LBC_VERSION_TARGET = 2,
// Types of constant table entries
LBC_CONSTANT_NIL = 0,
LBC_CONSTANT_BOOLEAN,

View file

@ -3,6 +3,7 @@
#include "Luau/Bytecode.h"
#include "Luau/DenseHash.h"
#include "Luau/StringUtils.h"
#include <string>
@ -80,6 +81,8 @@ public:
void pushDebugUpval(StringRef name);
uint32_t getDebugPC() const;
void addDebugRemark(const char* format, ...) LUAU_PRINTF_ATTR(2, 3);
void finalize();
enum DumpFlags
@ -88,6 +91,7 @@ public:
Dump_Lines = 1 << 1,
Dump_Source = 1 << 2,
Dump_Locals = 1 << 3,
Dump_Remarks = 1 << 4,
};
void setDumpFlags(uint32_t flags)
@ -115,6 +119,8 @@ public:
static std::string getError(const std::string& message);
static uint8_t getVersion();
private:
struct Constant
{
@ -220,6 +226,7 @@ private:
DenseHashMap<ConstantKey, int32_t, ConstantKeyHash> constantMap;
DenseHashMap<TableShape, int32_t, TableShapeHash> tableShapeMap;
DenseHashMap<uint32_t, int16_t> protoMap;
int debugLine = 0;
@ -228,6 +235,9 @@ private:
DenseHashMap<StringRef, unsigned int, StringRefHash> stringTable;
std::vector<std::pair<uint32_t, uint32_t>> debugRemarks;
std::string debugRemarkBuffer;
BytecodeEncoder* encoder = nullptr;
std::string bytecode;
@ -239,7 +249,7 @@ private:
void validate() const;
std::string dumpCurrentFunction() const;
const uint32_t* dumpInstruction(const uint32_t* opcode, std::string& output) const;
void dumpInstruction(const uint32_t* opcode, std::string& output, int targetLabel) const;
void writeFunction(std::string& ss, uint32_t id) const;
void writeLineInfo(std::string& ss) const;

View file

@ -9,6 +9,9 @@
namespace Luau
{
static_assert(LBC_VERSION_TARGET >= LBC_VERSION_MIN && LBC_VERSION_TARGET <= LBC_VERSION_MAX, "Invalid bytecode version setup");
static_assert(LBC_VERSION_MAX <= 127, "Bytecode version should be 7-bit so that we can extend the serialization to use varint transparently");
static const uint32_t kMaxConstantCount = 1 << 23;
static const uint32_t kMaxClosureCount = 1 << 15;
@ -96,6 +99,7 @@ inline bool isJumpD(LuauOpcode op)
case LOP_JUMPIFNOTLT:
case LOP_FORNPREP:
case LOP_FORNLOOP:
case LOP_FORGPREP:
case LOP_FORGLOOP:
case LOP_FORGPREP_INEXT:
case LOP_FORGLOOP_INEXT:
@ -127,6 +131,20 @@ inline bool isSkipC(LuauOpcode op)
}
}
static int getJumpTarget(uint32_t insn, uint32_t pc)
{
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
if (isJumpD(op))
return int(pc + LUAU_INSN_D(insn) + 1);
else if (isSkipC(op) && LUAU_INSN_C(insn))
return int(pc + LUAU_INSN_C(insn) + 1);
else if (op == LOP_JUMPX)
return int(pc + LUAU_INSN_E(insn) + 1);
else
return -1;
}
bool BytecodeBuilder::StringRef::operator==(const StringRef& other) const
{
return (data && other.data) ? (length == other.length && memcmp(data, other.data, length) == 0) : (data == other.data);
@ -180,10 +198,18 @@ size_t BytecodeBuilder::TableShapeHash::operator()(const TableShape& v) const
BytecodeBuilder::BytecodeBuilder(BytecodeEncoder* encoder)
: constantMap({Constant::Type_Nil, ~0ull})
, tableShapeMap(TableShape())
, protoMap(~0u)
, stringTable({nullptr, 0})
, encoder(encoder)
{
LUAU_ASSERT(stringTable.find(StringRef{"", 0}) == nullptr);
// preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays
insns.reserve(32);
lines.reserve(32);
constants.reserve(16);
protos.reserve(16);
functions.reserve(8);
}
uint32_t BytecodeBuilder::beginFunction(uint8_t numparams, bool isvararg)
@ -219,8 +245,8 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues)
validate();
#endif
// very approximate: 4 bytes per instruction for code, 1 byte for debug line, and 1-2 bytes for aux data like constants
func.data.reserve(insns.size() * 7);
// very approximate: 4 bytes per instruction for code, 1 byte for debug line, and 1-2 bytes for aux data like constants plus overhead
func.data.reserve(32 + insns.size() * 7);
writeFunction(func.data, currentFunction);
@ -242,10 +268,16 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues)
constantMap.clear();
tableShapeMap.clear();
protoMap.clear();
debugRemarks.clear();
debugRemarkBuffer.clear();
}
void BytecodeBuilder::setMainFunction(uint32_t fid)
{
LUAU_ASSERT(fid < functions.size());
mainFunction = fid;
}
@ -359,11 +391,15 @@ int32_t BytecodeBuilder::addConstantClosure(uint32_t fid)
int16_t BytecodeBuilder::addChildFunction(uint32_t fid)
{
if (int16_t* cache = protoMap.find(fid))
return *cache;
uint32_t id = uint32_t(protos.size());
if (id >= kMaxClosureCount)
return -1;
protoMap[fid] = int16_t(id);
protos.push_back(fid);
return int16_t(id);
@ -505,10 +541,44 @@ uint32_t BytecodeBuilder::getDebugPC() const
return uint32_t(insns.size());
}
void BytecodeBuilder::addDebugRemark(const char* format, ...)
{
if ((dumpFlags & Dump_Remarks) == 0)
return;
size_t offset = debugRemarkBuffer.size();
va_list args;
va_start(args, format);
vformatAppend(debugRemarkBuffer, format, args);
va_end(args);
// we null-terminate all remarks to avoid storing remark length
debugRemarkBuffer += '\0';
debugRemarks.emplace_back(uint32_t(insns.size()), uint32_t(offset));
}
void BytecodeBuilder::finalize()
{
LUAU_ASSERT(bytecode.empty());
bytecode = char(LBC_VERSION_FUTURE);
// preallocate space for bytecode blob
size_t capacity = 16;
for (auto& p : stringTable)
capacity += p.first.length + 2;
for (const Function& func : functions)
capacity += func.data.size();
bytecode.reserve(capacity);
// assemble final bytecode blob
uint8_t version = getVersion();
LUAU_ASSERT(version >= LBC_VERSION_MIN && version <= LBC_VERSION_MAX);
bytecode = char(version);
writeStringTable(bytecode);
@ -663,6 +733,8 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id) const
void BytecodeBuilder::writeLineInfo(std::string& ss) const
{
LUAU_ASSERT(!lines.empty());
// this function encodes lines inside each span as a 8-bit delta to span baseline
// span is always a power of two; depending on the line info input, it may need to be as low as 1
int span = 1 << 24;
@ -693,7 +765,17 @@ void BytecodeBuilder::writeLineInfo(std::string& ss) const
}
// second pass: compute span base
std::vector<int> baseline((lines.size() - 1) / span + 1);
int baselineOne = 0;
std::vector<int> baselineScratch;
int* baseline = &baselineOne;
size_t baselineSize = (lines.size() - 1) / span + 1;
if (baselineSize > 1)
{
// avoid heap allocation for single-element baseline which is most functions (<256 lines)
baselineScratch.resize(baselineSize);
baseline = baselineScratch.data();
}
for (size_t offset = 0; offset < lines.size(); offset += span)
{
@ -725,7 +807,7 @@ void BytecodeBuilder::writeLineInfo(std::string& ss) const
int lastLine = 0;
for (size_t i = 0; i < baseline.size(); ++i)
for (size_t i = 0; i < baselineSize; ++i)
{
writeInt(ss, baseline[i] - lastLine);
lastLine = baseline[i];
@ -964,7 +1046,7 @@ void BytecodeBuilder::expandJumps()
std::string BytecodeBuilder::getError(const std::string& message)
{
// 0 acts as a special marker for error bytecode (it's equal to LBC_VERSION for valid bytecode blobs)
// 0 acts as a special marker for error bytecode (it's equal to LBC_VERSION_TARGET for valid bytecode blobs)
std::string result;
result += char(0);
result += message;
@ -972,6 +1054,12 @@ std::string BytecodeBuilder::getError(const std::string& message)
return result;
}
uint8_t BytecodeBuilder::getVersion()
{
// This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags
return LBC_VERSION_TARGET;
}
#ifdef LUAU_ASSERTENABLED
void BytecodeBuilder::validate() const
{
@ -999,6 +1087,8 @@ void BytecodeBuilder::validate() const
LUAU_ASSERT(i <= insns.size());
}
std::vector<uint8_t> openCaptures;
// second pass: validate the rest of the bytecode
for (size_t i = 0; i < insns.size();)
{
@ -1045,6 +1135,8 @@ void BytecodeBuilder::validate() const
case LOP_CLOSEUPVALS:
VREG(LUAU_INSN_A(insn));
while (openCaptures.size() && openCaptures.back() >= LUAU_INSN_A(insn))
openCaptures.pop_back();
break;
case LOP_GETIMPORT:
@ -1214,6 +1306,11 @@ void BytecodeBuilder::validate() const
VJUMP(LUAU_INSN_D(insn));
break;
case LOP_FORGPREP:
VREG(LUAU_INSN_A(insn) + 2 + 1); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables
VJUMP(LUAU_INSN_D(insn));
break;
case LOP_FORGLOOP:
VREG(
LUAU_INSN_A(insn) + 2 + insns[i + 1]); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables
@ -1307,8 +1404,12 @@ void BytecodeBuilder::validate() const
switch (LUAU_INSN_A(insn))
{
case LCT_VAL:
VREG(LUAU_INSN_B(insn));
break;
case LCT_REF:
VREG(LUAU_INSN_B(insn));
openCaptures.push_back(LUAU_INSN_B(insn));
break;
case LCT_UPVAL:
@ -1328,6 +1429,12 @@ void BytecodeBuilder::validate() const
LUAU_ASSERT(i <= insns.size());
}
// all CAPTURE REF instructions must have a CLOSEUPVALS instruction after them in the bytecode stream
// this doesn't guarantee safety as it doesn't perform basic block based analysis, but if this fails
// then the bytecode is definitely unsafe to run since the compiler won't generate backwards branches
// except for loop edges
LUAU_ASSERT(openCaptures.empty());
#undef VREG
#undef VREGEND
#undef VUPVAL
@ -1337,7 +1444,7 @@ void BytecodeBuilder::validate() const
}
#endif
const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result) const
void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, int targetLabel) const
{
uint32_t insn = *code++;
@ -1432,39 +1539,39 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri
break;
case LOP_JUMP:
formatAppend(result, "JUMP %+d\n", LUAU_INSN_D(insn));
formatAppend(result, "JUMP L%d\n", targetLabel);
break;
case LOP_JUMPIF:
formatAppend(result, "JUMPIF R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
formatAppend(result, "JUMPIF R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
break;
case LOP_JUMPIFNOT:
formatAppend(result, "JUMPIFNOT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
formatAppend(result, "JUMPIFNOT R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
break;
case LOP_JUMPIFEQ:
formatAppend(result, "JUMPIFEQ R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
formatAppend(result, "JUMPIFEQ R%d R%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
break;
case LOP_JUMPIFLE:
formatAppend(result, "JUMPIFLE R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
formatAppend(result, "JUMPIFLE R%d R%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
break;
case LOP_JUMPIFLT:
formatAppend(result, "JUMPIFLT R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
formatAppend(result, "JUMPIFLT R%d R%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
break;
case LOP_JUMPIFNOTEQ:
formatAppend(result, "JUMPIFNOTEQ R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
formatAppend(result, "JUMPIFNOTEQ R%d R%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
break;
case LOP_JUMPIFNOTLE:
formatAppend(result, "JUMPIFNOTLE R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
formatAppend(result, "JUMPIFNOTLE R%d R%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
break;
case LOP_JUMPIFNOTLT:
formatAppend(result, "JUMPIFNOTLT R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
formatAppend(result, "JUMPIFNOTLT R%d R%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
break;
case LOP_ADD:
@ -1560,31 +1667,35 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri
break;
case LOP_FORNPREP:
formatAppend(result, "FORNPREP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
formatAppend(result, "FORNPREP R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
break;
case LOP_FORNLOOP:
formatAppend(result, "FORNLOOP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
formatAppend(result, "FORNLOOP R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
break;
case LOP_FORGPREP:
formatAppend(result, "FORGPREP R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
break;
case LOP_FORGLOOP:
formatAppend(result, "FORGLOOP R%d %+d %d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn), *code++);
formatAppend(result, "FORGLOOP R%d L%d %d\n", LUAU_INSN_A(insn), targetLabel, *code++);
break;
case LOP_FORGPREP_INEXT:
formatAppend(result, "FORGPREP_INEXT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
formatAppend(result, "FORGPREP_INEXT R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
break;
case LOP_FORGLOOP_INEXT:
formatAppend(result, "FORGLOOP_INEXT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
formatAppend(result, "FORGLOOP_INEXT R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
break;
case LOP_FORGPREP_NEXT:
formatAppend(result, "FORGPREP_NEXT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
formatAppend(result, "FORGPREP_NEXT R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
break;
case LOP_FORGLOOP_NEXT:
formatAppend(result, "FORGLOOP_NEXT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
formatAppend(result, "FORGLOOP_NEXT R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
break;
case LOP_GETVARARGS:
@ -1600,7 +1711,7 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri
break;
case LOP_JUMPBACK:
formatAppend(result, "JUMPBACK %+d\n", LUAU_INSN_D(insn));
formatAppend(result, "JUMPBACK L%d\n", targetLabel);
break;
case LOP_LOADKX:
@ -1608,26 +1719,26 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri
break;
case LOP_JUMPX:
formatAppend(result, "JUMPX %+d\n", LUAU_INSN_E(insn));
formatAppend(result, "JUMPX L%d\n", targetLabel);
break;
case LOP_FASTCALL:
formatAppend(result, "FASTCALL %d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_C(insn));
formatAppend(result, "FASTCALL %d L%d\n", LUAU_INSN_A(insn), targetLabel);
break;
case LOP_FASTCALL1:
formatAppend(result, "FASTCALL1 %d R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn));
formatAppend(result, "FASTCALL1 %d R%d L%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), targetLabel);
break;
case LOP_FASTCALL2:
{
uint32_t aux = *code++;
formatAppend(result, "FASTCALL2 %d R%d R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), aux, LUAU_INSN_C(insn));
formatAppend(result, "FASTCALL2 %d R%d R%d L%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), aux, targetLabel);
break;
}
case LOP_FASTCALL2K:
{
uint32_t aux = *code++;
formatAppend(result, "FASTCALL2K %d R%d K%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), aux, LUAU_INSN_C(insn));
formatAppend(result, "FASTCALL2K %d R%d K%d L%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), aux, targetLabel);
break;
}
@ -1637,23 +1748,24 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri
case LOP_CAPTURE:
formatAppend(result, "CAPTURE %s %c%d\n",
LUAU_INSN_A(insn) == LCT_UPVAL ? "UPVAL" : LUAU_INSN_A(insn) == LCT_REF ? "REF" : LUAU_INSN_A(insn) == LCT_VAL ? "VAL" : "",
LUAU_INSN_A(insn) == LCT_UPVAL ? "UPVAL"
: LUAU_INSN_A(insn) == LCT_REF ? "REF"
: LUAU_INSN_A(insn) == LCT_VAL ? "VAL"
: "",
LUAU_INSN_A(insn) == LCT_UPVAL ? 'U' : 'R', LUAU_INSN_B(insn));
break;
case LOP_JUMPIFEQK:
formatAppend(result, "JUMPIFEQK R%d K%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
formatAppend(result, "JUMPIFEQK R%d K%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
break;
case LOP_JUMPIFNOTEQK:
formatAppend(result, "JUMPIFNOTEQK R%d K%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
formatAppend(result, "JUMPIFNOTEQK R%d K%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
break;
default:
LUAU_ASSERT(!"Unsupported opcode");
}
return code;
}
std::string BytecodeBuilder::dumpCurrentFunction() const
@ -1661,10 +1773,8 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
if ((dumpFlags & Dump_Code) == 0)
return std::string();
const uint32_t* code = insns.data();
const uint32_t* codeEnd = insns.data() + insns.size();
int lastLine = -1;
size_t nextRemark = 0;
std::string result;
@ -1684,20 +1794,54 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
}
}
while (code != codeEnd)
std::vector<int> labels(insns.size(), -1);
// annotate valid jump targets with 0
for (size_t i = 0; i < insns.size();)
{
int target = getJumpTarget(insns[i], uint32_t(i));
if (target >= 0)
{
LUAU_ASSERT(size_t(target) < insns.size());
labels[target] = 0;
}
i += getOpLength(LuauOpcode(LUAU_INSN_OP(insns[i])));
LUAU_ASSERT(i <= insns.size());
}
int nextLabel = 0;
// compute label ids (sequential integers for all jump targets)
for (size_t i = 0; i < labels.size(); ++i)
if (labels[i] == 0)
labels[i] = nextLabel++;
for (size_t i = 0; i < insns.size();)
{
const uint32_t* code = &insns[i];
uint8_t op = LUAU_INSN_OP(*code);
if (op == LOP_PREPVARARGS)
{
// Don't emit function header in bytecode - it's used for call dispatching and doesn't contain "interesting" information
code++;
i++;
continue;
}
if (dumpFlags & Dump_Remarks)
{
while (nextRemark < debugRemarks.size() && debugRemarks[nextRemark].first == i)
{
formatAppend(result, "REMARK %s\n", debugRemarkBuffer.c_str() + debugRemarks[nextRemark].second);
nextRemark++;
}
}
if (dumpFlags & Dump_Source)
{
int line = lines[code - insns.data()];
int line = lines[i];
if (line > 0 && line != lastLine)
{
@ -1708,11 +1852,17 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
}
if (dumpFlags & Dump_Lines)
{
formatAppend(result, "%d: ", lines[code - insns.data()]);
}
formatAppend(result, "%d: ", lines[i]);
code = dumpInstruction(code, result);
if (labels[i] != -1)
formatAppend(result, "L%d: ", labels[i]);
int target = getJumpTarget(*code, uint32_t(i));
dumpInstruction(code, result, target >= 0 ? labels[target] : -1);
i += getOpLength(LuauOpcode(op));
LUAU_ASSERT(i <= insns.size());
}
return result;
@ -1722,11 +1872,11 @@ void BytecodeBuilder::setDumpSource(const std::string& source)
{
dumpSource.clear();
std::string::size_type pos = 0;
size_t pos = 0;
while (pos != std::string::npos)
{
std::string::size_type next = source.find('\n', pos);
size_t next = source.find('\n', pos);
if (next == std::string::npos)
{

File diff suppressed because it is too large Load diff

View file

@ -191,14 +191,18 @@ struct ConstantVisitor : AstVisitor
{
DenseHashMap<AstExpr*, Constant>& constants;
DenseHashMap<AstLocal*, Variable>& variables;
DenseHashMap<AstLocal*, Constant>& locals;
DenseHashMap<AstLocal*, Constant> locals;
bool wasEmpty = false;
ConstantVisitor(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables)
ConstantVisitor(
DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables, DenseHashMap<AstLocal*, Constant>& locals)
: constants(constants)
, variables(variables)
, locals(nullptr)
, locals(locals)
{
// since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries
wasEmpty = constants.empty() && locals.empty();
}
Constant analyze(AstExpr* node)
@ -290,7 +294,8 @@ struct ConstantVisitor : AstVisitor
Constant la = analyze(expr->left);
Constant ra = analyze(expr->right);
if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown)
// note: ra doesn't need to be constant to fold and/or
if (la.type != Constant::Type_Unknown)
foldBinary(result, expr->op, la, ra);
}
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
@ -313,12 +318,35 @@ struct ConstantVisitor : AstVisitor
LUAU_ASSERT(!"Unknown expression type");
}
if (result.type != Constant::Type_Unknown)
constants[node] = result;
recordConstant(constants, node, result);
return result;
}
template<typename T>
void recordConstant(DenseHashMap<T, Constant>& map, T key, const Constant& value)
{
if (value.type != Constant::Type_Unknown)
map[key] = value;
else if (wasEmpty)
;
else if (Constant* old = map.find(key))
old->type = Constant::Type_Unknown;
}
void recordValue(AstLocal* local, const Constant& value)
{
// note: we rely on trackValues to have been run before us
Variable* v = variables.find(local);
LUAU_ASSERT(v);
if (!v->written)
{
v->constant = (value.type != Constant::Type_Unknown);
recordConstant(locals, local, value);
}
}
bool visit(AstExpr* node) override
{
// note: we short-circuit the visitor traversal through any expression trees by returning false
@ -335,18 +363,7 @@ struct ConstantVisitor : AstVisitor
{
Constant arg = analyze(node->values.data[i]);
if (arg.type != Constant::Type_Unknown)
{
// note: we rely on trackValues to have been run before us
Variable* v = variables.find(node->vars.data[i]);
LUAU_ASSERT(v);
if (!v->written)
{
locals[node->vars.data[i]] = arg;
v->constant = true;
}
}
recordValue(node->vars.data[i], arg);
}
if (node->vars.size > node->values.size)
@ -360,15 +377,8 @@ struct ConstantVisitor : AstVisitor
{
for (size_t i = node->values.size; i < node->vars.size; ++i)
{
// note: we rely on trackValues to have been run before us
Variable* v = variables.find(node->vars.data[i]);
LUAU_ASSERT(v);
if (!v->written)
{
locals[node->vars.data[i]].type = Constant::Type_Nil;
v->constant = true;
}
Constant nil = {Constant::Type_Nil};
recordValue(node->vars.data[i], nil);
}
}
}
@ -384,9 +394,10 @@ struct ConstantVisitor : AstVisitor
}
};
void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables, AstNode* root)
void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables,
DenseHashMap<AstLocal*, Constant>& locals, AstNode* root)
{
ConstantVisitor visitor{constants, variables};
ConstantVisitor visitor{constants, variables, locals};
root->visit(&visitor);
}

View file

@ -42,7 +42,8 @@ struct Constant
}
};
void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables, AstNode* root);
void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables,
DenseHashMap<AstLocal*, Constant>& locals, AstNode* root);
} // namespace Compile
} // namespace Luau

372
Compiler/src/CostModel.cpp Normal file
View file

@ -0,0 +1,372 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "CostModel.h"
#include "Luau/Common.h"
#include "Luau/DenseHash.h"
#include <limits.h>
namespace Luau
{
namespace Compile
{
inline uint64_t parallelAddSat(uint64_t x, uint64_t y)
{
uint64_t r = x + y;
uint64_t s = r & 0x8080808080808080ull; // saturation mask
return (r ^ s) | (s - (s >> 7));
}
static uint64_t parallelMulSat(uint64_t a, int b)
{
int bs = (b < 127) ? b : 127;
// multiply every other value by b, yielding 14-bit products
uint64_t l = bs * ((a >> 0) & 0x007f007f007f007full);
uint64_t h = bs * ((a >> 8) & 0x007f007f007f007full);
// each product is 14-bit, so adding 32768-128 sets high bit iff the sum is 128 or larger without an overflow
uint64_t ls = l + 0x7f807f807f807f80ull;
uint64_t hs = h + 0x7f807f807f807f80ull;
// we now merge saturation bits as well as low 7-bits of each product into one
uint64_t s = (hs & 0x8000800080008000ull) | ((ls & 0x8000800080008000ull) >> 8);
uint64_t r = ((h & 0x007f007f007f007full) << 8) | (l & 0x007f007f007f007full);
// the low bits are now correct for values that didn't saturate, and we simply need to mask them if high bit is 1
return r | (s - (s >> 7));
}
inline bool getNumber(AstExpr* node, double& result)
{
// since constant model doesn't use constant folding atm, we perform the basic extraction that's sufficient to handle positive/negative literals
if (AstExprConstantNumber* ne = node->as<AstExprConstantNumber>())
{
result = ne->value;
return true;
}
if (AstExprUnary* ue = node->as<AstExprUnary>(); ue && ue->op == AstExprUnary::Minus)
if (AstExprConstantNumber* ne = ue->expr->as<AstExprConstantNumber>())
{
result = -ne->value;
return true;
}
return false;
}
struct Cost
{
static const uint64_t kLiteral = ~0ull;
// cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant
uint64_t model;
// constant mask: 8-byte 0xff mask; equal to all ff's for literals, for variables only byte #i (1+) is set to align with model
uint64_t constant;
Cost(int cost = 0, uint64_t constant = 0)
: model(cost < 0x7f ? cost : 0x7f)
, constant(constant)
{
}
Cost operator+(const Cost& other) const
{
Cost result;
result.model = parallelAddSat(model, other.model);
return result;
}
Cost& operator+=(const Cost& other)
{
model = parallelAddSat(model, other.model);
constant = 0;
return *this;
}
Cost operator*(int other) const
{
Cost result;
result.model = parallelMulSat(model, other);
return result;
}
static Cost fold(const Cost& x, const Cost& y)
{
uint64_t newmodel = parallelAddSat(x.model, y.model);
uint64_t newconstant = x.constant & y.constant;
// the extra cost for folding is 1; the discount is 1 for the variable that is shared by x&y (or whichever one is used in x/y if the other is
// literal)
uint64_t extra = (newconstant == kLiteral) ? 0 : (1 | (0x0101010101010101ull & newconstant));
Cost result;
result.model = parallelAddSat(newmodel, extra);
result.constant = newconstant;
return result;
}
};
struct CostVisitor : AstVisitor
{
DenseHashMap<AstLocal*, uint64_t> vars;
Cost result;
CostVisitor()
: vars(nullptr)
{
}
Cost model(AstExpr* node)
{
if (AstExprGroup* expr = node->as<AstExprGroup>())
{
return model(expr->expr);
}
else if (node->is<AstExprConstantNil>() || node->is<AstExprConstantBool>() || node->is<AstExprConstantNumber>() ||
node->is<AstExprConstantString>())
{
return Cost(0, Cost::kLiteral);
}
else if (AstExprLocal* expr = node->as<AstExprLocal>())
{
const uint64_t* i = vars.find(expr->local);
return Cost(0, i ? *i : 0); // locals typically don't require extra instructions to compute
}
else if (node->is<AstExprGlobal>())
{
return 1;
}
else if (node->is<AstExprVarargs>())
{
return 3;
}
else if (AstExprCall* expr = node->as<AstExprCall>())
{
Cost cost = 3;
cost += model(expr->func);
for (size_t i = 0; i < expr->args.size; ++i)
{
Cost ac = model(expr->args.data[i]);
// for constants/locals we still need to copy them to the argument list
cost += ac.model == 0 ? Cost(1) : ac;
}
return cost;
}
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
{
return model(expr->expr) + 1;
}
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
{
return model(expr->expr) + model(expr->index) + 1;
}
else if (AstExprFunction* expr = node->as<AstExprFunction>())
{
return 10; // high baseline cost due to allocation
}
else if (AstExprTable* expr = node->as<AstExprTable>())
{
Cost cost = 10; // high baseline cost due to allocation
for (size_t i = 0; i < expr->items.size; ++i)
{
const AstExprTable::Item& item = expr->items.data[i];
if (item.key)
cost += model(item.key);
cost += model(item.value);
cost += 1;
}
return cost;
}
else if (AstExprUnary* expr = node->as<AstExprUnary>())
{
return Cost::fold(model(expr->expr), Cost(0, Cost::kLiteral));
}
else if (AstExprBinary* expr = node->as<AstExprBinary>())
{
return Cost::fold(model(expr->left), model(expr->right));
}
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
{
return model(expr->expr);
}
else if (AstExprIfElse* expr = node->as<AstExprIfElse>())
{
return model(expr->condition) + model(expr->trueExpr) + model(expr->falseExpr) + 2;
}
else
{
LUAU_ASSERT(!"Unknown expression type");
return {};
}
}
void assign(AstExpr* expr)
{
// variable assignments reset variable mask, so that further uses of this variable aren't discounted
// this doesn't work perfectly with backwards control flow like loops, but is good enough for a single pass
if (AstExprLocal* lv = expr->as<AstExprLocal>())
if (uint64_t* i = vars.find(lv->local))
*i = 0;
}
void loop(AstStatBlock* body, Cost iterCost, int factor = 3)
{
Cost before = result;
result = Cost();
body->visit(this);
result = before + (result + iterCost) * factor;
}
bool visit(AstExpr* node) override
{
// note: we short-circuit the visitor traversal through any expression trees by returning false
// recursive traversal is happening inside model() which makes it easier to get the resulting value of the subexpression
result += model(node);
return false;
}
bool visit(AstStatFor* node) override
{
result += model(node->from);
result += model(node->to);
if (node->step)
result += model(node->step);
int tripCount = -1;
double from, to, step = 1;
if (getNumber(node->from, from) && getNumber(node->to, to) && (!node->step || getNumber(node->step, step)))
tripCount = getTripCount(from, to, step);
loop(node->body, 1, tripCount < 0 ? 3 : tripCount);
return false;
}
bool visit(AstStatForIn* node) override
{
for (size_t i = 0; i < node->values.size; ++i)
result += model(node->values.data[i]);
loop(node->body, 1);
return false;
}
bool visit(AstStatWhile* node) override
{
Cost condition = model(node->condition);
loop(node->body, condition);
return false;
}
bool visit(AstStatRepeat* node) override
{
Cost condition = model(node->condition);
loop(node->body, condition);
return false;
}
bool visit(AstStat* node) override
{
if (node->is<AstStatIf>())
result += 2;
else if (node->is<AstStatBreak>() || node->is<AstStatContinue>())
result += 1;
return true;
}
bool visit(AstStatLocal* node) override
{
for (size_t i = 0; i < node->values.size; ++i)
{
Cost arg = model(node->values.data[i]);
// propagate constant mask from expression through variables
if (arg.constant && i < node->vars.size)
vars[node->vars.data[i]] = arg.constant;
result += arg;
}
return false;
}
bool visit(AstStatAssign* node) override
{
for (size_t i = 0; i < node->vars.size; ++i)
assign(node->vars.data[i]);
return true;
}
bool visit(AstStatCompoundAssign* node) override
{
assign(node->var);
// if lhs is not a local, setting it requires an extra table operation
result += node->var->is<AstExprLocal>() ? 1 : 2;
return true;
}
};
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount)
{
CostVisitor visitor;
for (size_t i = 0; i < varCount && i < 7; ++i)
visitor.vars[vars[i]] = 0xffull << (i * 8 + 8);
root->visit(&visitor);
return visitor.result.model;
}
int computeCost(uint64_t model, const bool* varsConst, size_t varCount)
{
int cost = int(model & 0x7f);
// don't apply discounts to what is likely a saturated sum
if (cost == 0x7f)
return cost;
for (size_t i = 0; i < varCount && i < 7; ++i)
cost -= int((model >> (i * 8 + 8)) & 0x7f) * varsConst[i];
return cost;
}
int getTripCount(double from, double to, double step)
{
// we compute trip count in integers because that way we know that the loop math (repeated addition) is precise
int fromi = (from >= -32767 && from <= 32767 && double(int(from)) == from) ? int(from) : INT_MIN;
int toi = (to >= -32767 && to <= 32767 && double(int(to)) == to) ? int(to) : INT_MIN;
int stepi = (step >= -32767 && step <= 32767 && double(int(step)) == step) ? int(step) : INT_MIN;
if (fromi == INT_MIN || toi == INT_MIN || stepi == INT_MIN || stepi == 0)
return -1;
if ((stepi < 0 && toi > fromi) || (stepi > 0 && toi < fromi))
return 0;
return (toi - fromi) / stepi + 1;
}
} // namespace Compile
} // namespace Luau

21
Compiler/src/CostModel.h Normal file
View file

@ -0,0 +1,21 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Ast.h"
namespace Luau
{
namespace Compile
{
// cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount);
// cost is computed as B - sum(Di * Ci), where B is baseline cost, Di is the discount for each variable and Ci is 1 when variable #i is constant
int computeCost(uint64_t model, const bool* varsConst, size_t varCount);
// get loop trip count or -1 if we can't compute it precisely
int getTripCount(double from, double to, double step);
} // namespace Compile
} // namespace Luau

View file

@ -19,6 +19,10 @@ ANALYSIS_SOURCES=$(wildcard Analysis/src/*.cpp)
ANALYSIS_OBJECTS=$(ANALYSIS_SOURCES:%=$(BUILD)/%.o)
ANALYSIS_TARGET=$(BUILD)/libluauanalysis.a
CODEGEN_SOURCES=$(wildcard CodeGen/src/*.cpp)
CODEGEN_OBJECTS=$(CODEGEN_SOURCES:%=$(BUILD)/%.o)
CODEGEN_TARGET=$(BUILD)/libluaucodegen.a
VM_SOURCES=$(wildcard VM/src/*.cpp)
VM_OBJECTS=$(VM_SOURCES:%=$(BUILD)/%.o)
VM_TARGET=$(BUILD)/libluauvm.a
@ -47,7 +51,7 @@ ifneq ($(flags),)
TESTS_ARGS+=--fflags=$(flags)
endif
OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS)
OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS)
# common flags
CXXFLAGS=-g -Wall
@ -90,15 +94,16 @@ ifeq ($(config),fuzz)
endif
# target-specific flags
$(AST_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include
$(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -IAst/include
$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include
$(VM_OBJECTS): CXXFLAGS+=-std=c++11 -IVM/include
$(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include
$(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include
$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include
$(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include
$(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include
$(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include
$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICLI -Iextern
$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IVM/include -Iextern -Iextern/isocline/include
$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include -Iextern
$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include
$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern
$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -Iextern -Iextern/isocline/include
$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -Iextern
$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IVM/include
$(TESTS_TARGET): LDFLAGS+=-lpthread
$(REPL_CLI_TARGET): LDFLAGS+=-lpthread
@ -126,7 +131,7 @@ coverage: $(TESTS_TARGET)
llvm-cov export -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -format lcov --instr-profile default.profdata build/coverage/luau-tests >coverage.info
format:
find . -name '*.h' -or -name '*.cpp' | xargs clang-format -i
find . -name '*.h' -or -name '*.cpp' | xargs clang-format-11 -i
luau-size: luau
nm --print-size --demangle luau | grep ' t void luau_execute<false>' | awk -F ' ' '{sum += strtonum("0x" $$2)} END {print sum " interpreter" }'
@ -140,7 +145,7 @@ luau-analyze: $(ANALYZE_CLI_TARGET)
ln -fs $^ $@
# executable targets
$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET)
$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET)
$(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET)
$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(AST_TARGET)
@ -158,10 +163,11 @@ fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(B
$(AST_TARGET): $(AST_OBJECTS)
$(COMPILER_TARGET): $(COMPILER_OBJECTS)
$(ANALYSIS_TARGET): $(ANALYSIS_OBJECTS)
$(CODEGEN_TARGET): $(CODEGEN_OBJECTS)
$(VM_TARGET): $(VM_OBJECTS)
$(ISOCLINE_TARGET): $(ISOCLINE_OBJECTS)
$(AST_TARGET) $(COMPILER_TARGET) $(ANALYSIS_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET):
$(AST_TARGET) $(COMPILER_TARGET) $(ANALYSIS_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET):
ar rcs $@ $^
# object file targets

View file

@ -1,7 +1,15 @@
# Luau.Common Sources
# Note: Until 3.19, INTERFACE targets couldn't have SOURCES property set
if(NOT ${CMAKE_VERSION} VERSION_LESS "3.19")
target_sources(Luau.Common PRIVATE
Common/include/Luau/Common.h
Common/include/Luau/Bytecode.h
)
endif()
# Luau.Ast Sources
target_sources(Luau.Ast PRIVATE
Ast/include/Luau/Ast.h
Ast/include/Luau/Common.h
Ast/include/Luau/Confusables.h
Ast/include/Luau/DenseHash.h
Ast/include/Luau/Lexer.h
@ -23,7 +31,6 @@ target_sources(Luau.Ast PRIVATE
# Luau.Compiler Sources
target_sources(Luau.Compiler PRIVATE
Compiler/include/Luau/Bytecode.h
Compiler/include/Luau/BytecodeBuilder.h
Compiler/include/Luau/Compiler.h
Compiler/include/luacode.h
@ -32,31 +39,51 @@ target_sources(Luau.Compiler PRIVATE
Compiler/src/Compiler.cpp
Compiler/src/Builtins.cpp
Compiler/src/ConstantFolding.cpp
Compiler/src/CostModel.cpp
Compiler/src/TableShape.cpp
Compiler/src/ValueTracking.cpp
Compiler/src/lcode.cpp
Compiler/src/Builtins.h
Compiler/src/ConstantFolding.h
Compiler/src/CostModel.h
Compiler/src/TableShape.h
Compiler/src/ValueTracking.h
)
# Luau.CodeGen Sources
target_sources(Luau.CodeGen PRIVATE
CodeGen/include/Luau/AssemblyBuilderX64.h
CodeGen/include/Luau/Condition.h
CodeGen/include/Luau/Label.h
CodeGen/include/Luau/OperandX64.h
CodeGen/include/Luau/RegisterX64.h
CodeGen/src/AssemblyBuilderX64.cpp
)
# Luau.Analysis Sources
target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/AstQuery.h
Analysis/include/Luau/Autocomplete.h
Analysis/include/Luau/BuiltinDefinitions.h
Analysis/include/Luau/Clone.h
Analysis/include/Luau/Config.h
Analysis/include/Luau/Constraint.h
Analysis/include/Luau/ConstraintGraphBuilder.h
Analysis/include/Luau/ConstraintSolver.h
Analysis/include/Luau/ConstraintSolverLogger.h
Analysis/include/Luau/Documentation.h
Analysis/include/Luau/Error.h
Analysis/include/Luau/FileResolver.h
Analysis/include/Luau/Frontend.h
Analysis/include/Luau/Instantiation.h
Analysis/include/Luau/IostreamHelpers.h
Analysis/include/Luau/JsonEncoder.h
Analysis/include/Luau/Linter.h
Analysis/include/Luau/LValue.h
Analysis/include/Luau/Module.h
Analysis/include/Luau/ModuleResolver.h
Analysis/include/Luau/Normalize.h
Analysis/include/Luau/Predicate.h
Analysis/include/Luau/Quantify.h
Analysis/include/Luau/RecursionCounter.h
@ -69,7 +96,9 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/ToString.h
Analysis/include/Luau/Transpiler.h
Analysis/include/Luau/TxnLog.h
Analysis/include/Luau/TypeArena.h
Analysis/include/Luau/TypeAttach.h
Analysis/include/Luau/TypeChecker2.h
Analysis/include/Luau/TypedAllocator.h
Analysis/include/Luau/TypeInfer.h
Analysis/include/Luau/TypePack.h
@ -84,14 +113,21 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/AstQuery.cpp
Analysis/src/Autocomplete.cpp
Analysis/src/BuiltinDefinitions.cpp
Analysis/src/Clone.cpp
Analysis/src/Config.cpp
Analysis/src/Constraint.cpp
Analysis/src/ConstraintGraphBuilder.cpp
Analysis/src/ConstraintSolver.cpp
Analysis/src/ConstraintSolverLogger.cpp
Analysis/src/Error.cpp
Analysis/src/Frontend.cpp
Analysis/src/Instantiation.cpp
Analysis/src/IostreamHelpers.cpp
Analysis/src/JsonEncoder.cpp
Analysis/src/Linter.cpp
Analysis/src/LValue.cpp
Analysis/src/Module.cpp
Analysis/src/Normalize.cpp
Analysis/src/Quantify.cpp
Analysis/src/RequireTracer.cpp
Analysis/src/Scope.cpp
@ -102,7 +138,9 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/ToString.cpp
Analysis/src/Transpiler.cpp
Analysis/src/TxnLog.cpp
Analysis/src/TypeArena.cpp
Analysis/src/TypeAttach.cpp
Analysis/src/TypeChecker2.cpp
Analysis/src/TypedAllocator.cpp
Analysis/src/TypeInfer.cpp
Analysis/src/TypePack.cpp
@ -215,6 +253,7 @@ if(TARGET Luau.UnitTest)
tests/BuiltinDefinitions.test.cpp
tests/Compiler.test.cpp
tests/Config.test.cpp
tests/CostModel.test.cpp
tests/Error.test.cpp
tests/Frontend.test.cpp
tests/JsonEncoder.test.cpp
@ -222,8 +261,12 @@ if(TARGET Luau.UnitTest)
tests/LValue.test.cpp
tests/Module.test.cpp
tests/NonstrictMode.test.cpp
tests/Normalize.test.cpp
tests/ConstraintGraphBuilder.test.cpp
tests/ConstraintSolver.test.cpp
tests/Parser.test.cpp
tests/RequireTracer.test.cpp
tests/RuntimeLimits.test.cpp
tests/StringUtils.test.cpp
tests/Symbol.test.cpp
tests/ToDot.test.cpp
@ -255,6 +298,8 @@ if(TARGET Luau.UnitTest)
tests/TypePack.test.cpp
tests/TypeVar.test.cpp
tests/Variant.test.cpp
tests/VisitTypeVar.test.cpp
tests/AssemblyBuilderX64.test.cpp
tests/main.cpp)
endif()

Some files were not shown because too many files have changed in this diff Show more