Merge branch 'master' into global-type-paths

This commit is contained in:
HawDevelopment 2022-08-10 13:21:34 +02:00
commit b69e5bbbb3
196 changed files with 14482 additions and 5041 deletions

1
.github/codecov.yml vendored Normal file
View file

@ -0,0 +1 @@
comment: false

View file

@ -75,7 +75,7 @@ jobs:
name: ${{ matrix.bench.title }} (Windows ${{matrix.arch}})
tool: "benchmarkluau"
output-file-path: ./${{ matrix.bench.script }}-output.txt
external-data-json-path: ./gh-pages/dev/bench/data.json
external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Push benchmark results
@ -85,7 +85,7 @@ jobs:
cd gh-pages
git config user.name github-actions
git config user.email github@users.noreply.github.com
git add ./dev/bench/data.json
git add ./dev/bench/data-${{ matrix.os }}.json
git commit -m "Add benchmarks results for ${{ github.sha }}"
git push
cd ..
@ -156,7 +156,7 @@ jobs:
name: ${{ matrix.bench.title }}
tool: "benchmarkluau"
output-file-path: ./${{ matrix.bench.script }}-output.txt
external-data-json-path: ./gh-pages/dev/bench/data.json
external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json
github-token: ${{ secrets.BENCH_GITHUB_TOKEN }}
- name: Store ${{ matrix.bench.title }} result (CacheGrind)
@ -166,7 +166,7 @@ jobs:
name: ${{ matrix.bench.title }} (CacheGrind)
tool: "roblox"
output-file-path: ./${{ matrix.bench.script }}-output.txt
external-data-json-path: ./gh-pages/dev/bench/data.json
external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json
github-token: ${{ secrets.BENCH_GITHUB_TOKEN }}
- name: Push benchmark results
@ -176,7 +176,7 @@ jobs:
cd gh-pages
git config user.name github-actions
git config user.email github@users.noreply.github.com
git add ./dev/bench/data.json
git add ./dev/bench/data-${{ matrix.os }}.json
git commit -m "Add benchmarks results for ${{ github.sha }}"
git push
cd ..
@ -220,13 +220,13 @@ jobs:
sudo apt-get install valgrind
- name: Run Luau Analyze on static file
run: sudo python ./bench/measure_time.py ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee ${{ matrix.bench.script }}-output.txt
run: sudo python ./bench/measure_time.py ./build/release/luau-analyze bench/other/LuauPolyfillMap.lua | tee ${{ matrix.bench.script }}-output.txt
- name: Run ${{ matrix.bench.title }} (Cold Cachegrind)
run: sudo ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt
run: sudo ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 ./build/release/luau-analyze bench/other/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt
- name: Run ${{ matrix.bench.title }} (Warm Cachegrind)
run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt
run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}" 1 ./build/release/luau-analyze bench/other/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt
- name: Checkout Benchmark Results repository
uses: actions/checkout@v3
@ -244,7 +244,7 @@ jobs:
gh-pages-branch: "main"
output-file-path: ./${{ matrix.bench.script }}-output.txt
external-data-json-path: ./gh-pages/dev/bench/data.json
external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json
github-token: ${{ secrets.BENCH_GITHUB_TOKEN }}
- name: Store ${{ matrix.bench.title }} result (CacheGrind)
@ -254,7 +254,7 @@ jobs:
tool: "roblox"
gh-pages-branch: "main"
output-file-path: ./${{ matrix.bench.script }}-output.txt
external-data-json-path: ./gh-pages/dev/bench/data.json
external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json
github-token: ${{ secrets.BENCH_GITHUB_TOKEN }}
- name: Push benchmark results
@ -264,7 +264,7 @@ jobs:
cd gh-pages
git config user.name github-actions
git config user.email github@users.noreply.github.com
git add ./dev/bench/data.json
git add ./dev/bench/data-${{ matrix.os }}.json
git commit -m "Add benchmarks results for ${{ github.sha }}"
git push
cd ..

View file

@ -32,11 +32,33 @@ jobs:
sudo apt-get install valgrind
- name: Build Luau
run: CXX=${{ matrix.compiler }} make config=release CALLGRIND=1 luau
run: CXX=${{ matrix.compiler }} make config=release CALLGRIND=1 luau luau-analyze
- name: Run benchmark
- name: Run benchmark (bench)
run: |
python bench/bench.py --callgrind --vm "./luau -O2" | tee output.txt
python bench/bench.py --callgrind --vm "./luau -O2" | tee -a bench-output.txt
- name: Run benchmark (analyze)
run: |
filter() {
awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau-analyze"}'
}
valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-nonstrict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-strict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/regex.lua 2>&1 | filter regex-nonstrict | tee -a analyze-output.txt
valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/regex.lua 2>&1 | filter regex-strict | tee -a analyze-output.txt
- name: Run benchmark (compile)
run: |
filter() {
awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau --compile"}'
}
valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt
valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt
- name: Checkout benchmark results
uses: actions/checkout@v3
@ -46,13 +68,29 @@ jobs:
token: ${{ secrets.BENCH_GITHUB_TOKEN }}
path: "./gh-pages"
- name: Store results
- name: Store results (bench)
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: callgrind ${{ matrix.compiler }}
tool: "benchmarkluau"
output-file-path: ./output.txt
external-data-json-path: ./gh-pages/bench/data.json
output-file-path: ./bench-output.txt
external-data-json-path: ./gh-pages/bench.json
- name: Store results (analyze)
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: luau-analyze
tool: "benchmarkluau"
output-file-path: ./analyze-output.txt
external-data-json-path: ./gh-pages/analyze.json
- name: Store results (compile)
uses: Roblox/rhysd-github-action-benchmark@v-luau
with:
name: luau --compile
tool: "benchmarkluau"
output-file-path: ./compile-output.txt
external-data-json-path: ./gh-pages/compile.json
- name: Push benchmark results
if: github.event_name == 'push'
@ -61,7 +99,7 @@ jobs:
cd gh-pages
git config user.name github-actions
git config user.email github@users.noreply.github.com
git add ./bench/data.json
git add *.json
git commit -m "Add benchmarks results for ${{ github.sha }}"
git push
cd ..

View file

@ -76,20 +76,10 @@ jobs:
- name: make coverage
run: |
CXX=clang++-10 make -j2 config=coverage coverage
- name: debug coverage
run: |
git status
git log -5
echo SHA: $GITHUB_SHA
- name: upload coverage
uses: coverallsapp/github-action@master
uses: codecov/codecov-action@v3
with:
path-to-lcov: ./coverage.info
github-token: ${{ secrets.GITHUB_TOKEN }}
- uses: actions/upload-artifact@v2
with:
name: coverage
path: coverage
files: ./coverage.info
web:
runs-on: ubuntu-latest

3
.gitignore vendored
View file

@ -8,3 +8,6 @@
/default.prof*
/fuzz-*
/luau
/luau-tests
/luau-analyze
__pycache__

View file

@ -0,0 +1,32 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Substitution.h"
#include "Luau/TxnLog.h"
#include "Luau/TypeVar.h"
namespace Luau
{
// A substitution which replaces the type parameters of a type function by arguments
struct ApplyTypeFunction : Substitution
{
ApplyTypeFunction(TypeArena* arena)
: Substitution(TxnLog::empty(), arena)
, encounteredForwardedType(false)
{
}
// Never set under deferred constraint resolution.
bool encounteredForwardedType;
std::unordered_map<TypeId, TypeId> typeArguments;
std::unordered_map<TypePackId, TypePackId> typePackArguments;
bool ignoreChildren(TypeId ty) override;
bool ignoreChildren(TypePackId tp) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
};
} // namespace Luau

View file

@ -2,12 +2,15 @@
#pragma once
#include <string>
#include <vector>
namespace Luau
{
class AstNode;
struct Comment;
std::string toJson(AstNode* node);
std::string toJson(AstNode* node, const std::vector<Comment>& commentLocations);
} // namespace Luau

View file

@ -63,6 +63,7 @@ private:
AstLocal* local = nullptr;
};
std::vector<AstNode*> findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos);
std::vector<AstNode*> findAstAncestryOfPosition(const SourceModule& source, Position pos);
AstNode* findNodeAtPosition(const SourceModule& source, Position pos);
AstExpr* findExprAtPosition(const SourceModule& source, Position pos);

View file

@ -78,16 +78,6 @@ struct AutocompleteResult
using ModuleName = std::string;
using StringCompletionCallback = std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassTypeVar*> ctx)>;
struct OwningAutocompleteResult
{
AutocompleteResult result;
ModulePtr module;
std::unique_ptr<SourceModule> sourceModule;
};
AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback);
// Deprecated, do not use in new work.
OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback);
} // namespace Luau

View file

@ -4,6 +4,7 @@
#include "Luau/Ast.h" // Used for some of the enumerations
#include "Luau/NotNull.h"
#include "Luau/Variant.h"
#include "Luau/TypeVar.h"
#include <string>
#include <memory>
@ -12,7 +13,8 @@
namespace Luau
{
struct Scope2;
struct Scope;
struct TypeVar;
using TypeId = const TypeVar*;
@ -38,7 +40,7 @@ struct GeneralizationConstraint
{
TypeId generalizedType;
TypeId sourceType;
Scope2* scope;
Scope* scope;
};
// subType ~ inst superType
@ -70,8 +72,15 @@ struct NameConstraint
std::string name;
};
// target ~ inst target
struct TypeAliasExpansionConstraint
{
// Must be a PendingExpansionTypeVar.
TypeId target;
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint,
BinaryConstraint, NameConstraint>;
BinaryConstraint, NameConstraint, TypeAliasExpansionConstraint>;
using ConstraintPtr = std::unique_ptr<struct Constraint>;
struct Constraint

View file

@ -17,21 +17,22 @@
namespace Luau
{
struct Scope2;
struct Scope;
using ScopePtr = std::shared_ptr<Scope>;
struct ConstraintGraphBuilder
{
// A list of all the scopes in the module. This vector holds ownership of the
// scope pointers; the scopes themselves borrow pointers to other scopes to
// define the scope hierarchy.
std::vector<std::pair<Location, std::unique_ptr<Scope2>>> scopes;
std::vector<std::pair<Location, ScopePtr>> scopes;
ModuleName moduleName;
SingletonTypes& singletonTypes;
const NotNull<TypeArena> arena;
// The root scope of the module we're generating constraints for.
// This is null when the CGB is initially constructed.
Scope2* rootScope;
Scope* rootScope;
// A mapping of AST node to TypeId.
DenseHashMap<const AstExpr*, TypeId> astTypes{nullptr};
// A mapping of AST node to TypePackId.
@ -41,6 +42,8 @@ struct ConstraintGraphBuilder
DenseHashMap<const AstType*, TypeId> astResolvedTypes{nullptr};
// Type packs resolved from type annotations. Analogous to astTypePacks.
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
// Defining scopes for AST nodes.
DenseHashMap<const AstStatTypeAlias*, ScopePtr> astTypeAliasDefiningScopes{nullptr};
int recursionCount = 0;
@ -50,42 +53,42 @@ struct ConstraintGraphBuilder
// Occasionally constraint generation needs to produce an ICE.
const NotNull<InternalErrorReporter> ice;
NotNull<Scope2> globalScope;
NotNull<Scope> globalScope;
ConstraintGraphBuilder(const ModuleName& moduleName, TypeArena* arena, NotNull<InternalErrorReporter> ice, NotNull<Scope2> globalScope);
ConstraintGraphBuilder(const ModuleName& moduleName, TypeArena* arena, NotNull<InternalErrorReporter> ice, NotNull<Scope> globalScope);
/**
* Fabricates a new free type belonging to a given scope.
* @param scope the scope the free type belongs to.
*/
TypeId freshType(NotNull<Scope2> scope);
TypeId freshType(const ScopePtr& scope);
/**
* Fabricates a new free type pack belonging to a given scope.
* @param scope the scope the free type pack belongs to.
*/
TypePackId freshTypePack(NotNull<Scope2> scope);
TypePackId freshTypePack(const ScopePtr& scope);
/**
* Fabricates a scope that is a child of another scope.
* @param location the lexical extent of the scope in the source code.
* @param parent the parent scope of the new scope. Must not be null.
*/
NotNull<Scope2> childScope(Location location, NotNull<Scope2> parent);
ScopePtr childScope(Location location, const ScopePtr& parent);
/**
* Adds a new constraint with no dependencies to a given scope.
* @param scope the scope to add the constraint to.
* @param cv the constraint variant to add.
*/
void addConstraint(NotNull<Scope2> scope, ConstraintV cv);
void addConstraint(const ScopePtr& scope, ConstraintV cv);
/**
* Adds a constraint to a given scope.
* @param scope the scope to add the constraint to. Must not be null.
* @param c the constraint to add.
*/
void addConstraint(NotNull<Scope2> scope, std::unique_ptr<Constraint> c);
void addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c);
/**
* The entry point to the ConstraintGraphBuilder. This will construct a set
@ -94,22 +97,26 @@ struct ConstraintGraphBuilder
*/
void visit(AstStatBlock* block);
void visitBlockWithoutChildScope(NotNull<Scope2> scope, AstStatBlock* block);
void visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block);
void visit(NotNull<Scope2> scope, AstStat* stat);
void visit(NotNull<Scope2> scope, AstStatBlock* block);
void visit(NotNull<Scope2> scope, AstStatLocal* local);
void visit(NotNull<Scope2> scope, AstStatLocalFunction* function);
void visit(NotNull<Scope2> scope, AstStatFunction* function);
void visit(NotNull<Scope2> scope, AstStatReturn* ret);
void visit(NotNull<Scope2> scope, AstStatAssign* assign);
void visit(NotNull<Scope2> scope, AstStatIf* ifStatement);
void visit(NotNull<Scope2> scope, AstStatTypeAlias* alias);
void visit(const ScopePtr& scope, AstStat* stat);
void visit(const ScopePtr& scope, AstStatBlock* block);
void visit(const ScopePtr& scope, AstStatLocal* local);
void visit(const ScopePtr& scope, AstStatFor* for_);
void visit(const ScopePtr& scope, AstStatLocalFunction* function);
void visit(const ScopePtr& scope, AstStatFunction* function);
void visit(const ScopePtr& scope, AstStatReturn* ret);
void visit(const ScopePtr& scope, AstStatAssign* assign);
void visit(const ScopePtr& scope, AstStatIf* ifStatement);
void visit(const ScopePtr& scope, AstStatTypeAlias* alias);
void visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal);
void visit(const ScopePtr& scope, AstStatDeclareClass* declareClass);
void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction);
TypePackId checkExprList(NotNull<Scope2> scope, const AstArray<AstExpr*>& exprs);
TypePackId checkExprList(const ScopePtr& scope, const AstArray<AstExpr*>& exprs);
TypePackId checkPack(NotNull<Scope2> scope, AstArray<AstExpr*> exprs);
TypePackId checkPack(NotNull<Scope2> scope, AstExpr* expr);
TypePackId checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs);
TypePackId checkPack(const ScopePtr& scope, AstExpr* expr);
/**
* Checks an expression that is expected to evaluate to one type.
@ -117,13 +124,13 @@ struct ConstraintGraphBuilder
* @param expr the expression to check.
* @return the type of the expression.
*/
TypeId check(NotNull<Scope2> scope, AstExpr* expr);
TypeId check(const ScopePtr& scope, AstExpr* expr);
TypeId checkExprTable(NotNull<Scope2> scope, AstExprTable* expr);
TypeId check(NotNull<Scope2> scope, AstExprIndexName* indexName);
TypeId check(NotNull<Scope2> scope, AstExprIndexExpr* indexExpr);
TypeId check(NotNull<Scope2> scope, AstExprUnary* unary);
TypeId check(NotNull<Scope2> scope, AstExprBinary* binary);
TypeId checkExprTable(const ScopePtr& scope, AstExprTable* expr);
TypeId check(const ScopePtr& scope, AstExprIndexName* indexName);
TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
TypeId check(const ScopePtr& scope, AstExprUnary* unary);
TypeId check(const ScopePtr& scope, AstExprBinary* binary);
struct FunctionSignature
{
@ -132,28 +139,29 @@ struct ConstraintGraphBuilder
// The scope that encompasses the function's signature. May be nullptr
// if there was no need for a signature scope (the function has no
// generics).
Scope2* signatureScope;
ScopePtr signatureScope;
// The scope that encompasses the function's body. Is a child scope of
// signatureScope, if present.
NotNull<Scope2> bodyScope;
ScopePtr bodyScope;
};
FunctionSignature checkFunctionSignature(NotNull<Scope2> parent, AstExprFunction* fn);
FunctionSignature checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn);
/**
* Checks the body of a function expression.
* @param scope the interior scope of the body of the function.
* @param fn the function expression to check.
*/
void checkFunctionBody(NotNull<Scope2> scope, AstExprFunction* fn);
void checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn);
/**
* Resolves a type from its AST annotation.
* @param scope the scope that the type annotation appears within.
* @param ty the AST annotation to resolve.
* @param topLevel whether the annotation is a "top-level" annotation.
* @return the type of the AST annotation.
**/
TypeId resolveType(NotNull<Scope2> scope, AstType* ty);
TypeId resolveType(const ScopePtr& scope, AstType* ty, bool topLevel = false);
/**
* Resolves a type pack from its AST annotation.
@ -161,14 +169,14 @@ struct ConstraintGraphBuilder
* @param tp the AST annotation to resolve.
* @return the type pack of the AST annotation.
**/
TypePackId resolveTypePack(NotNull<Scope2> scope, AstTypePack* tp);
TypePackId resolveTypePack(const ScopePtr& scope, AstTypePack* tp);
TypePackId resolveTypePack(NotNull<Scope2> scope, const AstTypeList& list);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list);
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(NotNull<Scope2> scope, AstArray<AstGenericType> generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(NotNull<Scope2> scope, AstArray<AstGenericTypePack> packs);
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(const ScopePtr& scope, AstArray<AstGenericType> generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(const ScopePtr& scope, AstArray<AstGenericTypePack> packs);
TypeId flattenPack(NotNull<Scope2> scope, Location location, TypePackId tp);
TypeId flattenPack(const ScopePtr& scope, Location location, TypePackId tp);
void reportError(Location location, TypeErrorData err);
void reportCodeTooComplex(Location location);
@ -179,7 +187,7 @@ struct ConstraintGraphBuilder
* real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an
* initial scan of the AST and note what globals are defined.
*/
void prepopulateGlobalScope(NotNull<Scope2> globalScope, AstStatBlock* program);
void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program);
};
/**
@ -192,6 +200,6 @@ struct ConstraintGraphBuilder
* @return a list of pointers to constraints contained within the scope graph.
* None of these pointers should be null.
*/
std::vector<NotNull<Constraint>> collectConstraints(NotNull<Scope2> rootScope);
std::vector<NotNull<Constraint>> collectConstraints(NotNull<Scope> rootScope);
} // namespace Luau

View file

@ -17,15 +17,35 @@ namespace Luau
// never dereference this pointer.
using BlockedConstraintId = const void*;
struct InstantiationSignature
{
TypeFun fn;
std::vector<TypeId> arguments;
std::vector<TypePackId> packArguments;
bool operator==(const InstantiationSignature& rhs) const;
bool operator!=(const InstantiationSignature& rhs) const
{
return !((*this) == rhs);
}
};
struct HashInstantiationSignature
{
size_t operator()(const InstantiationSignature& signature) const;
};
struct ConstraintSolver
{
TypeArena* arena;
InternalErrorReporter iceReporter;
// The entire set of constraints that the solver is trying to resolve. It
// is important to not add elements to this vector, lest the underlying
// storage that we retain pointers to be mutated underneath us.
const std::vector<NotNull<Constraint>> constraints;
NotNull<Scope2> rootScope;
// The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints;
NotNull<Scope> rootScope;
// Constraints that the solver has generated, rather than sourcing from the
// scope tree.
std::vector<std::unique_ptr<Constraint>> solverConstraints;
// This includes every constraint that has not been fully solved.
// A constraint can be both blocked and unsolved, for instance.
@ -37,10 +57,12 @@ struct ConstraintSolver
std::unordered_map<NotNull<const Constraint>, size_t> blockedConstraints;
// A mapping of type/pack pointers to the constraints they block.
std::unordered_map<BlockedConstraintId, std::vector<NotNull<const Constraint>>> blocked;
// Memoized instantiations of type aliases.
DenseHashMap<InstantiationSignature, TypeId, HashInstantiationSignature> instantiatedAliases{{}};
ConstraintSolverLogger logger;
explicit ConstraintSolver(TypeArena* arena, NotNull<Scope2> rootScope);
explicit ConstraintSolver(TypeArena* arena, NotNull<Scope> rootScope);
/**
* Attempts to dispatch all pending constraints and reach a type solution
@ -62,6 +84,7 @@ struct ConstraintSolver
bool tryDispatch(const UnaryConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const BinaryConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const NameConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint);
void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);
/**
@ -102,6 +125,11 @@ struct ConstraintSolver
*/
void unify(TypePackId subPack, TypePackId superPack);
/** Pushes a new solver constraint to the solver.
* @param cv the body of the constraint.
**/
void pushConstraint(ConstraintV cv);
private:
/**
* Marks a constraint as being blocked on a type or type pack. The constraint
@ -121,6 +149,6 @@ private:
void unblock_(BlockedConstraintId progressed);
};
void dump(NotNull<Scope2> rootScope, struct ToStringOptions& opts);
void dump(NotNull<Scope> rootScope, struct ToStringOptions& opts);
} // namespace Luau

View file

@ -15,8 +15,8 @@ namespace Luau
struct ConstraintSolverLogger
{
std::string compileOutput();
void captureBoundarySnapshot(const Scope2* rootScope, std::vector<NotNull<const Constraint>>& unsolvedConstraints);
void prepareStepSnapshot(const Scope2* rootScope, NotNull<const Constraint> current, std::vector<NotNull<const Constraint>>& unsolvedConstraints);
void captureBoundarySnapshot(const Scope* rootScope, std::vector<NotNull<const Constraint>>& unsolvedConstraints);
void prepareStepSnapshot(const Scope* rootScope, NotNull<const Constraint> current, std::vector<NotNull<const Constraint>>& unsolvedConstraints);
void commitPreparedStepSnapshot();
private:

View file

@ -127,13 +127,6 @@ struct Frontend
CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess
LintResult lint(const ModuleName& name, std::optional<LintOptions> enabledLintWarnings = {});
/** Lint some code that has no associated DataModel object
*
* Since this source fragment has no name, we cannot cache its AST. Instead,
* we return it to the caller to use as they wish.
*/
std::pair<SourceModule, LintResult> lintFragment(std::string_view source, std::optional<LintOptions> enabledLintWarnings = {});
LintResult lint(const SourceModule& module, std::optional<LintOptions> enabledLintWarnings = {});
bool isDirty(const ModuleName& name, bool forAutocomplete = false) const;
@ -159,7 +152,9 @@ struct Frontend
void registerBuiltinDefinition(const std::string& name, std::function<void(TypeChecker&, ScopePtr)>);
void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName);
NotNull<Scope2> getGlobalScope2();
LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName);
NotNull<Scope> getGlobalScope();
private:
ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope);
@ -176,7 +171,7 @@ private:
std::unordered_map<std::string, ScopePtr> environments;
std::unordered_map<std::string, std::function<void(TypeChecker&, ScopePtr)>> builtinDefinitions;
std::unique_ptr<Scope2> globalScope2;
ScopePtr globalScope;
public:
FileResolver* fileResolver;
@ -187,6 +182,7 @@ public:
ConfigResolver* configResolver;
FrontendOptions options;
InternalErrorReporter iceHandler;
TypeArena globalTypes;
TypeArena arenaForAutocomplete;
std::unordered_map<ModuleName, SourceNode> sourceNodes;

View file

@ -0,0 +1,235 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <type_traits>
#include <string>
#include <optional>
#include <vector>
#include "Luau/NotNull.h"
namespace Luau::Json
{
struct JsonEmitter;
/// Writes a value to the JsonEmitter. Note that this can produce invalid JSON
/// if you do not insert commas or appropriate object / array syntax.
template<typename T>
void write(JsonEmitter&, T) = delete;
/// Writes a boolean to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param b the boolean to write.
void write(JsonEmitter& emitter, bool b);
/// Writes an integer to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param i the integer to write.
void write(JsonEmitter& emitter, int i);
/// Writes an integer to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param i the integer to write.
void write(JsonEmitter& emitter, long i);
/// Writes an integer to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param i the integer to write.
void write(JsonEmitter& emitter, long long i);
/// Writes an integer to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param i the integer to write.
void write(JsonEmitter& emitter, unsigned int i);
/// Writes an integer to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param i the integer to write.
void write(JsonEmitter& emitter, unsigned long i);
/// Writes an integer to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param i the integer to write.
void write(JsonEmitter& emitter, unsigned long long i);
/// Writes a double to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param d the double to write.
void write(JsonEmitter& emitter, double d);
/// Writes a string to a JsonEmitter. The string will be escaped.
/// @param emitter the emitter to write to.
/// @param sv the string to write.
void write(JsonEmitter& emitter, std::string_view sv);
/// Writes a character to a JsonEmitter as a single-character string. The
/// character will be escaped.
/// @param emitter the emitter to write to.
/// @param c the string to write.
void write(JsonEmitter& emitter, char c);
/// Writes a string to a JsonEmitter. The string will be escaped.
/// @param emitter the emitter to write to.
/// @param str the string to write.
void write(JsonEmitter& emitter, const char* str);
/// Writes a string to a JsonEmitter. The string will be escaped.
/// @param emitter the emitter to write to.
/// @param str the string to write.
void write(JsonEmitter& emitter, const std::string& str);
/// Writes null to a JsonEmitter.
/// @param emitter the emitter to write to.
void write(JsonEmitter& emitter, std::nullptr_t);
/// Writes null to a JsonEmitter.
/// @param emitter the emitter to write to.
void write(JsonEmitter& emitter, std::nullopt_t);
struct ObjectEmitter;
struct ArrayEmitter;
struct JsonEmitter
{
JsonEmitter();
/// Converts the current contents of the JsonEmitter to a string value. This
/// does not invalidate the emitter, but it does not clear it either.
std::string str();
/// Returns the current comma state and resets it to false. Use popComma to
/// restore the old state.
/// @returns the previous comma state.
bool pushComma();
/// Restores a previous comma state.
/// @param c the comma state to restore.
void popComma(bool c);
/// Writes a raw sequence of characters to the buffer, without escaping or
/// other processing.
/// @param sv the character sequence to write.
void writeRaw(std::string_view sv);
/// Writes a character to the buffer, without escaping or other processing.
/// @param c the character to write.
void writeRaw(char c);
/// Writes a comma if this wasn't the first time writeComma has been
/// invoked. Otherwise, sets the comma state to true.
/// @see pushComma
/// @see popComma
void writeComma();
/// Begins writing an object to the emitter.
/// @returns an ObjectEmitter that can be used to write key-value pairs.
ObjectEmitter writeObject();
/// Begins writing an array to the emitter.
/// @returns an ArrayEmitter that can be used to write values.
ArrayEmitter writeArray();
private:
bool comma = false;
std::vector<std::string> chunks;
void newChunk();
};
/// An interface for writing an object into a JsonEmitter instance.
/// @see JsonEmitter::writeObject
struct ObjectEmitter
{
ObjectEmitter(NotNull<JsonEmitter> emitter);
~ObjectEmitter();
NotNull<JsonEmitter> emitter;
bool comma;
bool finished;
/// Writes a key-value pair to the associated JsonEmitter. Keys will be escaped.
/// @param name the name of the key-value pair.
/// @param value the value to write.
template<typename T>
void writePair(std::string_view name, T value)
{
if (finished)
{
return;
}
emitter->writeComma();
write(*emitter, name);
emitter->writeRaw(':');
write(*emitter, value);
}
/// Finishes writing the object, appending a closing `}` character and
/// resetting the comma state of the associated emitter. This can only be
/// called once, and once called will render the emitter unusable. This
/// method is also called when the ObjectEmitter is destructed.
void finish();
};
/// An interface for writing an array into a JsonEmitter instance. Array values
/// do not need to be the same type.
/// @see JsonEmitter::writeArray
struct ArrayEmitter
{
ArrayEmitter(NotNull<JsonEmitter> emitter);
~ArrayEmitter();
NotNull<JsonEmitter> emitter;
bool comma;
bool finished;
/// Writes a value to the array.
/// @param value the value to write.
template<typename T>
void writeValue(T value)
{
if (finished)
{
return;
}
emitter->writeComma();
write(*emitter, value);
}
/// Finishes writing the object, appending a closing `]` character and
/// resetting the comma state of the associated emitter. This can only be
/// called once, and once called will render the emitter unusable. This
/// method is also called when the ArrayEmitter is destructed.
void finish();
};
/// Writes a vector as an array to a JsonEmitter.
/// @param emitter the emitter to write to.
/// @param vec the vector to write.
template<typename T>
void write(JsonEmitter& emitter, const std::vector<T>& vec)
{
ArrayEmitter a = emitter.writeArray();
for (const T& value : vec)
a.writeValue(value);
a.finish();
}
/// Writes an optional to a JsonEmitter. Will write the contained value, if
/// present, or null, if no value is present.
/// @param emitter the emitter to write to.
/// @param v the value to write.
template<typename T>
void write(JsonEmitter& emitter, const std::optional<T>& v)
{
if (v.has_value())
write(emitter, *v);
else
emitter.writeRaw("null");
}
} // namespace Luau::Json

View file

@ -52,6 +52,7 @@ struct LintWarning
Code_DuplicateCondition = 24,
Code_MisleadingAndOr = 25,
Code_CommentDirective = 26,
Code_IntegerParsing = 27,
Code__Count
};

View file

@ -68,8 +68,7 @@ struct Module
std::shared_ptr<Allocator> allocator;
std::shared_ptr<AstNameTable> names;
std::vector<std::pair<Location, ScopePtr>> scopes; // never empty
std::vector<std::pair<Location, std::unique_ptr<Scope2>>> scope2s; // never empty
std::vector<std::pair<Location, ScopePtr>> scopes; // never empty
DenseHashMap<const AstExpr*, TypeId> astTypes{nullptr};
DenseHashMap<const AstExpr*, TypePackId> astTypePacks{nullptr};
@ -86,7 +85,6 @@ struct Module
bool timeout = false;
ScopePtr getModuleScope() const;
Scope2* getModuleScope2() const;
// Once a module has been typechecked, we clone its public interface into a separate arena.
// This helps us to force TypeVar ownership into a DAG rather than a DCG.

View file

@ -7,9 +7,9 @@ namespace Luau
{
struct TypeArena;
struct Scope2;
struct Scope;
void quantify(TypeId ty, TypeLevel level);
TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope);
TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope);
} // namespace Luau

View file

@ -32,10 +32,16 @@ struct Scope
explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr.
const ScopePtr parent; // null for the root
// All the children of this scope.
std::vector<NotNull<Scope>> children;
std::unordered_map<Symbol, Binding> bindings;
std::unordered_map<Name, TypeFun> typeBindings;
std::unordered_map<Name, TypePackId> typePackBindings;
TypePackId returnType;
bool breakOk = false;
std::optional<TypePackId> varargPack;
// All constraints belonging to this scope.
std::vector<ConstraintPtr> constraints;
TypeLevel level;
@ -45,7 +51,9 @@ struct Scope
std::unordered_map<Name, std::unordered_map<Name, TypeFun>> importedTypeBindings;
std::optional<TypeId> lookup(const Symbol& name);
std::optional<TypeId> lookup(Symbol sym);
std::optional<TypeFun> lookupTypeBinding(const Name& name);
std::optional<TypePackId> lookupTypePackBinding(const Name& name);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);
@ -66,24 +74,4 @@ struct Scope
std::unordered_map<Name, TypePackId> typeAliasTypePackParameters;
};
struct Scope2
{
// The parent scope of this scope. Null if there is no parent (i.e. this
// is the module-level scope).
Scope2* parent = nullptr;
// All the children of this scope.
std::vector<NotNull<Scope2>> children;
std::unordered_map<Symbol, TypeId> bindings; // TODO: I think this can be a DenseHashMap
std::unordered_map<Name, TypeId> typeBindings;
std::unordered_map<Name, TypePackId> typePackBindings;
TypePackId returnType;
std::optional<TypePackId> varargPack;
// All constraints belonging to this scope.
std::vector<ConstraintPtr> constraints;
std::optional<TypeId> lookup(Symbol sym);
std::optional<TypeId> lookupTypeBinding(const Name& name);
std::optional<TypePackId> lookupTypePackBinding(const Name& name);
};
} // namespace Luau

View file

@ -139,6 +139,8 @@ struct FindDirty : Tarjan
{
std::vector<bool> dirty;
void clearTarjan();
// Get/set the dirty bit for an index (grows the vector if needed)
bool getDirty(int index);
void setDirty(int index, bool d);
@ -176,6 +178,8 @@ public:
TypeArena* arena;
DenseHashMap<TypeId, TypeId> newTypes{nullptr};
DenseHashMap<TypePackId, TypePackId> newPacks{nullptr};
DenseHashSet<TypeId> replacedTypes{nullptr};
DenseHashSet<TypePackId> replacedTypePacks{nullptr};
std::optional<TypeId> substitute(TypeId ty);
std::optional<TypePackId> substitute(TypePackId tp);

View file

@ -65,28 +65,6 @@ struct Anyification : Substitution
}
};
// A substitution which replaces the type parameters of a type function by arguments
struct ApplyTypeFunction : Substitution
{
ApplyTypeFunction(TypeArena* arena, TypeLevel level)
: Substitution(TxnLog::empty(), arena)
, level(level)
, encounteredForwardedType(false)
{
}
TypeLevel level;
bool encounteredForwardedType;
std::unordered_map<TypeId, TypeId> typeArguments;
std::unordered_map<TypePackId, TypePackId> typePackArguments;
bool ignoreChildren(TypeId ty) override;
bool ignoreChildren(TypePackId tp) override;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
};
struct GenericTypeDefinitions
{
std::vector<GenericTypeDefinition> genericTypes;
@ -153,7 +131,7 @@ struct TypeChecker
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {});
TypeId checkBinaryOperation(
const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {});
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprBinary& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional<TypeId> expectedType = std::nullopt);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprError& expr);
WithPredicate<TypeId> checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional<TypeId> expectedType = std::nullopt);
@ -180,8 +158,12 @@ struct TypeChecker
const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector<Location>& argLocations);
WithPredicate<TypePackId> checkExprPack(const ScopePtr& scope, const AstExpr& expr);
WithPredicate<TypePackId> checkExprPack(const ScopePtr& scope, const AstExprCall& expr);
WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr);
WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr);
std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall);
std::optional<WithPredicate<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors);
@ -236,10 +218,11 @@ struct TypeChecker
void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location);
std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry, const Location& location);
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location);
std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors);
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors);
std::optional<TypeId> getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors);
std::optional<TypeId> getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors);
// Reduces the union to its simplest possible shape.
// (A | B) | B | C yields A | B | C
@ -316,11 +299,12 @@ private:
TypeIdPredicate mkTruthyPredicate(bool sense);
// Returns nullopt if the predicate filters down the TypeId to 0 options.
std::optional<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
// TODO: Return TypeId only.
std::optional<TypeId> filterMapImpl(TypeId type, TypeIdPredicate predicate);
std::pair<std::optional<TypeId>, bool> filterMap(TypeId type, TypeIdPredicate predicate);
public:
std::optional<TypeId> pickTypesFromSense(TypeId type, bool sense);
std::pair<std::optional<TypeId>, bool> pickTypesFromSense(TypeId type, bool sense);
private:
TypeId unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes = true);
@ -345,6 +329,7 @@ private:
TypePackId freshTypePack(TypeLevel level);
TypeId resolveType(const ScopePtr& scope, const AstType& annotation);
TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types);
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation);
TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector<TypeId>& typeParams,
@ -412,8 +397,12 @@ public:
const TypeId booleanType;
const TypeId threadType;
const TypeId anyType;
const TypeId unknownType;
const TypeId neverType;
const TypePackId anyTypePack;
const TypePackId neverTypePack;
const TypePackId uninhabitableTypePack;
private:
int checkRecursionCount = 0;

View file

@ -173,5 +173,6 @@ std::pair<std::vector<TypeId>, std::optional<TypePackId>> flatten(TypePackId tp,
bool isVariadic(TypePackId tp);
bool isVariadic(TypePackId tp, const TxnLog& log);
bool containsNever(TypePackId tp);
} // namespace Luau

View file

@ -24,7 +24,7 @@ namespace Luau
{
struct TypeArena;
struct Scope2;
struct Scope;
/**
* There are three kinds of type variables:
@ -143,7 +143,7 @@ struct ConstrainedTypeVar
std::vector<TypeId> parts;
TypeLevel level;
Scope2* scope = nullptr;
Scope* scope = nullptr;
};
// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md
@ -223,12 +223,16 @@ struct GenericTypeDefinition
{
TypeId ty;
std::optional<TypeId> defaultValue;
bool operator==(const GenericTypeDefinition& rhs) const;
};
struct GenericTypePackDefinition
{
TypePackId tp;
std::optional<TypePackId> defaultValue;
bool operator==(const GenericTypePackDefinition& rhs) const;
};
struct FunctionArgument
@ -275,7 +279,7 @@ struct FunctionTypeVar
std::optional<FunctionDefinition> defn = {}, bool hasSelf = false);
TypeLevel level;
Scope2* scope = nullptr;
Scope* scope = nullptr;
/// These should all be generic
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
@ -344,7 +348,7 @@ struct TableTypeVar
TableState state = TableState::Unsealed;
TypeLevel level;
Scope2* scope = nullptr;
Scope* scope = nullptr;
std::optional<std::string> name;
// Sometimes we throw a type on a name to make for nicer error messages, but without creating any entry in the type namespace
@ -426,6 +430,12 @@ struct TypeFun
TypeId type;
TypeFun() = default;
explicit TypeFun(TypeId ty)
: type(ty)
{
}
TypeFun(std::vector<GenericTypeDefinition> typeParams, TypeId type)
: typeParams(std::move(typeParams))
, type(type)
@ -438,6 +448,27 @@ struct TypeFun
, type(type)
{
}
bool operator==(const TypeFun& rhs) const;
};
/** Represents a pending type alias instantiation.
*
* In order to afford (co)recursive type aliases, we need to reason about a
* partially-complete instantiation. This requires encoding more information in
* a type variable than a BlockedTypeVar affords, hence this. Each
* PendingExpansionTypeVar has a corresponding TypeAliasExpansionConstraint
* enqueued in the solver to convert it to an actual instantiated type
*/
struct PendingExpansionTypeVar
{
PendingExpansionTypeVar(TypeFun fn, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments);
TypeFun fn;
std::vector<TypeId> typeArguments;
std::vector<TypePackId> packArguments;
size_t index;
static size_t nextIndex;
};
// Anything! All static checking is off.
@ -460,10 +491,20 @@ struct LazyTypeVar
std::function<TypeId()> thunk;
};
struct UnknownTypeVar
{
};
struct NeverTypeVar
{
};
using ErrorTypeVar = Unifiable::Error;
using TypeVariant = Unifiable::Variant<TypeId, PrimitiveTypeVar, ConstrainedTypeVar, BlockedTypeVar, SingletonTypeVar, FunctionTypeVar, TableTypeVar,
MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar>;
using TypeVariant =
Unifiable::Variant<TypeId, PrimitiveTypeVar, ConstrainedTypeVar, BlockedTypeVar, PendingExpansionTypeVar, SingletonTypeVar, FunctionTypeVar,
TableTypeVar, MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar>;
struct TypeVar final
{
@ -575,8 +616,12 @@ struct SingletonTypes
const TypeId trueType;
const TypeId falseType;
const TypeId anyType;
const TypeId unknownType;
const TypeId neverType;
const TypePackId anyTypePack;
const TypePackId neverTypePack;
const TypePackId uninhabitableTypePack;
SingletonTypes();
~SingletonTypes();
@ -632,12 +677,30 @@ T* getMutable(TypeId tv)
return get_if<T>(&asMutable(tv)->ty);
}
/* Traverses the UnionTypeVar yielding each TypeId.
* If the iterator encounters a nested UnionTypeVar, it will instead yield each TypeId within.
*
* Beware: the iterator does not currently filter for unique TypeIds. This may change in the future.
const std::vector<TypeId>& getTypes(const UnionTypeVar* utv);
const std::vector<TypeId>& getTypes(const IntersectionTypeVar* itv);
const std::vector<TypeId>& getTypes(const ConstrainedTypeVar* ctv);
template<typename T>
struct TypeIterator;
using UnionTypeVarIterator = TypeIterator<UnionTypeVar>;
UnionTypeVarIterator begin(const UnionTypeVar* utv);
UnionTypeVarIterator end(const UnionTypeVar* utv);
using IntersectionTypeVarIterator = TypeIterator<IntersectionTypeVar>;
IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv);
IntersectionTypeVarIterator end(const IntersectionTypeVar* itv);
using ConstrainedTypeVarIterator = TypeIterator<ConstrainedTypeVar>;
ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv);
ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv);
/* Traverses the type T yielding each TypeId.
* If the iterator encounters a nested type T, it will instead yield each TypeId within.
*/
struct UnionTypeVarIterator
template<typename T>
struct TypeIterator
{
using value_type = Luau::TypeId;
using pointer = value_type*;
@ -645,33 +708,116 @@ struct UnionTypeVarIterator
using difference_type = size_t;
using iterator_category = std::input_iterator_tag;
explicit UnionTypeVarIterator(const UnionTypeVar* utv);
explicit TypeIterator(const T* t)
{
LUAU_ASSERT(t);
UnionTypeVarIterator& operator++();
UnionTypeVarIterator operator++(int);
bool operator!=(const UnionTypeVarIterator& rhs);
bool operator==(const UnionTypeVarIterator& rhs);
const std::vector<TypeId>& types = getTypes(t);
if (!types.empty())
stack.push_front({t, 0});
const TypeId& operator*();
seen.insert(t);
}
friend UnionTypeVarIterator end(const UnionTypeVar* utv);
TypeIterator<T>& operator++()
{
advance();
descend();
return *this;
}
TypeIterator<T> operator++(int)
{
TypeIterator<T> copy = *this;
++copy;
return copy;
}
bool operator==(const TypeIterator<T>& rhs) const
{
if (!stack.empty() && !rhs.stack.empty())
return stack.front() == rhs.stack.front();
return stack.empty() && rhs.stack.empty();
}
bool operator!=(const TypeIterator<T>& rhs) const
{
return !(*this == rhs);
}
const TypeId& operator*()
{
LUAU_ASSERT(!stack.empty());
descend();
auto [t, currentIndex] = stack.front();
LUAU_ASSERT(t);
const std::vector<TypeId>& types = getTypes(t);
LUAU_ASSERT(currentIndex < types.size());
const TypeId& ty = types[currentIndex];
LUAU_ASSERT(!get<T>(follow(ty)));
return ty;
}
// Normally, we'd have `begin` and `end` be a template but there's too much trouble
// with templates portability in this area, so not worth it. Thanks MSVC.
friend UnionTypeVarIterator end(const UnionTypeVar*);
friend IntersectionTypeVarIterator end(const IntersectionTypeVar*);
friend ConstrainedTypeVarIterator end(const ConstrainedTypeVar*);
private:
UnionTypeVarIterator() = default;
TypeIterator() = default;
// (UnionTypeVar* utv, size_t currentIndex)
using SavedIterInfo = std::pair<const UnionTypeVar*, size_t>;
// (T* t, size_t currentIndex)
using SavedIterInfo = std::pair<const T*, size_t>;
std::deque<SavedIterInfo> stack;
std::unordered_set<const UnionTypeVar*> seen; // Only needed to protect the iterator from hanging the thread.
std::unordered_set<const T*> seen; // Only needed to protect the iterator from hanging the thread.
void advance();
void descend();
void advance()
{
while (!stack.empty())
{
auto& [t, currentIndex] = stack.front();
++currentIndex;
const std::vector<TypeId>& types = getTypes(t);
if (currentIndex >= types.size())
stack.pop_front();
else
break;
}
}
void descend()
{
while (!stack.empty())
{
auto [current, currentIndex] = stack.front();
const std::vector<TypeId>& types = getTypes(current);
if (auto inner = get<T>(follow(types[currentIndex])))
{
// If we're about to descend into a cyclic type, we should skip over this.
// Ideally this should never happen, but alas it does from time to time. :(
if (seen.find(inner) != seen.end())
advance();
else
{
seen.insert(inner);
stack.push_front({inner, 0});
}
continue;
}
break;
}
}
};
UnionTypeVarIterator begin(const UnionTypeVar* utv);
UnionTypeVarIterator end(const UnionTypeVar* utv);
using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>;
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);

View file

@ -8,7 +8,7 @@
namespace Luau
{
struct Scope2;
struct Scope;
/**
* The 'level' of a TypeVar is an indirect way to talk about the scope that it 'belongs' too.
@ -84,11 +84,11 @@ using Name = std::string;
struct Free
{
explicit Free(TypeLevel level);
explicit Free(Scope2* scope);
explicit Free(Scope* scope);
int index;
TypeLevel level;
Scope2* scope = nullptr;
Scope* scope = nullptr;
// True if this free type variable is part of a mutually
// recursive type alias whose definitions haven't been
// resolved yet.
@ -115,13 +115,13 @@ struct Generic
Generic();
explicit Generic(TypeLevel level);
explicit Generic(const Name& name);
explicit Generic(Scope2* scope);
explicit Generic(Scope* scope);
Generic(TypeLevel level, const Name& name);
Generic(Scope2* scope, const Name& name);
Generic(Scope* scope, const Name& name);
int index;
TypeLevel level;
Scope2* scope = nullptr;
Scope* scope = nullptr;
Name name;
bool explicitName = false;

View file

@ -79,6 +79,7 @@ private:
void tryUnifySingletons(TypeId subTy, TypeId superTy);
void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false);
void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false);
void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed);

View file

@ -9,7 +9,7 @@
#include "Luau/TypeVar.h"
LUAU_FASTINT(LuauVisitRecursionLimit)
LUAU_FASTFLAG(LuauNormalizeFlagIsConservative)
LUAU_FASTFLAG(LuauCompleteVisitor);
namespace Luau
{
@ -129,6 +129,14 @@ struct GenericTypeVarVisitor
{
return visit(ty);
}
virtual bool visit(TypeId ty, const UnknownTypeVar& utv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const NeverTypeVar& ntv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const UnionTypeVar& utv)
{
return visit(ty);
@ -137,6 +145,18 @@ struct GenericTypeVarVisitor
{
return visit(ty);
}
virtual bool visit(TypeId ty, const BlockedTypeVar& btv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const PendingExpansionTypeVar& petv)
{
return visit(ty);
}
virtual bool visit(TypeId ty, const SingletonTypeVar& stv)
{
return visit(ty);
}
virtual bool visit(TypePackId tp)
{
@ -182,16 +202,12 @@ struct GenericTypeVarVisitor
if (visit(ty, *btv))
traverse(btv->boundTo);
}
else if (auto ftv = get<FreeTypeVar>(ty))
visit(ty, *ftv);
else if (auto gtv = get<GenericTypeVar>(ty))
visit(ty, *gtv);
else if (auto etv = get<ErrorTypeVar>(ty))
visit(ty, *etv);
else if (auto ctv = get<ConstrainedTypeVar>(ty))
{
if (visit(ty, *ctv))
@ -200,10 +216,8 @@ struct GenericTypeVarVisitor
traverse(part);
}
}
else if (auto ptv = get<PrimitiveTypeVar>(ty))
visit(ty, *ptv);
else if (auto ftv = get<FunctionTypeVar>(ty))
{
if (visit(ty, *ftv))
@ -212,7 +226,6 @@ struct GenericTypeVarVisitor
traverse(ftv->retTypes);
}
}
else if (auto ttv = get<TableTypeVar>(ty))
{
// Some visitors want to see bound tables, that's why we traverse the original type
@ -235,7 +248,6 @@ struct GenericTypeVarVisitor
}
}
}
else if (auto mtv = get<MetatableTypeVar>(ty))
{
if (visit(ty, *mtv))
@ -244,7 +256,6 @@ struct GenericTypeVarVisitor
traverse(mtv->metatable);
}
}
else if (auto ctv = get<ClassTypeVar>(ty))
{
if (visit(ty, *ctv))
@ -259,10 +270,8 @@ struct GenericTypeVarVisitor
traverse(*ctv->metatable);
}
}
else if (auto atv = get<AnyTypeVar>(ty))
visit(ty, *atv);
else if (auto utv = get<UnionTypeVar>(ty))
{
if (visit(ty, *utv))
@ -271,7 +280,6 @@ struct GenericTypeVarVisitor
traverse(optTy);
}
}
else if (auto itv = get<IntersectionTypeVar>(ty))
{
if (visit(ty, *itv))
@ -280,6 +288,53 @@ struct GenericTypeVarVisitor
traverse(partTy);
}
}
else if (get<LazyTypeVar>(ty))
{
// Visiting into LazyTypeVar may necessarily cause infinite expansion, so we don't do that on purpose.
// Asserting also makes no sense, because the type _will_ happen here, most likely as a property of some ClassTypeVar
// that doesn't need to be expanded.
}
else if (auto stv = get<SingletonTypeVar>(ty))
visit(ty, *stv);
else if (auto btv = get<BlockedTypeVar>(ty))
visit(ty, *btv);
else if (auto utv = get<UnknownTypeVar>(ty))
visit(ty, *utv);
else if (auto ntv = get<NeverTypeVar>(ty))
visit(ty, *ntv);
else if (auto petv = get<PendingExpansionTypeVar>(ty))
{
if (visit(ty, *petv))
{
traverse(petv->fn.type);
for (const GenericTypeDefinition& p : petv->fn.typeParams)
{
traverse(p.ty);
if (p.defaultValue)
traverse(*p.defaultValue);
}
for (const GenericTypePackDefinition& p : petv->fn.typePackParams)
{
traverse(p.tp);
if (p.defaultValue)
traverse(*p.defaultValue);
}
for (TypeId a : petv->typeArguments)
traverse(a);
for (TypePackId a : petv->packArguments)
traverse(a);
}
}
else if (!FFlag::LuauCompleteVisitor)
return visit_detail::unsee(seen, ty);
else
LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypeId) is not exhaustive!");
visit_detail::unsee(seen, ty);
}
@ -310,7 +365,7 @@ struct GenericTypeVarVisitor
else if (auto pack = get<TypePack>(tp))
{
bool res = visit(tp, *pack);
if (!FFlag::LuauNormalizeFlagIsConservative || res)
if (res)
{
for (TypeId ty : pack->head)
traverse(ty);
@ -322,7 +377,7 @@ struct GenericTypeVarVisitor
else if (auto pack = get<VariadicTypePack>(tp))
{
bool res = visit(tp, *pack);
if (!FFlag::LuauNormalizeFlagIsConservative || res)
if (res)
traverse(pack->ty);
}
else

View file

@ -0,0 +1,60 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ApplyTypeFunction.h"
namespace Luau
{
bool ApplyTypeFunction::isDirty(TypeId ty)
{
if (typeArguments.count(ty))
return true;
else if (const FreeTypeVar* ftv = get<FreeTypeVar>(ty))
{
if (ftv->forwardedTypeAlias)
encounteredForwardedType = true;
return false;
}
else
return false;
}
bool ApplyTypeFunction::isDirty(TypePackId tp)
{
if (typePackArguments.count(tp))
return true;
else
return false;
}
bool ApplyTypeFunction::ignoreChildren(TypeId ty)
{
if (get<GenericTypeVar>(ty))
return true;
else
return false;
}
bool ApplyTypeFunction::ignoreChildren(TypePackId tp)
{
if (get<GenericTypePack>(tp))
return true;
else
return false;
}
TypeId ApplyTypeFunction::clean(TypeId ty)
{
TypeId& arg = typeArguments[ty];
LUAU_ASSERT(arg);
return arg;
}
TypePackId ApplyTypeFunction::clean(TypePackId tp)
{
TypePackId& arg = typePackArguments[tp];
LUAU_ASSERT(arg);
return arg;
}
} // namespace Luau

View file

@ -1,7 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/JsonEncoder.h"
#include "Luau/AstJsonEncoder.h"
#include "Luau/Ast.h"
#include "Luau/ParseResult.h"
#include "Luau/StringUtils.h"
#include "Luau/Common.h"
@ -75,6 +76,11 @@ struct AstJsonEncoder : public AstVisitor
writeRaw(std::string_view{&c, 1});
}
void writeType(std::string_view propValue)
{
write("type", propValue);
}
template<typename T>
void write(std::string_view propName, const T& value)
{
@ -97,8 +103,8 @@ struct AstJsonEncoder : public AstVisitor
void write(double d)
{
char b[256];
sprintf(b, "%g", d);
char b[32];
snprintf(b, sizeof(b), "%.17g", d);
writeRaw(b);
}
@ -111,8 +117,12 @@ struct AstJsonEncoder : public AstVisitor
{
if (c == '"')
writeRaw("\\\"");
else if (c == '\0')
writeRaw("\\\0");
else if (c == '\\')
writeRaw("\\\\");
else if (c < ' ')
writeRaw(format("\\u%04x", c));
else if (c == '\n')
writeRaw("\\n");
else
writeRaw(c);
}
@ -189,10 +199,11 @@ struct AstJsonEncoder : public AstVisitor
writeRaw("{");
bool c = pushComma();
if (local->annotation != nullptr)
write("type", local->annotation);
write("luauType", local->annotation);
else
write("type", nullptr);
write("luauType", nullptr);
write("name", local->name);
writeType("AstLocal");
write("location", local->location);
popComma(c);
writeRaw("}");
@ -208,7 +219,7 @@ struct AstJsonEncoder : public AstVisitor
{
writeRaw("{");
bool c = pushComma();
write("type", name);
writeType(name);
writeNode(node);
f();
popComma(c);
@ -358,6 +369,7 @@ struct AstJsonEncoder : public AstVisitor
{
writeRaw("{");
bool c = pushComma();
writeType("AstTypeList");
write("types", typeList.types);
if (typeList.tailType)
write("tailType", typeList.tailType);
@ -369,9 +381,10 @@ struct AstJsonEncoder : public AstVisitor
{
writeRaw("{");
bool c = pushComma();
writeType("AstGenericType");
write("name", genericType.name);
if (genericType.defaultValue)
write("type", genericType.defaultValue);
write("luauType", genericType.defaultValue);
popComma(c);
writeRaw("}");
}
@ -380,9 +393,10 @@ struct AstJsonEncoder : public AstVisitor
{
writeRaw("{");
bool c = pushComma();
writeType("AstGenericTypePack");
write("name", genericTypePack.name);
if (genericTypePack.defaultValue)
write("type", genericTypePack.defaultValue);
write("luauType", genericTypePack.defaultValue);
popComma(c);
writeRaw("}");
}
@ -404,6 +418,7 @@ struct AstJsonEncoder : public AstVisitor
{
writeRaw("{");
bool c = pushComma();
writeType("AstExprTableItem");
write("kind", item.kind);
switch (item.kind)
{
@ -419,6 +434,17 @@ struct AstJsonEncoder : public AstVisitor
writeRaw("}");
}
void write(class AstExprIfElse* node)
{
writeNode(node, "AstExprIfElse", [&]() {
PROP(condition);
PROP(hasThen);
PROP(trueExpr);
PROP(hasElse);
PROP(falseExpr);
});
}
void write(class AstExprTable* node)
{
writeNode(node, "AstExprTable", [&]() {
@ -431,11 +457,11 @@ struct AstJsonEncoder : public AstVisitor
switch (op)
{
case AstExprUnary::Not:
return writeString("not");
return writeString("Not");
case AstExprUnary::Minus:
return writeString("minus");
return writeString("Minus");
case AstExprUnary::Len:
return writeString("len");
return writeString("Len");
}
}
@ -541,7 +567,7 @@ struct AstJsonEncoder : public AstVisitor
void write(class AstStatWhile* node)
{
writeNode(node, "AtStatWhile", [&]() {
writeNode(node, "AstStatWhile", [&]() {
PROP(condition);
PROP(body);
PROP(hasDo);
@ -684,7 +710,8 @@ struct AstJsonEncoder : public AstVisitor
writeRaw("{");
bool c = pushComma();
write("name", prop.name);
write("type", prop.ty);
writeType("AstDeclaredClassProp");
write("luauType", prop.ty);
popComma(c);
writeRaw("}");
}
@ -731,8 +758,9 @@ struct AstJsonEncoder : public AstVisitor
bool c = pushComma();
write("name", prop.name);
writeType("AstTableProp");
write("location", prop.location);
write("type", prop.type);
write("propType", prop.type);
popComma(c);
writeRaw("}");
@ -746,6 +774,24 @@ struct AstJsonEncoder : public AstVisitor
});
}
void write(struct AstTableIndexer* indexer)
{
if (indexer)
{
writeRaw("{");
bool c = pushComma();
write("location", indexer->location);
write("indexType", indexer->indexType);
write("resultType", indexer->resultType);
popComma(c);
writeRaw("}");
}
else
{
writeRaw("null");
}
}
void write(class AstTypeFunction* node)
{
writeNode(node, "AstTypeFunction", [&]() {
@ -836,6 +882,12 @@ struct AstJsonEncoder : public AstVisitor
return false;
}
bool visit(class AstExprIfElse* node) override
{
write(node);
return false;
}
bool visit(class AstExprLocal* node) override
{
write(node);
@ -1093,6 +1145,41 @@ struct AstJsonEncoder : public AstVisitor
write(node);
return false;
}
void writeComments(std::vector<Comment> commentLocations)
{
bool commentComma = false;
for (Comment comment : commentLocations)
{
if (commentComma)
{
writeRaw(",");
}
else
{
commentComma = true;
}
writeRaw("{");
bool c = pushComma();
switch (comment.type)
{
case Lexeme::Comment:
writeType("Comment");
break;
case Lexeme::BlockComment:
writeType("BlockComment");
break;
case Lexeme::BrokenComment:
writeType("BrokenComment");
break;
default:
break;
}
write("location", comment.location);
popComma(c);
writeRaw("}");
}
}
};
std::string toJson(AstNode* node)
@ -1102,4 +1189,15 @@ std::string toJson(AstNode* node)
return encoder.str();
}
std::string toJson(AstNode* node, const std::vector<Comment>& commentLocations)
{
AstJsonEncoder encoder;
encoder.writeRaw(R"({"root":)");
node->visit(&encoder);
encoder.writeRaw(R"(,"commentLocations":[)");
encoder.writeComments(commentLocations);
encoder.writeRaw("]}");
return encoder.str();
}
} // namespace Luau

View file

@ -17,6 +17,104 @@ namespace Luau
namespace
{
struct AutocompleteNodeFinder : public AstVisitor
{
const Position pos;
std::vector<AstNode*> ancestry;
explicit AutocompleteNodeFinder(Position pos, AstNode* root)
: pos(pos)
{
}
bool visit(AstExpr* expr) override
{
if (expr->location.begin < pos && pos <= expr->location.end)
{
ancestry.push_back(expr);
return true;
}
return false;
}
bool visit(AstStat* stat) override
{
if (stat->location.begin < pos && pos <= stat->location.end)
{
ancestry.push_back(stat);
return true;
}
return false;
}
bool visit(AstType* type) override
{
if (type->location.begin < pos && pos <= type->location.end)
{
ancestry.push_back(type);
return true;
}
return false;
}
bool visit(AstTypeError* type) override
{
// For a missing type, match the whole range including the start position
if (type->isMissing && type->location.containsClosed(pos))
{
ancestry.push_back(type);
return true;
}
return false;
}
bool visit(class AstTypePack* typePack) override
{
return true;
}
bool visit(AstStatBlock* block) override
{
// If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite.
if (ancestry.empty())
{
ancestry.push_back(block);
return true;
}
// AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes.
// ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz}
if (!ancestry.empty() && ancestry.back()->is<AstExprIndexName>())
return false;
// Type annotation error might intersect the block statement when the function header is being written,
// annotation takes priority
if (!ancestry.empty() && ancestry.back()->is<AstTypeError>())
return false;
// If the cursor is at the end of an expression or type and simultaneously at the beginning of a block,
// the expression or type wins out.
// The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to
// be within the block.
if (block->location.begin == pos && !ancestry.empty())
{
if (ancestry.back()->asExpr() && !ancestry.back()->is<AstExprFunction>())
return false;
if (ancestry.back()->asType())
return false;
}
if (block->location.begin <= pos && pos <= block->location.end)
{
ancestry.push_back(block);
return true;
}
return false;
}
};
struct FindNode : public AstVisitor
{
const Position pos;
@ -102,6 +200,13 @@ struct FindFullAncestry final : public AstVisitor
} // namespace
std::vector<AstNode*> findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos)
{
AutocompleteNodeFinder finder{pos, source.root};
source.root->visit(&finder);
return finder.ancestry;
}
std::vector<AstNode*> findAstAncestryOfPosition(const SourceModule& source, Position pos)
{
const Position end = source.root->location.end;
@ -110,7 +215,7 @@ std::vector<AstNode*> findAstAncestryOfPosition(const SourceModule& source, Posi
FindFullAncestry finder(pos, end);
source.root->visit(&finder);
return std::move(finder.nodes);
return finder.nodes;
}
AstNode* findNodeAtPosition(const SourceModule& source, Position pos)
@ -209,7 +314,7 @@ std::optional<Binding> findBindingAtPosition(const Module& module, const SourceM
auto iter = currentScope->bindings.find(name);
if (iter != currentScope->bindings.end() && iter->second.location.begin <= pos)
{
/* Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope */
// Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope
std::optional<AstStatLocal*> bindingStatement = findBindingLocalStatement(source, iter->second);
if (!bindingStatement || !(*bindingStatement)->location.contains(pos))
return iter->second;

View file

@ -7,13 +7,12 @@
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/Parser.h" // TODO: only needed for autocompleteSource which is deprecated
#include <algorithm>
#include <unordered_set>
#include <utility>
LUAU_FASTFLAG(LuauSelfCallAutocompleteFix2)
LUAU_FASTFLAG(LuauSelfCallAutocompleteFix3)
static const std::unordered_set<std::string> kStatementStartingKeywords = {
"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
@ -21,102 +20,6 @@ static const std::unordered_set<std::string> kStatementStartingKeywords = {
namespace Luau
{
struct NodeFinder : public AstVisitor
{
const Position pos;
std::vector<AstNode*> ancestry;
explicit NodeFinder(Position pos, AstNode* root)
: pos(pos)
{
}
bool visit(AstExpr* expr) override
{
if (expr->location.begin < pos && pos <= expr->location.end)
{
ancestry.push_back(expr);
return true;
}
return false;
}
bool visit(AstStat* stat) override
{
if (stat->location.begin < pos && pos <= stat->location.end)
{
ancestry.push_back(stat);
return true;
}
return false;
}
bool visit(AstType* type) override
{
if (type->location.begin < pos && pos <= type->location.end)
{
ancestry.push_back(type);
return true;
}
return false;
}
bool visit(AstTypeError* type) override
{
// For a missing type, match the whole range including the start position
if (type->isMissing && type->location.containsClosed(pos))
{
ancestry.push_back(type);
return true;
}
return false;
}
bool visit(class AstTypePack* typePack) override
{
return true;
}
bool visit(AstStatBlock* block) override
{
// If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite.
if (ancestry.empty())
{
ancestry.push_back(block);
return true;
}
// AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes.
// ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz}
if (!ancestry.empty() && ancestry.back()->is<AstExprIndexName>())
return false;
// Type annotation error might intersect the block statement when the function header is being written,
// annotation takes priority
if (!ancestry.empty() && ancestry.back()->is<AstTypeError>())
return false;
// If the cursor is at the end of an expression or type and simultaneously at the beginning of a block,
// the expression or type wins out.
// The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to
// be within the block.
if (block->location.begin == pos && !ancestry.empty())
{
if (ancestry.back()->asExpr() && !ancestry.back()->is<AstExprFunction>())
return false;
if (ancestry.back()->asType())
return false;
}
if (block->location.begin <= pos && pos <= block->location.end)
{
ancestry.push_back(block);
return true;
}
return false;
}
};
static bool alreadyHasParens(const std::vector<AstNode*>& nodes)
{
@ -246,7 +149,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
ty = follow(ty);
auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) {
LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2);
LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix3);
InternalErrorReporter iceReporter;
UnifierSharedState unifierState(&iceReporter);
@ -265,7 +168,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
TypeId expectedType = follow(*typeAtPosition);
auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) {
if (FFlag::LuauSelfCallAutocompleteFix2)
if (FFlag::LuauSelfCallAutocompleteFix3)
{
if (std::optional<TypeId> firstRetTy = first(ftv->retTypes))
return checkTypeMatch(typeArena, *firstRetTy, expectedType);
@ -306,7 +209,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
}
}
if (FFlag::LuauSelfCallAutocompleteFix2)
if (FFlag::LuauSelfCallAutocompleteFix3)
return checkTypeMatch(typeArena, ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None;
else
return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None;
@ -323,7 +226,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
const std::vector<AstNode*>& nodes, AutocompleteEntryMap& result, std::unordered_set<TypeId>& seen,
std::optional<const ClassTypeVar*> containingClass = std::nullopt)
{
if (FFlag::LuauSelfCallAutocompleteFix2)
if (FFlag::LuauSelfCallAutocompleteFix3)
rootTy = follow(rootTy);
ty = follow(ty);
@ -333,7 +236,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
seen.insert(ty);
auto isWrongIndexer_DEPRECATED = [indexType, useStrictFunctionIndexers = !!get<ClassTypeVar>(ty)](Luau::TypeId type) {
LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2);
LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix3);
if (indexType == PropIndexType::Key)
return false;
@ -366,7 +269,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
}
};
auto isWrongIndexer = [typeArena, rootTy, indexType](Luau::TypeId type) {
LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix2);
LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix3);
if (indexType == PropIndexType::Key)
return false;
@ -374,21 +277,20 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
bool calledWithSelf = indexType == PropIndexType::Colon;
auto isCompatibleCall = [typeArena, rootTy, calledWithSelf](const FunctionTypeVar* ftv) {
if (get<ClassTypeVar>(rootTy))
{
// Calls on classes require strict match between how function is declared and how it's called
return calledWithSelf == ftv->hasSelf;
}
// Strong match with definition is a success
if (calledWithSelf == ftv->hasSelf)
return true;
// If a call is made with ':', it is invalid if a function has incompatible first argument or no arguments at all
// If a call is made with '.', but it was declared with 'self', it is considered invalid if first argument is compatible
if (calledWithSelf || ftv->hasSelf)
// Calls on classes require strict match between how function is declared and how it's called
if (get<ClassTypeVar>(rootTy))
return false;
// When called with ':', but declared without 'self', it is invalid if a function has incompatible first argument or no arguments at all
// When called with '.', but declared with 'self', it is considered invalid if first argument is compatible
if (std::optional<TypeId> firstArgTy = first(ftv->argTypes))
{
if (std::optional<TypeId> firstArgTy = first(ftv->argTypes))
{
if (checkTypeMatch(typeArena, rootTy, *firstArgTy))
return calledWithSelf;
}
if (checkTypeMatch(typeArena, rootTy, *firstArgTy))
return calledWithSelf;
}
return !calledWithSelf;
@ -430,7 +332,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
AutocompleteEntryKind::Property,
type,
prop.deprecated,
FFlag::LuauSelfCallAutocompleteFix2 ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type),
FFlag::LuauSelfCallAutocompleteFix3 ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type),
typeCorrect,
containingClass,
&prop,
@ -473,7 +375,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
{
autocompleteProps(module, typeArena, rootTy, mt->table, indexType, nodes, result, seen);
if (FFlag::LuauSelfCallAutocompleteFix2)
if (FFlag::LuauSelfCallAutocompleteFix3)
{
if (auto mtable = get<TableTypeVar>(mt->metatable))
fillMetatableProps(mtable);
@ -539,7 +441,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
AutocompleteEntryMap inner;
std::unordered_set<TypeId> innerSeen;
if (!FFlag::LuauSelfCallAutocompleteFix2)
if (!FFlag::LuauSelfCallAutocompleteFix3)
innerSeen = seen;
if (isNil(*iter))
@ -565,7 +467,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
++iter;
}
}
else if (auto pt = get<PrimitiveTypeVar>(ty); pt && FFlag::LuauSelfCallAutocompleteFix2)
else if (auto pt = get<PrimitiveTypeVar>(ty); pt && FFlag::LuauSelfCallAutocompleteFix3)
{
if (pt->metatable)
{
@ -573,7 +475,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
fillMetatableProps(mtable);
}
}
else if (FFlag::LuauSelfCallAutocompleteFix2 && get<StringSingleton>(get<SingletonTypeVar>(ty)))
else if (FFlag::LuauSelfCallAutocompleteFix3 && get<StringSingleton>(get<SingletonTypeVar>(ty)))
{
autocompleteProps(module, typeArena, rootTy, getSingletonTypes().stringType, indexType, nodes, result, seen);
}
@ -905,7 +807,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi
}
AstNode* parent = nullptr;
AstType* topType = nullptr;
AstType* topType = nullptr; // TODO: rename?
for (auto it = ancestry.rbegin(), e = ancestry.rend(); it != e; ++it)
{
@ -1477,21 +1379,20 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
if (isWithinComment(sourceModule, position))
return {};
NodeFinder finder{position, sourceModule.root};
sourceModule.root->visit(&finder);
LUAU_ASSERT(!finder.ancestry.empty());
AstNode* node = finder.ancestry.back();
std::vector<AstNode*> ancestry = findAncestryAtPositionForAutocomplete(sourceModule, position);
LUAU_ASSERT(!ancestry.empty());
AstNode* node = ancestry.back();
AstExprConstantNil dummy{Location{}};
AstNode* parent = finder.ancestry.size() >= 2 ? finder.ancestry.rbegin()[1] : &dummy;
AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy;
// If we are inside a body of a function that doesn't have a completed argument list, ignore the body node
if (auto exprFunction = parent->as<AstExprFunction>(); exprFunction && !exprFunction->argLocation && node == exprFunction->body)
{
finder.ancestry.pop_back();
ancestry.pop_back();
node = finder.ancestry.back();
parent = finder.ancestry.size() >= 2 ? finder.ancestry.rbegin()[1] : &dummy;
node = ancestry.back();
parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy;
}
if (auto indexName = node->as<AstExprIndexName>())
@ -1503,48 +1404,48 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
TypeId ty = follow(*it);
PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point;
if (!FFlag::LuauSelfCallAutocompleteFix2 && isString(ty))
return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry),
finder.ancestry};
if (!FFlag::LuauSelfCallAutocompleteFix3 && isString(ty))
return {
autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), ancestry};
else
return {autocompleteProps(*module, typeArena, ty, indexType, finder.ancestry), finder.ancestry};
return {autocompleteProps(*module, typeArena, ty, indexType, ancestry), ancestry};
}
else if (auto typeReference = node->as<AstTypeReference>())
{
if (typeReference->prefix)
return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), finder.ancestry};
return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), ancestry};
else
return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry};
return {autocompleteTypeNames(*module, position, ancestry), ancestry};
}
else if (node->is<AstTypeError>())
{
return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry};
return {autocompleteTypeNames(*module, position, ancestry), ancestry};
}
else if (AstStatLocal* statLocal = node->as<AstStatLocal>())
{
if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin))
return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry};
return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry};
else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end)
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry};
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry};
else
return {};
}
else if (AstStatFor* statFor = extractStat<AstStatFor>(finder.ancestry))
else if (AstStatFor* statFor = extractStat<AstStatFor>(ancestry))
{
if (!statFor->hasDo || position < statFor->doLocation.begin)
{
if (!statFor->from->is<AstExprError>() && !statFor->to->is<AstExprError>() && (!statFor->step || !statFor->step->is<AstExprError>()))
return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry};
return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry};
if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) ||
(statFor->step && statFor->step->location.containsClosed(position)))
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry};
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry};
return {};
}
return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry};
return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry};
}
else if (AstStatForIn* statForIn = parent->as<AstStatForIn>(); statForIn && (node->is<AstStatBlock>() || isIdentifier(node)))
@ -1560,7 +1461,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
return {};
}
return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry};
return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry};
}
if (!statForIn->hasDo || position <= statForIn->doLocation.begin)
@ -1569,58 +1470,58 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1];
if (lastExpr->location.containsClosed(position))
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry};
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry};
if (position > lastExpr->location.end)
return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry};
return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry};
return {}; // Not sure what this means
}
}
else if (AstStatForIn* statForIn = extractStat<AstStatForIn>(finder.ancestry))
else if (AstStatForIn* statForIn = extractStat<AstStatForIn>(ancestry))
{
// The AST looks a bit differently if the cursor is at a position where only the "do" keyword is allowed.
// ex "for f in f do"
if (!statForIn->hasDo)
return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry};
return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry};
return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry};
return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry};
}
else if (AstStatWhile* statWhile = parent->as<AstStatWhile>(); node->is<AstStatBlock>() && statWhile)
{
if (!statWhile->hasDo && !statWhile->condition->is<AstStatError>() && position > statWhile->condition->location.end)
return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry};
return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry};
if (!statWhile->hasDo || position < statWhile->doLocation.begin)
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry};
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry};
if (statWhile->hasDo && position > statWhile->doLocation.end)
return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry};
return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry};
}
else if (AstStatWhile* statWhile = extractStat<AstStatWhile>(finder.ancestry); statWhile && !statWhile->hasDo)
return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry};
else if (AstStatWhile* statWhile = extractStat<AstStatWhile>(ancestry); statWhile && !statWhile->hasDo)
return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry};
else if (AstStatIf* statIf = node->as<AstStatIf>(); statIf && !statIf->elseLocation.has_value())
{
return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}},
finder.ancestry};
return {
{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry};
}
else if (AstStatIf* statIf = parent->as<AstStatIf>(); statIf && node->is<AstStatBlock>())
{
if (statIf->condition->is<AstExprError>())
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry};
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry};
else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position))
return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry};
return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry};
}
else if (AstStatIf* statIf = extractStat<AstStatIf>(finder.ancestry);
else if (AstStatIf* statIf = extractStat<AstStatIf>(ancestry);
statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)))
return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry};
return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry};
else if (AstStatRepeat* statRepeat = node->as<AstStatRepeat>(); statRepeat && statRepeat->condition->is<AstExprError>())
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry};
else if (AstStatRepeat* statRepeat = extractStat<AstStatRepeat>(finder.ancestry); statRepeat)
return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry};
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry};
else if (AstStatRepeat* statRepeat = extractStat<AstStatRepeat>(ancestry); statRepeat)
return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry};
else if (AstExprTable* exprTable = parent->as<AstExprTable>(); exprTable && (node->is<AstExprGlobal>() || node->is<AstExprConstantString>()))
{
for (const auto& [kind, key, value] : exprTable->items)
@ -1630,7 +1531,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
{
if (auto it = module->astExpectedTypes.find(exprTable))
{
auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, finder.ancestry);
auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, ancestry);
// Remove keys that are already completed
for (const auto& item : exprTable->items)
@ -1644,9 +1545,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
// If we know for sure that a key is being written, do not offer general expression suggestions
if (!key)
autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position, result);
autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position, result);
return {result, finder.ancestry};
return {result, ancestry};
}
break;
@ -1654,11 +1555,11 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
}
}
else if (isIdentifier(node) && (parent->is<AstStatExpr>() || parent->is<AstStatError>()))
return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry};
return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry};
if (std::optional<AutocompleteEntryMap> ret = autocompleteStringParams(sourceModule, module, finder.ancestry, position, callback))
if (std::optional<AutocompleteEntryMap> ret = autocompleteStringParams(sourceModule, module, ancestry, position, callback))
{
return {*ret, finder.ancestry};
return {*ret, ancestry};
}
else if (node->is<AstExprConstantString>())
{
@ -1667,14 +1568,14 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
if (auto it = module->astExpectedTypes.find(node->asExpr()))
autocompleteStringSingleton(*it, false, result);
if (finder.ancestry.size() >= 2)
if (ancestry.size() >= 2)
{
if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as<AstExprIndexExpr>())
if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as<AstExprIndexExpr>())
{
if (auto it = module->astTypes.find(idxExpr->expr))
autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry, result);
autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, ancestry, result);
}
else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as<AstExprBinary>())
else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as<AstExprBinary>())
{
if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe)
{
@ -1684,7 +1585,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
}
}
return {result, finder.ancestry};
return {result, ancestry};
}
if (node->is<AstExprConstantNumber>())
@ -1693,9 +1594,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
}
if (node->asExpr())
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry};
return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry};
else if (node->asStat())
return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry};
return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry};
return {};
}
@ -1725,32 +1626,4 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName
return autocompleteResult;
}
OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback)
{
// TODO: Remove #include "Luau/Parser.h" with this function
auto sourceModule = std::make_unique<SourceModule>();
ParseOptions parseOptions;
parseOptions.captureComments = true;
ParseResult result = Parser::parse(source.data(), source.size(), *sourceModule->names, *sourceModule->allocator, parseOptions);
if (!result.root)
return {AutocompleteResult{}, {}, nullptr};
sourceModule->name = "FRAGMENT_SCRIPT";
sourceModule->root = result.root;
sourceModule->mode = Mode::Strict;
sourceModule->commentLocations = std::move(result.commentLocations);
TypeChecker& typeChecker = frontend.typeCheckerForAutocomplete;
ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict);
OwningAutocompleteResult autocompleteResult = {
autocomplete(*sourceModule, module, typeChecker, &frontend.arenaForAutocomplete, position, callback), std::move(module),
std::move(sourceModule)};
frontend.arenaForAutocomplete.clear();
return autocompleteResult;
}
} // namespace Luau

View file

@ -9,6 +9,8 @@
#include <algorithm>
LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false)
LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauBuiltInMetatableNoBadSynthetic, false)
/** FIXME: Many of these type definitions are not quite completely accurate.
*
@ -222,14 +224,14 @@ void registerBuiltinTypes(TypeChecker& typeChecker)
addGlobalBinding(typeChecker, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau");
// setmetatable<MT>({ @metatable MT }, MT) -> { @metatable MT }
// clang-format off
// setmetatable<T: {}, MT>(T, MT) -> { @metatable MT, T }
addGlobalBinding(typeChecker, "setmetatable",
arena.addType(
FunctionTypeVar{
{genericMT},
{},
arena.addTypePack(TypePack{{tableMetaMT, genericMT}}),
arena.addTypePack(TypePack{{FFlag::LuauUnknownAndNeverType ? tabTy : tableMetaMT, genericMT}}),
arena.addTypePack(TypePack{{tableMetaMT}})
}
), "@luau"
@ -309,6 +311,12 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
{
auto [paramPack, _predicates] = withPredicate;
if (FFlag::LuauUnknownAndNeverType)
{
if (size(paramPack) < 2 && finite(paramPack))
return std::nullopt;
}
TypeArena& arena = typechecker.currentModule->internalTypes;
std::vector<TypeId> expectedArgs = typechecker.unTypePack(scope, paramPack, 2, expr.location);
@ -316,6 +324,12 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
TypeId target = follow(expectedArgs[0]);
TypeId mt = follow(expectedArgs[1]);
if (FFlag::LuauUnknownAndNeverType)
{
typechecker.tablify(target);
typechecker.tablify(mt);
}
if (const auto& tab = get<TableTypeVar>(target))
{
if (target->persistent)
@ -324,7 +338,8 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
}
else
{
typechecker.tablify(mt);
if (!FFlag::LuauUnknownAndNeverType)
typechecker.tablify(mt);
const TableTypeVar* mtTtv = get<TableTypeVar>(mt);
MetatableTypeVar mtv{target, mt};
@ -335,7 +350,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
if (tableName == metatableName)
mtv.syntheticName = tableName;
else
else if (!FFlag::LuauBuiltInMetatableNoBadSynthetic)
mtv.syntheticName = "{ @metatable: " + metatableName + ", " + tableName + " }";
}
@ -343,7 +358,10 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionSetMetaTable(
if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1)
{
return WithPredicate<TypePackId>{};
if (FFlag::LuauUnknownAndNeverType)
return std::nullopt;
else
return WithPredicate<TypePackId>{};
}
if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self)
@ -390,11 +408,21 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionAssert(
if (head.size() > 0)
{
std::optional<TypeId> newhead = typechecker.pickTypesFromSense(head[0], true);
if (!newhead)
head = {typechecker.nilType};
auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true);
if (FFlag::LuauUnknownAndNeverType)
{
if (get<NeverTypeVar>(*ty))
head = {*ty};
else
head[0] = *ty;
}
else
head[0] = *newhead;
{
if (!ty)
head = {typechecker.nilType};
else
head[0] = *ty;
}
}
return WithPredicate<TypePackId>{arena.addTypePack(TypePack{std::move(head), tail})};

View file

@ -48,6 +48,7 @@ struct TypeCloner
void operator()(const Unifiable::Bound<TypeId>& t);
void operator()(const Unifiable::Error& t);
void operator()(const BlockedTypeVar& t);
void operator()(const PendingExpansionTypeVar& t);
void operator()(const PrimitiveTypeVar& t);
void operator()(const ConstrainedTypeVar& t);
void operator()(const SingletonTypeVar& t);
@ -59,6 +60,8 @@ struct TypeCloner
void operator()(const UnionTypeVar& t);
void operator()(const IntersectionTypeVar& t);
void operator()(const LazyTypeVar& t);
void operator()(const UnknownTypeVar& t);
void operator()(const NeverTypeVar& t);
};
struct TypePackCloner
@ -164,6 +167,52 @@ void TypeCloner::operator()(const BlockedTypeVar& t)
defaultClone(t);
}
void TypeCloner::operator()(const PendingExpansionTypeVar& t)
{
TypeId res = dest.addType(PendingExpansionTypeVar{t.fn, t.typeArguments, t.packArguments});
PendingExpansionTypeVar* petv = getMutable<PendingExpansionTypeVar>(res);
LUAU_ASSERT(petv);
seenTypes[typeId] = res;
std::vector<TypeId> typeArguments;
for (TypeId arg : t.typeArguments)
typeArguments.push_back(clone(arg, dest, cloneState));
std::vector<TypePackId> packArguments;
for (TypePackId arg : t.packArguments)
packArguments.push_back(clone(arg, dest, cloneState));
TypeFun fn;
fn.type = clone(t.fn.type, dest, cloneState);
for (const GenericTypeDefinition& param : t.fn.typeParams)
{
TypeId ty = clone(param.ty, dest, cloneState);
std::optional<TypeId> defaultValue = param.defaultValue;
if (defaultValue)
defaultValue = clone(*defaultValue, dest, cloneState);
fn.typeParams.push_back(GenericTypeDefinition{ty, defaultValue});
}
for (const GenericTypePackDefinition& param : t.fn.typePackParams)
{
TypePackId tp = clone(param.tp, dest, cloneState);
std::optional<TypePackId> defaultValue = param.defaultValue;
if (defaultValue)
defaultValue = clone(*defaultValue, dest, cloneState);
fn.typePackParams.push_back(GenericTypePackDefinition{tp, defaultValue});
}
petv->fn = std::move(fn);
petv->typeArguments = std::move(typeArguments);
petv->packArguments = std::move(packArguments);
}
void TypeCloner::operator()(const PrimitiveTypeVar& t)
{
defaultClone(t);
@ -310,6 +359,16 @@ void TypeCloner::operator()(const LazyTypeVar& t)
defaultClone(t);
}
void TypeCloner::operator()(const UnknownTypeVar& t)
{
defaultClone(t);
}
void TypeCloner::operator()(const NeverTypeVar& t)
{
defaultClone(t);
}
} // anonymous namespace
TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
@ -440,6 +499,11 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log)
ConstrainedTypeVar clone{ctv->level, ctv->parts};
result = dest.addType(std::move(clone));
}
else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty))
{
PendingExpansionTypeVar clone{petv->fn, petv->typeArguments, petv->packArguments};
result = dest.addType(std::move(clone));
}
else
return result;

View file

@ -14,7 +14,7 @@ namespace Luau
const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp
ConstraintGraphBuilder::ConstraintGraphBuilder(
const ModuleName& moduleName, TypeArena* arena, NotNull<InternalErrorReporter> ice, NotNull<Scope2> globalScope)
const ModuleName& moduleName, TypeArena* arena, NotNull<InternalErrorReporter> ice, NotNull<Scope> globalScope)
: moduleName(moduleName)
, singletonTypes(getSingletonTypes())
, arena(arena)
@ -25,36 +25,34 @@ ConstraintGraphBuilder::ConstraintGraphBuilder(
LUAU_ASSERT(arena);
}
TypeId ConstraintGraphBuilder::freshType(NotNull<Scope2> scope)
TypeId ConstraintGraphBuilder::freshType(const ScopePtr& scope)
{
return arena->addType(FreeTypeVar{scope});
return arena->addType(FreeTypeVar{scope.get()});
}
TypePackId ConstraintGraphBuilder::freshTypePack(NotNull<Scope2> scope)
TypePackId ConstraintGraphBuilder::freshTypePack(const ScopePtr& scope)
{
FreeTypePack f{scope};
FreeTypePack f{scope.get()};
return arena->addTypePack(TypePackVar{std::move(f)});
}
NotNull<Scope2> ConstraintGraphBuilder::childScope(Location location, NotNull<Scope2> parent)
ScopePtr ConstraintGraphBuilder::childScope(Location location, const ScopePtr& parent)
{
auto scope = std::make_unique<Scope2>();
NotNull<Scope2> borrow = NotNull(scope.get());
scopes.emplace_back(location, std::move(scope));
auto scope = std::make_shared<Scope>(parent);
scopes.emplace_back(location, scope);
borrow->parent = parent;
borrow->returnType = parent->returnType;
parent->children.push_back(borrow);
scope->returnType = parent->returnType;
parent->children.push_back(NotNull(scope.get()));
return borrow;
return scope;
}
void ConstraintGraphBuilder::addConstraint(NotNull<Scope2> scope, ConstraintV cv)
void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, ConstraintV cv)
{
scope->constraints.emplace_back(new Constraint{std::move(cv)});
}
void ConstraintGraphBuilder::addConstraint(NotNull<Scope2> scope, std::unique_ptr<Constraint> c)
void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c)
{
scope->constraints.emplace_back(std::move(c));
}
@ -63,25 +61,25 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block)
{
LUAU_ASSERT(scopes.empty());
LUAU_ASSERT(rootScope == nullptr);
scopes.emplace_back(block->location, std::make_unique<Scope2>());
rootScope = scopes.back().second.get();
NotNull<Scope2> borrow = NotNull(rootScope);
ScopePtr scope = std::make_shared<Scope>(singletonTypes.anyTypePack);
rootScope = scope.get();
scopes.emplace_back(block->location, scope);
rootScope->returnType = freshTypePack(borrow);
rootScope->returnType = freshTypePack(scope);
prepopulateGlobalScope(borrow, block);
prepopulateGlobalScope(scope, block);
// TODO: We should share the global scope.
rootScope->typeBindings["nil"] = singletonTypes.nilType;
rootScope->typeBindings["number"] = singletonTypes.numberType;
rootScope->typeBindings["string"] = singletonTypes.stringType;
rootScope->typeBindings["boolean"] = singletonTypes.booleanType;
rootScope->typeBindings["thread"] = singletonTypes.threadType;
rootScope->typeBindings["nil"] = TypeFun{singletonTypes.nilType};
rootScope->typeBindings["number"] = TypeFun{singletonTypes.numberType};
rootScope->typeBindings["string"] = TypeFun{singletonTypes.stringType};
rootScope->typeBindings["boolean"] = TypeFun{singletonTypes.booleanType};
rootScope->typeBindings["thread"] = TypeFun{singletonTypes.threadType};
visitBlockWithoutChildScope(borrow, block);
visitBlockWithoutChildScope(scope, block);
}
void ConstraintGraphBuilder::visitBlockWithoutChildScope(NotNull<Scope2> scope, AstStatBlock* block)
void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block)
{
RecursionCounter counter{&recursionCount};
@ -91,11 +89,58 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(NotNull<Scope2> scope,
return;
}
std::unordered_map<Name, Location> aliasDefinitionLocations;
// In order to enable mutually-recursive type aliases, we need to
// populate the type bindings before we actually check any of the
// alias statements. Since we're not ready to actually resolve
// any of the annotations, we just use a fresh type for now.
for (AstStat* stat : block->body)
{
if (auto alias = stat->as<AstStatTypeAlias>())
{
if (scope->typeBindings.count(alias->name.value) != 0)
{
auto it = aliasDefinitionLocations.find(alias->name.value);
LUAU_ASSERT(it != aliasDefinitionLocations.end());
reportError(alias->location, DuplicateTypeDefinition{alias->name.value, it->second});
continue;
}
bool hasGenerics = alias->generics.size > 0 || alias->genericPacks.size > 0;
ScopePtr defnScope = scope;
if (hasGenerics)
{
defnScope = childScope(alias->location, scope);
}
TypeId initialType = freshType(scope);
TypeFun initialFun = TypeFun{initialType};
for (const auto& [name, gen] : createGenerics(defnScope, alias->generics))
{
initialFun.typeParams.push_back(gen);
defnScope->typeBindings[name] = TypeFun{gen.ty};
}
for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks))
{
initialFun.typePackParams.push_back(genPack);
defnScope->typePackBindings[name] = genPack.tp;
}
scope->typeBindings[alias->name.value] = std::move(initialFun);
astTypeAliasDefiningScopes[alias] = defnScope;
aliasDefinitionLocations[alias->name.value] = alias->location;
}
}
for (AstStat* stat : block->body)
visit(scope, stat);
}
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStat* stat)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat)
{
RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit};
@ -103,6 +148,8 @@ void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStat* stat)
visit(scope, s);
else if (auto s = stat->as<AstStatLocal>())
visit(scope, s);
else if (auto s = stat->as<AstStatFor>())
visit(scope, s);
else if (auto f = stat->as<AstStatFunction>())
visit(scope, f);
else if (auto f = stat->as<AstStatLocalFunction>())
@ -117,26 +164,34 @@ void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStat* stat)
visit(scope, i);
else if (auto a = stat->as<AstStatTypeAlias>())
visit(scope, a);
else if (auto s = stat->as<AstStatDeclareGlobal>())
visit(scope, s);
else if (auto s = stat->as<AstStatDeclareClass>())
visit(scope, s);
else if (auto s = stat->as<AstStatDeclareFunction>())
visit(scope, s);
else
LUAU_ASSERT(0);
}
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatLocal* local)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
{
std::vector<TypeId> varTypes;
for (AstLocal* local : local->vars)
{
TypeId ty = freshType(scope);
Location location = local->location;
if (local->annotation)
{
TypeId annotation = resolveType(scope, local->annotation);
location = local->annotation->location;
TypeId annotation = resolveType(scope, local->annotation, /* topLevel */ true);
addConstraint(scope, SubtypeConstraint{ty, annotation});
}
varTypes.push_back(ty);
scope->bindings[local] = ty;
scope->bindings[local] = Binding{ty, location};
}
for (size_t i = 0; i < local->values.size; ++i)
@ -167,18 +222,38 @@ void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatLocal* local)
}
}
void addConstraints(Constraint* constraint, NotNull<Scope2> scope)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_)
{
auto checkNumber = [&](AstExpr* expr) {
if (!expr)
return;
TypeId t = check(scope, expr);
addConstraint(scope, SubtypeConstraint{t, singletonTypes.numberType});
};
checkNumber(for_->from);
checkNumber(for_->to);
checkNumber(for_->step);
ScopePtr forScope = childScope(for_->location, scope);
forScope->bindings[for_->var] = Binding{singletonTypes.numberType, for_->var->location};
visit(forScope, for_->body);
}
void addConstraints(Constraint* constraint, NotNull<Scope> scope)
{
scope->constraints.reserve(scope->constraints.size() + scope->constraints.size());
for (const auto& c : scope->constraints)
constraint->dependencies.push_back(NotNull{c.get()});
for (NotNull<Scope2> childScope : scope->children)
for (NotNull<Scope> childScope : scope->children)
addConstraints(constraint, childScope);
}
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatLocalFunction* function)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* function)
{
// Local
// Global
@ -190,21 +265,21 @@ void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatLocalFunction*
LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name.
functionType = arena->addType(BlockedTypeVar{});
scope->bindings[function->name] = functionType;
scope->bindings[function->name] = Binding{functionType, function->name->location};
FunctionSignature sig = checkFunctionSignature(scope, function->func);
sig.bodyScope->bindings[function->name] = sig.signature;
sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location};
checkFunctionBody(sig.bodyScope, function->func);
std::unique_ptr<Constraint> c{
new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope : sig.bodyScope}}};
addConstraints(c.get(), sig.bodyScope);
new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}}};
addConstraints(c.get(), NotNull(sig.bodyScope.get()));
addConstraint(scope, std::move(c));
}
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatFunction* function)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* function)
{
// Name could be AstStatLocal, AstStatGlobal, AstStatIndexName.
// With or without self
@ -224,9 +299,9 @@ void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatFunction* funct
else
{
functionType = arena->addType(BlockedTypeVar{});
scope->bindings[localName->local] = functionType;
scope->bindings[localName->local] = Binding{functionType, localName->location};
}
sig.bodyScope->bindings[localName->local] = sig.signature;
sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location};
}
else if (AstExprGlobal* globalName = function->name->as<AstExprGlobal>())
{
@ -239,9 +314,9 @@ void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatFunction* funct
else
{
functionType = arena->addType(BlockedTypeVar{});
rootScope->bindings[globalName->name] = functionType;
rootScope->bindings[globalName->name] = Binding{functionType, globalName->location};
}
sig.bodyScope->bindings[globalName->name] = sig.signature;
sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location};
}
else if (AstExprIndexName* indexName = function->name->as<AstExprIndexName>())
{
@ -268,39 +343,26 @@ void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatFunction* funct
checkFunctionBody(sig.bodyScope, function->func);
std::unique_ptr<Constraint> c{
new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope : sig.bodyScope}}};
addConstraints(c.get(), sig.bodyScope);
new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}}};
addConstraints(c.get(), NotNull(sig.bodyScope.get()));
addConstraint(scope, std::move(c));
}
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatReturn* ret)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret)
{
TypePackId exprTypes = checkPack(scope, ret->list);
addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType});
}
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatBlock* block)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block)
{
NotNull<Scope2> innerScope = childScope(block->location, scope);
// In order to enable mutually-recursive type aliases, we need to
// populate the type bindings before we actually check any of the
// alias statements. Since we're not ready to actually resolve
// any of the annotations, we just use a fresh type for now.
for (AstStat* stat : block->body)
{
if (auto alias = stat->as<AstStatTypeAlias>())
{
TypeId initialType = freshType(scope);
scope->typeBindings[alias->name.value] = initialType;
}
}
ScopePtr innerScope = childScope(block->location, scope);
visitBlockWithoutChildScope(innerScope, block);
}
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatAssign* assign)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign)
{
TypePackId varPackId = checkExprList(scope, assign->vars);
TypePackId valuePack = checkPack(scope, assign->values);
@ -308,47 +370,66 @@ void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatAssign* assign)
addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId});
}
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatIf* ifStatement)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement)
{
check(scope, ifStatement->condition);
NotNull<Scope2> thenScope = childScope(ifStatement->thenbody->location, scope);
ScopePtr thenScope = childScope(ifStatement->thenbody->location, scope);
visit(thenScope, ifStatement->thenbody);
if (ifStatement->elsebody)
{
NotNull<Scope2> elseScope = childScope(ifStatement->elsebody->location, scope);
ScopePtr elseScope = childScope(ifStatement->elsebody->location, scope);
visit(elseScope, ifStatement->elsebody);
}
}
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatTypeAlias* alias)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alias)
{
// TODO: Exported type aliases
// TODO: Generic type aliases
auto it = scope->typeBindings.find(alias->name.value);
// This should always be here since we do a separate pass over the
// AST to set up typeBindings. If it's not, we've somehow skipped
// this alias in that first pass.
LUAU_ASSERT(it != scope->typeBindings.end());
if (it == scope->typeBindings.end())
auto bindingIt = scope->typeBindings.find(alias->name.value);
ScopePtr* defnIt = astTypeAliasDefiningScopes.find(alias);
// These will be undefined if the alias was a duplicate definition, in which
// case we just skip over it.
if (bindingIt == scope->typeBindings.end() || defnIt == nullptr)
{
ice->ice("Type alias does not have a pre-populated binding", alias->location);
return;
}
TypeId ty = resolveType(scope, alias->type);
ScopePtr resolvingScope = *defnIt;
TypeId ty = resolveType(resolvingScope, alias->type, /* topLevel */ true);
LUAU_ASSERT(get<FreeTypeVar>(bindingIt->second.type));
// Rather than using a subtype constraint, we instead directly bind
// the free type we generated in the first pass to the resolved type.
// This prevents a case where you could cause another constraint to
// bind the free alias type to an unrelated type, causing havoc.
asMutable(it->second)->ty.emplace<BoundTypeVar>(ty);
asMutable(bindingIt->second.type)->ty.emplace<BoundTypeVar>(ty);
addConstraint(scope, NameConstraint{ty, alias->name.value});
}
TypePackId ConstraintGraphBuilder::checkPack(NotNull<Scope2> scope, AstArray<AstExpr*> exprs)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* global)
{
LUAU_ASSERT(global->type);
TypeId globalTy = resolveType(scope, global->type);
scope->bindings[global->name] = Binding{globalTy, global->location};
}
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* global)
{
LUAU_ASSERT(false); // TODO: implement
}
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global)
{
LUAU_ASSERT(false); // TODO: implement
}
TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstArray<AstExpr*> exprs)
{
if (exprs.size == 0)
return arena->addTypePack({});
@ -369,7 +450,7 @@ TypePackId ConstraintGraphBuilder::checkPack(NotNull<Scope2> scope, AstArray<Ast
return arena->addTypePack(TypePack{std::move(types), last});
}
TypePackId ConstraintGraphBuilder::checkExprList(NotNull<Scope2> scope, const AstArray<AstExpr*>& exprs)
TypePackId ConstraintGraphBuilder::checkExprList(const ScopePtr& scope, const AstArray<AstExpr*>& exprs)
{
TypePackId result = arena->addTypePack({});
TypePack* resultPack = getMutable<TypePack>(result);
@ -390,7 +471,7 @@ TypePackId ConstraintGraphBuilder::checkExprList(NotNull<Scope2> scope, const As
return result;
}
TypePackId ConstraintGraphBuilder::checkPack(NotNull<Scope2> scope, AstExpr* expr)
TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* expr)
{
RecursionCounter counter{&recursionCount};
@ -445,7 +526,7 @@ TypePackId ConstraintGraphBuilder::checkPack(NotNull<Scope2> scope, AstExpr* exp
return result;
}
TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExpr* expr)
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr)
{
RecursionCounter counter{&recursionCount};
@ -525,7 +606,7 @@ TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExpr* expr)
return result;
}
TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExprIndexName* indexName)
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName)
{
TypeId obj = check(scope, indexName->expr);
TypeId result = freshType(scope);
@ -541,7 +622,7 @@ TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExprIndexName* in
return result;
}
TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExprIndexExpr* indexExpr)
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* indexExpr)
{
TypeId obj = check(scope, indexExpr->expr);
TypeId indexType = check(scope, indexExpr->index);
@ -556,7 +637,7 @@ TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExprIndexExpr* in
return result;
}
TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExprUnary* unary)
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary)
{
TypeId operandType = check(scope, unary->expr);
@ -576,7 +657,7 @@ TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExprUnary* unary)
return singletonTypes.errorRecoveryType();
}
TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExprBinary* binary)
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binary)
{
TypeId leftType = check(scope, binary->left);
TypeId rightType = check(scope, binary->right);
@ -601,7 +682,7 @@ TypeId ConstraintGraphBuilder::check(NotNull<Scope2> scope, AstExprBinary* binar
return nullptr;
}
TypeId ConstraintGraphBuilder::checkExprTable(NotNull<Scope2> scope, AstExprTable* expr)
TypeId ConstraintGraphBuilder::checkExprTable(const ScopePtr& scope, AstExprTable* expr)
{
TypeId ty = arena->addType(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(ty);
@ -651,10 +732,10 @@ TypeId ConstraintGraphBuilder::checkExprTable(NotNull<Scope2> scope, AstExprTabl
return ty;
}
ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(NotNull<Scope2> parent, AstExprFunction* fn)
ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn)
{
Scope2* signatureScope = nullptr;
Scope2* bodyScope = nullptr;
ScopePtr signatureScope = nullptr;
ScopePtr bodyScope = nullptr;
TypePackId returnType = nullptr;
std::vector<TypeId> genericTypes;
@ -667,25 +748,24 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
// generics properly.
if (hasGenerics)
{
NotNull signatureBorrow = childScope(fn->location, parent);
signatureScope = signatureBorrow.get();
signatureScope = childScope(fn->location, parent);
// We need to assign returnType before creating bodyScope so that the
// return type gets propogated to bodyScope.
returnType = freshTypePack(signatureBorrow);
returnType = freshTypePack(signatureScope);
signatureScope->returnType = returnType;
bodyScope = childScope(fn->body->location, signatureBorrow).get();
bodyScope = childScope(fn->body->location, signatureScope);
std::vector<std::pair<Name, GenericTypeDefinition>> genericDefinitions = createGenerics(signatureBorrow, fn->generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> genericPackDefinitions = createGenericPacks(signatureBorrow, fn->genericPacks);
std::vector<std::pair<Name, GenericTypeDefinition>> genericDefinitions = createGenerics(signatureScope, fn->generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks);
// We do not support default values on function generics, so we only
// care about the types involved.
for (const auto& [name, g] : genericDefinitions)
{
genericTypes.push_back(g.ty);
signatureScope->typeBindings[name] = g.ty;
signatureScope->typeBindings[name] = TypeFun{g.ty};
}
for (const auto& [name, g] : genericPackDefinitions)
@ -696,11 +776,10 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
}
else
{
NotNull bodyBorrow = childScope(fn->body->location, parent);
bodyScope = bodyBorrow.get();
bodyScope = childScope(fn->body->location, parent);
returnType = freshTypePack(bodyBorrow);
bodyBorrow->returnType = returnType;
returnType = freshTypePack(bodyScope);
bodyScope->returnType = returnType;
// To eliminate the need to branch on hasGenerics below, we say that the
// signature scope is the body scope when there is no real signature
@ -708,27 +787,24 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
signatureScope = bodyScope;
}
NotNull bodyBorrow = NotNull(bodyScope);
NotNull signatureBorrow = NotNull(signatureScope);
if (fn->returnAnnotation)
{
TypePackId annotatedRetType = resolveTypePack(signatureBorrow, *fn->returnAnnotation);
addConstraint(signatureBorrow, PackSubtypeConstraint{returnType, annotatedRetType});
TypePackId annotatedRetType = resolveTypePack(signatureScope, *fn->returnAnnotation);
addConstraint(signatureScope, PackSubtypeConstraint{returnType, annotatedRetType});
}
std::vector<TypeId> argTypes;
for (AstLocal* local : fn->args)
{
TypeId t = freshType(signatureBorrow);
TypeId t = freshType(signatureScope);
argTypes.push_back(t);
signatureScope->bindings[local] = t;
signatureScope->bindings[local] = Binding{t, local->location};
if (local->annotation)
{
TypeId argAnnotation = resolveType(signatureBorrow, local->annotation);
addConstraint(signatureBorrow, SubtypeConstraint{t, argAnnotation});
TypeId argAnnotation = resolveType(signatureScope, local->annotation, /* topLevel */ true);
addConstraint(signatureScope, SubtypeConstraint{t, argAnnotation});
}
}
@ -749,11 +825,11 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
// Undo the workaround we made above: if there's no signature scope,
// don't report it.
/* signatureScope */ hasGenerics ? signatureScope : nullptr,
/* bodyScope */ bodyBorrow,
/* bodyScope */ bodyScope,
};
}
void ConstraintGraphBuilder::checkFunctionBody(NotNull<Scope2> scope, AstExprFunction* fn)
void ConstraintGraphBuilder::checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn)
{
visitBlockWithoutChildScope(scope, fn->body);
@ -766,20 +842,65 @@ void ConstraintGraphBuilder::checkFunctionBody(NotNull<Scope2> scope, AstExprFun
}
}
TypeId ConstraintGraphBuilder::resolveType(NotNull<Scope2> scope, AstType* ty)
TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, bool topLevel)
{
TypeId result = nullptr;
if (auto ref = ty->as<AstTypeReference>())
{
// TODO: Support imported types w/ require tracing.
// TODO: Support generic type references.
LUAU_ASSERT(!ref->prefix);
LUAU_ASSERT(!ref->hasParameterList);
// TODO: If it doesn't exist, should we introduce a free binding?
// This is probably important for handling type aliases.
result = scope->lookupTypeBinding(ref->name.value).value_or(singletonTypes.errorRecoveryType());
std::optional<TypeFun> alias = scope->lookupTypeBinding(ref->name.value);
if (alias.has_value())
{
// If the alias is not generic, we don't need to set up a blocked
// type and an instantiation constraint.
if (alias->typeParams.empty() && alias->typePackParams.empty())
{
result = alias->type;
}
else
{
std::vector<TypeId> parameters;
std::vector<TypePackId> packParameters;
for (const AstTypeOrPack& p : ref->parameters)
{
// We do not enforce the ordering of types vs. type packs here;
// that is done in the parser.
if (p.type)
{
parameters.push_back(resolveType(scope, p.type));
}
else if (p.typePack)
{
packParameters.push_back(resolveTypePack(scope, p.typePack));
}
else
{
// This indicates a parser bug: one of these two pointers
// should be set.
LUAU_ASSERT(false);
}
}
result = arena->addType(PendingExpansionTypeVar{*alias, parameters, packParameters});
if (topLevel)
{
addConstraint(scope, TypeAliasExpansionConstraint{
/* target */ result,
});
}
}
}
else
{
reportError(ty->location, UnknownSymbol{ref->name.value, UnknownSymbol::Context::Type});
result = singletonTypes.errorRecoveryType();
}
}
else if (auto tab = ty->as<AstTypeTable>())
{
@ -811,7 +932,7 @@ TypeId ConstraintGraphBuilder::resolveType(NotNull<Scope2> scope, AstType* ty)
{
// TODO: Recursion limit.
bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0;
Scope2* signatureScope = nullptr;
ScopePtr signatureScope = nullptr;
std::vector<TypeId> genericTypes;
std::vector<TypePackId> genericTypePacks;
@ -820,22 +941,21 @@ TypeId ConstraintGraphBuilder::resolveType(NotNull<Scope2> scope, AstType* ty)
// for the generic bindings to live on.
if (hasGenerics)
{
NotNull<Scope2> signatureBorrow = childScope(fn->location, scope);
signatureScope = signatureBorrow.get();
signatureScope = childScope(fn->location, scope);
std::vector<std::pair<Name, GenericTypeDefinition>> genericDefinitions = createGenerics(signatureBorrow, fn->generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> genericPackDefinitions = createGenericPacks(signatureBorrow, fn->genericPacks);
std::vector<std::pair<Name, GenericTypeDefinition>> genericDefinitions = createGenerics(signatureScope, fn->generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks);
for (const auto& [name, g] : genericDefinitions)
{
genericTypes.push_back(g.ty);
signatureBorrow->typeBindings[name] = g.ty;
signatureScope->typeBindings[name] = TypeFun{g.ty};
}
for (const auto& [name, g] : genericPackDefinitions)
{
genericTypePacks.push_back(g.tp);
signatureBorrow->typePackBindings[name] = g.tp;
signatureScope->typePackBindings[name] = g.tp;
}
}
else
@ -843,13 +963,11 @@ TypeId ConstraintGraphBuilder::resolveType(NotNull<Scope2> scope, AstType* ty)
// To eliminate the need to branch on hasGenerics below, we say that
// the signature scope is the parent scope if we don't have
// generics.
signatureScope = scope.get();
signatureScope = scope;
}
NotNull<Scope2> signatureBorrow(signatureScope);
TypePackId argTypes = resolveTypePack(signatureBorrow, fn->argTypes);
TypePackId returnTypes = resolveTypePack(signatureBorrow, fn->returnTypes);
TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes);
TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes);
// TODO: FunctionTypeVar needs a pointer to the scope so that we know
// how to quantify/instantiate it.
@ -927,7 +1045,7 @@ TypeId ConstraintGraphBuilder::resolveType(NotNull<Scope2> scope, AstType* ty)
return result;
}
TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull<Scope2> scope, AstTypePack* tp)
TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, AstTypePack* tp)
{
TypePackId result;
if (auto expl = tp->as<AstTypePackExplicit>())
@ -941,7 +1059,15 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull<Scope2> scope, AstTyp
}
else if (auto gen = tp->as<AstTypePackGeneric>())
{
result = arena->addTypePack(TypePackVar{GenericTypePack{scope, gen->genericName.value}});
if (std::optional<TypePackId> lookup = scope->lookupTypePackBinding(gen->genericName.value))
{
result = *lookup;
}
else
{
reportError(tp->location, UnknownSymbol{gen->genericName.value, UnknownSymbol::Context::Type});
result = singletonTypes.errorRecoveryTypePack();
}
}
else
{
@ -953,7 +1079,7 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull<Scope2> scope, AstTyp
return result;
}
TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull<Scope2> scope, const AstTypeList& list)
TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const AstTypeList& list)
{
std::vector<TypeId> head;
@ -971,12 +1097,12 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull<Scope2> scope, const
return arena->addTypePack(TypePack{head, tail});
}
std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGraphBuilder::createGenerics(NotNull<Scope2> scope, AstArray<AstGenericType> generics)
std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGraphBuilder::createGenerics(const ScopePtr& scope, AstArray<AstGenericType> generics)
{
std::vector<std::pair<Name, GenericTypeDefinition>> result;
for (const auto& generic : generics)
{
TypeId genericTy = arena->addType(GenericTypeVar{scope, generic.name.value});
TypeId genericTy = arena->addType(GenericTypeVar{scope.get(), generic.name.value});
std::optional<TypeId> defaultTy = std::nullopt;
if (generic.defaultValue)
@ -992,12 +1118,12 @@ std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGraphBuilder::crea
}
std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGraphBuilder::createGenericPacks(
NotNull<Scope2> scope, AstArray<AstGenericTypePack> generics)
const ScopePtr& scope, AstArray<AstGenericTypePack> generics)
{
std::vector<std::pair<Name, GenericTypePackDefinition>> result;
for (const auto& generic : generics)
{
TypePackId genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope, generic.name.value}});
TypePackId genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic.name.value}});
std::optional<TypePackId> defaultTy = std::nullopt;
if (generic.defaultValue)
@ -1012,7 +1138,7 @@ std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGraphBuilder::
return result;
}
TypeId ConstraintGraphBuilder::flattenPack(NotNull<Scope2> scope, Location location, TypePackId tp)
TypeId ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location location, TypePackId tp)
{
if (auto f = first(tp))
return *f;
@ -1038,10 +1164,10 @@ void ConstraintGraphBuilder::reportCodeTooComplex(Location location)
struct GlobalPrepopulator : AstVisitor
{
const NotNull<Scope2> globalScope;
const NotNull<Scope> globalScope;
const NotNull<TypeArena> arena;
GlobalPrepopulator(NotNull<Scope2> globalScope, NotNull<TypeArena> arena)
GlobalPrepopulator(NotNull<Scope> globalScope, NotNull<TypeArena> arena)
: globalScope(globalScope)
, arena(arena)
{
@ -1050,29 +1176,29 @@ struct GlobalPrepopulator : AstVisitor
bool visit(AstStatFunction* function) override
{
if (AstExprGlobal* g = function->name->as<AstExprGlobal>())
globalScope->bindings[g->name] = arena->addType(BlockedTypeVar{});
globalScope->bindings[g->name] = Binding{arena->addType(BlockedTypeVar{})};
return true;
}
};
void ConstraintGraphBuilder::prepopulateGlobalScope(NotNull<Scope2> globalScope, AstStatBlock* program)
void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program)
{
GlobalPrepopulator gp{NotNull{globalScope}, arena};
GlobalPrepopulator gp{NotNull{globalScope.get()}, arena};
program->visit(&gp);
}
void collectConstraints(std::vector<NotNull<Constraint>>& result, NotNull<Scope2> scope)
void collectConstraints(std::vector<NotNull<Constraint>>& result, NotNull<Scope> scope)
{
for (const auto& c : scope->constraints)
result.push_back(NotNull{c.get()});
for (NotNull<Scope2> child : scope->children)
for (NotNull<Scope> child : scope->children)
collectConstraints(result, child);
}
std::vector<NotNull<Constraint>> collectConstraints(NotNull<Scope2> rootScope)
std::vector<NotNull<Constraint>> collectConstraints(NotNull<Scope> rootScope)
{
std::vector<NotNull<Constraint>> result;
collectConstraints(result, rootScope);

View file

@ -1,11 +1,13 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/ApplyTypeFunction.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/Instantiation.h"
#include "Luau/Location.h"
#include "Luau/Quantify.h"
#include "Luau/ToString.h"
#include "Luau/Unifier.h"
#include "Luau/VisitTypeVar.h"
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false);
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false);
@ -13,31 +15,195 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false);
namespace Luau
{
[[maybe_unused]] static void dumpBindings(NotNull<Scope2> scope, ToStringOptions& opts)
[[maybe_unused]] static void dumpBindings(NotNull<Scope> scope, ToStringOptions& opts)
{
for (const auto& [k, v] : scope->bindings)
{
auto d = toStringDetailed(v, opts);
auto d = toStringDetailed(v.typeId, opts);
opts.nameMap = d.nameMap;
printf("\t%s : %s\n", k.c_str(), d.name.c_str());
}
for (NotNull<Scope2> child : scope->children)
for (NotNull<Scope> child : scope->children)
dumpBindings(child, opts);
}
static void dumpConstraints(NotNull<Scope2> scope, ToStringOptions& opts)
static void dumpConstraints(NotNull<Scope> scope, ToStringOptions& opts)
{
for (const ConstraintPtr& c : scope->constraints)
{
printf("\t%s\n", toString(*c, opts).c_str());
}
for (NotNull<Scope2> child : scope->children)
for (NotNull<Scope> child : scope->children)
dumpConstraints(child, opts);
}
void dump(NotNull<Scope2> rootScope, ToStringOptions& opts)
static std::pair<std::vector<TypeId>, std::vector<TypePackId>> saturateArguments(
const TypeFun& fn, const std::vector<TypeId>& rawTypeArguments, const std::vector<TypePackId>& rawPackArguments, TypeArena* arena)
{
std::vector<TypeId> saturatedTypeArguments;
std::vector<TypeId> extraTypes;
std::vector<TypePackId> saturatedPackArguments;
for (size_t i = 0; i < rawTypeArguments.size(); ++i)
{
TypeId ty = rawTypeArguments[i];
if (i < fn.typeParams.size())
saturatedTypeArguments.push_back(ty);
else
extraTypes.push_back(ty);
}
// If we collected extra types, put them in a type pack now. This case is
// mutually exclusive with the type pack -> type conversion we do below:
// extraTypes will only have elements in it if we have more types than we
// have parameter slots for them to go into.
if (!extraTypes.empty())
{
saturatedPackArguments.push_back(arena->addTypePack(extraTypes));
}
for (size_t i = 0; i < rawPackArguments.size(); ++i)
{
TypePackId tp = rawPackArguments[i];
// If we are short on regular type saturatedTypeArguments and we have a single
// element type pack, we can decompose that to the type it contains and
// use that as a type parameter.
if (saturatedTypeArguments.size() < fn.typeParams.size() && size(tp) == 1 && finite(tp) && first(tp) && saturatedPackArguments.empty())
{
saturatedTypeArguments.push_back(*first(tp));
}
else
{
saturatedPackArguments.push_back(tp);
}
}
size_t typesProvided = saturatedTypeArguments.size();
size_t typesRequired = fn.typeParams.size();
size_t packsProvided = saturatedPackArguments.size();
size_t packsRequired = fn.typePackParams.size();
// Extra types should be accumulated in extraTypes, not saturatedTypeArguments. Extra
// packs will be accumulated in saturatedPackArguments, so we don't have an
// assertion for that.
LUAU_ASSERT(typesProvided <= typesRequired);
// If we didn't provide enough types, but we did provide a type pack, we
// don't want to use defaults. The rationale for this is that if the user
// provides a pack but doesn't provide enough types, we want to report an
// error, rather than simply using the default saturatedTypeArguments, if they exist. If
// they did provide enough types, but not enough packs, we of course want to
// use the default packs.
bool needsDefaults = (typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired);
if (needsDefaults)
{
// Default types can reference earlier types. It's legal to write
// something like
// type T<A, B = A> = (A, B) -> number
// and we need to respect that. We use an ApplyTypeFunction for this.
ApplyTypeFunction atf{arena};
for (size_t i = 0; i < typesProvided; ++i)
atf.typeArguments[fn.typeParams[i].ty] = saturatedTypeArguments[i];
for (size_t i = typesProvided; i < typesRequired; ++i)
{
TypeId defaultTy = fn.typeParams[i].defaultValue.value_or(nullptr);
// We will fill this in with the error type later.
if (!defaultTy)
break;
TypeId instantiatedDefault = atf.substitute(defaultTy).value_or(getSingletonTypes().errorRecoveryType());
atf.typeArguments[fn.typeParams[i].ty] = instantiatedDefault;
saturatedTypeArguments.push_back(instantiatedDefault);
}
for (size_t i = 0; i < packsProvided; ++i)
{
atf.typePackArguments[fn.typePackParams[i].tp] = saturatedPackArguments[i];
}
for (size_t i = packsProvided; i < packsRequired; ++i)
{
TypePackId defaultTp = fn.typePackParams[i].defaultValue.value_or(nullptr);
// We will fill this in with the error type pack later.
if (!defaultTp)
break;
TypePackId instantiatedDefault = atf.substitute(defaultTp).value_or(getSingletonTypes().errorRecoveryTypePack());
atf.typePackArguments[fn.typePackParams[i].tp] = instantiatedDefault;
saturatedPackArguments.push_back(instantiatedDefault);
}
}
// If we didn't create an extra type pack from overflowing parameter packs,
// and we're still missing a type pack, plug in an empty type pack as the
// value of the empty packs.
if (extraTypes.empty() && saturatedPackArguments.size() + 1 == fn.typePackParams.size())
{
saturatedPackArguments.push_back(arena->addTypePack({}));
}
// We need to have _something_ when we substitute the generic saturatedTypeArguments,
// even if they're missing, so we use the error type as a filler.
for (size_t i = saturatedTypeArguments.size(); i < typesRequired; ++i)
{
saturatedTypeArguments.push_back(getSingletonTypes().errorRecoveryType());
}
for (size_t i = saturatedPackArguments.size(); i < packsRequired; ++i)
{
saturatedPackArguments.push_back(getSingletonTypes().errorRecoveryTypePack());
}
// At this point, these two conditions should be true. If they aren't we
// will run into access violations.
LUAU_ASSERT(saturatedTypeArguments.size() == fn.typeParams.size());
LUAU_ASSERT(saturatedPackArguments.size() == fn.typePackParams.size());
return {saturatedTypeArguments, saturatedPackArguments};
}
bool InstantiationSignature::operator==(const InstantiationSignature& rhs) const
{
return fn == rhs.fn && arguments == rhs.arguments && packArguments == rhs.packArguments;
}
size_t HashInstantiationSignature::operator()(const InstantiationSignature& signature) const
{
size_t hash = std::hash<TypeId>{}(signature.fn.type);
for (const GenericTypeDefinition& p : signature.fn.typeParams)
{
hash ^= (std::hash<TypeId>{}(p.ty) << 1);
}
for (const GenericTypePackDefinition& p : signature.fn.typePackParams)
{
hash ^= (std::hash<TypePackId>{}(p.tp) << 1);
}
for (const TypeId a : signature.arguments)
{
hash ^= (std::hash<TypeId>{}(a) << 1);
}
for (const TypePackId a : signature.packArguments)
{
hash ^= (std::hash<TypePackId>{}(a) << 1);
}
return hash;
}
void dump(NotNull<Scope> rootScope, ToStringOptions& opts)
{
printf("constraints:\n");
dumpConstraints(rootScope, opts);
@ -55,7 +221,7 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts)
}
}
ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull<Scope2> rootScope)
ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull<Scope> rootScope)
: arena(arena)
, constraints(collectConstraints(rootScope))
, rootScope(rootScope)
@ -77,6 +243,7 @@ void ConstraintSolver::run()
return;
ToStringOptions opts;
opts.exhaustive = true;
if (FFlag::DebugLuauLogSolver)
{
@ -186,6 +353,8 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*bc, constraint, force);
else if (auto nc = get<NameConstraint>(*constraint))
success = tryDispatch(*nc, constraint);
else if (auto taec = get<TypeAliasExpansionConstraint>(*constraint))
success = tryDispatch(*taec, constraint);
else
LUAU_ASSERT(0);
@ -325,6 +494,198 @@ bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull<const Constr
return true;
}
struct InfiniteTypeFinder : TypeVarOnceVisitor
{
ConstraintSolver* solver;
const InstantiationSignature& signature;
bool foundInfiniteType = false;
explicit InfiniteTypeFinder(ConstraintSolver* solver, const InstantiationSignature& signature)
: solver(solver)
, signature(signature)
{
}
bool visit(TypeId ty, const PendingExpansionTypeVar& petv) override
{
auto [typeArguments, packArguments] = saturateArguments(petv.fn, petv.typeArguments, petv.packArguments, solver->arena);
if (follow(petv.fn.type) == follow(signature.fn.type) && (signature.arguments != typeArguments || signature.packArguments != packArguments))
{
foundInfiniteType = true;
return false;
}
return true;
}
};
struct InstantiationQueuer : TypeVarOnceVisitor
{
ConstraintSolver* solver;
const InstantiationSignature& signature;
explicit InstantiationQueuer(ConstraintSolver* solver, const InstantiationSignature& signature)
: solver(solver)
, signature(signature)
{
}
bool visit(TypeId ty, const PendingExpansionTypeVar& petv) override
{
solver->pushConstraint(TypeAliasExpansionConstraint{ty});
return false;
}
};
bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNull<const Constraint> constraint)
{
const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(follow(c.target));
if (!petv)
{
unblock(c.target);
return true;
}
auto bindResult = [this, &c](TypeId result) {
asMutable(c.target)->ty.emplace<BoundTypeVar>(result);
unblock(c.target);
};
// If there are no parameters to the type function we can just use the type
// directly.
if (petv->fn.typeParams.empty() && petv->fn.typePackParams.empty())
{
bindResult(petv->fn.type);
return true;
}
auto [typeArguments, packArguments] = saturateArguments(petv->fn, petv->typeArguments, petv->packArguments, arena);
bool sameTypes =
std::equal(typeArguments.begin(), typeArguments.end(), petv->fn.typeParams.begin(), petv->fn.typeParams.end(), [](auto&& itp, auto&& p) {
return itp == p.ty;
});
bool samePacks = std::equal(
packArguments.begin(), packArguments.end(), petv->fn.typePackParams.begin(), petv->fn.typePackParams.end(), [](auto&& itp, auto&& p) {
return itp == p.tp;
});
// If we're instantiating the type with its generic saturatedTypeArguments we are
// performing the identity substitution. We can just short-circuit and bind
// to the TypeFun's type.
if (sameTypes && samePacks)
{
bindResult(petv->fn.type);
return true;
}
InstantiationSignature signature{
petv->fn,
typeArguments,
packArguments,
};
// If we use the same signature, we don't need to bother trying to
// instantiate the alias again, since the instantiation should be
// deterministic.
if (TypeId* cached = instantiatedAliases.find(signature))
{
bindResult(*cached);
return true;
}
// In order to prevent infinite types from being expanded and causing us to
// cycle infinitely, we need to scan the type function for cases where we
// expand the same alias with different type saturatedTypeArguments. See
// https://github.com/Roblox/luau/pull/68 for the RFC responsible for this.
// This is a little nicer than using a recursion limit because we can catch
// the infinite expansion before actually trying to expand it.
InfiniteTypeFinder itf{this, signature};
itf.traverse(petv->fn.type);
if (itf.foundInfiniteType)
{
// TODO (CLI-56761): Report an error.
bindResult(getSingletonTypes().errorRecoveryType());
return true;
}
ApplyTypeFunction applyTypeFunction{arena};
for (size_t i = 0; i < typeArguments.size(); ++i)
{
applyTypeFunction.typeArguments[petv->fn.typeParams[i].ty] = typeArguments[i];
}
for (size_t i = 0; i < packArguments.size(); ++i)
{
applyTypeFunction.typePackArguments[petv->fn.typePackParams[i].tp] = packArguments[i];
}
std::optional<TypeId> maybeInstantiated = applyTypeFunction.substitute(petv->fn.type);
// Note that ApplyTypeFunction::encounteredForwardedType is never set in
// DCR, because we do not use free types for forward-declared generic
// aliases.
if (!maybeInstantiated.has_value())
{
// TODO (CLI-56761): Report an error.
bindResult(getSingletonTypes().errorRecoveryType());
return true;
}
TypeId instantiated = *maybeInstantiated;
TypeId target = follow(instantiated);
// Type function application will happily give us the exact same type if
// there are e.g. generic saturatedTypeArguments that go unused.
bool needsClone = follow(petv->fn.type) == target;
// Only tables have the properties we're trying to set.
TableTypeVar* ttv = getMutableTableType(target);
if (ttv)
{
if (needsClone)
{
// Substitution::clone is a shallow clone. If this is a
// metatable type, we want to mutate its table, so we need to
// explicitly clone that table as well. If we don't, we will
// mutate another module's type surface and cause a
// use-after-free.
if (get<MetatableTypeVar>(target))
{
instantiated = applyTypeFunction.clone(target);
MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(instantiated);
mtv->table = applyTypeFunction.clone(mtv->table);
ttv = getMutable<TableTypeVar>(mtv->table);
}
else if (get<TableTypeVar>(target))
{
instantiated = applyTypeFunction.clone(target);
ttv = getMutable<TableTypeVar>(instantiated);
}
target = follow(instantiated);
}
ttv->instantiatedTypeParams = typeArguments;
ttv->instantiatedTypePackParams = packArguments;
// TODO: Fill in definitionModuleName.
}
bindResult(target);
// The application is not recursive, so we need to queue up application of
// any child type function instantiations within the result in order for it
// to be complete.
InstantiationQueuer queuer{this, signature};
queuer.traverse(target);
instantiatedAliases[signature] = target;
return true;
}
void ConstraintSolver::block_(BlockedConstraintId target, NotNull<const Constraint> constraint)
{
blocked[target].push_back(constraint);
@ -388,7 +749,7 @@ void ConstraintSolver::unblock(TypePackId progressed)
bool ConstraintSolver::isBlocked(TypeId ty)
{
return nullptr != get<BlockedTypeVar>(follow(ty));
return nullptr != get<BlockedTypeVar>(follow(ty)) || nullptr != get<PendingExpansionTypeVar>(follow(ty));
}
bool ConstraintSolver::isBlocked(NotNull<const Constraint> constraint)
@ -415,4 +776,12 @@ void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack)
u.log.commit();
}
void ConstraintSolver::pushConstraint(ConstraintV cv)
{
std::unique_ptr<Constraint> c = std::make_unique<Constraint>(std::move(cv));
NotNull<Constraint> borrow = NotNull(c.get());
solverConstraints.push_back(std::move(c));
unsolvedConstraints.push_back(borrow);
}
} // namespace Luau

View file

@ -2,45 +2,39 @@
#include "Luau/ConstraintSolverLogger.h"
#include "Luau/JsonEmitter.h"
namespace Luau
{
static std::string dumpScopeAndChildren(const Scope2* scope, ToStringOptions& opts)
static void dumpScopeAndChildren(const Scope* scope, Json::JsonEmitter& emitter, ToStringOptions& opts)
{
std::string output = "{\"bindings\":{";
emitter.writeRaw("{");
Json::write(emitter, "bindings");
emitter.writeRaw(":");
bool comma = false;
for (const auto& [name, type] : scope->bindings)
Json::ObjectEmitter o = emitter.writeObject();
for (const auto& [name, binding] : scope->bindings)
{
if (comma)
output += ",";
output += "\"";
output += name.c_str();
output += "\": \"";
ToStringResult result = toStringDetailed(type, opts);
ToStringResult result = toStringDetailed(binding.typeId, opts);
opts.nameMap = std::move(result.nameMap);
output += result.name;
output += "\"";
comma = true;
o.writePair(name.c_str(), result.name);
}
output += "},\"children\":[";
comma = false;
o.finish();
emitter.writeRaw(",");
Json::write(emitter, "children");
emitter.writeRaw(":");
for (const Scope2* child : scope->children)
Json::ArrayEmitter a = emitter.writeArray();
for (const Scope* child : scope->children)
{
if (comma)
output += ",";
output += dumpScopeAndChildren(child, opts);
comma = true;
dumpScopeAndChildren(child, emitter, opts);
}
output += "]}";
return output;
a.finish();
emitter.writeRaw("}");
}
static std::string dumpConstraintsToDot(std::vector<NotNull<const Constraint>>& constraints, ToStringOptions& opts)
@ -80,51 +74,49 @@ static std::string dumpConstraintsToDot(std::vector<NotNull<const Constraint>>&
std::string ConstraintSolverLogger::compileOutput()
{
std::string output = "[";
bool comma = false;
Json::JsonEmitter emitter;
emitter.writeRaw("[");
for (const std::string& snapshot : snapshots)
{
if (comma)
output += ",";
output += snapshot;
comma = true;
emitter.writeComma();
emitter.writeRaw(snapshot);
}
output += "]";
return output;
emitter.writeRaw("]");
return emitter.str();
}
void ConstraintSolverLogger::captureBoundarySnapshot(const Scope2* rootScope, std::vector<NotNull<const Constraint>>& unsolvedConstraints)
void ConstraintSolverLogger::captureBoundarySnapshot(const Scope* rootScope, std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
std::string snapshot = "{\"type\":\"boundary\",\"rootScope\":";
Json::JsonEmitter emitter;
Json::ObjectEmitter o = emitter.writeObject();
o.writePair("type", "boundary");
o.writePair("constraintGraph", dumpConstraintsToDot(unsolvedConstraints, opts));
emitter.writeComma();
Json::write(emitter, "rootScope");
emitter.writeRaw(":");
dumpScopeAndChildren(rootScope, emitter, opts);
o.finish();
snapshot += dumpScopeAndChildren(rootScope, opts);
snapshot += ",\"constraintGraph\":\"";
snapshot += dumpConstraintsToDot(unsolvedConstraints, opts);
snapshot += "\"}";
snapshots.push_back(std::move(snapshot));
snapshots.push_back(emitter.str());
}
void ConstraintSolverLogger::prepareStepSnapshot(
const Scope2* rootScope, NotNull<const Constraint> current, std::vector<NotNull<const Constraint>>& unsolvedConstraints)
const Scope* rootScope, NotNull<const Constraint> current, std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
// LUAU_ASSERT(!preparedSnapshot);
Json::JsonEmitter emitter;
Json::ObjectEmitter o = emitter.writeObject();
o.writePair("type", "step");
o.writePair("constraintGraph", dumpConstraintsToDot(unsolvedConstraints, opts));
o.writePair("currentId", std::to_string(reinterpret_cast<size_t>(current.get())));
o.writePair("current", toString(*current, opts));
emitter.writeComma();
Json::write(emitter, "rootScope");
emitter.writeRaw(":");
dumpScopeAndChildren(rootScope, emitter, opts);
o.finish();
std::string snapshot = "{\"type\":\"step\",\"rootScope\":";
snapshot += dumpScopeAndChildren(rootScope, opts);
snapshot += ",\"constraintGraph\":\"";
snapshot += dumpConstraintsToDot(unsolvedConstraints, opts);
snapshot += "\",\"currentId\":\"";
snapshot += std::to_string(reinterpret_cast<size_t>(current.get()));
snapshot += "\",\"current\":\"";
snapshot += toString(*current, opts);
snapshot += "\"}";
preparedSnapshot = std::move(snapshot);
preparedSnapshot = emitter.str();
}
void ConstraintSolverLogger::commitPreparedStepSnapshot()

View file

@ -1,7 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/BuiltinDefinitions.h"
LUAU_FASTFLAG(LuauCheckLenMT)
LUAU_FASTFLAG(LuauUnknownAndNeverType)
namespace Luau
{
@ -116,14 +116,13 @@ declare function typeof<T>(value: T): string
-- `assert` has a magic function attached that will give more detailed type information
declare function assert<T>(value: T, errorMessage: string?): T
declare function error<T>(message: T, level: number?)
declare function tostring<T>(value: T): string
declare function tonumber<T>(value: T, radix: number?): number?
declare function rawequal<T1, T2>(a: T1, b: T2): boolean
declare function rawget<K, V>(tab: {[K]: V}, k: K): V
declare function rawset<K, V>(tab: {[K]: V}, k: K, v: V): {[K]: V}
declare function rawlen<K, V>(obj: {[K]: V} | string): number
declare function setfenv<T..., R...>(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)?
@ -204,11 +203,13 @@ declare function unpack<V>(tab: {V}, i: number?, j: number?): ...V
std::string getBuiltinDefinitionSource()
{
std::string result = kBuiltinDefinitionLuaSrc;
// TODO: move this into kBuiltinDefinitionLuaSrc
if (FFlag::LuauCheckLenMT)
result += "declare function rawlen<K, V>(obj: {[K]: V} | string): number\n";
if (FFlag::LuauUnknownAndNeverType)
result += "declare function error<T>(message: T, level: number?): never\n";
else
result += "declare function error<T>(message: T, level: number?)\n";
return result;
}

View file

@ -102,8 +102,11 @@ void loadModuleIntoScope(TypeChecker& typeChecker, ModulePtr module, ScopePtr sc
}
}
LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName)
LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, const std::string& packageName)
{
if (!FFlag::DebugLuauDeferredConstraintResolution)
return Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, source, packageName);
LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend");
Luau::Allocator allocator;
@ -121,12 +124,12 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
module.root = parseResult.root;
module.mode = Mode::Definition;
ModulePtr checkedModule = typeChecker.check(module, Mode::Definition);
ModulePtr checkedModule = check(module, Mode::Definition, globalScope);
if (checkedModule->errors.size() > 0)
return LoadDefinitionFileResult{false, parseResult, checkedModule};
loadModuleIntoScope(typeChecker, checkedModule, targetScope, packageName);
loadModuleIntoScope(typeChecker, checkedModule, globalScope, packageName);
return LoadDefinitionFileResult{true, parseResult, checkedModule};
}
@ -503,6 +506,8 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
module->astTypes.clear();
module->astExpectedTypes.clear();
module->astOriginalCallTypes.clear();
module->astResolvedTypes.clear();
module->astResolvedTypePacks.clear();
module->scopes.resize(1);
}
@ -692,29 +697,6 @@ LintResult Frontend::lint(const ModuleName& name, std::optional<Luau::LintOption
return lint(*sourceModule, enabledLintWarnings);
}
std::pair<SourceModule, LintResult> Frontend::lintFragment(std::string_view source, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lintFragment", "Frontend");
const Config& config = configResolver->getConfig("");
SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions);
uint64_t ignoreLints = LintWarning::parseMask(sourceModule.hotcomments);
Luau::LintOptions lintOptions = enabledLintWarnings.value_or(config.enabledLint);
lintOptions.warningMask &= ~ignoreLints;
double timestamp = getTimestamp();
std::vector<LintWarning> warnings = Luau::lint(sourceModule.root, *sourceModule.names.get(), typeChecker.globalScope, nullptr,
sourceModule.hotcomments, enabledLintWarnings.value_or(config.enabledLint));
stats.timeLint += getTimestamp() - timestamp;
return {std::move(sourceModule), classifyLints(warnings, config)};
}
LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend");
@ -819,35 +801,28 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons
return const_cast<Frontend*>(this)->getSourceModule(moduleName);
}
NotNull<Scope2> Frontend::getGlobalScope2()
NotNull<Scope> Frontend::getGlobalScope()
{
if (!globalScope2)
if (!globalScope)
{
const SingletonTypes& singletonTypes = getSingletonTypes();
globalScope2 = std::make_unique<Scope2>();
globalScope2->typeBindings["nil"] = singletonTypes.nilType;
globalScope2->typeBindings["number"] = singletonTypes.numberType;
globalScope2->typeBindings["string"] = singletonTypes.stringType;
globalScope2->typeBindings["boolean"] = singletonTypes.booleanType;
globalScope2->typeBindings["thread"] = singletonTypes.threadType;
globalScope = typeChecker.globalScope;
}
return NotNull(globalScope2.get());
return NotNull(globalScope.get());
}
ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope)
{
ModulePtr result = std::make_shared<Module>();
ConstraintGraphBuilder cgb{sourceModule.name, &result->internalTypes, NotNull(&iceHandler), getGlobalScope2()};
ConstraintGraphBuilder cgb{sourceModule.name, &result->internalTypes, NotNull(&iceHandler), getGlobalScope()};
cgb.visit(sourceModule.root);
result->errors = std::move(cgb.errors);
ConstraintSolver cs{&result->internalTypes, NotNull(cgb.rootScope)};
cs.run();
result->scope2s = std::move(cgb.scopes);
result->scopes = std::move(cgb.scopes);
result->astTypes = std::move(cgb.astTypes);
result->astTypePacks = std::move(cgb.astTypePacks);
result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes);
@ -988,7 +963,7 @@ std::optional<ModuleInfo> FrontendModuleResolver::resolveModuleInfo(const Module
{
// CLI-43699
// If we can't find the current module name, that's because we bypassed the frontend's initializer
// and called typeChecker.check directly. (This is done by autocompleteSource, for example).
// and called typeChecker.check directly.
// In that case, requires will always fail.
return std::nullopt;
}

View file

@ -4,6 +4,8 @@
#include "Luau/TxnLog.h"
#include "Luau/TypeArena.h"
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
namespace Luau
{
@ -31,6 +33,8 @@ bool Instantiation::ignoreChildren(TypeId ty)
{
if (log->getMutable<FunctionTypeVar>(ty))
return true;
else if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassTypeVar>(ty))
return true;
else
return false;
}

View file

@ -0,0 +1,220 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/JsonEmitter.h"
#include "Luau/StringUtils.h"
#include <string.h>
namespace Luau::Json
{
static constexpr int CHUNK_SIZE = 1024;
ObjectEmitter::ObjectEmitter(NotNull<JsonEmitter> emitter)
: emitter(emitter), finished(false)
{
comma = emitter->pushComma();
emitter->writeRaw('{');
}
ObjectEmitter::~ObjectEmitter()
{
finish();
}
void ObjectEmitter::finish()
{
if (finished)
return;
emitter->writeRaw('}');
emitter->popComma(comma);
finished = true;
}
ArrayEmitter::ArrayEmitter(NotNull<JsonEmitter> emitter)
: emitter(emitter), finished(false)
{
comma = emitter->pushComma();
emitter->writeRaw('[');
}
ArrayEmitter::~ArrayEmitter()
{
finish();
}
void ArrayEmitter::finish()
{
if (finished)
return;
emitter->writeRaw(']');
emitter->popComma(comma);
finished = true;
}
JsonEmitter::JsonEmitter()
{
newChunk();
}
std::string JsonEmitter::str()
{
return join(chunks, "");
}
bool JsonEmitter::pushComma()
{
bool current = comma;
comma = false;
return current;
}
void JsonEmitter::popComma(bool c)
{
comma = c;
}
void JsonEmitter::writeRaw(std::string_view sv)
{
if (sv.size() > CHUNK_SIZE)
{
chunks.emplace_back(sv);
newChunk();
return;
}
auto& chunk = chunks.back();
if (chunk.size() + sv.size() < CHUNK_SIZE)
{
chunk.append(sv.data(), sv.size());
return;
}
size_t prefix = CHUNK_SIZE - chunk.size();
chunk.append(sv.data(), prefix);
newChunk();
chunks.back().append(sv.data() + prefix, sv.size() - prefix);
}
void JsonEmitter::writeRaw(char c)
{
writeRaw(std::string_view{&c, 1});
}
void write(JsonEmitter& emitter, bool b)
{
if (b)
emitter.writeRaw("true");
else
emitter.writeRaw("false");
}
void write(JsonEmitter& emitter, double d)
{
emitter.writeRaw(std::to_string(d));
}
void write(JsonEmitter& emitter, int i)
{
emitter.writeRaw(std::to_string(i));
}
void write(JsonEmitter& emitter, long i)
{
emitter.writeRaw(std::to_string(i));
}
void write(JsonEmitter& emitter, long long i)
{
emitter.writeRaw(std::to_string(i));
}
void write(JsonEmitter& emitter, unsigned int i)
{
emitter.writeRaw(std::to_string(i));
}
void write(JsonEmitter& emitter, unsigned long i)
{
emitter.writeRaw(std::to_string(i));
}
void write(JsonEmitter& emitter, unsigned long long i)
{
emitter.writeRaw(std::to_string(i));
}
void write(JsonEmitter& emitter, std::string_view sv)
{
emitter.writeRaw('\"');
for (char c : sv)
{
if (c == '"')
emitter.writeRaw("\\\"");
else if (c == '\\')
emitter.writeRaw("\\\\");
else if (c == '\n')
emitter.writeRaw("\\n");
else if (c < ' ')
emitter.writeRaw(format("\\u%04x", c));
else
emitter.writeRaw(c);
}
emitter.writeRaw('\"');
}
void write(JsonEmitter& emitter, char c)
{
write(emitter, std::string_view{&c, 1});
}
void write(JsonEmitter& emitter, const char* str)
{
write(emitter, std::string_view{str, strlen(str)});
}
void write(JsonEmitter& emitter, const std::string& str)
{
write(emitter, std::string_view{str});
}
void write(JsonEmitter& emitter, std::nullptr_t)
{
emitter.writeRaw("null");
}
void write(JsonEmitter& emitter, std::nullopt_t)
{
emitter.writeRaw("null");
}
void JsonEmitter::writeComma()
{
if (comma)
writeRaw(',');
else
comma = true;
}
ObjectEmitter JsonEmitter::writeObject()
{
return ObjectEmitter{NotNull(this)};
}
ArrayEmitter JsonEmitter::writeArray()
{
return ArrayEmitter{NotNull(this)};
}
void JsonEmitter::newChunk()
{
chunks.emplace_back();
chunks.back().reserve(CHUNK_SIZE);
}
} // namespace Luau::Json

View file

@ -48,6 +48,7 @@ static const char* kWarningNames[] = {
"DuplicateCondition",
"MisleadingAndOr",
"CommentDirective",
"IntegerParsing",
};
// clang-format on
@ -1433,7 +1434,7 @@ private:
const char* checkStringFormat(const char* data, size_t size)
{
const char* flags = "-+ #0";
const char* options = "cdiouxXeEfgGqs";
const char* options = "cdiouxXeEfgGqs*";
for (size_t i = 0; i < size; ++i)
{
@ -2589,6 +2590,45 @@ private:
}
};
class LintIntegerParsing : AstVisitor
{
public:
LUAU_NOINLINE static void process(LintContext& context)
{
LintIntegerParsing pass;
pass.context = &context;
context.root->visit(&pass);
}
private:
LintContext* context;
bool visit(AstExprConstantNumber* node) override
{
switch (node->parseResult)
{
case ConstantNumberParseResult::Ok:
case ConstantNumberParseResult::Malformed:
break;
case ConstantNumberParseResult::BinOverflow:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Binary number literal exceeded available precision and has been truncated to 2^64");
break;
case ConstantNumberParseResult::HexOverflow:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Hexadecimal number literal exceeded available precision and has been truncated to 2^64");
break;
case ConstantNumberParseResult::DoublePrefix:
emitWarning(*context, LintWarning::Code_IntegerParsing, node->location,
"Hexadecimal number literal has a double prefix, which will fail to parse in the future; remove the extra 0x to fix");
break;
}
return true;
}
};
static void fillBuiltinGlobals(LintContext& context, const AstNameTable& names, const ScopePtr& env)
{
ScopePtr current = env;
@ -2688,6 +2728,21 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
else
seenMode = true;
}
else if (first == "optimize")
{
size_t notspace = hc.content.find_first_not_of(" \t", space);
if (space == std::string::npos || notspace == std::string::npos)
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "optimize directive requires an optimization level");
else
{
const char* level = hc.content.c_str() + notspace;
if (strcmp(level, "0") && strcmp(level, "1") && strcmp(level, "2"))
emitWarning(context, LintWarning::Code_CommentDirective, hc.location,
"optimize directive uses unknown optimization level '%s', 0..2 expected", level);
}
}
else
{
static const char* kHotComments[] = {
@ -2695,6 +2750,7 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
"nocheck",
"nonstrict",
"strict",
"optimize",
};
if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments)))
@ -2794,6 +2850,9 @@ std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const Sc
if (context.warningEnabled(LintWarning::Code_CommentDirective))
lintComments(context, hotcomments);
if (context.warningEnabled(LintWarning::Code_IntegerParsing))
LintIntegerParsing::process(context);
std::sort(context.result.begin(), context.result.end(), WarningComparator());
return context.result;

View file

@ -15,8 +15,8 @@
#include <algorithm>
LUAU_FASTFLAG(LuauLowerBoundsCalculation);
LUAU_FASTFLAG(LuauNormalizeFlagIsConservative);
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAGVARIABLE(LuauForceExportSurfacesToBeNormal, false);
namespace Luau
{
@ -99,40 +99,37 @@ void Module::clonePublicInterface(InternalErrorReporter& ice)
CloneState cloneState;
ScopePtr moduleScope = FFlag::DebugLuauDeferredConstraintResolution ? nullptr : getModuleScope();
Scope2* moduleScope2 = FFlag::DebugLuauDeferredConstraintResolution ? getModuleScope2() : nullptr;
ScopePtr moduleScope = getModuleScope();
TypePackId returnType = FFlag::DebugLuauDeferredConstraintResolution ? moduleScope2->returnType : moduleScope->returnType;
TypePackId returnType = moduleScope->returnType;
std::optional<TypePackId> varargPack = FFlag::DebugLuauDeferredConstraintResolution ? std::nullopt : moduleScope->varargPack;
std::unordered_map<Name, TypeFun>* exportedTypeBindings =
FFlag::DebugLuauDeferredConstraintResolution ? nullptr : &moduleScope->exportedTypeBindings;
returnType = clone(returnType, interfaceTypes, cloneState);
if (moduleScope)
moduleScope->returnType = returnType;
if (varargPack)
{
moduleScope->returnType = returnType;
if (varargPack)
{
varargPack = clone(*varargPack, interfaceTypes, cloneState);
moduleScope->varargPack = varargPack;
}
}
else
{
LUAU_ASSERT(moduleScope2);
moduleScope2->returnType = returnType; // TODO varargPack
varargPack = clone(*varargPack, interfaceTypes, cloneState);
moduleScope->varargPack = varargPack;
}
ForceNormal forceNormal{&interfaceTypes};
if (FFlag::LuauLowerBoundsCalculation)
{
normalize(returnType, interfaceTypes, ice);
if (FFlag::LuauForceExportSurfacesToBeNormal)
forceNormal.traverse(returnType);
if (varargPack)
{
normalize(*varargPack, interfaceTypes, ice);
if (FFlag::LuauForceExportSurfacesToBeNormal)
forceNormal.traverse(*varargPack);
}
}
ForceNormal forceNormal{&interfaceTypes};
if (exportedTypeBindings)
{
for (auto& [name, tf] : *exportedTypeBindings)
@ -142,11 +139,18 @@ void Module::clonePublicInterface(InternalErrorReporter& ice)
{
normalize(tf.type, interfaceTypes, ice);
if (FFlag::LuauNormalizeFlagIsConservative)
// We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables
// won't be marked normal. If the types aren't normal by now, they never will be.
forceNormal.traverse(tf.type);
for (GenericTypeDefinition param : tf.typeParams)
{
// We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables
// won't be marked normal. If the types aren't normal by now, they never will be.
forceNormal.traverse(tf.type);
forceNormal.traverse(param.ty);
if (param.defaultValue)
{
normalize(*param.defaultValue, interfaceTypes, ice);
forceNormal.traverse(*param.defaultValue);
}
}
}
}
@ -166,7 +170,12 @@ void Module::clonePublicInterface(InternalErrorReporter& ice)
{
ty = clone(ty, interfaceTypes, cloneState);
if (FFlag::LuauLowerBoundsCalculation)
{
normalize(ty, interfaceTypes, ice);
if (FFlag::LuauForceExportSurfacesToBeNormal)
forceNormal.traverse(ty);
}
}
freeze(internalTypes);
@ -179,10 +188,4 @@ ScopePtr Module::getModuleScope() const
return scopes.front().second;
}
Scope2* Module::getModuleScope2() const
{
LUAU_ASSERT(!scope2s.empty());
return scope2s.front().second.get();
}
} // namespace Luau

View file

@ -13,8 +13,8 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false)
// This could theoretically be 2000 on amd64, but x86 requires this.
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false);
LUAU_FASTFLAGVARIABLE(LuauFixNormalizationOfCyclicUnions, false);
LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(LuauQuantifyConstrained)
namespace Luau
@ -88,13 +88,7 @@ static bool areNormal_(const T& t, const std::unordered_set<void*>& seen, Intern
if (count >= FInt::LuauNormalizeIterationLimit)
ice.ice("Luau::areNormal hit iteration limit");
if (FFlag::LuauNormalizeFlagIsConservative)
return ty->normal;
else
{
// The follow is here because a bound type may not be normal, but the bound type is normal.
return ty->normal || follow(ty)->normal || seen.find(asMutable(ty)) != seen.end();
}
return ty->normal;
};
return std::all_of(begin(t), end(t), isNormal);
@ -182,7 +176,6 @@ struct Normalize final : TypeVarVisitor
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
@ -193,6 +186,20 @@ struct Normalize final : TypeVarVisitor
return false;
}
bool visit(TypeId ty, const UnknownTypeVar&) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const NeverTypeVar&) override
{
if (!ty->normal)
asMutable(ty)->normal = true;
return false;
}
bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override
{
CHECK_ITERATION_LIMIT(false);
@ -327,13 +334,19 @@ struct Normalize final : TypeVarVisitor
return false;
UnionTypeVar* utv = &const_cast<UnionTypeVar&>(utvRef);
std::vector<TypeId> options = std::move(utv->options);
// TODO: Clip tempOptions and optionsRef when clipping FFlag::LuauFixNormalizationOfCyclicUnions
std::vector<TypeId> tempOptions;
if (!FFlag::LuauFixNormalizationOfCyclicUnions)
tempOptions = std::move(utv->options);
std::vector<TypeId>& optionsRef = FFlag::LuauFixNormalizationOfCyclicUnions ? utv->options : tempOptions;
// We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar
for (TypeId option : options)
for (TypeId option : optionsRef)
traverse(option);
std::vector<TypeId> newOptions = normalizeUnion(options);
std::vector<TypeId> newOptions = normalizeUnion(optionsRef);
const bool normal = areNormal(newOptions, seen, ice);
@ -358,51 +371,106 @@ struct Normalize final : TypeVarVisitor
IntersectionTypeVar* itv = &const_cast<IntersectionTypeVar&>(itvRef);
std::vector<TypeId> oldParts = std::move(itv->parts);
for (TypeId part : oldParts)
traverse(part);
std::vector<TypeId> tables;
for (TypeId part : oldParts)
if (FFlag::LuauFixNormalizationOfCyclicUnions)
{
part = follow(part);
if (get<TableTypeVar>(part))
tables.push_back(part);
else
{
Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD
combineIntoIntersection(replacer, itv, part);
}
}
std::vector<TypeId> oldParts = itv->parts;
IntersectionTypeVar newIntersection;
// Don't allocate a new table if there's just one in the intersection.
if (tables.size() == 1)
itv->parts.push_back(tables[0]);
else if (!tables.empty())
{
const TableTypeVar* first = get<TableTypeVar>(tables[0]);
LUAU_ASSERT(first);
for (TypeId part : oldParts)
traverse(part);
TypeId newTable = arena.addType(TableTypeVar{first->state, first->level});
TableTypeVar* ttv = getMutable<TableTypeVar>(newTable);
for (TypeId part : tables)
std::vector<TypeId> tables;
for (TypeId part : oldParts)
{
// Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need
// to be rewritten to point at 'newTable' in the clone.
Replacer replacer{&arena, part, newTable};
combineIntoTable(replacer, ttv, part);
part = follow(part);
if (get<TableTypeVar>(part))
tables.push_back(part);
else
{
Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD
combineIntoIntersection(replacer, &newIntersection, part);
}
}
itv->parts.push_back(newTable);
// Don't allocate a new table if there's just one in the intersection.
if (tables.size() == 1)
newIntersection.parts.push_back(tables[0]);
else if (!tables.empty())
{
const TableTypeVar* first = get<TableTypeVar>(tables[0]);
LUAU_ASSERT(first);
TypeId newTable = arena.addType(TableTypeVar{first->state, first->level});
TableTypeVar* ttv = getMutable<TableTypeVar>(newTable);
for (TypeId part : tables)
{
// Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need
// to be rewritten to point at 'newTable' in the clone.
Replacer replacer{&arena, part, newTable};
combineIntoTable(replacer, ttv, part);
}
newIntersection.parts.push_back(newTable);
}
itv->parts = std::move(newIntersection.parts);
asMutable(ty)->normal = areNormal(itv->parts, seen, ice);
if (itv->parts.size() == 1)
{
TypeId part = itv->parts[0];
*asMutable(ty) = BoundTypeVar{part};
}
}
asMutable(ty)->normal = areNormal(itv->parts, seen, ice);
if (itv->parts.size() == 1)
else
{
TypeId part = itv->parts[0];
*asMutable(ty) = BoundTypeVar{part};
std::vector<TypeId> oldParts = std::move(itv->parts);
for (TypeId part : oldParts)
traverse(part);
std::vector<TypeId> tables;
for (TypeId part : oldParts)
{
part = follow(part);
if (get<TableTypeVar>(part))
tables.push_back(part);
else
{
Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD
combineIntoIntersection(replacer, itv, part);
}
}
// Don't allocate a new table if there's just one in the intersection.
if (tables.size() == 1)
itv->parts.push_back(tables[0]);
else if (!tables.empty())
{
const TableTypeVar* first = get<TableTypeVar>(tables[0]);
LUAU_ASSERT(first);
TypeId newTable = arena.addType(TableTypeVar{first->state, first->level});
TableTypeVar* ttv = getMutable<TableTypeVar>(newTable);
for (TypeId part : tables)
{
// Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need
// to be rewritten to point at 'newTable' in the clone.
Replacer replacer{&arena, part, newTable};
combineIntoTable(replacer, ttv, part);
}
itv->parts.push_back(newTable);
}
asMutable(ty)->normal = areNormal(itv->parts, seen, ice);
if (itv->parts.size() == 1)
{
TypeId part = itv->parts[0];
*asMutable(ty) = BoundTypeVar{part};
}
}
return false;
@ -416,7 +484,13 @@ struct Normalize final : TypeVarVisitor
std::vector<TypeId> result;
for (TypeId part : options)
{
// AnyTypeVar always win the battle no matter what we do, so we're done.
if (FFlag::LuauUnknownAndNeverType && get<AnyTypeVar>(follow(part)))
return {part};
combineIntoUnion(result, part);
}
return result;
}
@ -427,7 +501,17 @@ struct Normalize final : TypeVarVisitor
if (auto utv = get<UnionTypeVar>(ty))
{
for (TypeId t : utv)
{
// AnyTypeVar always win the battle no matter what we do, so we're done.
if (FFlag::LuauUnknownAndNeverType && get<AnyTypeVar>(t))
{
result = {t};
return;
}
combineIntoUnion(result, t);
}
return;
}
@ -561,6 +645,24 @@ struct Normalize final : TypeVarVisitor
table->props.insert({propName, prop});
}
if (FFlag::LuauFixNormalizationOfCyclicUnions)
{
if (tyTable->indexer)
{
if (table->indexer)
{
table->indexer->indexType = combine(replacer, replacer.smartClone(tyTable->indexer->indexType), table->indexer->indexType);
table->indexer->indexResultType =
combine(replacer, replacer.smartClone(tyTable->indexer->indexResultType), table->indexer->indexResultType);
}
else
{
table->indexer =
TableIndexer{replacer.smartClone(tyTable->indexer->indexType), replacer.smartClone(tyTable->indexer->indexResultType)};
}
}
}
table->state = combineTableStates(table->state, tyTable->state);
table->level = max(table->level, tyTable->level);
}
@ -571,8 +673,7 @@ struct Normalize final : TypeVarVisitor
*/
TypeId combine(Replacer& replacer, TypeId a, TypeId b)
{
if (FFlag::LuauNormalizeCombineEqFix)
b = follow(b);
b = follow(b);
if (FFlag::LuauNormalizeCombineTableFix && a == b)
return a;
@ -592,7 +693,7 @@ struct Normalize final : TypeVarVisitor
}
else if (auto ttv = getMutable<TableTypeVar>(a))
{
if (FFlag::LuauNormalizeCombineTableFix && !get<TableTypeVar>(FFlag::LuauNormalizeCombineEqFix ? b : follow(b)))
if (FFlag::LuauNormalizeCombineTableFix && !get<TableTypeVar>(b))
return arena.addType(IntersectionTypeVar{{a, b}});
combineIntoTable(replacer, ttv, b);
return a;

View file

@ -7,7 +7,6 @@
#include "Luau/TxnLog.h"
#include "Luau/VisitTypeVar.h"
LUAU_FASTFLAG(LuauAlwaysQuantify);
LUAU_FASTFLAG(DebugLuauSharedSelf)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false)
@ -16,13 +15,13 @@ namespace Luau
{
/// @return true if outer encloses inner
static bool subsumes(Scope2* outer, Scope2* inner)
static bool subsumes(Scope* outer, Scope* inner)
{
while (inner)
{
if (inner == outer)
return true;
inner = inner->parent;
inner = inner->parent.get();
}
return false;
@ -33,7 +32,7 @@ struct Quantifier final : TypeVarOnceVisitor
TypeLevel level;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
Scope2* scope = nullptr;
Scope* scope = nullptr;
bool seenGenericType = false;
bool seenMutableType = false;
@ -43,20 +42,20 @@ struct Quantifier final : TypeVarOnceVisitor
LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution);
}
explicit Quantifier(Scope2* scope)
explicit Quantifier(Scope* scope)
: scope(scope)
{
LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution);
}
/// @return true if outer encloses inner
bool subsumes(Scope2* outer, Scope2* inner)
bool subsumes(Scope* outer, Scope* inner)
{
while (inner)
{
if (inner == outer)
return true;
inner = inner->parent;
inner = inner->parent.get();
}
return false;
@ -203,36 +202,20 @@ void quantify(TypeId ty, TypeLevel level)
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty);
LUAU_ASSERT(ftv);
if (FFlag::LuauAlwaysQuantify)
{
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
}
else
{
ftv->generics = q.generics;
ftv->genericPacks = q.genericPacks;
}
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
}
}
void quantify(TypeId ty, Scope2* scope)
void quantify(TypeId ty, Scope* scope)
{
Quantifier q{scope};
q.traverse(ty);
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty);
LUAU_ASSERT(ftv);
if (FFlag::LuauAlwaysQuantify)
{
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
}
else
{
ftv->generics = q.generics;
ftv->genericPacks = q.genericPacks;
}
ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end());
ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end());
if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType)
ftv->hasNoGenerics = true;
@ -240,11 +223,11 @@ void quantify(TypeId ty, Scope2* scope)
struct PureQuantifier : Substitution
{
Scope2* scope;
Scope* scope;
std::vector<TypeId> insertedGenerics;
std::vector<TypePackId> insertedGenericPacks;
PureQuantifier(TypeArena* arena, Scope2* scope)
PureQuantifier(TypeArena* arena, Scope* scope)
: Substitution(TxnLog::empty(), arena)
, scope(scope)
{
@ -322,7 +305,7 @@ struct PureQuantifier : Substitution
}
};
TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope)
TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope)
{
PureQuantifier quantifier{arena, scope};
std::optional<TypeId> result = quantifier.substitute(ty);

View file

@ -21,22 +21,6 @@ Scope::Scope(const ScopePtr& parent, int subLevel)
level.subLevel = subLevel;
}
std::optional<TypeId> Scope::lookup(const Symbol& name)
{
Scope* scope = this;
while (scope)
{
auto it = scope->bindings.find(name);
if (it != scope->bindings.end())
return it->second.typeId;
scope = scope->parent.get();
}
return std::nullopt;
}
std::optional<TypeFun> Scope::lookupType(const Name& name)
{
const Scope* scope = this;
@ -121,48 +105,48 @@ std::optional<Binding> Scope::linearSearchForBinding(const std::string& name, bo
return std::nullopt;
}
std::optional<TypeId> Scope2::lookup(Symbol sym)
std::optional<TypeId> Scope::lookup(Symbol sym)
{
Scope2* s = this;
Scope* s = this;
while (true)
{
auto it = s->bindings.find(sym);
if (it != s->bindings.end())
return it->second;
return it->second.typeId;
if (s->parent)
s = s->parent;
s = s->parent.get();
else
return std::nullopt;
}
}
std::optional<TypeId> Scope2::lookupTypeBinding(const Name& name)
std::optional<TypeFun> Scope::lookupTypeBinding(const Name& name)
{
Scope2* s = this;
Scope* s = this;
while (s)
{
auto it = s->typeBindings.find(name);
if (it != s->typeBindings.end())
return it->second;
s = s->parent;
s = s->parent.get();
}
return std::nullopt;
}
std::optional<TypePackId> Scope2::lookupTypePackBinding(const Name& name)
std::optional<TypePackId> Scope::lookupTypePackBinding(const Name& name)
{
Scope2* s = this;
Scope* s = this;
while (s)
{
auto it = s->typePackBindings.find(name);
if (it != s->typePackBindings.end())
return it->second;
s = s->parent;
s = s->parent.get();
}
return std::nullopt;

View file

@ -8,8 +8,13 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAGVARIABLE(LuauAnyificationMustClone, false)
LUAU_FASTFLAGVARIABLE(LuauSubstitutionFixMissingFields, false)
LUAU_FASTFLAG(LuauLowerBoundsCalculation)
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000)
LUAU_FASTFLAGVARIABLE(LuauClassTypeVarsInSubstitution, false)
LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauSubstitutionReentrant, false)
namespace Luau
{
@ -26,6 +31,14 @@ void Tarjan::visitChildren(TypeId ty, int index)
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{
if (FFlag::LuauSubstitutionFixMissingFields)
{
for (TypeId generic : ftv->generics)
visitChild(generic);
for (TypePackId genericPack : ftv->genericPacks)
visitChild(genericPack);
}
visitChild(ftv->argTypes);
visitChild(ftv->retTypes);
}
@ -66,6 +79,25 @@ void Tarjan::visitChildren(TypeId ty, int index)
for (TypeId part : ctv->parts)
visitChild(part);
}
else if (const PendingExpansionTypeVar* petv = get<PendingExpansionTypeVar>(ty))
{
for (TypeId a : petv->typeArguments)
visitChild(a);
for (TypePackId a : petv->packArguments)
visitChild(a);
}
else if (const ClassTypeVar* ctv = get<ClassTypeVar>(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv)
{
for (auto [name, prop] : ctv->props)
visitChild(prop.type);
if (ctv->parent)
visitChild(*ctv->parent);
if (ctv->metatable)
visitChild(*ctv->metatable);
}
}
void Tarjan::visitChildren(TypePackId tp, int index)
@ -154,7 +186,7 @@ TarjanResult Tarjan::loop()
if (currEdge == -1)
{
++childCount;
if (childLimit > 0 && childLimit < childCount)
if (childLimit > 0 && (FFlag::LuauUnknownAndNeverType ? childLimit <= childCount : childLimit < childCount))
return TarjanResult::TooManyChildren;
stack.push_back(index);
@ -265,6 +297,24 @@ TarjanResult Tarjan::visitRoot(TypePackId tp)
return loop();
}
void FindDirty::clearTarjan()
{
dirty.clear();
typeToIndex.clear();
packToIndex.clear();
indexToType.clear();
indexToPack.clear();
stack.clear();
onStack.clear();
lowlink.clear();
edgesTy.clear();
edgesTp.clear();
worklist.clear();
}
bool FindDirty::getDirty(int index)
{
if (dirty.size() <= size_t(index))
@ -328,16 +378,46 @@ std::optional<TypeId> Substitution::substitute(TypeId ty)
{
ty = log->follow(ty);
// clear algorithm state for reentrancy
if (FFlag::LuauSubstitutionReentrant)
clearTarjan();
auto result = findDirty(ty);
if (result != TarjanResult::Ok)
return std::nullopt;
for (auto [oldTy, newTy] : newTypes)
if (!ignoreChildren(oldTy))
replaceChildren(newTy);
{
if (FFlag::LuauSubstitutionReentrant)
{
if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy))
{
replaceChildren(newTy);
replacedTypes.insert(newTy);
}
}
else
{
if (!ignoreChildren(oldTy))
replaceChildren(newTy);
}
}
for (auto [oldTp, newTp] : newPacks)
if (!ignoreChildren(oldTp))
replaceChildren(newTp);
{
if (FFlag::LuauSubstitutionReentrant)
{
if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp))
{
replaceChildren(newTp);
replacedTypePacks.insert(newTp);
}
}
else
{
if (!ignoreChildren(oldTp))
replaceChildren(newTp);
}
}
TypeId newTy = replace(ty);
return newTy;
}
@ -346,16 +426,46 @@ std::optional<TypePackId> Substitution::substitute(TypePackId tp)
{
tp = log->follow(tp);
// clear algorithm state for reentrancy
if (FFlag::LuauSubstitutionReentrant)
clearTarjan();
auto result = findDirty(tp);
if (result != TarjanResult::Ok)
return std::nullopt;
for (auto [oldTy, newTy] : newTypes)
if (!ignoreChildren(oldTy))
replaceChildren(newTy);
{
if (FFlag::LuauSubstitutionReentrant)
{
if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy))
{
replaceChildren(newTy);
replacedTypes.insert(newTy);
}
}
else
{
if (!ignoreChildren(oldTy))
replaceChildren(newTy);
}
}
for (auto [oldTp, newTp] : newPacks)
if (!ignoreChildren(oldTp))
replaceChildren(newTp);
{
if (FFlag::LuauSubstitutionReentrant)
{
if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp))
{
replaceChildren(newTp);
replacedTypePacks.insert(newTp);
}
}
else
{
if (!ignoreChildren(oldTp))
replaceChildren(newTp);
}
}
TypePackId newTp = replace(tp);
return newTp;
}
@ -383,6 +493,8 @@ TypePackId Substitution::clone(TypePackId tp)
{
VariadicTypePack clone;
clone.ty = vtp->ty;
if (FFlag::LuauSubstitutionFixMissingFields)
clone.hidden = vtp->hidden;
return addTypePack(std::move(clone));
}
else
@ -393,6 +505,9 @@ void Substitution::foundDirty(TypeId ty)
{
ty = log->follow(ty);
if (FFlag::LuauSubstitutionReentrant && newTypes.contains(ty))
return;
if (isDirty(ty))
newTypes[ty] = follow(clean(ty));
else
@ -403,6 +518,9 @@ void Substitution::foundDirty(TypePackId tp)
{
tp = log->follow(tp);
if (FFlag::LuauSubstitutionReentrant && newPacks.contains(tp))
return;
if (isDirty(tp))
newPacks[tp] = follow(clean(tp));
else
@ -439,8 +557,19 @@ void Substitution::replaceChildren(TypeId ty)
if (ignoreChildren(ty))
return;
if (FFlag::LuauAnyificationMustClone && ty->owningArena != arena)
return;
if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty))
{
if (FFlag::LuauSubstitutionFixMissingFields)
{
for (TypeId& generic : ftv->generics)
generic = replace(generic);
for (TypePackId& genericPack : ftv->genericPacks)
genericPack = replace(genericPack);
}
ftv->argTypes = replace(ftv->argTypes);
ftv->retTypes = replace(ftv->retTypes);
}
@ -481,6 +610,25 @@ void Substitution::replaceChildren(TypeId ty)
for (TypeId& part : ctv->parts)
part = replace(part);
}
else if (PendingExpansionTypeVar* petv = getMutable<PendingExpansionTypeVar>(ty))
{
for (TypeId& a : petv->typeArguments)
a = replace(a);
for (TypePackId& a : petv->packArguments)
a = replace(a);
}
else if (ClassTypeVar* ctv = getMutable<ClassTypeVar>(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv)
{
for (auto& [name, prop] : ctv->props)
prop.type = replace(prop.type);
if (ctv->parent)
ctv->parent = replace(*ctv->parent);
if (ctv->metatable)
ctv->metatable = replace(*ctv->metatable);
}
}
void Substitution::replaceChildren(TypePackId tp)
@ -490,6 +638,9 @@ void Substitution::replaceChildren(TypePackId tp)
if (ignoreChildren(tp))
return;
if (FFlag::LuauAnyificationMustClone && tp->owningArena != arena)
return;
if (TypePack* tpp = getMutable<TypePack>(tp))
{
for (TypeId& tv : tpp->head)

View file

@ -11,6 +11,8 @@
#include <stdexcept>
LUAU_FASTFLAG(LuauLowerBoundsCalculation)
LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false)
/*
* Prefix generic typenames with gen-
@ -18,7 +20,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation)
* Fair warning: Setting this will break a lot of Luau unit tests.
*/
LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false)
LUAU_FASTFLAGVARIABLE(LuauToStringTableBracesNewlines, false)
namespace Luau
{
@ -231,6 +232,11 @@ struct StringifierState
emit(std::to_string(i).c_str());
}
void emit(size_t i)
{
emit(std::to_string(i).c_str());
}
void indent()
{
indentation += 4;
@ -277,7 +283,10 @@ struct TypeVarStringifier
if (tv->ty.valueless_by_exception())
{
state.result.error = true;
state.emit("< VALUELESS BY EXCEPTION >");
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("* VALUELESS BY EXCEPTION *");
else
state.emit("< VALUELESS BY EXCEPTION >");
return;
}
@ -406,6 +415,13 @@ struct TypeVarStringifier
state.emit("*");
}
void operator()(TypeId ty, const PendingExpansionTypeVar& petv)
{
state.emit("*pending-expansion-");
state.emit(petv.index);
state.emit("*");
}
void operator()(TypeId, const PrimitiveTypeVar& ptv)
{
switch (ptv.type)
@ -453,7 +469,10 @@ struct TypeVarStringifier
if (state.hasSeen(&ftv))
{
state.result.cycle = true;
state.emit("<CYCLE>");
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return;
}
@ -561,7 +580,10 @@ struct TypeVarStringifier
if (state.hasSeen(&ttv))
{
state.result.cycle = true;
state.emit("<CYCLE>");
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return;
}
@ -571,54 +593,22 @@ struct TypeVarStringifier
{
case TableState::Sealed:
state.result.invalid = true;
if (FFlag::LuauToStringTableBracesNewlines)
{
openbrace = "{|";
closedbrace = "|}";
}
else
{
openbrace = "{| ";
closedbrace = " |}";
}
openbrace = "{|";
closedbrace = "|}";
break;
case TableState::Unsealed:
if (FFlag::LuauToStringTableBracesNewlines)
{
openbrace = "{";
closedbrace = "}";
}
else
{
openbrace = "{ ";
closedbrace = " }";
}
openbrace = "{";
closedbrace = "}";
break;
case TableState::Free:
state.result.invalid = true;
if (FFlag::LuauToStringTableBracesNewlines)
{
openbrace = "{-";
closedbrace = "-}";
}
else
{
openbrace = "{- ";
closedbrace = " -}";
}
openbrace = "{-";
closedbrace = "-}";
break;
case TableState::Generic:
state.result.invalid = true;
if (FFlag::LuauToStringTableBracesNewlines)
{
openbrace = "{+";
closedbrace = "+}";
}
else
{
openbrace = "{+ ";
closedbrace = " +}";
}
openbrace = "{+";
closedbrace = "+}";
break;
}
@ -637,8 +627,7 @@ struct TypeVarStringifier
bool comma = false;
if (ttv.indexer)
{
if (FFlag::LuauToStringTableBracesNewlines)
state.newline();
state.newline();
state.emit("[");
stringify(ttv.indexer->indexType);
state.emit("]: ");
@ -655,10 +644,8 @@ struct TypeVarStringifier
state.emit(",");
state.newline();
}
else if (FFlag::LuauToStringTableBracesNewlines)
{
else
state.newline();
}
size_t length = state.result.name.length() - oldLength;
@ -685,13 +672,10 @@ struct TypeVarStringifier
}
state.dedent();
if (FFlag::LuauToStringTableBracesNewlines)
{
if (comma)
state.newline();
else
state.emit(" ");
}
if (comma)
state.newline();
else
state.emit(" ");
state.emit(closedbrace);
state.unsee(&ttv);
@ -700,6 +684,12 @@ struct TypeVarStringifier
void operator()(TypeId, const MetatableTypeVar& mtv)
{
state.result.invalid = true;
if (!state.exhaustive && mtv.syntheticName)
{
state.emit(*mtv.syntheticName);
return;
}
state.emit("{ @metatable ");
stringify(mtv.metatable);
state.emit(",");
@ -723,7 +713,10 @@ struct TypeVarStringifier
if (state.hasSeen(&uv))
{
state.result.cycle = true;
state.emit("<CYCLE>");
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return;
}
@ -790,7 +783,10 @@ struct TypeVarStringifier
if (state.hasSeen(&uv))
{
state.result.cycle = true;
state.emit("<CYCLE>");
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLE*");
else
state.emit("<CYCLE>");
return;
}
@ -835,7 +831,10 @@ struct TypeVarStringifier
void operator()(TypeId, const ErrorTypeVar& tv)
{
state.result.error = true;
state.emit("*unknown*");
if (FFlag::LuauSpecialTypesAsterisked)
state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
else
state.emit(FFlag::LuauUnknownAndNeverType ? "<error-type>" : "*unknown*");
}
void operator()(TypeId, const LazyTypeVar& ltv)
@ -844,7 +843,16 @@ struct TypeVarStringifier
state.emit("lazy?");
}
}; // namespace
void operator()(TypeId, const UnknownTypeVar& ttv)
{
state.emit("unknown");
}
void operator()(TypeId, const NeverTypeVar& ttv)
{
state.emit("never");
}
};
struct TypePackStringifier
{
@ -880,7 +888,10 @@ struct TypePackStringifier
if (tp->ty.valueless_by_exception())
{
state.result.error = true;
state.emit("< VALUELESS TP BY EXCEPTION >");
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("* VALUELESS TP BY EXCEPTION *");
else
state.emit("< VALUELESS TP BY EXCEPTION >");
return;
}
@ -904,7 +915,10 @@ struct TypePackStringifier
if (state.hasSeen(&tp))
{
state.result.cycle = true;
state.emit("<CYCLETP>");
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*CYCLETP*");
else
state.emit("<CYCLETP>");
return;
}
@ -949,14 +963,22 @@ struct TypePackStringifier
void operator()(TypePackId, const Unifiable::Error& error)
{
state.result.error = true;
state.emit("*unknown*");
if (FFlag::LuauSpecialTypesAsterisked)
state.emit(FFlag::LuauUnknownAndNeverType ? "*error-type*" : "*unknown*");
else
state.emit(FFlag::LuauUnknownAndNeverType ? "<error-type>" : "*unknown*");
}
void operator()(TypePackId, const VariadicTypePack& pack)
{
state.emit("...");
if (FFlag::DebugLuauVerboseTypeNames && pack.hidden)
state.emit("<hidden>");
{
if (FFlag::LuauSpecialTypesAsterisked)
state.emit("*hidden*");
else
state.emit("<hidden>");
}
stringify(pack.ty);
}
@ -1151,7 +1173,11 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts)
if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength)
{
result.truncated = true;
result.name += "... <TRUNCATED>";
if (FFlag::LuauSpecialTypesAsterisked)
result.name += "... *TRUNCATED*";
else
result.name += "... <TRUNCATED>";
}
return result;
@ -1222,7 +1248,12 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts)
}
if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength)
result.name += "... <TRUNCATED>";
{
if (FFlag::LuauSpecialTypesAsterisked)
result.name += "... *TRUNCATED*";
else
result.name += "... <TRUNCATED>";
}
return result;
}
@ -1440,6 +1471,12 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
opts.nameMap = std::move(namedStr.nameMap);
return "@name(" + namedStr.name + ") = " + c.name;
}
else if constexpr (std::is_same_v<T, TypeAliasExpansionConstraint>)
{
ToStringResult targetStr = toStringDetailed(c.target, opts);
opts.nameMap = std::move(targetStr.nameMap);
return "expand " + targetStr.name;
}
else
static_assert(always_false_v<T>, "Non-exhaustive constraint switch");
};

View file

@ -205,20 +205,6 @@ struct Printer
}
}
void visualizeWithSelf(AstExpr& expr, bool self)
{
if (!self)
return visualize(expr);
AstExprIndexName* func = expr.as<AstExprIndexName>();
LUAU_ASSERT(func);
visualize(*func->expr);
writer.symbol(":");
advance(func->indexLocation.begin);
writer.identifier(func->index.value);
}
void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg)
{
advance(annotation.location.begin);
@ -366,7 +352,7 @@ struct Printer
}
else if (const auto& a = expr.as<AstExprCall>())
{
visualizeWithSelf(*a->func, a->self);
visualize(*a->func);
writer.symbol("(");
bool first = true;
@ -385,7 +371,7 @@ struct Printer
else if (const auto& a = expr.as<AstExprIndexName>())
{
visualize(*a->expr);
writer.symbol(".");
writer.symbol(std::string(1, a->op));
writer.write(a->index.value);
}
else if (const auto& a = expr.as<AstExprIndexExpr>())
@ -766,7 +752,7 @@ struct Printer
else if (const auto& a = program.as<AstStatFunction>())
{
writer.keyword("function");
visualizeWithSelf(*a->name, a->func->self != nullptr);
visualize(*a->name);
visualizeFunctionBody(*a->func);
}
else if (const auto& a = program.as<AstStatLocalFunction>())

View file

@ -7,7 +7,7 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauNonCopyableTypeVarFields)
LUAU_FASTFLAG(LuauUnknownAndNeverType)
namespace Luau
{
@ -81,34 +81,10 @@ void TxnLog::concat(TxnLog rhs)
void TxnLog::commit()
{
for (auto& [ty, rep] : typeVarChanges)
{
if (FFlag::LuauNonCopyableTypeVarFields)
{
asMutable(ty)->reassign(rep.get()->pending);
}
else
{
TypeArena* owningArena = ty->owningArena;
TypeVar* mtv = asMutable(ty);
*mtv = rep.get()->pending;
mtv->owningArena = owningArena;
}
}
asMutable(ty)->reassign(rep.get()->pending);
for (auto& [tp, rep] : typePackChanges)
{
if (FFlag::LuauNonCopyableTypeVarFields)
{
asMutable(tp)->reassign(rep.get()->pending);
}
else
{
TypeArena* owningArena = tp->owningArena;
TypePackVar* mpv = asMutable(tp);
*mpv = rep.get()->pending;
mpv->owningArena = owningArena;
}
}
asMutable(tp)->reassign(rep.get()->pending);
clear();
}
@ -196,9 +172,7 @@ PendingType* TxnLog::queue(TypeId ty)
if (!pending)
{
pending = std::make_unique<PendingType>(*ty);
if (FFlag::LuauNonCopyableTypeVarFields)
pending->pending.owningArena = nullptr;
pending->pending.owningArena = nullptr;
}
return pending.get();
@ -214,9 +188,7 @@ PendingTypePack* TxnLog::queue(TypePackId tp)
if (!pending)
{
pending = std::make_unique<PendingTypePack>(*tp);
if (FFlag::LuauNonCopyableTypeVarFields)
pending->pending.owningArena = nullptr;
pending->pending.owningArena = nullptr;
}
return pending.get();
@ -255,24 +227,14 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const
PendingType* TxnLog::replace(TypeId ty, TypeVar replacement)
{
PendingType* newTy = queue(ty);
if (FFlag::LuauNonCopyableTypeVarFields)
newTy->pending.reassign(replacement);
else
newTy->pending = replacement;
newTy->pending.reassign(replacement);
return newTy;
}
PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement)
{
PendingTypePack* newTp = queue(tp);
if (FFlag::LuauNonCopyableTypeVarFields)
newTp->pending.reassign(replacement);
else
newTp->pending = replacement;
newTp->pending.reassign(replacement);
return newTp;
}
@ -289,7 +251,7 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional<TypeId> newBoundTo)
PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel)
{
LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty));
LUAU_ASSERT(get<FreeTypeVar>(ty) || get<TableTypeVar>(ty) || get<FunctionTypeVar>(ty) || get<ConstrainedTypeVar>(ty));
PendingType* newTy = queue(ty);
if (FreeTypeVar* ftv = Luau::getMutable<FreeTypeVar>(newTy))
@ -305,6 +267,11 @@ PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel)
{
ftv->level = newLevel;
}
else if (ConstrainedTypeVar* ctv = Luau::getMutable<ConstrainedTypeVar>(newTy))
{
if (FFlag::LuauUnknownAndNeverType)
ctv->level = newLevel;
}
return newTy;
}

View file

@ -99,6 +99,11 @@ public:
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*blocked*"));
}
AstType* operator()(const PendingExpansionTypeVar& petv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("*pending-expansion*"));
}
AstType* operator()(const ConstrainedTypeVar& ctv)
{
AstArray<AstType*> types;
@ -335,6 +340,14 @@ public:
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Lazy?>"));
}
AstType* operator()(const UnknownTypeVar& ttv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName{"unknown"});
}
AstType* operator()(const NeverTypeVar& ttv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName{"never"});
}
private:
Allocator* allocator;

View file

@ -67,8 +67,18 @@ struct TypeChecker2 : public AstVisitor
return follow(*ty);
}
TypePackId lookupPackAnnotation(AstTypePack* annotation)
{
TypePackId* tp = module->astResolvedTypePacks.find(annotation);
LUAU_ASSERT(tp);
return follow(*tp);
}
TypePackId reconstructPack(AstArray<AstExpr*> exprs, TypeArena& arena)
{
if (exprs.size == 0)
return arena.addTypePack(TypePack{{}, std::nullopt});
std::vector<TypeId> head;
for (size_t i = 0; i < exprs.size - 1; ++i)
@ -80,14 +90,14 @@ struct TypeChecker2 : public AstVisitor
return arena.addTypePack(TypePack{head, tail});
}
Scope2* findInnermostScope(Location location)
Scope* findInnermostScope(Location location)
{
Scope2* bestScope = module->getModuleScope2();
Location bestLocation = module->scope2s[0].first;
Scope* bestScope = module->getModuleScope().get();
Location bestLocation = module->scopes[0].first;
for (size_t i = 0; i < module->scope2s.size(); ++i)
for (size_t i = 0; i < module->scopes.size(); ++i)
{
auto& [scopeBounds, scope] = module->scope2s[i];
auto& [scopeBounds, scope] = module->scopes[i];
if (scopeBounds.encloses(location))
{
if (scopeBounds.begin > bestLocation.begin || scopeBounds.end < bestLocation.end)
@ -181,7 +191,7 @@ struct TypeChecker2 : public AstVisitor
bool visit(AstStatReturn* ret) override
{
Scope2* scope = findInnermostScope(ret->location);
Scope* scope = findInnermostScope(ret->location);
TypePackId expectedRetType = scope->returnType;
TypeArena arena;
@ -322,10 +332,13 @@ struct TypeChecker2 : public AstVisitor
{
pack = follow(pack);
while (auto tp = get<TypePack>(pack))
while (true)
{
if (tp->head.empty() && tp->tail)
auto tp = get<TypePack>(pack);
if (tp && tp->head.empty() && tp->tail)
pack = *tp->tail;
else
break;
}
if (auto ty = first(pack))
@ -356,13 +369,154 @@ struct TypeChecker2 : public AstVisitor
bool visit(AstTypeReference* ty) override
{
Scope2* scope = findInnermostScope(ty->location);
Scope* scope = findInnermostScope(ty->location);
LUAU_ASSERT(scope);
// TODO: Imported types
// TODO: Generic types
if (!scope->lookupTypeBinding(ty->name.value))
std::optional<TypeFun> alias = scope->lookupTypeBinding(ty->name.value);
if (alias.has_value())
{
reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location);
size_t typesRequired = alias->typeParams.size();
size_t packsRequired = alias->typePackParams.size();
bool hasDefaultTypes = std::any_of(alias->typeParams.begin(), alias->typeParams.end(), [](auto&& el) {
return el.defaultValue.has_value();
});
bool hasDefaultPacks = std::any_of(alias->typePackParams.begin(), alias->typePackParams.end(), [](auto&& el) {
return el.defaultValue.has_value();
});
if (!ty->hasParameterList)
{
if ((!alias->typeParams.empty() && !hasDefaultTypes) || (!alias->typePackParams.empty() && !hasDefaultPacks))
{
reportError(GenericError{"Type parameter list is required"}, ty->location);
}
}
size_t typesProvided = 0;
size_t extraTypes = 0;
size_t packsProvided = 0;
for (const AstTypeOrPack& p : ty->parameters)
{
if (p.type)
{
if (packsProvided != 0)
{
reportError(GenericError{"Type parameters must come before type pack parameters"}, ty->location);
}
if (typesProvided < typesRequired)
{
typesProvided += 1;
}
else
{
extraTypes += 1;
}
}
else if (p.typePack)
{
TypePackId tp = lookupPackAnnotation(p.typePack);
if (typesProvided < typesRequired && size(tp) == 1 && finite(tp) && first(tp))
{
typesProvided += 1;
}
else
{
packsProvided += 1;
}
}
}
if (extraTypes != 0 && packsProvided == 0)
{
packsProvided += 1;
}
for (size_t i = typesProvided; i < typesRequired; ++i)
{
if (alias->typeParams[i].defaultValue)
{
typesProvided += 1;
}
}
for (size_t i = packsProvided; i < packsProvided; ++i)
{
if (alias->typePackParams[i].defaultValue)
{
packsProvided += 1;
}
}
if (extraTypes == 0 && packsProvided + 1 == packsRequired)
{
packsProvided += 1;
}
if (typesProvided != typesRequired || packsProvided != packsRequired)
{
reportError(IncorrectGenericParameterCount{
/* name */ ty->name.value,
/* typeFun */ *alias,
/* actualParameters */ typesProvided,
/* actualPackParameters */ packsProvided,
},
ty->location);
}
}
else
{
if (scope->lookupTypePackBinding(ty->name.value))
{
reportError(
SwappedGenericTypeParameter{
ty->name.value,
SwappedGenericTypeParameter::Kind::Type,
},
ty->location);
}
else
{
reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location);
}
}
return true;
}
bool visit(AstTypePack*) override
{
return true;
}
bool visit(AstTypePackGeneric* tp) override
{
Scope* scope = findInnermostScope(tp->location);
LUAU_ASSERT(scope);
std::optional<TypePackId> alias = scope->lookupTypePackBinding(tp->genericName.value);
if (!alias.has_value())
{
if (scope->lookupTypeBinding(tp->genericName.value))
{
reportError(
SwappedGenericTypeParameter{
tp->genericName.value,
SwappedGenericTypeParameter::Kind::Pack,
},
tp->location);
}
else
{
reportError(UnknownSymbol{tp->genericName.value, UnknownSymbol::Context::Type}, tp->location);
}
}
return true;

File diff suppressed because it is too large Load diff

View file

@ -5,8 +5,6 @@
#include <stdexcept>
LUAU_FASTFLAG(LuauNonCopyableTypeVarFields)
namespace Luau
{
@ -40,19 +38,10 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp)
TypePackVar& TypePackVar::operator=(const TypePackVar& rhs)
{
if (FFlag::LuauNonCopyableTypeVarFields)
{
LUAU_ASSERT(owningArena == rhs.owningArena);
LUAU_ASSERT(!rhs.persistent);
LUAU_ASSERT(owningArena == rhs.owningArena);
LUAU_ASSERT(!rhs.persistent);
reassign(rhs);
}
else
{
ty = rhs.ty;
persistent = rhs.persistent;
owningArena = rhs.owningArena;
}
reassign(rhs);
return *this;
}
@ -294,6 +283,16 @@ std::optional<TypeId> first(TypePackId tp, bool ignoreHiddenVariadics)
return std::nullopt;
}
TypePackVar* asMutable(TypePackId tp)
{
return const_cast<TypePackVar*>(tp);
}
TypePack* asMutable(const TypePack* tp)
{
return const_cast<TypePack*>(tp);
}
bool isEmpty(TypePackId tp)
{
tp = follow(tp);
@ -360,13 +359,25 @@ bool isVariadic(TypePackId tp, const TxnLog& log)
return false;
}
TypePackVar* asMutable(TypePackId tp)
bool containsNever(TypePackId tp)
{
return const_cast<TypePackVar*>(tp);
auto it = begin(tp);
auto endIt = end(tp);
while (it != endIt)
{
if (get<NeverTypeVar>(follow(*it)))
return true;
++it;
}
if (auto tail = it.tail())
{
if (auto vtp = get<VariadicTypePack>(*tail); vtp && get<NeverTypeVar>(follow(vtp->ty)))
return true;
}
return false;
}
TypePack* asMutable(const TypePack* tp)
{
return const_cast<TypePack*>(tp);
}
} // namespace Luau

View file

@ -24,7 +24,7 @@ std::optional<TypeId> findMetatableEntry(ErrorVec& errors, TypeId type, std::str
const TableTypeVar* mtt = getTableType(unwrapped);
if (!mtt)
{
errors.push_back(TypeError{location, GenericError{"Metatable was not a table."}});
errors.push_back(TypeError{location, GenericError{"Metatable was not a table"}});
return std::nullopt;
}

View file

@ -23,7 +23,10 @@ LUAU_FASTFLAG(DebugLuauFreezeArena)
LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauNonCopyableTypeVarFields)
LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauDeduceGmatchReturnTypes, false)
LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false)
LUAU_FASTFLAGVARIABLE(LuauDeduceFindMatchReturnTypes, false)
namespace Luau
{
@ -31,6 +34,15 @@ namespace Luau
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
TypeId follow(TypeId t)
{
return follow(t, [](TypeId t) {
@ -159,10 +171,12 @@ bool isNumber(TypeId ty)
// Returns true when ty is a subtype of string
bool isString(TypeId ty)
{
if (isPrim(ty, PrimitiveTypeVar::String) || get<StringSingleton>(get<SingletonTypeVar>(follow(ty))))
ty = follow(ty);
if (isPrim(ty, PrimitiveTypeVar::String) || get<StringSingleton>(get<SingletonTypeVar>(ty)))
return true;
if (auto utv = get<UnionTypeVar>(follow(ty)))
if (auto utv = get<UnionTypeVar>(ty))
return std::all_of(begin(utv), end(utv), isString);
return false;
@ -194,7 +208,7 @@ bool isOptional(TypeId ty)
ty = follow(ty);
if (get<AnyTypeVar>(ty))
if (get<AnyTypeVar>(ty) || (FFlag::LuauUnknownAndNeverType && get<UnknownTypeVar>(ty)))
return true;
auto utv = get<UnionTypeVar>(ty);
@ -228,6 +242,8 @@ bool isOverloadedFunction(TypeId ty)
std::optional<TypeId> getMetatable(TypeId type)
{
type = follow(type);
if (const MetatableTypeVar* mtType = get<MetatableTypeVar>(type))
return mtType->metatable;
else if (const ClassTypeVar* classType = get<ClassTypeVar>(type))
@ -334,6 +350,28 @@ bool isGeneric(TypeId ty)
bool maybeGeneric(TypeId ty)
{
if (FFlag::LuauMaybeGenericIntersectionTypes)
{
ty = follow(ty);
if (get<FreeTypeVar>(ty))
return true;
if (auto ttv = get<TableTypeVar>(ty))
{
// TODO: recurse on table types CLI-39914
(void)ttv;
return true;
}
if (auto itv = get<IntersectionTypeVar>(ty))
{
return std::any_of(begin(itv), end(itv), maybeGeneric);
}
return isGeneric(ty);
}
ty = follow(ty);
if (get<FreeTypeVar>(ty))
return true;
@ -407,6 +445,16 @@ BlockedTypeVar::BlockedTypeVar()
int BlockedTypeVar::nextIndex = 0;
PendingExpansionTypeVar::PendingExpansionTypeVar(TypeFun fn, std::vector<TypeId> typeArguments, std::vector<TypePackId> packArguments)
: fn(fn)
, typeArguments(typeArguments)
, packArguments(packArguments)
, index(++nextIndex)
{
}
size_t PendingExpansionTypeVar::nextIndex = 0;
FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional<FunctionDefinition> defn, bool hasSelf)
: argTypes(argTypes)
, retTypes(retTypes)
@ -646,20 +694,10 @@ TypeVar& TypeVar::operator=(TypeVariant&& rhs)
TypeVar& TypeVar::operator=(const TypeVar& rhs)
{
if (FFlag::LuauNonCopyableTypeVarFields)
{
LUAU_ASSERT(owningArena == rhs.owningArena);
LUAU_ASSERT(!rhs.persistent);
LUAU_ASSERT(owningArena == rhs.owningArena);
LUAU_ASSERT(!rhs.persistent);
reassign(rhs);
}
else
{
ty = rhs.ty;
persistent = rhs.persistent;
normal = rhs.normal;
owningArena = rhs.owningArena;
}
reassign(rhs);
return *this;
}
@ -676,10 +714,14 @@ static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persist
static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true};
static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true};
static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true};
static TypeVar unknownType_{UnknownTypeVar{}, /*persistent*/ true};
static TypeVar neverType_{NeverTypeVar{}, /*persistent*/ true};
static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true};
static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true};
static TypePackVar errorTypePack_{Unifiable::Error{}};
static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, /*persistent*/ true};
static TypePackVar errorTypePack_{Unifiable::Error{}, /*persistent*/ true};
static TypePackVar neverTypePack_{VariadicTypePack{&neverType_}, /*persistent*/ true};
static TypePackVar uninhabitableTypePack_{TypePack{{&neverType_}, &neverTypePack_}, /*persistent*/ true};
SingletonTypes::SingletonTypes()
: nilType(&nilType_)
@ -690,7 +732,11 @@ SingletonTypes::SingletonTypes()
, trueType(&trueType_)
, falseType(&falseType_)
, anyType(&anyType_)
, unknownType(&unknownType_)
, neverType(&neverType_)
, anyTypePack(&anyTypePack_)
, neverTypePack(&neverTypePack_)
, uninhabitableTypePack(&uninhabitableTypePack_)
, arena(new TypeArena)
{
TypeId stringMetatable = makeStringMetatable();
@ -738,19 +784,26 @@ TypeId SingletonTypes::makeStringMetatable()
const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType});
const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})});
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
const TypeId matchFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}),
arena->addTypePack(TypePackVar{VariadicTypePack{FFlag::LuauDeduceFindMatchReturnTypes ? stringType : optionalString}})});
attachMagicFunction(matchFunc, magicFunctionMatch);
const TypeId findFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})});
attachMagicFunction(findFunc, magicFunctionFind);
TableTypeVar::Props stringLib = {
{"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}},
{"char", {arena->addType(FunctionTypeVar{numberVariadicList, arena->addTypePack({stringType})})}},
{"find", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})})}},
{"find", {findFunc}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}),
arena->addTypePack(TypePackVar{VariadicTypePack{optionalString}})})}},
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
@ -911,6 +964,8 @@ const TypeLevel* getLevel(TypeId ty)
return &ttv->level;
else if (auto ftv = get<FunctionTypeVar>(ty))
return &ftv->level;
else if (auto ctv = get<ConstrainedTypeVar>(ty))
return &ctv->level;
else
return nullptr;
}
@ -965,94 +1020,19 @@ bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent)
return false;
}
UnionTypeVarIterator::UnionTypeVarIterator(const UnionTypeVar* utv)
const std::vector<TypeId>& getTypes(const UnionTypeVar* utv)
{
LUAU_ASSERT(utv);
if (!utv->options.empty())
stack.push_front({utv, 0});
seen.insert(utv);
return utv->options;
}
UnionTypeVarIterator& UnionTypeVarIterator::operator++()
const std::vector<TypeId>& getTypes(const IntersectionTypeVar* itv)
{
advance();
descend();
return *this;
return itv->parts;
}
UnionTypeVarIterator UnionTypeVarIterator::operator++(int)
const std::vector<TypeId>& getTypes(const ConstrainedTypeVar* ctv)
{
UnionTypeVarIterator copy = *this;
++copy;
return copy;
}
bool UnionTypeVarIterator::operator!=(const UnionTypeVarIterator& rhs)
{
return !(*this == rhs);
}
bool UnionTypeVarIterator::operator==(const UnionTypeVarIterator& rhs)
{
if (!stack.empty() && !rhs.stack.empty())
return stack.front() == rhs.stack.front();
return stack.empty() && rhs.stack.empty();
}
const TypeId& UnionTypeVarIterator::operator*()
{
LUAU_ASSERT(!stack.empty());
descend();
auto [utv, currentIndex] = stack.front();
LUAU_ASSERT(utv);
LUAU_ASSERT(currentIndex < utv->options.size());
const TypeId& ty = utv->options[currentIndex];
LUAU_ASSERT(!get<UnionTypeVar>(follow(ty)));
return ty;
}
void UnionTypeVarIterator::advance()
{
while (!stack.empty())
{
auto& [utv, currentIndex] = stack.front();
++currentIndex;
if (currentIndex >= utv->options.size())
stack.pop_front();
else
break;
}
}
void UnionTypeVarIterator::descend()
{
while (!stack.empty())
{
auto [utv, currentIndex] = stack.front();
if (auto innerUnion = get<UnionTypeVar>(follow(utv->options[currentIndex])))
{
// If we're about to descend into a cyclic UnionTypeVar, we should skip over this.
// Ideally this should never happen, but alas it does from time to time. :(
if (seen.find(innerUnion) != seen.end())
advance();
else
{
seen.insert(innerUnion);
stack.push_front({innerUnion, 0});
}
continue;
}
break;
}
return ctv->parts;
}
UnionTypeVarIterator begin(const UnionTypeVar* utv)
@ -1065,9 +1045,30 @@ UnionTypeVarIterator end(const UnionTypeVar* utv)
return UnionTypeVarIterator{};
}
IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv)
{
return IntersectionTypeVarIterator{itv};
}
IntersectionTypeVarIterator end(const IntersectionTypeVar* itv)
{
return IntersectionTypeVarIterator{};
}
ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv)
{
return ConstrainedTypeVarIterator{ctv};
}
ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv)
{
return ConstrainedTypeVarIterator{};
}
static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const char* data, size_t size)
{
const char* options = "cdiouxXeEfgGqs";
const char* options = "cdiouxXeEfgGqs*";
std::vector<TypeId> result;
@ -1081,7 +1082,7 @@ static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const cha
continue;
// we just ignore all characters (including flags/precision) up until first alphabetic character
while (i < size && !(data[i] > 0 && isalpha(data[i])))
while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*')))
i++;
if (i == size)
@ -1089,6 +1090,8 @@ static std::vector<TypeId> parseFormatString(TypeChecker& typechecker, const cha
if (data[i] == 'q' || data[i] == 's')
result.push_back(typechecker.stringType);
else if (data[i] == '*')
result.push_back(typechecker.unknownType);
else if (strchr(options, data[i]))
result.push_back(typechecker.numberType);
else
@ -1144,6 +1147,197 @@ std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})};
}
static std::vector<TypeId> parsePatternString(TypeChecker& typechecker, const char* data, size_t size)
{
std::vector<TypeId> result;
int depth = 0;
bool parsingSet = false;
for (size_t i = 0; i < size; ++i)
{
if (data[i] == '%')
{
++i;
if (!parsingSet && i < size && data[i] == 'b')
i += 2;
}
else if (!parsingSet && data[i] == '[')
{
parsingSet = true;
if (i + 1 < size && data[i + 1] == ']')
i += 1;
}
else if (parsingSet && data[i] == ']')
{
parsingSet = false;
}
else if (data[i] == '(')
{
if (parsingSet)
continue;
if (i + 1 < size && data[i + 1] == ')')
{
i++;
result.push_back(typechecker.numberType);
continue;
}
++depth;
result.push_back(typechecker.stringType);
}
else if (data[i] == ')')
{
if (parsingSet)
continue;
--depth;
if (depth < 0)
break;
}
}
if (depth != 0 || parsingSet)
return std::vector<TypeId>();
if (result.empty())
result.push_back(typechecker.stringType);
return result;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
if (!FFlag::LuauDeduceGmatchReturnTypes)
return std::nullopt;
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() != 2)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t index = expr.self ? 0 : 1;
if (expr.args.size > index)
pattern = expr.args.data[index]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
std::vector<TypeId> returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location);
const TypePackId emptyPack = arena.addTypePack({});
const TypePackId returnList = arena.addTypePack(returnTypes);
const TypeId iteratorType = arena.addType(FunctionTypeVar{emptyPack, returnList});
return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})};
}
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
if (!FFlag::LuauDeduceFindMatchReturnTypes)
return std::nullopt;
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 3)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
std::vector<TypeId> returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() == 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, expr.args.data[initIndex]->location);
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
if (!FFlag::LuauDeduceFindMatchReturnTypes)
return std::nullopt;
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 4)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
bool plain = false;
size_t plainIndex = expr.self ? 2 : 3;
if (expr.args.size > plainIndex)
{
AstExprConstantBool* p = expr.args.data[plainIndex]->as<AstExprConstantBool>();
plain = p && p->value;
}
std::vector<TypeId> returnTypes;
if (!plain)
{
returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
}
typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}});
const TypeId optionalBoolean = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.booleanType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() >= 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, expr.args.data[initIndex]->location);
if (params.size() == 4 && expr.args.size > plainIndex)
typechecker.unify(params[3], optionalBoolean, expr.args.data[plainIndex]->location);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate)
{
type = follow(type);
@ -1228,4 +1422,19 @@ bool hasTag(const Property& prop, const std::string& tagName)
return hasTag(prop.tags, tagName);
}
bool TypeFun::operator==(const TypeFun& rhs) const
{
return type == rhs.type && typeParams == rhs.typeParams && typePackParams == rhs.typePackParams;
}
bool GenericTypeDefinition::operator==(const GenericTypeDefinition& rhs) const
{
return ty == rhs.ty && defaultValue == rhs.defaultValue;
}
bool GenericTypePackDefinition::operator==(const GenericTypePackDefinition& rhs) const
{
return tp == rhs.tp && defaultValue == rhs.defaultValue;
}
} // namespace Luau

View file

@ -12,7 +12,7 @@ Free::Free(TypeLevel level)
{
}
Free::Free(Scope2* scope)
Free::Free(Scope* scope)
: scope(scope)
{
}
@ -39,7 +39,7 @@ Generic::Generic(const Name& name)
{
}
Generic::Generic(Scope2* scope)
Generic::Generic(Scope* scope)
: index(++nextIndex)
, scope(scope)
{
@ -53,7 +53,7 @@ Generic::Generic(TypeLevel level, const Name& name)
{
}
Generic::Generic(Scope2* scope, const Name& name)
Generic::Generic(Scope* scope, const Name& name)
: index(++nextIndex)
, scope(scope)
, name(name)

View file

@ -19,7 +19,9 @@ LUAU_FASTFLAG(LuauAutocompleteDynamicLimits)
LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000);
LUAU_FASTFLAG(LuauLowerBoundsCalculation);
LUAU_FASTFLAG(LuauErrorRecoveryType);
LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(LuauQuantifyConstrained)
LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false)
namespace Luau
{
@ -47,33 +49,6 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor
}
}
// TODO cycle and operator() need to be clipped when FFlagLuauUseVisitRecursionLimit is clipped
template<typename TID>
void cycle(TID)
{
}
template<typename TID, typename T>
bool operator()(TID ty, const T&)
{
return visit(ty);
}
bool operator()(TypeId ty, const FreeTypeVar& ftv)
{
return visit(ty, ftv);
}
bool operator()(TypeId ty, const FunctionTypeVar& ftv)
{
return visit(ty, ftv);
}
bool operator()(TypeId ty, const TableTypeVar& ttv)
{
return visit(ty, ttv);
}
bool operator()(TypePackId tp, const FreeTypePack& ftp)
{
return visit(tp, ftp);
}
bool visit(TypeId ty) override
{
// Type levels of types from other modules are already global, so we don't need to promote anything inside
@ -103,6 +78,15 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor
return true;
}
bool visit(TypeId ty, const ConstrainedTypeVar&) override
{
if (!FFlag::LuauUnknownAndNeverType)
return visit(ty);
promote(ty, log.getMutable<ConstrainedTypeVar>(ty));
return true;
}
bool visit(TypeId ty, const FunctionTypeVar&) override
{
// Type levels of types from other modules are already global, so we don't need to promote anything inside
@ -445,6 +429,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
}
else if (subFree)
{
if (FFlag::LuauUnknownAndNeverType)
{
// Normally, if the subtype is free, it should not be bound to any, unknown, or error types.
// But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors.
if (log.get<UnknownTypeVar>(superTy))
return;
}
TypeLevel subLevel = subFree->level;
occursCheck(subTy, superTy);
@ -468,7 +460,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
return;
}
if (get<ErrorTypeVar>(superTy) || get<AnyTypeVar>(superTy))
if (get<ErrorTypeVar>(superTy) || get<AnyTypeVar>(superTy) || get<UnknownTypeVar>(superTy))
return tryUnifyWithAny(subTy, superTy);
if (get<AnyTypeVar>(subTy))
@ -482,7 +474,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
return tryUnifyWithAny(superTy, subTy);
}
if (get<ErrorTypeVar>(subTy))
if (log.get<ErrorTypeVar>(subTy))
return tryUnifyWithAny(superTy, subTy);
if (log.get<NeverTypeVar>(subTy))
return tryUnifyWithAny(superTy, subTy);
auto& cache = sharedState.cachedUnify;
@ -544,6 +539,16 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
{
tryUnifyTables(subTy, superTy, isIntersection);
}
else if (FFlag::LuauScalarShapeSubtyping && log.get<TableTypeVar>(superTy) &&
(log.get<PrimitiveTypeVar>(subTy) || log.get<SingletonTypeVar>(subTy)))
{
tryUnifyScalarShape(subTy, superTy, /*reversed*/ false);
}
else if (FFlag::LuauScalarShapeSubtyping && log.get<TableTypeVar>(subTy) &&
(log.get<PrimitiveTypeVar>(superTy) || log.get<SingletonTypeVar>(superTy)))
{
tryUnifyScalarShape(subTy, superTy, /*reversed*/ true);
}
// tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical.
else if (log.getMutable<MetatableTypeVar>(superTy))
@ -1606,6 +1611,60 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
}
}
void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
{
LUAU_ASSERT(FFlag::LuauScalarShapeSubtyping);
TypeId osubTy = subTy;
TypeId osuperTy = superTy;
if (reversed)
std::swap(subTy, superTy);
if (auto ttv = log.get<TableTypeVar>(superTy); !ttv || ttv->state != TableState::Free)
return reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}});
auto fail = [&](std::optional<TypeError> e) {
std::string reason = "The former's metatable does not satisfy the requirements.";
if (e)
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason, *e}});
else
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason}});
};
// Given t1 where t1 = { lower: (t1) -> (a, b...) }
// It should be the case that `string <: t1` iff `(subtype's metatable).__index <: t1`
if (auto metatable = getMetatable(subTy))
{
auto mttv = log.get<TableTypeVar>(*metatable);
if (!mttv)
fail(std::nullopt);
if (auto it = mttv->props.find("__index"); it != mttv->props.end())
{
TypeId ty = it->second.type;
Unifier child = makeChildUnifier();
child.tryUnify_(ty, superTy);
if (auto e = hasUnificationTooComplex(child.errors))
reportError(*e);
else if (!child.errors.empty())
fail(child.errors.front());
log.concat(std::move(child.log));
return;
}
else
{
return fail(std::nullopt);
}
}
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}});
return;
}
TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> seen)
{
ty = follow(ty);
@ -1862,6 +1921,7 @@ static void tryUnifyWithAny(std::vector<TypeId>& queue, Unifier& state, DenseHas
if (state.log.getMutable<FreeTypeVar>(ty))
{
// TODO: Only bind if the anyType isn't any, unknown, or error (?)
state.log.replace(ty, BoundTypeVar{anyType});
}
else if (auto fun = state.log.getMutable<FunctionTypeVar>(ty))
@ -1901,22 +1961,28 @@ static void tryUnifyWithAny(std::vector<TypeId>& queue, Unifier& state, DenseHas
void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy)
{
LUAU_ASSERT(get<AnyTypeVar>(anyTy) || get<ErrorTypeVar>(anyTy));
LUAU_ASSERT(get<AnyTypeVar>(anyTy) || get<ErrorTypeVar>(anyTy) || get<UnknownTypeVar>(anyTy) || get<NeverTypeVar>(anyTy));
// These types are not visited in general loop below
if (get<PrimitiveTypeVar>(subTy) || get<AnyTypeVar>(subTy) || get<ClassTypeVar>(subTy))
return;
const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}});
const TypePackId anyTP = get<AnyTypeVar>(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}});
TypePackId anyTp;
if (FFlag::LuauUnknownAndNeverType)
anyTp = types->addTypePack(TypePackVar{VariadicTypePack{anyTy}});
else
{
const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}});
anyTp = get<AnyTypeVar>(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}});
}
std::vector<TypeId> queue = {subTy};
sharedState.tempSeenTy.clear();
sharedState.tempSeenTp.clear();
Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, getSingletonTypes().anyType, anyTP);
Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types,
FFlag::LuauUnknownAndNeverType ? anyTy : getSingletonTypes().anyType, anyTp);
}
void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp)

View file

@ -474,16 +474,26 @@ public:
bool value;
};
enum class ConstantNumberParseResult
{
Ok,
Malformed,
BinOverflow,
HexOverflow,
DoublePrefix,
};
class AstExprConstantNumber : public AstExpr
{
public:
LUAU_RTTI(AstExprConstantNumber)
AstExprConstantNumber(const Location& location, double value);
AstExprConstantNumber(const Location& location, double value, ConstantNumberParseResult parseResult = ConstantNumberParseResult::Ok);
void visit(AstVisitor* visitor) override;
double value;
ConstantNumberParseResult parseResult;
};
class AstExprConstantString : public AstExpr

View file

@ -50,9 +50,10 @@ void AstExprConstantBool::visit(AstVisitor* visitor)
visitor->visit(this);
}
AstExprConstantNumber::AstExprConstantNumber(const Location& location, double value)
AstExprConstantNumber::AstExprConstantNumber(const Location& location, double value, ConstantNumberParseResult parseResult)
: AstExpr(ClassIndex(), location)
, value(value)
, parseResult(parseResult)
{
}

View file

@ -15,14 +15,14 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000)
LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false)
LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false)
LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false)
bool lua_telemetry_parsed_named_non_function_type = false;
LUAU_FASTFLAGVARIABLE(LuauErrorParseIntegerIssues, false)
LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false)
LUAU_FASTFLAGVARIABLE(LuauLintParseIntegerIssues, false)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false)
bool lua_telemetry_parsed_out_of_range_bin_integer = false;
@ -1134,10 +1134,9 @@ AstTypePack* Parser::parseTypeList(TempVector<AstType*>& result, TempVector<std:
std::optional<AstTypeList> Parser::parseOptionalReturnTypeAnnotation()
{
if (options.allowTypeAnnotations &&
(lexer.current().type == ':' || (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow)))
if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow))
{
if (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow)
if (lexer.current().type == Lexeme::SkinnyArrow)
report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'");
nextLexeme();
@ -1373,12 +1372,10 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack)
if (FFlag::LuauFixNamedFunctionParse && !names.empty())
forceFunctionType = true;
bool returnTypeIntroducer =
FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false;
bool returnTypeIntroducer = lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':';
// Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element
if (params.size() == 1 && !varargAnnotation && !forceFunctionType &&
(FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow))
if (params.size() == 1 && !varargAnnotation && !forceFunctionType && !returnTypeIntroducer)
{
if (DFFlag::LuaReportParseWrongNamedType && !names.empty())
lua_telemetry_parsed_named_non_function_type = true;
@ -1389,8 +1386,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack)
return {params[0], {}};
}
if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && !forceFunctionType &&
allowPack)
if (!forceFunctionType && !returnTypeIntroducer && allowPack)
{
if (DFFlag::LuaReportParseWrongNamedType && !names.empty())
lua_telemetry_parsed_named_non_function_type = true;
@ -1409,7 +1405,7 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray<A
{
incrementRecursionCounter("type annotation");
if (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == ':')
if (lexer.current().type == ':')
{
report(lexer.current().location, "Return types in function type annotations are written after '->' instead of ':'");
lexer.next();
@ -2037,8 +2033,10 @@ AstExpr* Parser::parseAssertionExpr()
return expr;
}
static const char* parseInteger(double& result, const char* data, int base)
static const char* parseInteger_DEPRECATED(double& result, const char* data, int base)
{
LUAU_ASSERT(!FFlag::LuauLintParseIntegerIssues);
char* end = nullptr;
unsigned long long value = strtoull(data, &end, base);
@ -2058,9 +2056,6 @@ static const char* parseInteger(double& result, const char* data, int base)
else
lua_telemetry_parsed_out_of_range_hex_integer = true;
}
if (FFlag::LuauErrorParseIntegerIssues)
return "Integer number value is out of range";
}
}
@ -2068,11 +2063,13 @@ static const char* parseInteger(double& result, const char* data, int base)
return *end == 0 ? nullptr : "Malformed number";
}
static const char* parseNumber(double& result, const char* data)
static const char* parseNumber_DEPRECATED2(double& result, const char* data)
{
LUAU_ASSERT(!FFlag::LuauLintParseIntegerIssues);
// binary literal
if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2])
return parseInteger(result, data + 2, 2);
return parseInteger_DEPRECATED(result, data + 2, 2);
// hexadecimal literal
if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2])
@ -2080,10 +2077,7 @@ static const char* parseNumber(double& result, const char* data)
if (DFFlag::LuaReportParseIntegerIssues && data[2] == '0' && (data[3] == 'x' || data[3] == 'X'))
lua_telemetry_parsed_double_prefix_hex_integer = true;
if (FFlag::LuauErrorParseIntegerIssues)
return parseInteger(result, data, 16); // keep prefix, it's handled by 'strtoull'
else
return parseInteger(result, data + 2, 16);
return parseInteger_DEPRECATED(result, data + 2, 16);
}
char* end = nullptr;
@ -2095,6 +2089,8 @@ static const char* parseNumber(double& result, const char* data)
static bool parseNumber_DEPRECATED(double& result, const char* data)
{
LUAU_ASSERT(!FFlag::LuauLintParseIntegerIssues);
// binary literal
if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2])
{
@ -2123,6 +2119,73 @@ static bool parseNumber_DEPRECATED(double& result, const char* data)
}
}
static ConstantNumberParseResult parseInteger(double& result, const char* data, int base)
{
LUAU_ASSERT(FFlag::LuauLintParseIntegerIssues);
LUAU_ASSERT(base == 2 || base == 16);
char* end = nullptr;
unsigned long long value = strtoull(data, &end, base);
if (*end != 0)
return ConstantNumberParseResult::Malformed;
result = double(value);
if (value == ULLONG_MAX && errno == ERANGE)
{
// 'errno' might have been set before we called 'strtoull', but we don't want the overhead of resetting a TLS variable on each call
// so we only reset it when we get a result that might be an out-of-range error and parse again to make sure
errno = 0;
value = strtoull(data, &end, base);
if (errno == ERANGE)
{
if (DFFlag::LuaReportParseIntegerIssues)
{
if (base == 2)
lua_telemetry_parsed_out_of_range_bin_integer = true;
else
lua_telemetry_parsed_out_of_range_hex_integer = true;
}
return base == 2 ? ConstantNumberParseResult::BinOverflow : ConstantNumberParseResult::HexOverflow;
}
}
return ConstantNumberParseResult::Ok;
}
static ConstantNumberParseResult parseNumber(double& result, const char* data)
{
LUAU_ASSERT(FFlag::LuauLintParseIntegerIssues);
// binary literal
if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2])
return parseInteger(result, data + 2, 2);
// hexadecimal literal
if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2])
{
if (!FFlag::LuauErrorDoubleHexPrefix && data[2] == '0' && (data[3] == 'x' || data[3] == 'X'))
{
if (DFFlag::LuaReportParseIntegerIssues)
lua_telemetry_parsed_double_prefix_hex_integer = true;
ConstantNumberParseResult parseResult = parseInteger(result, data + 2, 16);
return parseResult == ConstantNumberParseResult::Malformed ? parseResult : ConstantNumberParseResult::DoublePrefix;
}
return parseInteger(result, data, 16); // pass in '0x' prefix, it's handled by 'strtoull'
}
char* end = nullptr;
double value = strtod(data, &end);
result = value;
return *end == 0 ? ConstantNumberParseResult::Ok : ConstantNumberParseResult::Malformed;
}
// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp
AstExpr* Parser::parseSimpleExpr()
{
@ -2163,10 +2226,21 @@ AstExpr* Parser::parseSimpleExpr()
scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end());
}
if (DFFlag::LuaReportParseIntegerIssues || FFlag::LuauErrorParseIntegerIssues)
if (FFlag::LuauLintParseIntegerIssues)
{
double value = 0;
if (const char* error = parseNumber(value, scratchData.c_str()))
ConstantNumberParseResult result = parseNumber(value, scratchData.c_str());
nextLexeme();
if (result == ConstantNumberParseResult::Malformed)
return reportExprError(start, {}, "Malformed number");
return allocator.alloc<AstExprConstantNumber>(start, value, result);
}
else if (DFFlag::LuaReportParseIntegerIssues)
{
double value = 0;
if (const char* error = parseNumber_DEPRECATED2(value, scratchData.c_str()))
{
nextLexeme();
@ -2923,37 +2997,34 @@ AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const
void Parser::nextLexeme()
{
if (options.captureComments)
{
Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type;
Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type;
while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment)
{
const Lexeme& lexeme = lexer.current();
while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment)
{
const Lexeme& lexeme = lexer.current();
if (options.captureComments)
commentLocations.push_back(Comment{lexeme.type, lexeme.location});
// Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme.
// The parser will turn this into a proper syntax error.
if (lexeme.type == Lexeme::BrokenComment)
return;
// Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme.
// The parser will turn this into a proper syntax error.
if (lexeme.type == Lexeme::BrokenComment)
return;
// Comments starting with ! are called "hot comments" and contain directives for type checking / linting
if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!')
{
const char* text = lexeme.data;
// Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling
if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!')
{
const char* text = lexeme.data;
unsigned int end = lexeme.length;
while (end > 0 && isSpace(text[end - 1]))
--end;
unsigned int end = lexeme.length;
while (end > 0 && isSpace(text[end - 1]))
--end;
hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)});
}
type = lexer.next(/* skipComments= */ false, /* updatePrevLocation= */ false).type;
hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)});
}
type = lexer.next(/* skipComments= */ false, /* updatePrevLocation= */ false).type;
}
else
lexer.next();
}
} // namespace Luau

View file

@ -7,6 +7,11 @@
#include "Luau/Transpiler.h"
#include "FileUtils.h"
#include "Flags.h"
#ifdef CALLGRIND
#include <valgrind/callgrind.h>
#endif
LUAU_FASTFLAG(DebugLuauTimeTracing)
LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution)
@ -112,6 +117,7 @@ static void displayHelp(const char* argv0)
printf("Available options:\n");
printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n");
printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n");
printf(" --mode=strict: default to strict mode when typechecking\n");
printf(" --timetrace: record compiler time tracing information into trace.json\n");
}
@ -178,9 +184,9 @@ struct CliConfigResolver : Luau::ConfigResolver
mutable std::unordered_map<std::string, Luau::Config> configCache;
mutable std::vector<std::pair<std::string, std::string>> configErrors;
CliConfigResolver()
CliConfigResolver(Luau::Mode mode)
{
defaultConfig.mode = Luau::Mode::Nonstrict;
defaultConfig.mode = mode;
}
const Luau::Config& getConfig(const Luau::ModuleName& name) const override
@ -218,9 +224,7 @@ int main(int argc, char** argv)
{
Luau::assertHandler() = assertionHandler;
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
if (strncmp(flag->name, "Luau", 4) == 0)
flag->value = true;
setLuauFlagsDefault();
if (argc >= 2 && strcmp(argv[1], "--help") == 0)
{
@ -229,6 +233,7 @@ int main(int argc, char** argv)
}
ReportFormat format = ReportFormat::Default;
Luau::Mode mode = Luau::Mode::Nonstrict;
bool annotate = false;
for (int i = 1; i < argc; ++i)
@ -240,16 +245,20 @@ int main(int argc, char** argv)
format = ReportFormat::Luacheck;
else if (strcmp(argv[i], "--formatter=gnu") == 0)
format = ReportFormat::Gnu;
else if (strcmp(argv[i], "--mode=strict") == 0)
mode = Luau::Mode::Strict;
else if (strcmp(argv[i], "--annotate") == 0)
annotate = true;
else if (strcmp(argv[i], "--timetrace") == 0)
FFlag::DebugLuauTimeTracing.value = true;
else if (strncmp(argv[i], "--fflags=", 9) == 0)
setLuauFlags(argv[i] + 9);
}
#if !defined(LUAU_ENABLE_TIME_TRACE)
if (FFlag::DebugLuauTimeTracing)
{
printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n");
fprintf(stderr, "To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n");
return 1;
}
#endif
@ -258,12 +267,16 @@ int main(int argc, char** argv)
frontendOptions.retainFullTypeGraphs = annotate;
CliFileResolver fileResolver;
CliConfigResolver configResolver;
CliConfigResolver configResolver(mode);
Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions);
Luau::registerBuiltinTypes(frontend.typeChecker);
Luau::freeze(frontend.typeChecker.globalTypes);
#ifdef CALLGRIND
CALLGRIND_ZERO_STATS;
#endif
std::vector<std::string> files = getSourceFiles(argc, argv);
int failed = 0;

View file

@ -3,7 +3,7 @@
#include "Luau/Common.h"
#include "Luau/Ast.h"
#include "Luau/JsonEncoder.h"
#include "Luau/AstJsonEncoder.h"
#include "Luau/Parser.h"
#include "Luau/ParseOptions.h"
@ -62,6 +62,7 @@ int main(int argc, char** argv)
Luau::AstNameTable names(allocator);
Luau::ParseOptions options;
options.captureComments = true;
options.supportContinueStatement = true;
options.allowTypeAnnotations = true;
options.allowDeclarationSyntax = true;
@ -78,7 +79,7 @@ int main(int argc, char** argv)
fprintf(stderr, "\n");
}
printf("%s", Luau::toJson(parseResult.root).c_str());
printf("%s", Luau::toJson(parseResult.root, parseResult.commentLocations).c_str());
return parseResult.errors.size() > 0 ? 1 : 0;
}

75
CLI/Flags.cpp Normal file
View file

@ -0,0 +1,75 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Common.h"
#include "Luau/ExperimentalFlags.h"
#include <string_view>
#include <stdio.h>
#include <string.h>
static void setLuauFlag(std::string_view name, bool state)
{
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
{
if (name == flag->name)
{
flag->value = state;
return;
}
}
fprintf(stderr, "Warning: unrecognized flag '%.*s'.\n", int(name.length()), name.data());
}
static void setLuauFlags(bool state)
{
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
if (strncmp(flag->name, "Luau", 4) == 0)
flag->value = state;
}
void setLuauFlagsDefault()
{
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
if (strncmp(flag->name, "Luau", 4) == 0 && !Luau::isFlagExperimental(flag->name))
flag->value = true;
}
void setLuauFlags(const char* list)
{
std::string_view rest = list;
while (!rest.empty())
{
size_t ending = rest.find(",");
std::string_view element = rest.substr(0, ending);
if (size_t separator = element.find('='); separator != std::string_view::npos)
{
std::string_view key = element.substr(0, separator);
std::string_view value = element.substr(separator + 1);
if (value == "true" || value == "True")
setLuauFlag(key, true);
else if (value == "false" || value == "False")
setLuauFlag(key, false);
else
fprintf(stderr, "Warning: unrecognized value '%.*s' for flag '%.*s'.\n", int(value.length()), value.data(), int(key.length()),
key.data());
}
else
{
if (element == "true" || element == "True")
setLuauFlags(true);
else if (element == "false" || element == "False")
setLuauFlags(false);
else
setLuauFlag(element, true);
}
if (ending != std::string_view::npos)
rest.remove_prefix(ending + 1);
else
break;
}
}

5
CLI/Flags.h Normal file
View file

@ -0,0 +1,5 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
void setLuauFlagsDefault();
void setLuauFlags(const char* list);

View file

@ -8,9 +8,10 @@
#include "Luau/BytecodeBuilder.h"
#include "Luau/Parser.h"
#include "FileUtils.h"
#include "Profiler.h"
#include "Coverage.h"
#include "FileUtils.h"
#include "Flags.h"
#include "Profiler.h"
#include "isocline.h"
@ -19,6 +20,9 @@
#ifdef _WIN32
#include <io.h>
#include <fcntl.h>
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#endif
#ifdef CALLGRIND
@ -26,6 +30,7 @@
#endif
#include <locale.h>
#include <signal.h>
LUAU_FASTFLAG(DebugLuauTimeTracing)
@ -46,6 +51,35 @@ enum class CompileFormat
constexpr int MaxTraversalLimit = 50;
// Ctrl-C handling
static void sigintCallback(lua_State* L, int gc)
{
if (gc >= 0)
return;
lua_callbacks(L)->interrupt = NULL;
lua_rawcheckstack(L, 1); // reserve space for error string
luaL_error(L, "Execution interrupted");
}
static lua_State* replState = NULL;
#ifdef _WIN32
BOOL WINAPI sigintHandler(DWORD signal)
{
if (signal == CTRL_C_EVENT && replState)
lua_callbacks(replState)->interrupt = &sigintCallback;
return TRUE;
}
#else
static void sigintHandler(int signum)
{
if (signum == SIGINT && replState)
lua_callbacks(replState)->interrupt = &sigintCallback;
}
#endif
struct GlobalOptions
{
int optimizationLevel = 1;
@ -75,8 +109,8 @@ static int lua_loadstring(lua_State* L)
return 1;
lua_pushnil(L);
lua_insert(L, -2); /* put before error message */
return 2; /* return nil plus error message */
lua_insert(L, -2); // put before error message
return 2; // return nil plus error message
}
static int finishrequire(lua_State* L)
@ -97,7 +131,11 @@ static int lua_require(lua_State* L)
// return the module from the cache
lua_getfield(L, -1, name.c_str());
if (!lua_isnil(L, -1))
{
// L stack: _MODULES result
return finishrequire(L);
}
lua_pop(L, 1);
std::optional<std::string> source = readFile(name + ".luau");
@ -109,6 +147,7 @@ static int lua_require(lua_State* L)
}
// module needs to run in a new thread, isolated from the rest
// note: we create ML on main thread so that it doesn't inherit environment of L
lua_State* GL = lua_mainthread(L);
lua_State* ML = lua_newthread(GL);
lua_xmove(GL, L, 1);
@ -142,11 +181,12 @@ static int lua_require(lua_State* L)
}
}
// there's now a return value on top of ML; stack of L is MODULES thread
// there's now a return value on top of ML; L stack: _MODULES ML
lua_xmove(ML, L, 1);
lua_pushvalue(L, -1);
lua_setfield(L, -4, name.c_str());
// L stack: _MODULES ML result
return finishrequire(L);
}
@ -528,6 +568,15 @@ static void runRepl()
lua_State* L = globalState.get();
setupState(L);
// setup Ctrl+C handling
replState = L;
#ifdef _WIN32
SetConsoleCtrlHandler(sigintHandler, TRUE);
#else
signal(SIGINT, sigintHandler);
#endif
luaL_sandboxthread(L);
runReplImpl(L);
}
@ -682,60 +731,11 @@ static int assertionHandler(const char* expr, const char* file, int line, const
return 1;
}
static void setLuauFlags(bool state)
{
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
{
if (strncmp(flag->name, "Luau", 4) == 0)
flag->value = state;
}
}
static void setFlag(std::string_view name, bool state)
{
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
{
if (name == flag->name)
{
flag->value = state;
return;
}
}
fprintf(stderr, "Warning: --fflag unrecognized flag '%.*s'.\n\n", int(name.length()), name.data());
}
static void applyFlagKeyValue(std::string_view element)
{
if (size_t separator = element.find('='); separator != std::string_view::npos)
{
std::string_view key = element.substr(0, separator);
std::string_view value = element.substr(separator + 1);
if (value == "true")
setFlag(key, true);
else if (value == "false")
setFlag(key, false);
else
fprintf(stderr, "Warning: --fflag unrecognized value '%.*s' for flag '%.*s'.\n\n", int(value.length()), value.data(), int(key.length()),
key.data());
}
else
{
if (element == "true")
setLuauFlags(true);
else if (element == "false")
setLuauFlags(false);
else
setFlag(element, true);
}
}
int replMain(int argc, char** argv)
{
Luau::assertHandler() = assertionHandler;
setLuauFlags(true);
setLuauFlagsDefault();
CliMode mode = CliMode::Unknown;
CompileFormat compileFormat{};
@ -818,27 +818,10 @@ int replMain(int argc, char** argv)
else if (strcmp(argv[i], "--timetrace") == 0)
{
FFlag::DebugLuauTimeTracing.value = true;
#if !defined(LUAU_ENABLE_TIME_TRACE)
printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n");
return 1;
#endif
}
else if (strncmp(argv[i], "--fflags=", 9) == 0)
{
std::string_view list = argv[i] + 9;
while (!list.empty())
{
size_t ending = list.find(",");
applyFlagKeyValue(list.substr(0, ending));
if (ending != std::string_view::npos)
list.remove_prefix(ending + 1);
else
break;
}
setLuauFlags(argv[i] + 9);
}
else if (argv[i][0] == '-')
{
@ -848,6 +831,14 @@ int replMain(int argc, char** argv)
}
}
#if !defined(LUAU_ENABLE_TIME_TRACE)
if (FFlag::DebugLuauTimeTracing)
{
fprintf(stderr, "To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n");
return 1;
}
#endif
const std::vector<std::string> files = getSourceFiles(argc, argv);
if (mode == CliMode::Unknown)
{

View file

@ -1,8 +1,6 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Repl.h"
int main(int argc, char** argv)
{
return replMain(argc, argv);

View file

@ -175,6 +175,7 @@ endif()
if(LUAU_BUILD_TESTS)
target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS})
target_compile_definitions(Luau.UnitTest PRIVATE DOCTEST_CONFIG_DOUBLE_STRINGIFY)
target_include_directories(Luau.UnitTest PRIVATE extern)
target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen)

View file

@ -58,12 +58,25 @@ public:
void jmp(Label& label);
void jmp(OperandX64 op);
void call(Label& label);
void call(OperandX64 op);
void int3();
// AVX
void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vaddsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vaddss(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vsubsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vmulsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vdivsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vxorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2);
void vcomisd(OperandX64 src1, OperandX64 src2);
void vsqrtpd(OperandX64 dst, OperandX64 src);
void vsqrtps(OperandX64 dst, OperandX64 src);
void vsqrtsd(OperandX64 dst, OperandX64 src1, OperandX64 src2);

View file

@ -37,8 +37,6 @@ enum class Condition
Zero,
NotZero,
// TODO: ordered and unordered floating-point conditions
Count
};

View file

@ -231,6 +231,7 @@ void AssemblyBuilderX64::lea(OperandX64 lhs, OperandX64 rhs)
if (logText)
log("lea", lhs, rhs);
LUAU_ASSERT(rhs.cat == CategoryX64::mem);
placeBinaryRegAndRegMem(lhs, rhs, 0x8d, 0x8d);
}
@ -286,11 +287,42 @@ void AssemblyBuilderX64::jmp(OperandX64 op)
if (logText)
log("jmp", op);
placeRex(op);
place(0xff);
placeModRegMem(op, 4);
commit();
}
void AssemblyBuilderX64::call(Label& label)
{
place(0xe8);
placeLabel(label);
if (logText)
log("call", label);
commit();
}
void AssemblyBuilderX64::call(OperandX64 op)
{
if (logText)
log("call", op);
placeRex(op);
place(0xff);
placeModRegMem(op, 2);
commit();
}
void AssemblyBuilderX64::int3()
{
if (logText)
log("int3");
place(0xcc);
}
void AssemblyBuilderX64::vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{
placeAvx("vaddpd", dst, src1, src2, 0x58, false, AVX_0F, AVX_66);
@ -311,6 +343,31 @@ void AssemblyBuilderX64::vaddss(OperandX64 dst, OperandX64 src1, OperandX64 src2
placeAvx("vaddss", dst, src1, src2, 0x58, false, AVX_0F, AVX_F3);
}
void AssemblyBuilderX64::vsubsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{
placeAvx("vsubsd", dst, src1, src2, 0x5c, false, AVX_0F, AVX_F2);
}
void AssemblyBuilderX64::vmulsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{
placeAvx("vmulsd", dst, src1, src2, 0x59, false, AVX_0F, AVX_F2);
}
void AssemblyBuilderX64::vdivsd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{
placeAvx("vdivsd", dst, src1, src2, 0x5e, false, AVX_0F, AVX_F2);
}
void AssemblyBuilderX64::vxorpd(OperandX64 dst, OperandX64 src1, OperandX64 src2)
{
placeAvx("vxorpd", dst, src1, src2, 0x57, false, AVX_0F, AVX_66);
}
void AssemblyBuilderX64::vcomisd(OperandX64 src1, OperandX64 src2)
{
placeAvx("vcomisd", src1, src2, 0x2f, false, AVX_0F, AVX_66);
}
void AssemblyBuilderX64::vsqrtpd(OperandX64 dst, OperandX64 src)
{
placeAvx("vsqrtpd", dst, src, 0x51, false, AVX_0F, AVX_66);
@ -471,9 +528,10 @@ void AssemblyBuilderX64::placeBinaryRegMemAndImm(OperandX64 lhs, OperandX64 rhs,
LUAU_ASSERT(lhs.cat == CategoryX64::reg || lhs.cat == CategoryX64::mem);
LUAU_ASSERT(rhs.cat == CategoryX64::imm);
SizeX64 size = lhs.base.size;
SizeX64 size = lhs.cat == CategoryX64::reg ? lhs.base.size : lhs.memSize;
LUAU_ASSERT(size == SizeX64::byte || size == SizeX64::dword || size == SizeX64::qword);
placeRex(lhs.base);
placeRex(lhs);
if (size == SizeX64::byte)
{

View file

@ -37,6 +37,14 @@
// Note that Luau runtime doesn't provide indefinite bytecode compatibility: support for older versions gets removed over time. As such, bytecode isn't a durable storage format and it's expected
// that Luau users can recompile bytecode from source on Luau version upgrades if necessary.
// # Bytecode version history
//
// Note: due to limitations of the versioning scheme, some bytecode blobs that carry version 2 are using features from version 3. Starting from version 3, version should be sufficient to indicate bytecode compatibility.
//
// Version 1: Baseline version for the open-source release. Supported until 0.521.
// Version 2: Adds Proto::linedefined. Currently supported.
// Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported.
// Bytecode opcode, part of the instruction header
enum LuauOpcode
{
@ -367,6 +375,20 @@ enum LuauOpcode
// D: jump offset (-32768..32767)
LOP_FORGPREP,
// JUMPXEQKNIL, JUMPXEQKB: jumps to target offset if the comparison with constant is true (or false, see AUX)
// A: source register 1
// D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump")
// AUX: constant value (for boolean) in low bit, NOT flag (that flips comparison result) in high bit
LOP_JUMPXEQKNIL,
LOP_JUMPXEQKB,
// JUMPXEQKN, JUMPXEQKS: jumps to target offset if the comparison with constant is true (or false, see AUX)
// A: source register 1
// D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump")
// AUX: constant table index in low 24 bits, NOT flag (that flips comparison result) in high bit
LOP_JUMPXEQKN,
LOP_JUMPXEQKS,
// Enum entry for number of opcodes, not a valid opcode by itself!
LOP__COUNT
};
@ -391,7 +413,7 @@ enum LuauBytecodeTag
{
// Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled
LBC_VERSION_MIN = 2,
LBC_VERSION_MAX = 2,
LBC_VERSION_MAX = 3,
LBC_VERSION_TARGET = 2,
// Types of constant table entries
LBC_CONSTANT_NIL = 0,

View file

@ -20,12 +20,6 @@
#define LUAU_DEBUGBREAK() __builtin_trap()
#endif
namespace Luau
{
@ -67,16 +61,13 @@ struct FValue
const char* name;
FValue* next;
FValue(const char* name, T def, bool dynamic, void (*reg)(const char*, T*, bool) = nullptr)
FValue(const char* name, T def, bool dynamic)
: value(def)
, dynamic(dynamic)
, name(name)
, next(list)
{
list = this;
if (reg)
reg(name, &value, dynamic);
}
operator T() const
@ -98,7 +89,7 @@ FValue<T>* FValue<T>::list = nullptr;
#define LUAU_FASTFLAGVARIABLE(flag, def) \
namespace FFlag \
{ \
Luau::FValue<bool> flag(#flag, def, false, nullptr); \
Luau::FValue<bool> flag(#flag, def, false); \
}
#define LUAU_FASTINT(flag) \
namespace FInt \
@ -108,7 +99,7 @@ FValue<T>* FValue<T>::list = nullptr;
#define LUAU_FASTINTVARIABLE(flag, def) \
namespace FInt \
{ \
Luau::FValue<int> flag(#flag, def, false, nullptr); \
Luau::FValue<int> flag(#flag, def, false); \
}
#define LUAU_DYNAMIC_FASTFLAG(flag) \
@ -119,7 +110,7 @@ FValue<T>* FValue<T>::list = nullptr;
#define LUAU_DYNAMIC_FASTFLAGVARIABLE(flag, def) \
namespace DFFlag \
{ \
Luau::FValue<bool> flag(#flag, def, true, nullptr); \
Luau::FValue<bool> flag(#flag, def, true); \
}
#define LUAU_DYNAMIC_FASTINT(flag) \
namespace DFInt \
@ -129,5 +120,5 @@ FValue<T>* FValue<T>::list = nullptr;
#define LUAU_DYNAMIC_FASTINTVARIABLE(flag, def) \
namespace DFInt \
{ \
Luau::FValue<int> flag(#flag, def, true, nullptr); \
Luau::FValue<int> flag(#flag, def, true); \
}

View file

@ -0,0 +1,25 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include <string.h>
namespace Luau
{
inline bool isFlagExperimental(const char* flag)
{
// Flags in this list are disabled by default in various command-line tools. They may have behavior that is not fully final,
// or critical bugs that are found after the code has been submitted.
static const char* kList[] = {
"LuauLowerBoundsCalculation",
nullptr, // makes sure we always have at least one entry
};
for (const char* item : kList)
if (item && strcmp(item, flag) == 0)
return true;
return false;
}
} // namespace Luau

View file

@ -8,8 +8,8 @@
namespace Luau
{
class AstStatBlock;
class AstNameTable;
struct ParseResult;
class BytecodeBuilder;
class BytecodeEncoder;
@ -58,7 +58,7 @@ private:
};
// compiles bytecode into bytecode builder using either a pre-parsed AST or parsing it from source; throws on errors
void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options = {});
void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& options = {});
void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const CompileOptions& options = {}, const ParseOptions& parseOptions = {});
// compiles bytecode into a bytecode blob, that either contains the valid bytecode or an encoded error that luau_load can decode

View file

@ -3,7 +3,7 @@
#include <stddef.h>
/* Can be used to reconfigure visibility/exports for public APIs */
// Can be used to reconfigure visibility/exports for public APIs
#ifndef LUACODE_API
#define LUACODE_API extern
#endif
@ -35,5 +35,5 @@ struct lua_CompileOptions
const char** mutableGlobals;
};
/* compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy */
// compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy
LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize);

View file

@ -0,0 +1,463 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "BuiltinFolding.h"
#include "Luau/Bytecode.h"
#include <math.h>
namespace Luau
{
namespace Compile
{
const double kRadDeg = 3.14159265358979323846 / 180.0;
static Constant cvar()
{
return Constant();
}
static Constant cbool(bool v)
{
Constant res = {Constant::Type_Boolean};
res.valueBoolean = v;
return res;
}
static Constant cnum(double v)
{
Constant res = {Constant::Type_Number};
res.valueNumber = v;
return res;
}
static Constant cstring(const char* v)
{
Constant res = {Constant::Type_String};
res.stringLength = unsigned(strlen(v));
res.valueString = v;
return res;
}
static Constant ctype(const Constant& c)
{
LUAU_ASSERT(c.type != Constant::Type_Unknown);
switch (c.type)
{
case Constant::Type_Nil:
return cstring("nil");
case Constant::Type_Boolean:
return cstring("boolean");
case Constant::Type_Number:
return cstring("number");
case Constant::Type_String:
return cstring("string");
default:
LUAU_ASSERT(!"Unsupported constant type");
return cvar();
}
}
static uint32_t bit32(double v)
{
// convert through signed 64-bit integer to match runtime behavior and gracefully truncate negative integers
return uint32_t(int64_t(v));
}
Constant foldBuiltin(int bfid, const Constant* args, size_t count)
{
switch (bfid)
{
case LBF_MATH_ABS:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(fabs(args[0].valueNumber));
break;
case LBF_MATH_ACOS:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(acos(args[0].valueNumber));
break;
case LBF_MATH_ASIN:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(asin(args[0].valueNumber));
break;
case LBF_MATH_ATAN2:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
return cnum(atan2(args[0].valueNumber, args[1].valueNumber));
break;
case LBF_MATH_ATAN:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(atan(args[0].valueNumber));
break;
case LBF_MATH_CEIL:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(ceil(args[0].valueNumber));
break;
case LBF_MATH_COSH:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(cosh(args[0].valueNumber));
break;
case LBF_MATH_COS:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(cos(args[0].valueNumber));
break;
case LBF_MATH_DEG:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(args[0].valueNumber / kRadDeg);
break;
case LBF_MATH_EXP:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(exp(args[0].valueNumber));
break;
case LBF_MATH_FLOOR:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(floor(args[0].valueNumber));
break;
case LBF_MATH_FMOD:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
return cnum(fmod(args[0].valueNumber, args[1].valueNumber));
break;
// Note: FREXP isn't folded since it returns multiple values
case LBF_MATH_LDEXP:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
return cnum(ldexp(args[0].valueNumber, int(args[1].valueNumber)));
break;
case LBF_MATH_LOG10:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(log10(args[0].valueNumber));
break;
case LBF_MATH_LOG:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(log(args[0].valueNumber));
else if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
if (args[1].valueNumber == 2.0)
return cnum(log2(args[0].valueNumber));
else if (args[1].valueNumber == 10.0)
return cnum(log10(args[0].valueNumber));
else
return cnum(log(args[0].valueNumber) / log(args[1].valueNumber));
}
break;
case LBF_MATH_MAX:
if (count >= 1 && args[0].type == Constant::Type_Number)
{
double r = args[0].valueNumber;
for (size_t i = 1; i < count; ++i)
{
if (args[i].type != Constant::Type_Number)
return cvar();
double a = args[i].valueNumber;
r = (a > r) ? a : r;
}
return cnum(r);
}
break;
case LBF_MATH_MIN:
if (count >= 1 && args[0].type == Constant::Type_Number)
{
double r = args[0].valueNumber;
for (size_t i = 1; i < count; ++i)
{
if (args[i].type != Constant::Type_Number)
return cvar();
double a = args[i].valueNumber;
r = (a < r) ? a : r;
}
return cnum(r);
}
break;
// Note: MODF isn't folded since it returns multiple values
case LBF_MATH_POW:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
return cnum(pow(args[0].valueNumber, args[1].valueNumber));
break;
case LBF_MATH_RAD:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(args[0].valueNumber * kRadDeg);
break;
case LBF_MATH_SINH:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(sinh(args[0].valueNumber));
break;
case LBF_MATH_SIN:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(sin(args[0].valueNumber));
break;
case LBF_MATH_SQRT:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(sqrt(args[0].valueNumber));
break;
case LBF_MATH_TANH:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(tanh(args[0].valueNumber));
break;
case LBF_MATH_TAN:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(tan(args[0].valueNumber));
break;
case LBF_BIT32_ARSHIFT:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
uint32_t u = bit32(args[0].valueNumber);
int s = int(args[1].valueNumber);
if (unsigned(s) < 32)
return cnum(double(uint32_t(int32_t(u) >> s)));
}
break;
case LBF_BIT32_BAND:
if (count >= 1 && args[0].type == Constant::Type_Number)
{
uint32_t r = bit32(args[0].valueNumber);
for (size_t i = 1; i < count; ++i)
{
if (args[i].type != Constant::Type_Number)
return cvar();
r &= bit32(args[i].valueNumber);
}
return cnum(double(r));
}
break;
case LBF_BIT32_BNOT:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(double(uint32_t(~bit32(args[0].valueNumber))));
break;
case LBF_BIT32_BOR:
if (count >= 1 && args[0].type == Constant::Type_Number)
{
uint32_t r = bit32(args[0].valueNumber);
for (size_t i = 1; i < count; ++i)
{
if (args[i].type != Constant::Type_Number)
return cvar();
r |= bit32(args[i].valueNumber);
}
return cnum(double(r));
}
break;
case LBF_BIT32_BXOR:
if (count >= 1 && args[0].type == Constant::Type_Number)
{
uint32_t r = bit32(args[0].valueNumber);
for (size_t i = 1; i < count; ++i)
{
if (args[i].type != Constant::Type_Number)
return cvar();
r ^= bit32(args[i].valueNumber);
}
return cnum(double(r));
}
break;
case LBF_BIT32_BTEST:
if (count >= 1 && args[0].type == Constant::Type_Number)
{
uint32_t r = bit32(args[0].valueNumber);
for (size_t i = 1; i < count; ++i)
{
if (args[i].type != Constant::Type_Number)
return cvar();
r &= bit32(args[i].valueNumber);
}
return cbool(r != 0);
}
break;
case LBF_BIT32_EXTRACT:
if (count == 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number)
{
uint32_t u = bit32(args[0].valueNumber);
int f = int(args[1].valueNumber);
int w = int(args[2].valueNumber);
if (f >= 0 && w > 0 && f + w <= 32)
{
uint32_t m = ~(0xfffffffeu << (w - 1));
return cnum(double((u >> f) & m));
}
}
break;
case LBF_BIT32_LROTATE:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
uint32_t u = bit32(args[0].valueNumber);
int s = int(args[1].valueNumber);
return cnum(double((u << (s & 31)) | (u >> ((32 - s) & 31))));
}
break;
case LBF_BIT32_LSHIFT:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
uint32_t u = bit32(args[0].valueNumber);
int s = int(args[1].valueNumber);
if (unsigned(s) < 32)
return cnum(double(u << s));
}
break;
case LBF_BIT32_REPLACE:
if (count == 4 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number &&
args[3].type == Constant::Type_Number)
{
uint32_t n = bit32(args[0].valueNumber);
uint32_t v = bit32(args[1].valueNumber);
int f = int(args[2].valueNumber);
int w = int(args[3].valueNumber);
if (f >= 0 && w > 0 && f + w <= 32)
{
uint32_t m = ~(0xfffffffeu << (w - 1));
return cnum(double((n & ~(m << f)) | ((v & m) << f)));
}
}
break;
case LBF_BIT32_RROTATE:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
uint32_t u = bit32(args[0].valueNumber);
int s = int(args[1].valueNumber);
return cnum(double((u >> (s & 31)) | (u << ((32 - s) & 31))));
}
break;
case LBF_BIT32_RSHIFT:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
uint32_t u = bit32(args[0].valueNumber);
int s = int(args[1].valueNumber);
if (unsigned(s) < 32)
return cnum(double(u >> s));
}
break;
case LBF_TYPE:
if (count == 1 && args[0].type != Constant::Type_Unknown)
return ctype(args[0]);
break;
case LBF_STRING_BYTE:
if (count == 1 && args[0].type == Constant::Type_String)
{
if (args[0].stringLength > 0)
return cnum(double(uint8_t(args[0].valueString[0])));
}
else if (count == 2 && args[0].type == Constant::Type_String && args[1].type == Constant::Type_Number)
{
int i = int(args[1].valueNumber);
if (i > 0 && unsigned(i) <= args[0].stringLength)
return cnum(double(uint8_t(args[0].valueString[i - 1])));
}
break;
case LBF_STRING_LEN:
if (count == 1 && args[0].type == Constant::Type_String)
return cnum(double(args[0].stringLength));
break;
case LBF_TYPEOF:
if (count == 1 && args[0].type != Constant::Type_Unknown)
return ctype(args[0]);
break;
case LBF_MATH_CLAMP:
if (count == 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number)
{
double min = args[1].valueNumber;
double max = args[2].valueNumber;
if (min <= max)
{
double v = args[0].valueNumber;
v = v < min ? min : v;
v = v > max ? max : v;
return cnum(v);
}
}
break;
case LBF_MATH_SIGN:
if (count == 1 && args[0].type == Constant::Type_Number)
{
double v = args[0].valueNumber;
return cnum(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0);
}
break;
case LBF_MATH_ROUND:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(round(args[0].valueNumber));
break;
}
return cvar();
}
} // namespace Compile
} // namespace Luau

View file

@ -0,0 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "ConstantFolding.h"
namespace Luau
{
namespace Compile
{
Constant foldBuiltin(int bfid, const Constant* args, size_t count);
} // namespace Compile
} // namespace Luau

View file

@ -4,8 +4,6 @@
#include "Luau/Bytecode.h"
#include "Luau/Compiler.h"
LUAU_FASTFLAGVARIABLE(LuauCompileRawlen, false)
namespace Luau
{
namespace Compile
@ -40,11 +38,8 @@ Builtin getBuiltin(AstExpr* node, const DenseHashMap<AstName, Global>& globals,
}
}
int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options)
static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options)
{
if (builtin.empty())
return -1;
if (builtin.isGlobal("assert"))
return LBF_ASSERT;
@ -60,7 +55,7 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options)
return LBF_RAWGET;
if (builtin.isGlobal("rawequal"))
return LBF_RAWEQUAL;
if (FFlag::LuauCompileRawlen && builtin.isGlobal("rawlen"))
if (builtin.isGlobal("rawlen"))
return LBF_RAWLEN;
if (builtin.isGlobal("unpack"))
@ -200,5 +195,49 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options)
return -1;
}
struct BuiltinVisitor : AstVisitor
{
DenseHashMap<AstExprCall*, int>& result;
const DenseHashMap<AstName, Global>& globals;
const DenseHashMap<AstLocal*, Variable>& variables;
const CompileOptions& options;
BuiltinVisitor(DenseHashMap<AstExprCall*, int>& result, const DenseHashMap<AstName, Global>& globals,
const DenseHashMap<AstLocal*, Variable>& variables, const CompileOptions& options)
: result(result)
, globals(globals)
, variables(variables)
, options(options)
{
}
bool visit(AstExprCall* node) override
{
Builtin builtin = node->self ? Builtin() : getBuiltin(node->func, globals, variables);
if (builtin.empty())
return true;
int bfid = getBuiltinFunctionId(builtin, options);
// getBuiltinFunctionId optimistically assumes all select() calls are builtin but actually the second argument must be a vararg
if (bfid == LBF_SELECT_VARARG && !(node->args.size == 2 && node->args.data[1]->is<AstExprVarargs>()))
bfid = -1;
if (bfid >= 0)
result[node] = bfid;
return true; // propagate to nested calls
}
};
void analyzeBuiltins(DenseHashMap<AstExprCall*, int>& result, const DenseHashMap<AstName, Global>& globals,
const DenseHashMap<AstLocal*, Variable>& variables, const CompileOptions& options, AstNode* root)
{
BuiltinVisitor visitor{result, globals, variables, options};
root->visit(&visitor);
}
} // namespace Compile
} // namespace Luau

View file

@ -35,7 +35,9 @@ struct Builtin
};
Builtin getBuiltin(AstExpr* node, const DenseHashMap<AstName, Global>& globals, const DenseHashMap<AstLocal*, Variable>& variables);
int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options);
void analyzeBuiltins(DenseHashMap<AstExprCall*, int>& result, const DenseHashMap<AstName, Global>& globals,
const DenseHashMap<AstLocal*, Variable>& variables, const CompileOptions& options, AstNode* root);
} // namespace Compile
} // namespace Luau

View file

@ -6,6 +6,8 @@
#include <algorithm>
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauCompileBytecodeV3, false)
namespace Luau
{
@ -77,6 +79,10 @@ static int getOpLength(LuauOpcode op)
case LOP_JUMPIFNOTEQK:
case LOP_FASTCALL2:
case LOP_FASTCALL2K:
case LOP_JUMPXEQKNIL:
case LOP_JUMPXEQKB:
case LOP_JUMPXEQKN:
case LOP_JUMPXEQKS:
return 2;
default:
@ -108,6 +114,10 @@ inline bool isJumpD(LuauOpcode op)
case LOP_JUMPBACK:
case LOP_JUMPIFEQK:
case LOP_JUMPIFNOTEQK:
case LOP_JUMPXEQKNIL:
case LOP_JUMPXEQKB:
case LOP_JUMPXEQKN:
case LOP_JUMPXEQKS:
return true;
default:
@ -120,6 +130,17 @@ inline bool isSkipC(LuauOpcode op)
switch (op)
{
case LOP_LOADB:
return true;
default:
return false;
}
}
inline bool isFastCall(LuauOpcode op)
{
switch (op)
{
case LOP_FASTCALL:
case LOP_FASTCALL1:
case LOP_FASTCALL2:
@ -137,6 +158,8 @@ static int getJumpTarget(uint32_t insn, uint32_t pc)
if (isJumpD(op))
return int(pc + LUAU_INSN_D(insn) + 1);
else if (isFastCall(op))
return int(pc + LUAU_INSN_C(insn) + 2);
else if (isSkipC(op) && LUAU_INSN_C(insn))
return int(pc + LUAU_INSN_C(insn) + 1);
else if (op == LOP_JUMPX)
@ -479,7 +502,7 @@ bool BytecodeBuilder::patchSkipC(size_t jumpLabel, size_t targetLabel)
unsigned int jumpInsn = insns[jumpLabel];
(void)jumpInsn;
LUAU_ASSERT(isSkipC(LuauOpcode(LUAU_INSN_OP(jumpInsn))));
LUAU_ASSERT(isSkipC(LuauOpcode(LUAU_INSN_OP(jumpInsn))) || isFastCall(LuauOpcode(LUAU_INSN_OP(jumpInsn))));
LUAU_ASSERT(LUAU_INSN_C(jumpInsn) == 0);
int offset = int(targetLabel) - int(jumpLabel) - 1;
@ -1056,6 +1079,9 @@ std::string BytecodeBuilder::getError(const std::string& message)
uint8_t BytecodeBuilder::getVersion()
{
if (FFlag::LuauCompileBytecodeV3)
return 3;
// This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags
return LBC_VERSION_TARGET;
}
@ -1233,6 +1259,24 @@ void BytecodeBuilder::validate() const
VJUMP(LUAU_INSN_D(insn));
break;
case LOP_JUMPXEQKNIL:
case LOP_JUMPXEQKB:
VREG(LUAU_INSN_A(insn));
VJUMP(LUAU_INSN_D(insn));
break;
case LOP_JUMPXEQKN:
VREG(LUAU_INSN_A(insn));
VCONST(insns[i + 1] & 0xffffff, Number);
VJUMP(LUAU_INSN_D(insn));
break;
case LOP_JUMPXEQKS:
VREG(LUAU_INSN_A(insn));
VCONST(insns[i + 1] & 0xffffff, String);
VJUMP(LUAU_INSN_D(insn));
break;
case LOP_ADD:
case LOP_SUB:
case LOP_MUL:
@ -1766,6 +1810,26 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result,
formatAppend(result, "JUMPIFNOTEQK R%d K%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
break;
case LOP_JUMPXEQKNIL:
formatAppend(result, "JUMPXEQKNIL R%d L%d%s\n", LUAU_INSN_A(insn), targetLabel, *code >> 31 ? " NOT" : "");
code++;
break;
case LOP_JUMPXEQKB:
formatAppend(result, "JUMPXEQKB R%d %d L%d%s\n", LUAU_INSN_A(insn), *code & 1, targetLabel, *code >> 31 ? " NOT" : "");
code++;
break;
case LOP_JUMPXEQKN:
formatAppend(result, "JUMPXEQKN R%d K%d L%d%s\n", LUAU_INSN_A(insn), *code & 0xffffff, targetLabel, *code >> 31 ? " NOT" : "");
code++;
break;
case LOP_JUMPXEQKS:
formatAppend(result, "JUMPXEQKS R%d K%d L%d%s\n", LUAU_INSN_A(insn), *code & 0xffffff, targetLabel, *code >> 31 ? " NOT" : "");
code++;
break;
default:
LUAU_ASSERT(!"Unsupported opcode");
}

View file

@ -25,6 +25,9 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5)
LUAU_FASTFLAGVARIABLE(LuauCompileNoIpairs, false)
LUAU_FASTFLAGVARIABLE(LuauCompileFreeReassign, false)
LUAU_FASTFLAGVARIABLE(LuauCompileXEQ, false)
namespace Luau
{
@ -75,6 +78,12 @@ static BytecodeBuilder::StringRef sref(AstArray<char> data)
return {data.data, data.size};
}
static BytecodeBuilder::StringRef sref(AstArray<const char> data)
{
LUAU_ASSERT(data.data);
return {data.data, data.size};
}
struct Compiler
{
struct RegScope;
@ -89,6 +98,7 @@ struct Compiler
, constants(nullptr)
, locstants(nullptr)
, tableShapes(nullptr)
, builtins(nullptr)
{
// preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays
localStack.reserve(16);
@ -245,7 +255,7 @@ struct Compiler
{
f.canInline = true;
f.stackSize = stackSize;
f.costModel = modelCost(func->body, func->args.data, func->args.size);
f.costModel = modelCost(func->body, func->args.data, func->args.size, builtins);
// track functions that only ever return a single value so that we can convert multret calls to fixedret calls
if (allPathsEndWithReturn(func->body))
@ -262,23 +272,44 @@ struct Compiler
return fid;
}
// returns true if node can return multiple values; may conservatively return true even if expr is known to return just a single value
bool isExprMultRet(AstExpr* node)
{
AstExprCall* expr = node->as<AstExprCall>();
if (!expr)
return node->is<AstExprVarargs>();
// conservative version, optimized for compilation throughput
if (options.optimizationLevel <= 1)
return true;
// handles builtin calls that can be constant-folded
// without this we may omit some optimizations eg compiling fast calls without use of FASTCALL2K
if (isConstant(expr))
return false;
// handles local function calls where we know only one argument is returned
AstExprFunction* func = getFunctionExpr(expr->func);
Function* fi = func ? functions.find(func) : nullptr;
if (fi && fi->returnsOne)
return false;
// unrecognized call, so we conservatively assume multret
return true;
}
// note: this doesn't just clobber target (assuming it's temp), but also clobbers *all* allocated registers >= target!
// this is important to be able to support "multret" semantics due to Lua call frame structure
bool compileExprTempMultRet(AstExpr* node, uint8_t target)
{
if (AstExprCall* expr = node->as<AstExprCall>())
{
// Optimization: convert multret calls to functions that always return one value to fixedret calls; this facilitates inlining
if (options.optimizationLevel >= 2)
// Optimization: convert multret calls that always return one value to fixedret calls; this facilitates inlining/constant folding
if (options.optimizationLevel >= 2 && !isExprMultRet(node))
{
AstExprFunction* func = getFunctionExpr(expr->func);
Function* fi = func ? functions.find(func) : nullptr;
if (fi && fi->returnsOne)
{
compileExprTemp(node, target);
return false;
}
compileExprTemp(node, target);
return false;
}
// We temporarily swap out regTop to have targetTop work correctly...
@ -483,8 +514,7 @@ struct Compiler
varc[i] = isConstant(expr->args.data[i]);
// if the last argument only returns a single value, all following arguments are nil
if (expr->args.size != 0 &&
!(expr->args.data[expr->args.size - 1]->is<AstExprCall>() || expr->args.data[expr->args.size - 1]->is<AstExprVarargs>()))
if (expr->args.size != 0 && !isExprMultRet(expr->args.data[expr->args.size - 1]))
for (size_t i = expr->args.size; i < func->args.size && i < 8; ++i)
varc[i] = true;
@ -523,7 +553,7 @@ struct Compiler
AstLocal* var = func->args.data[i];
AstExpr* arg = i < expr->args.size ? expr->args.data[i] : nullptr;
if (i + 1 == expr->args.size && func->args.size > expr->args.size && (arg->is<AstExprCall>() || arg->is<AstExprVarargs>()))
if (i + 1 == expr->args.size && func->args.size > expr->args.size && isExprMultRet(arg))
{
// if the last argument can return multiple values, we need to compute all of them into the remaining arguments
unsigned int tail = unsigned(func->args.size - expr->args.size) + 1;
@ -566,7 +596,7 @@ struct Compiler
}
else
{
AstExprLocal* le = arg->as<AstExprLocal>();
AstExprLocal* le = FFlag::LuauCompileFreeReassign ? getExprLocal(arg) : arg->as<AstExprLocal>();
Variable* lv = le ? variables.find(le->local) : nullptr;
// if the argument is a local that isn't mutated, we will simply reuse the existing register
@ -591,7 +621,7 @@ struct Compiler
}
// fold constant values updated above into expressions in the function body
foldConstants(constants, variables, locstants, func->body);
foldConstants(constants, variables, locstants, builtinsFold, func->body);
bool usedFallthrough = false;
@ -632,7 +662,7 @@ struct Compiler
if (Constant* var = locstants.find(func->args.data[i]))
var->type = Constant::Type_Unknown;
foldConstants(constants, variables, locstants, func->body);
foldConstants(constants, variables, locstants, builtinsFold, func->body);
}
void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false)
@ -675,29 +705,23 @@ struct Compiler
int bfid = -1;
if (options.optimizationLevel >= 1)
{
Builtin builtin = getBuiltin(expr->func, globals, variables);
bfid = getBuiltinFunctionId(builtin, options);
}
if (options.optimizationLevel >= 1 && !expr->self)
if (const int* id = builtins.find(expr))
bfid = *id;
if (bfid == LBF_SELECT_VARARG)
{
// Optimization: compile select(_, ...) as FASTCALL1; the builtin will read variadic arguments directly
// note: for now we restrict this to single-return expressions since our runtime code doesn't deal with general cases
if (multRet == false && targetCount == 1 && expr->args.size == 2 && expr->args.data[1]->is<AstExprVarargs>())
if (multRet == false && targetCount == 1)
return compileExprSelectVararg(expr, target, targetCount, targetTop, multRet, regs);
else
bfid = -1;
}
// Optimization: for 1/2 argument fast calls use specialized opcodes
if (!expr->self && bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2)
{
AstExpr* last = expr->args.data[expr->args.size - 1];
if (!last->is<AstExprCall>() && !last->is<AstExprVarargs>())
return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid);
}
if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2 && !isExprMultRet(expr->args.data[expr->args.size - 1]))
return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid);
if (expr->self)
{
@ -985,9 +1009,8 @@ struct Compiler
size_t compileCompareJump(AstExprBinary* expr, bool not_ = false)
{
RegScope rs(this);
LuauOpcode opc = getJumpOpCompare(expr->op, not_);
bool isEq = (opc == LOP_JUMPIFEQ || opc == LOP_JUMPIFNOTEQ);
bool isEq = (expr->op == AstExprBinary::CompareEq || expr->op == AstExprBinary::CompareNe);
AstExpr* left = expr->left;
AstExpr* right = expr->right;
@ -999,36 +1022,112 @@ struct Compiler
std::swap(left, right);
}
uint8_t rl = compileExprAuto(left, rs);
int32_t rr = -1;
if (isEq && operandIsConstant)
if (FFlag::LuauCompileXEQ)
{
if (opc == LOP_JUMPIFEQ)
opc = LOP_JUMPIFEQK;
else if (opc == LOP_JUMPIFNOTEQ)
opc = LOP_JUMPIFNOTEQK;
uint8_t rl = compileExprAuto(left, rs);
rr = getConstantIndex(right);
LUAU_ASSERT(rr >= 0);
}
else
rr = compileExprAuto(right, rs);
if (isEq && operandIsConstant)
{
const Constant* cv = constants.find(right);
LUAU_ASSERT(cv && cv->type != Constant::Type_Unknown);
size_t jumpLabel = bytecode.emitLabel();
LuauOpcode opc = LOP_NOP;
int32_t cid = -1;
uint32_t flip = (expr->op == AstExprBinary::CompareEq) == not_ ? 0x80000000 : 0;
if (expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::CompareGe)
{
bytecode.emitAD(opc, uint8_t(rr), 0);
bytecode.emitAux(rl);
switch (cv->type)
{
case Constant::Type_Nil:
opc = LOP_JUMPXEQKNIL;
cid = 0;
break;
case Constant::Type_Boolean:
opc = LOP_JUMPXEQKB;
cid = cv->valueBoolean;
break;
case Constant::Type_Number:
opc = LOP_JUMPXEQKN;
cid = getConstantIndex(right);
break;
case Constant::Type_String:
opc = LOP_JUMPXEQKS;
cid = getConstantIndex(right);
break;
default:
LUAU_ASSERT(!"Unexpected constant type");
}
if (cid < 0)
CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile");
size_t jumpLabel = bytecode.emitLabel();
bytecode.emitAD(opc, rl, 0);
bytecode.emitAux(cid | flip);
return jumpLabel;
}
else
{
LuauOpcode opc = getJumpOpCompare(expr->op, not_);
uint8_t rr = compileExprAuto(right, rs);
size_t jumpLabel = bytecode.emitLabel();
if (expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::CompareGe)
{
bytecode.emitAD(opc, rr, 0);
bytecode.emitAux(rl);
}
else
{
bytecode.emitAD(opc, rl, 0);
bytecode.emitAux(rr);
}
return jumpLabel;
}
}
else
{
bytecode.emitAD(opc, rl, 0);
bytecode.emitAux(rr);
}
LuauOpcode opc = getJumpOpCompare(expr->op, not_);
return jumpLabel;
uint8_t rl = compileExprAuto(left, rs);
int32_t rr = -1;
if (isEq && operandIsConstant)
{
if (opc == LOP_JUMPIFEQ)
opc = LOP_JUMPIFEQK;
else if (opc == LOP_JUMPIFNOTEQ)
opc = LOP_JUMPIFNOTEQK;
rr = getConstantIndex(right);
LUAU_ASSERT(rr >= 0);
}
else
rr = compileExprAuto(right, rs);
size_t jumpLabel = bytecode.emitLabel();
if (expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::CompareGe)
{
bytecode.emitAD(opc, uint8_t(rr), 0);
bytecode.emitAux(rl);
}
else
{
bytecode.emitAD(opc, rl, 0);
bytecode.emitAux(rr);
}
return jumpLabel;
}
}
int32_t getConstantNumber(AstExpr* node)
@ -2156,19 +2255,27 @@ struct Compiler
compileLValueUse(lv, source, /* set= */ true);
}
int getExprLocalReg(AstExpr* node)
AstExprLocal* getExprLocal(AstExpr* node)
{
if (AstExprLocal* expr = node->as<AstExprLocal>())
return expr;
else if (AstExprGroup* expr = node->as<AstExprGroup>())
return getExprLocal(expr->expr);
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
return getExprLocal(expr->expr);
else
return nullptr;
}
int getExprLocalReg(AstExpr* node)
{
if (AstExprLocal* expr = getExprLocal(node))
{
// note: this can't check expr->upvalue because upvalues may be upgraded to locals during inlining
Local* l = locals.find(expr->local);
return l && l->allocated ? l->reg : -1;
}
else if (AstExprGroup* expr = node->as<AstExprGroup>())
return getExprLocalReg(expr->expr);
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
return getExprLocalReg(expr->expr);
else
return -1;
}
@ -2454,6 +2561,22 @@ struct Compiler
if (options.optimizationLevel >= 1 && options.debugLevel <= 1 && areLocalsRedundant(stat))
return;
// Optimization: for 1-1 local assignments, we can reuse the register *if* neither local is mutated
if (FFlag::LuauCompileFreeReassign && options.optimizationLevel >= 1 && stat->vars.size == 1 && stat->values.size == 1)
{
if (AstExprLocal* re = getExprLocal(stat->values.data[0]))
{
Variable* lv = variables.find(stat->vars.data[0]);
Variable* rv = variables.find(re->local);
if (int reg = getExprLocalReg(re); reg >= 0 && (!lv || !lv->written) && (!rv || !rv->written))
{
pushLocal(stat->vars.data[0], uint8_t(reg));
return;
}
}
}
// note: allocReg in this case allocates into parent block register - note that we don't have RegScope here
uint8_t vars = allocReg(stat, unsigned(stat->vars.size));
@ -2495,7 +2618,7 @@ struct Compiler
}
AstLocal* var = stat->var;
uint64_t costModel = modelCost(stat->body, &var, 1);
uint64_t costModel = modelCost(stat->body, &var, 1, builtins);
// we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to unrolling
bool varc = true;
@ -2533,7 +2656,7 @@ struct Compiler
locstants[var].type = Constant::Type_Number;
locstants[var].valueNumber = from + iv * step;
foldConstants(constants, variables, locstants, stat);
foldConstants(constants, variables, locstants, builtinsFold, stat);
size_t iterJumps = loopJumps.size();
@ -2561,7 +2684,7 @@ struct Compiler
// clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again
locstants[var].type = Constant::Type_Unknown;
foldConstants(constants, variables, locstants, stat);
foldConstants(constants, variables, locstants, builtinsFold, stat);
}
void compileStatFor(AstStatFor* stat)
@ -3368,26 +3491,7 @@ struct Compiler
bool visit(AstStatReturn* stat) override
{
if (stat->list.size == 1)
{
AstExpr* value = stat->list.data[0];
if (AstExprCall* expr = value->as<AstExprCall>())
{
AstExprFunction* func = self->getFunctionExpr(expr->func);
Function* fi = func ? self->functions.find(func) : nullptr;
returnsOne &= fi && fi->returnsOne;
}
else if (value->is<AstExprVarargs>())
{
returnsOne = false;
}
}
else
{
returnsOne = false;
}
returnsOne &= stat->list.size == 1 && !self->isExprMultRet(stat->list.data[0]);
return false;
}
@ -3487,6 +3591,8 @@ struct Compiler
DenseHashMap<AstExpr*, Constant> constants;
DenseHashMap<AstLocal*, Constant> locstants;
DenseHashMap<AstExprTable*, TableShape> tableShapes;
DenseHashMap<AstExprCall*, int> builtins;
const DenseHashMap<AstExprCall*, int>* builtinsFold = nullptr;
unsigned int regTop = 0;
unsigned int stackSize = 0;
@ -3502,10 +3608,21 @@ struct Compiler
std::vector<Capture> captures;
};
void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options)
void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& inputOptions)
{
LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler");
LUAU_ASSERT(parseResult.root);
LUAU_ASSERT(parseResult.errors.empty());
CompileOptions options = inputOptions;
for (const HotComment& hc : parseResult.hotcomments)
if (hc.header && hc.content.compare(0, 9, "optimize ") == 0)
options.optimizationLevel = std::max(0, std::min(2, atoi(hc.content.c_str() + 9)));
AstStatBlock* root = parseResult.root;
Compiler compiler(bytecode, options);
// since access to some global objects may result in values that change over time, we block imports from non-readonly tables
@ -3514,10 +3631,17 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName
// this pass analyzes mutability of locals/globals and associates locals with their initial values
trackValues(compiler.globals, compiler.variables, root);
// builtin folding is enabled on optimization level 2 since we can't deoptimize folding at runtime
if (options.optimizationLevel >= 2)
compiler.builtinsFold = &compiler.builtins;
if (options.optimizationLevel >= 1)
{
// this pass tracks which calls are builtins and can be compiled more efficiently
analyzeBuiltins(compiler.builtins, compiler.globals, compiler.variables, options, root);
// this pass analyzes constantness of expressions
foldConstants(compiler.constants, compiler.variables, compiler.locstants, root);
foldConstants(compiler.constants, compiler.variables, compiler.locstants, compiler.builtinsFold, root);
// this pass analyzes table assignments to estimate table shapes for initially empty tables
predictTableShapes(compiler.tableShapes, root);
@ -3559,9 +3683,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const
if (!result.errors.empty())
throw ParseErrors(result.errors);
AstStatBlock* root = result.root;
compileOrThrow(bytecode, root, names, options);
compileOrThrow(bytecode, result, names, options);
}
std::string compile(const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions, BytecodeEncoder* encoder)
@ -3584,7 +3706,7 @@ std::string compile(const std::string& source, const CompileOptions& options, co
try
{
BytecodeBuilder bcb(encoder);
compileOrThrow(bcb, result.root, names, options);
compileOrThrow(bcb, result, names, options);
return bcb.getBytecode();
}

View file

@ -1,6 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "ConstantFolding.h"
#include "BuiltinFolding.h"
#include <math.h>
namespace Luau
@ -193,13 +195,18 @@ struct ConstantVisitor : AstVisitor
DenseHashMap<AstLocal*, Variable>& variables;
DenseHashMap<AstLocal*, Constant>& locals;
const DenseHashMap<AstExprCall*, int>* builtins;
bool wasEmpty = false;
ConstantVisitor(
DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables, DenseHashMap<AstLocal*, Constant>& locals)
std::vector<Constant> builtinArgs;
ConstantVisitor(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables,
DenseHashMap<AstLocal*, Constant>& locals, const DenseHashMap<AstExprCall*, int>* builtins)
: constants(constants)
, variables(variables)
, locals(locals)
, builtins(builtins)
{
// since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries
wasEmpty = constants.empty() && locals.empty();
@ -253,8 +260,37 @@ struct ConstantVisitor : AstVisitor
{
analyze(expr->func);
for (size_t i = 0; i < expr->args.size; ++i)
analyze(expr->args.data[i]);
if (const int* bfid = builtins ? builtins->find(expr) : nullptr)
{
// since recursive calls to analyze() may reuse the vector we need to be careful and preserve existing contents
size_t offset = builtinArgs.size();
bool canFold = true;
builtinArgs.reserve(offset + expr->args.size);
for (size_t i = 0; i < expr->args.size; ++i)
{
Constant ac = analyze(expr->args.data[i]);
if (ac.type == Constant::Type_Unknown)
canFold = false;
else
builtinArgs.push_back(ac);
}
if (canFold)
{
LUAU_ASSERT(builtinArgs.size() == offset + expr->args.size);
result = foldBuiltin(*bfid, builtinArgs.data() + offset, expr->args.size);
}
builtinArgs.resize(offset);
}
else
{
for (size_t i = 0; i < expr->args.size; ++i)
analyze(expr->args.data[i]);
}
}
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
{
@ -395,9 +431,9 @@ struct ConstantVisitor : AstVisitor
};
void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables,
DenseHashMap<AstLocal*, Constant>& locals, AstNode* root)
DenseHashMap<AstLocal*, Constant>& locals, const DenseHashMap<AstExprCall*, int>* builtins, AstNode* root)
{
ConstantVisitor visitor{constants, variables, locals};
ConstantVisitor visitor{constants, variables, locals, builtins};
root->visit(&visitor);
}

View file

@ -26,7 +26,7 @@ struct Constant
{
bool valueBoolean;
double valueNumber;
char* valueString = nullptr; // length stored in stringLength
const char* valueString = nullptr; // length stored in stringLength
};
bool isTruthful() const
@ -35,7 +35,7 @@ struct Constant
return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false);
}
AstArray<char> getString() const
AstArray<const char> getString() const
{
LUAU_ASSERT(type == Type_String);
return {valueString, stringLength};
@ -43,7 +43,7 @@ struct Constant
};
void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables,
DenseHashMap<AstLocal*, Constant>& locals, AstNode* root);
DenseHashMap<AstLocal*, Constant>& locals, const DenseHashMap<AstExprCall*, int>* builtins, AstNode* root);
} // namespace Compile
} // namespace Luau

View file

@ -113,11 +113,14 @@ struct Cost
struct CostVisitor : AstVisitor
{
const DenseHashMap<AstExprCall*, int>& builtins;
DenseHashMap<AstLocal*, uint64_t> vars;
Cost result;
CostVisitor()
: vars(nullptr)
CostVisitor(const DenseHashMap<AstExprCall*, int>& builtins)
: builtins(builtins)
, vars(nullptr)
{
}
@ -148,14 +151,21 @@ struct CostVisitor : AstVisitor
}
else if (AstExprCall* expr = node->as<AstExprCall>())
{
Cost cost = 3;
cost += model(expr->func);
// builtin cost modeling is different from regular calls because we use FASTCALL to compile these
// thus we use a cheaper baseline, don't account for function, and assume constant/local copy is free
bool builtin = builtins.find(expr) != nullptr;
bool builtinShort = builtin && expr->args.size <= 2; // FASTCALL1/2
Cost cost = builtin ? 2 : 3;
if (!builtin)
cost += model(expr->func);
for (size_t i = 0; i < expr->args.size; ++i)
{
Cost ac = model(expr->args.data[i]);
// for constants/locals we still need to copy them to the argument list
cost += ac.model == 0 ? Cost(1) : ac;
cost += ac.model == 0 && !builtinShort ? Cost(1) : ac;
}
return cost;
@ -327,9 +337,9 @@ struct CostVisitor : AstVisitor
}
};
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount)
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount, const DenseHashMap<AstExprCall*, int>& builtins)
{
CostVisitor visitor;
CostVisitor visitor{builtins};
for (size_t i = 0; i < varCount && i < 7; ++i)
visitor.vars[vars[i]] = 0xffull << (i * 8 + 8);

View file

@ -2,6 +2,7 @@
#pragma once
#include "Luau/Ast.h"
#include "Luau/DenseHash.h"
namespace Luau
{
@ -9,7 +10,7 @@ namespace Compile
{
// cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount);
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount, const DenseHashMap<AstExprCall*, int>& builtins);
// cost is computed as B - sum(Di * Ci), where B is baseline cost, Di is the discount for each variable and Ci is 1 when variable #i is constant
int computeCost(uint64_t model, const bool* varsConst, size_t varCount);

View file

@ -31,15 +31,15 @@ ISOCLINE_SOURCES=extern/isocline/src/isocline.c
ISOCLINE_OBJECTS=$(ISOCLINE_SOURCES:%=$(BUILD)/%.o)
ISOCLINE_TARGET=$(BUILD)/libisocline.a
TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp
TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/FileUtils.cpp CLI/Flags.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp
TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o)
TESTS_TARGET=$(BUILD)/luau-tests
REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp CLI/ReplEntry.cpp
REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp CLI/ReplEntry.cpp
REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o)
REPL_CLI_TARGET=$(BUILD)/luau
ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Analyze.cpp
ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Analyze.cpp
ANALYZE_CLI_OBJECTS=$(ANALYZE_CLI_SOURCES:%=$(BUILD)/%.o)
ANALYZE_CLI_TARGET=$(BUILD)/luau-analyze
@ -50,8 +50,12 @@ TESTS_ARGS=
ifneq ($(flags),)
TESTS_ARGS+=--fflags=$(flags)
endif
ifneq ($(opt),)
TESTS_ARGS+=-O$(opt)
endif
OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS)
EXECUTABLE_ALIASES = luau luau-analyze luau-tests
# common flags
CXXFLAGS=-g -Wall
@ -104,7 +108,7 @@ $(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnaly
$(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include
$(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include
$(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include
$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern
$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY
$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -Iextern -Iextern/isocline/include
$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -Iextern
$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IVM/include
@ -114,15 +118,18 @@ $(REPL_CLI_TARGET): LDFLAGS+=-lpthread
fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a build/libprotobuf-mutator/external.protobuf/lib/libprotobuf.a
# pseudo targets
.PHONY: all test clean coverage format luau-size
.PHONY: all test clean coverage format luau-size aliases
all: $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET) $(TESTS_TARGET)
all: $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET) $(TESTS_TARGET) aliases
aliases: $(EXECUTABLE_ALIASES)
test: $(TESTS_TARGET)
$(TESTS_TARGET) $(TESTS_ARGS)
clean:
rm -rf $(BUILD)
rm -rf $(EXECUTABLE_ALIASES)
coverage: $(TESTS_TARGET)
$(TESTS_TARGET) --fflags=true
@ -148,6 +155,9 @@ luau: $(REPL_CLI_TARGET)
luau-analyze: $(ANALYZE_CLI_TARGET)
ln -fs $^ $@
luau-tests: $(TESTS_TARGET)
ln -fs $^ $@
# executable targets
$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET)
$(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET)

View file

@ -1,4 +1,4 @@
Luau ![CI](https://github.com/Roblox/luau/workflows/build/badge.svg) [![Coverage](https://coveralls.io/repos/github/Roblox/luau/badge.svg?branch=master&t=2PXMow)](https://coveralls.io/github/Roblox/luau?branch=master)
Luau ![CI](https://github.com/Roblox/luau/workflows/build/badge.svg) [![codecov](https://codecov.io/gh/Roblox/luau/branch/master/graph/badge.svg?token=S3U44WN416)](https://codecov.io/gh/Roblox/luau)
====
Luau (lowercase u, /ˈlu.aʊ/) is a fast, small, safe, gradually typed embeddable scripting language derived from [Lua](https://lua.org).

View file

@ -4,6 +4,7 @@ if(NOT ${CMAKE_VERSION} VERSION_LESS "3.19")
target_sources(Luau.Common PRIVATE
Common/include/Luau/Common.h
Common/include/Luau/Bytecode.h
Common/include/Luau/ExperimentalFlags.h
)
endif()
@ -38,12 +39,14 @@ target_sources(Luau.Compiler PRIVATE
Compiler/src/BytecodeBuilder.cpp
Compiler/src/Compiler.cpp
Compiler/src/Builtins.cpp
Compiler/src/BuiltinFolding.cpp
Compiler/src/ConstantFolding.cpp
Compiler/src/CostModel.cpp
Compiler/src/TableShape.cpp
Compiler/src/ValueTracking.cpp
Compiler/src/lcode.cpp
Compiler/src/Builtins.h
Compiler/src/BuiltinFolding.h
Compiler/src/ConstantFolding.h
Compiler/src/CostModel.h
Compiler/src/TableShape.h
@ -63,6 +66,8 @@ target_sources(Luau.CodeGen PRIVATE
# Luau.Analysis Sources
target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/ApplyTypeFunction.h
Analysis/include/Luau/AstJsonEncoder.h
Analysis/include/Luau/AstQuery.h
Analysis/include/Luau/Autocomplete.h
Analysis/include/Luau/BuiltinDefinitions.h
@ -78,7 +83,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Frontend.h
Analysis/include/Luau/Instantiation.h
Analysis/include/Luau/IostreamHelpers.h
Analysis/include/Luau/JsonEncoder.h
Analysis/include/Luau/JsonEmitter.h
Analysis/include/Luau/Linter.h
Analysis/include/Luau/LValue.h
Analysis/include/Luau/Module.h
@ -110,6 +115,8 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Variant.h
Analysis/include/Luau/VisitTypeVar.h
Analysis/src/ApplyTypeFunction.cpp
Analysis/src/AstJsonEncoder.cpp
Analysis/src/AstQuery.cpp
Analysis/src/Autocomplete.cpp
Analysis/src/BuiltinDefinitions.cpp
@ -123,7 +130,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Frontend.cpp
Analysis/src/Instantiation.cpp
Analysis/src/IostreamHelpers.cpp
Analysis/src/JsonEncoder.cpp
Analysis/src/JsonEmitter.cpp
Analysis/src/Linter.cpp
Analysis/src/LValue.cpp
Analysis/src/Module.cpp
@ -218,6 +225,8 @@ if(TARGET Luau.Repl.CLI)
CLI/Coverage.cpp
CLI/FileUtils.h
CLI/FileUtils.cpp
CLI/Flags.h
CLI/Flags.cpp
CLI/Profiler.h
CLI/Profiler.cpp
CLI/Repl.cpp
@ -229,6 +238,8 @@ if(TARGET Luau.Analyze.CLI)
target_sources(Luau.Analyze.CLI PRIVATE
CLI/FileUtils.h
CLI/FileUtils.cpp
CLI/Flags.h
CLI/Flags.cpp
CLI/Analyze.cpp)
endif()
@ -247,23 +258,27 @@ if(TARGET Luau.UnitTest)
tests/IostreamOptional.h
tests/ScopedFlags.h
tests/Fixture.cpp
tests/AssemblyBuilderX64.test.cpp
tests/AstJsonEncoder.test.cpp
tests/AstQuery.test.cpp
tests/AstVisitor.test.cpp
tests/Autocomplete.test.cpp
tests/BuiltinDefinitions.test.cpp
tests/Compiler.test.cpp
tests/Config.test.cpp
tests/ConstraintGraphBuilder.test.cpp
tests/ConstraintSolver.test.cpp
tests/CostModel.test.cpp
tests/Error.test.cpp
tests/Frontend.test.cpp
tests/JsonEncoder.test.cpp
tests/JsonEmitter.test.cpp
tests/Lexer.test.cpp
tests/Linter.test.cpp
tests/LValue.test.cpp
tests/Module.test.cpp
tests/NonstrictMode.test.cpp
tests/Normalize.test.cpp
tests/ConstraintGraphBuilder.test.cpp
tests/ConstraintSolver.test.cpp
tests/NotNull.test.cpp
tests/Parser.test.cpp
tests/RequireTracer.test.cpp
tests/RuntimeLimits.test.cpp
@ -295,11 +310,11 @@ if(TARGET Luau.UnitTest)
tests/TypeInfer.tryUnify.test.cpp
tests/TypeInfer.typePacks.cpp
tests/TypeInfer.unionTypes.test.cpp
tests/TypeInfer.unknownnever.test.cpp
tests/TypePack.test.cpp
tests/TypeVar.test.cpp
tests/Variant.test.cpp
tests/VisitTypeVar.test.cpp
tests/AssemblyBuilderX64.test.cpp
tests/main.cpp)
endif()
@ -317,6 +332,8 @@ if(TARGET Luau.CLI.Test)
CLI/Coverage.cpp
CLI/FileUtils.h
CLI/FileUtils.cpp
CLI/Flags.h
CLI/Flags.cpp
CLI/Profiler.h
CLI/Profiler.cpp
CLI/Repl.cpp

View file

@ -11,7 +11,7 @@
/* option for multiple returns in `lua_pcall' and `lua_call' */
// option for multiple returns in `lua_pcall' and `lua_call'
#define LUA_MULTRET (-1)
/*
@ -23,7 +23,7 @@
#define lua_upvalueindex(i) (LUA_GLOBALSINDEX - (i))
#define lua_ispseudo(i) ((i) <= LUA_REGISTRYINDEX)
/* thread status; 0 is OK */
// thread status; 0 is OK
enum lua_Status
{
LUA_OK = 0,
@ -32,7 +32,7 @@ enum lua_Status
LUA_ERRSYNTAX,
LUA_ERRMEM,
LUA_ERRERR,
LUA_BREAK, /* yielded for a debug breakpoint */
LUA_BREAK, // yielded for a debug breakpoint
};
typedef struct lua_State lua_State;
@ -46,7 +46,7 @@ typedef int (*lua_Continuation)(lua_State* L, int status);
typedef void* (*lua_Alloc)(void* ud, void* ptr, size_t osize, size_t nsize);
/* non-return type */
// non-return type
#define l_noret void LUA_NORETURN
/*
@ -61,39 +61,39 @@ typedef void* (*lua_Alloc)(void* ud, void* ptr, size_t osize, size_t nsize);
// clang-format off
enum lua_Type
{
LUA_TNIL = 0, /* must be 0 due to lua_isnoneornil */
LUA_TBOOLEAN = 1, /* must be 1 due to l_isfalse */
LUA_TNIL = 0, // must be 0 due to lua_isnoneornil
LUA_TBOOLEAN = 1, // must be 1 due to l_isfalse
LUA_TLIGHTUSERDATA,
LUA_TNUMBER,
LUA_TVECTOR,
LUA_TSTRING, /* all types above this must be value types, all types below this must be GC types - see iscollectable */
LUA_TSTRING, // all types above this must be value types, all types below this must be GC types - see iscollectable
LUA_TTABLE,
LUA_TFUNCTION,
LUA_TUSERDATA,
LUA_TTHREAD,
/* values below this line are used in GCObject tags but may never show up in TValue type tags */
// values below this line are used in GCObject tags but may never show up in TValue type tags
LUA_TPROTO,
LUA_TUPVAL,
LUA_TDEADKEY,
/* the count of TValue type tags */
// the count of TValue type tags
LUA_T_COUNT = LUA_TPROTO
};
// clang-format on
/* type of numbers in Luau */
// type of numbers in Luau
typedef double lua_Number;
/* type for integer functions */
// type for integer functions
typedef int lua_Integer;
/* unsigned integer type */
// unsigned integer type
typedef unsigned lua_Unsigned;
/*
@ -117,7 +117,7 @@ LUA_API void lua_remove(lua_State* L, int idx);
LUA_API void lua_insert(lua_State* L, int idx);
LUA_API void lua_replace(lua_State* L, int idx);
LUA_API int lua_checkstack(lua_State* L, int sz);
LUA_API void lua_rawcheckstack(lua_State* L, int sz); /* allows for unlimited stack frames */
LUA_API void lua_rawcheckstack(lua_State* L, int sz); // allows for unlimited stack frames
LUA_API void lua_xmove(lua_State* from, lua_State* to, int n);
LUA_API void lua_xpush(lua_State* from, lua_State* to, int idx);
@ -231,18 +231,18 @@ LUA_API void lua_setthreaddata(lua_State* L, void* data);
enum lua_GCOp
{
/* stop and resume incremental garbage collection */
// stop and resume incremental garbage collection
LUA_GCSTOP,
LUA_GCRESTART,
/* run a full GC cycle; not recommended for latency sensitive applications */
// run a full GC cycle; not recommended for latency sensitive applications
LUA_GCCOLLECT,
/* return the heap size in KB and the remainder in bytes */
// return the heap size in KB and the remainder in bytes
LUA_GCCOUNT,
LUA_GCCOUNTB,
/* return 1 if GC is active (not stopped); note that GC may not be actively collecting even if it's running */
// return 1 if GC is active (not stopped); note that GC may not be actively collecting even if it's running
LUA_GCISRUNNING,
/*
@ -300,6 +300,7 @@ LUA_API uintptr_t lua_encodepointer(lua_State* L, uintptr_t p);
LUA_API double lua_clock();
LUA_API void lua_setuserdatatag(lua_State* L, int idx, int tag);
LUA_API void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, void*));
LUA_API void lua_clonefunction(lua_State* L, int idx);
@ -358,9 +359,9 @@ LUA_API void lua_unref(lua_State* L, int ref);
** =======================================================================
*/
typedef struct lua_Debug lua_Debug; /* activation record */
typedef struct lua_Debug lua_Debug; // activation record
/* Functions to be called by the debugger in specific events */
// Functions to be called by the debugger in specific events
typedef void (*lua_Hook)(lua_State* L, lua_Debug* ar);
LUA_API int lua_stackdepth(lua_State* L);
@ -372,30 +373,30 @@ LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n);
LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n);
LUA_API void lua_singlestep(lua_State* L, int enabled);
LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled);
LUA_API int lua_breakpoint(lua_State* L, int funcindex, int line, int enabled);
typedef void (*lua_Coverage)(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size);
LUA_API void lua_getcoverage(lua_State* L, int funcindex, void* context, lua_Coverage callback);
/* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */
// Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging.
LUA_API const char* lua_debugtrace(lua_State* L);
struct lua_Debug
{
const char* name; /* (n) */
const char* what; /* (s) `Lua', `C', `main', `tail' */
const char* source; /* (s) */
int linedefined; /* (s) */
int currentline; /* (l) */
unsigned char nupvals; /* (u) number of upvalues */
unsigned char nparams; /* (a) number of parameters */
char isvararg; /* (a) */
char short_src[LUA_IDSIZE]; /* (s) */
void* userdata; /* only valid in luau_callhook */
const char* name; // (n)
const char* what; // (s) `Lua', `C', `main', `tail'
const char* source; // (s)
int linedefined; // (s)
int currentline; // (l)
unsigned char nupvals; // (u) number of upvalues
unsigned char nparams; // (a) number of parameters
char isvararg; // (a)
char short_src[LUA_IDSIZE]; // (s)
void* userdata; // only valid in luau_callhook
};
/* }====================================================================== */
// }======================================================================
/* Callbacks that can be used to reconfigure behavior of the VM dynamically.
* These are shared between all coroutines.
@ -404,18 +405,18 @@ struct lua_Debug
* can only be changed when the VM is not running any code */
struct lua_Callbacks
{
void* userdata; /* arbitrary userdata pointer that is never overwritten by Luau */
void* userdata; // arbitrary userdata pointer that is never overwritten by Luau
void (*interrupt)(lua_State* L, int gc); /* gets called at safepoints (loop back edges, call/ret, gc) if set */
void (*panic)(lua_State* L, int errcode); /* gets called when an unprotected error is raised (if longjmp is used) */
void (*interrupt)(lua_State* L, int gc); // gets called at safepoints (loop back edges, call/ret, gc) if set
void (*panic)(lua_State* L, int errcode); // gets called when an unprotected error is raised (if longjmp is used)
void (*userthread)(lua_State* LP, lua_State* L); /* gets called when L is created (LP == parent) or destroyed (LP == NULL) */
int16_t (*useratom)(const char* s, size_t l); /* gets called when a string is created; returned atom can be retrieved via tostringatom */
void (*userthread)(lua_State* LP, lua_State* L); // gets called when L is created (LP == parent) or destroyed (LP == NULL)
int16_t (*useratom)(const char* s, size_t l); // gets called when a string is created; returned atom can be retrieved via tostringatom
void (*debugbreak)(lua_State* L, lua_Debug* ar); /* gets called when BREAK instruction is encountered */
void (*debugstep)(lua_State* L, lua_Debug* ar); /* gets called after each instruction in single step mode */
void (*debuginterrupt)(lua_State* L, lua_Debug* ar); /* gets called when thread execution is interrupted by break in another thread */
void (*debugprotectederror)(lua_State* L); /* gets called when protected call results in an error */
void (*debugbreak)(lua_State* L, lua_Debug* ar); // gets called when BREAK instruction is encountered
void (*debugstep)(lua_State* L, lua_Debug* ar); // gets called after each instruction in single step mode
void (*debuginterrupt)(lua_State* L, lua_Debug* ar); // gets called when thread execution is interrupted by break in another thread
void (*debugprotectederror)(lua_State* L); // gets called when protected call results in an error
};
typedef struct lua_Callbacks lua_Callbacks;

View file

@ -33,14 +33,14 @@
#define LUA_NORETURN __attribute__((__noreturn__))
#endif
/* Can be used to reconfigure visibility/exports for public APIs */
// Can be used to reconfigure visibility/exports for public APIs
#ifndef LUA_API
#define LUA_API extern
#endif
#define LUALIB_API LUA_API
/* Can be used to reconfigure visibility for internal APIs */
// Can be used to reconfigure visibility for internal APIs
#if defined(__GNUC__)
#define LUAI_FUNC __attribute__((visibility("hidden"))) extern
#define LUAI_DATA LUAI_FUNC
@ -49,67 +49,67 @@
#define LUAI_DATA extern
#endif
/* Can be used to reconfigure internal error handling to use longjmp instead of C++ EH */
// Can be used to reconfigure internal error handling to use longjmp instead of C++ EH
#ifndef LUA_USE_LONGJMP
#define LUA_USE_LONGJMP 0
#endif
/* LUA_IDSIZE gives the maximum size for the description of the source */
// LUA_IDSIZE gives the maximum size for the description of the source
#ifndef LUA_IDSIZE
#define LUA_IDSIZE 256
#endif
/* LUA_MINSTACK is the guaranteed number of Lua stack slots available to a C function */
// LUA_MINSTACK is the guaranteed number of Lua stack slots available to a C function
#ifndef LUA_MINSTACK
#define LUA_MINSTACK 20
#endif
/* LUAI_MAXCSTACK limits the number of Lua stack slots that a C function can use */
// LUAI_MAXCSTACK limits the number of Lua stack slots that a C function can use
#ifndef LUAI_MAXCSTACK
#define LUAI_MAXCSTACK 8000
#endif
/* LUAI_MAXCALLS limits the number of nested calls */
// LUAI_MAXCALLS limits the number of nested calls
#ifndef LUAI_MAXCALLS
#define LUAI_MAXCALLS 20000
#endif
/* LUAI_MAXCCALLS is the maximum depth for nested C calls; this limit depends on native stack size */
// LUAI_MAXCCALLS is the maximum depth for nested C calls; this limit depends on native stack size
#ifndef LUAI_MAXCCALLS
#define LUAI_MAXCCALLS 200
#endif
/* buffer size used for on-stack string operations; this limit depends on native stack size */
// buffer size used for on-stack string operations; this limit depends on native stack size
#ifndef LUA_BUFFERSIZE
#define LUA_BUFFERSIZE 512
#endif
/* number of valid Lua userdata tags */
// number of valid Lua userdata tags
#ifndef LUA_UTAG_LIMIT
#define LUA_UTAG_LIMIT 128
#endif
/* upper bound for number of size classes used by page allocator */
// upper bound for number of size classes used by page allocator
#ifndef LUA_SIZECLASSES
#define LUA_SIZECLASSES 32
#endif
/* available number of separate memory categories */
// available number of separate memory categories
#ifndef LUA_MEMORY_CATEGORIES
#define LUA_MEMORY_CATEGORIES 256
#endif
/* minimum size for the string table (must be power of 2) */
// minimum size for the string table (must be power of 2)
#ifndef LUA_MINSTRTABSIZE
#define LUA_MINSTRTABSIZE 32
#endif
/* maximum number of captures supported by pattern matching */
// maximum number of captures supported by pattern matching
#ifndef LUA_MAXCAPTURES
#define LUA_MAXCAPTURES 32
#endif
/* }================================================================== */
// }==================================================================
/*
@@ LUAI_USER_ALIGNMENT_T is a type that requires maximum alignment.
@ -126,6 +126,6 @@
long l; \
}
#define LUA_VECTOR_SIZE 3 /* must be 3 or 4 */
#define LUA_VECTOR_SIZE 3 // must be 3 or 4
#define LUA_EXTRA_SIZE LUA_VECTOR_SIZE - 2

View file

@ -72,7 +72,7 @@ LUALIB_API const char* luaL_typename(lua_State* L, int idx);
#define luaL_opt(L, f, n, d) (lua_isnoneornil(L, (n)) ? (d) : f(L, (n)))
/* generic buffer manipulation */
// generic buffer manipulation
struct luaL_Buffer
{
@ -102,7 +102,7 @@ LUALIB_API void luaL_addvalue(luaL_Buffer* B);
LUALIB_API void luaL_pushresult(luaL_Buffer* B);
LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size);
/* builtin libraries */
// builtin libraries
LUALIB_API int luaopen_base(lua_State* L);
#define LUA_COLIBNAME "coroutine"
@ -129,9 +129,9 @@ LUALIB_API int luaopen_math(lua_State* L);
#define LUA_DBLIBNAME "debug"
LUALIB_API int luaopen_debug(lua_State* L);
/* open all builtin libraries */
// open all builtin libraries
LUALIB_API void luaL_openlibs(lua_State* L);
/* sandbox libraries and globals */
// sandbox libraries and globals
LUALIB_API void luaL_sandbox(lua_State* L);
LUALIB_API void luaL_sandboxthread(lua_State* L);

View file

@ -51,10 +51,16 @@ const char* luau_ident = "$Luau: Copyright (C) 2019-2022 Roblox Corporation $\n"
L->top++; \
}
#define updateatom(L, ts) \
{ \
if (ts->atom == ATOM_UNDEF) \
ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1; \
}
static Table* getcurrenv(lua_State* L)
{
if (L->ci == L->base_ci) /* no enclosing function? */
return L->gt; /* use global table as environment */
if (L->ci == L->base_ci) // no enclosing function?
return L->gt; // use global table as environment
else
return curr_func(L)->env;
}
@ -63,7 +69,7 @@ static LUAU_NOINLINE TValue* pseudo2addr(lua_State* L, int idx)
{
api_check(L, lua_ispseudo(idx));
switch (idx)
{ /* pseudo-indices */
{ // pseudo-indices
case LUA_REGISTRYINDEX:
return registry(L);
case LUA_ENVIRONINDEX:
@ -123,7 +129,7 @@ int lua_checkstack(lua_State* L, int size)
{
int res = 1;
if (size > LUAI_MAXCSTACK || (L->top - L->base + size) > LUAI_MAXCSTACK)
res = 0; /* stack overflow */
res = 0; // stack overflow
else if (size > 0)
{
luaD_checkstack(L, size);
@ -213,7 +219,7 @@ void lua_settop(lua_State* L, int idx)
else
{
api_check(L, -(idx + 1) <= (L->top - L->base));
L->top += idx + 1; /* `subtract' index (index is negative) */
L->top += idx + 1; // `subtract' index (index is negative)
}
return;
}
@ -261,7 +267,7 @@ void lua_replace(lua_State* L, int idx)
else
{
setobj(L, o, L->top - 1);
if (idx < LUA_GLOBALSINDEX) /* function upvalue? */
if (idx < LUA_GLOBALSINDEX) // function upvalue?
luaC_barrier(L, curr_func(L), L->top - 1);
}
L->top--;
@ -423,13 +429,13 @@ const char* lua_tolstring(lua_State* L, int idx, size_t* len)
{
luaC_checkthreadsleep(L);
if (!luaV_tostring(L, o))
{ /* conversion failed? */
{ // conversion failed?
if (len != NULL)
*len = 0;
return NULL;
}
luaC_checkGC(L);
o = index2addr(L, idx); /* previous call may reallocate the stack */
o = index2addr(L, idx); // previous call may reallocate the stack
}
if (len != NULL)
*len = tsvalue(o)->len;
@ -441,19 +447,25 @@ const char* lua_tostringatom(lua_State* L, int idx, int* atom)
StkId o = index2addr(L, idx);
if (!ttisstring(o))
return NULL;
const TString* s = tsvalue(o);
TString* s = tsvalue(o);
if (atom)
{
updateatom(L, s);
*atom = s->atom;
}
return getstr(s);
}
const char* lua_namecallatom(lua_State* L, int* atom)
{
const TString* s = L->namecall;
TString* s = L->namecall;
if (!s)
return NULL;
if (atom)
{
updateatom(L, s);
*atom = s->atom;
}
return getstr(s);
}
@ -648,7 +660,7 @@ void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, in
void lua_pushboolean(lua_State* L, int b)
{
setbvalue(L->top, (b != 0)); /* ensure that true is 1 */
setbvalue(L->top, (b != 0)); // ensure that true is 1
api_incr_top(L);
return;
}
@ -817,7 +829,7 @@ void lua_settable(lua_State* L, int idx)
StkId t = index2addr(L, idx);
api_checkvalidindex(L, t);
luaV_settable(L, t, L->top - 2, L->top - 1);
L->top -= 2; /* pop index and value */
L->top -= 2; // pop index and value
return;
}
@ -839,7 +851,7 @@ void lua_rawset(lua_State* L, int idx)
StkId t = index2addr(L, idx);
api_check(L, ttistable(t));
if (hvalue(t)->readonly)
luaG_runerror(L, "Attempt to modify a readonly table");
luaG_readonlyerror(L);
setobj2t(L, luaH_set(L, hvalue(t), L->top - 2), L->top - 1);
luaC_barriert(L, hvalue(t), L->top - 1);
L->top -= 2;
@ -852,7 +864,7 @@ void lua_rawseti(lua_State* L, int idx, int n)
StkId o = index2addr(L, idx);
api_check(L, ttistable(o));
if (hvalue(o)->readonly)
luaG_runerror(L, "Attempt to modify a readonly table");
luaG_readonlyerror(L);
setobj2t(L, luaH_setnum(L, hvalue(o), n), L->top - 1);
luaC_barriert(L, hvalue(o), L->top - 1);
L->top--;
@ -875,7 +887,7 @@ int lua_setmetatable(lua_State* L, int objindex)
case LUA_TTABLE:
{
if (hvalue(obj)->readonly)
luaG_runerror(L, "Attempt to modify a readonly table");
luaG_readonlyerror(L);
hvalue(obj)->metatable = mt;
if (mt)
luaC_objbarrier(L, hvalue(obj), mt);
@ -955,7 +967,7 @@ void lua_call(lua_State* L, int nargs, int nresults)
** Execute a protected call.
*/
struct CallS
{ /* data to `f_call' */
{ // data to `f_call'
StkId func;
int nresults;
};
@ -980,7 +992,7 @@ int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc)
func = savestack(L, o);
}
struct CallS c;
c.func = L->top - (nargs + 1); /* function to be called */
c.func = L->top - (nargs + 1); // function to be called
c.nresults = nresults;
int status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func);
@ -1032,7 +1044,7 @@ int lua_gc(lua_State* L, int what, int data)
}
case LUA_GCCOUNT:
{
/* GC values are expressed in Kbytes: #bytes/2^10 */
// GC values are expressed in Kbytes: #bytes/2^10
res = cast_int(g->totalbytes >> 10);
break;
}
@ -1072,8 +1084,8 @@ int lua_gc(lua_State* L, int what, int data)
actualwork += stepsize;
if (g->gcstate == GCSpause)
{ /* end of cycle? */
res = 1; /* signal it */
{ // end of cycle?
res = 1; // signal it
break;
}
}
@ -1125,13 +1137,13 @@ int lua_gc(lua_State* L, int what, int data)
}
case LUA_GCSETSTEPSIZE:
{
/* GC values are expressed in Kbytes: #bytes/2^10 */
// GC values are expressed in Kbytes: #bytes/2^10
res = g->gcstepsize >> 10;
g->gcstepsize = data << 10;
break;
}
default:
res = -1; /* invalid option */
res = -1; // invalid option
}
return res;
}
@ -1157,8 +1169,8 @@ int lua_next(lua_State* L, int idx)
{
api_incr_top(L);
}
else /* no more elements */
L->top -= 1; /* remove key */
else // no more elements
L->top -= 1; // remove key
return more;
}
@ -1173,12 +1185,12 @@ void lua_concat(lua_State* L, int n)
L->top -= (n - 1);
}
else if (n == 0)
{ /* push empty string */
{ // push empty string
luaC_checkthreadsleep(L);
setsvalue2s(L, L->top, luaS_newlstr(L, "", 0));
api_incr_top(L);
}
/* else n == 1; nothing to do */
// else n == 1; nothing to do
return;
}
@ -1265,7 +1277,7 @@ uintptr_t lua_encodepointer(lua_State* L, uintptr_t p)
int lua_ref(lua_State* L, int idx)
{
api_check(L, idx != LUA_REGISTRYINDEX); /* idx is a stack index for value */
api_check(L, idx != LUA_REGISTRYINDEX); // idx is a stack index for value
int ref = LUA_REFNIL;
global_State* g = L->global;
StkId p = index2addr(L, idx);
@ -1274,13 +1286,13 @@ int lua_ref(lua_State* L, int idx)
Table* reg = hvalue(registry(L));
if (g->registryfree != 0)
{ /* reuse existing slot */
{ // reuse existing slot
ref = g->registryfree;
}
else
{ /* no free elements */
{ // no free elements
ref = luaH_getn(reg);
ref++; /* create new reference */
ref++; // create new reference
}
TValue* slot = luaH_setnum(L, reg, ref);
@ -1300,11 +1312,19 @@ void lua_unref(lua_State* L, int ref)
global_State* g = L->global;
Table* reg = hvalue(registry(L));
TValue* slot = luaH_setnum(L, reg, ref);
setnvalue(slot, g->registryfree); /* NB: no barrier needed because value isn't collectable */
setnvalue(slot, g->registryfree); // NB: no barrier needed because value isn't collectable
g->registryfree = ref;
return;
}
void lua_setuserdatatag(lua_State* L, int idx, int tag)
{
api_check(L, unsigned(tag) < LUA_UTAG_LIMIT);
StkId o = index2addr(L, idx);
api_check(L, ttisuserdata(o));
uvalue(o)->tag = uint8_t(tag);
}
void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, void*))
{
api_check(L, unsigned(tag) < LUA_UTAG_LIMIT);

View file

@ -11,7 +11,7 @@
#include <string.h>
/* convert a stack index to positive */
// convert a stack index to positive
#define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1)
/*
@ -75,7 +75,7 @@ void luaL_where(lua_State* L, int level)
lua_pushfstring(L, "%s:%d: ", ar.short_src, ar.currentline);
return;
}
lua_pushliteral(L, ""); /* else, no information available... */
lua_pushliteral(L, ""); // else, no information available...
}
l_noret luaL_errorL(lua_State* L, const char* fmt, ...)
@ -89,7 +89,7 @@ l_noret luaL_errorL(lua_State* L, const char* fmt, ...)
lua_error(L);
}
/* }====================================================== */
// }======================================================
int luaL_checkoption(lua_State* L, int narg, const char* def, const char* const lst[])
{
@ -104,13 +104,13 @@ int luaL_checkoption(lua_State* L, int narg, const char* def, const char* const
int luaL_newmetatable(lua_State* L, const char* tname)
{
lua_getfield(L, LUA_REGISTRYINDEX, tname); /* get registry.name */
if (!lua_isnil(L, -1)) /* name already in use? */
return 0; /* leave previous value on top, but return 0 */
lua_getfield(L, LUA_REGISTRYINDEX, tname); // get registry.name
if (!lua_isnil(L, -1)) // name already in use?
return 0; // leave previous value on top, but return 0
lua_pop(L, 1);
lua_newtable(L); /* create metatable */
lua_newtable(L); // create metatable
lua_pushvalue(L, -1);
lua_setfield(L, LUA_REGISTRYINDEX, tname); /* registry.name = metatable */
lua_setfield(L, LUA_REGISTRYINDEX, tname); // registry.name = metatable
return 1;
}
@ -118,18 +118,18 @@ void* luaL_checkudata(lua_State* L, int ud, const char* tname)
{
void* p = lua_touserdata(L, ud);
if (p != NULL)
{ /* value is a userdata? */
{ // value is a userdata?
if (lua_getmetatable(L, ud))
{ /* does it have a metatable? */
lua_getfield(L, LUA_REGISTRYINDEX, tname); /* get correct metatable */
{ // does it have a metatable?
lua_getfield(L, LUA_REGISTRYINDEX, tname); // get correct metatable
if (lua_rawequal(L, -1, -2))
{ /* does it have the correct mt? */
lua_pop(L, 2); /* remove both metatables */
{ // does it have the correct mt?
lua_pop(L, 2); // remove both metatables
return p;
}
}
}
luaL_typeerrorL(L, ud, tname); /* else error */
luaL_typeerrorL(L, ud, tname); // else error
}
void luaL_checkstack(lua_State* L, int space, const char* mes)
@ -243,18 +243,18 @@ const float* luaL_optvector(lua_State* L, int narg, const float* def)
int luaL_getmetafield(lua_State* L, int obj, const char* event)
{
if (!lua_getmetatable(L, obj)) /* no metatable? */
if (!lua_getmetatable(L, obj)) // no metatable?
return 0;
lua_pushstring(L, event);
lua_rawget(L, -2);
if (lua_isnil(L, -1))
{
lua_pop(L, 2); /* remove metatable and metafield */
lua_pop(L, 2); // remove metatable and metafield
return 0;
}
else
{
lua_remove(L, -2); /* remove only metatable */
lua_remove(L, -2); // remove only metatable
return 1;
}
}
@ -262,7 +262,7 @@ int luaL_getmetafield(lua_State* L, int obj, const char* event)
int luaL_callmeta(lua_State* L, int obj, const char* event)
{
obj = abs_index(L, obj);
if (!luaL_getmetafield(L, obj, event)) /* no metafield? */
if (!luaL_getmetafield(L, obj, event)) // no metafield?
return 0;
lua_pushvalue(L, obj);
lua_call(L, 1, 1);
@ -282,19 +282,19 @@ void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l)
if (libname)
{
int size = libsize(l);
/* check whether lib already exists */
// check whether lib already exists
luaL_findtable(L, LUA_REGISTRYINDEX, "_LOADED", 1);
lua_getfield(L, -1, libname); /* get _LOADED[libname] */
lua_getfield(L, -1, libname); // get _LOADED[libname]
if (!lua_istable(L, -1))
{ /* not found? */
lua_pop(L, 1); /* remove previous result */
/* try global variable (and create one if it does not exist) */
{ // not found?
lua_pop(L, 1); // remove previous result
// try global variable (and create one if it does not exist)
if (luaL_findtable(L, LUA_GLOBALSINDEX, libname, size) != NULL)
luaL_error(L, "name conflict for module '%s'", libname);
lua_pushvalue(L, -1);
lua_setfield(L, -3, libname); /* _LOADED[libname] = new table */
lua_setfield(L, -3, libname); // _LOADED[libname] = new table
}
lua_remove(L, -2); /* remove _LOADED table */
lua_remove(L, -2); // remove _LOADED table
}
for (; l->name; l++)
{
@ -315,19 +315,19 @@ const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint)
lua_pushlstring(L, fname, e - fname);
lua_rawget(L, -2);
if (lua_isnil(L, -1))
{ /* no such field? */
lua_pop(L, 1); /* remove this nil */
lua_createtable(L, 0, (*e == '.' ? 1 : szhint)); /* new table for field */
{ // no such field?
lua_pop(L, 1); // remove this nil
lua_createtable(L, 0, (*e == '.' ? 1 : szhint)); // new table for field
lua_pushlstring(L, fname, e - fname);
lua_pushvalue(L, -2);
lua_settable(L, -4); /* set new table into field */
lua_settable(L, -4); // set new table into field
}
else if (!lua_istable(L, -1))
{ /* field has a non-table value? */
lua_pop(L, 2); /* remove table and value */
return fname; /* return problematic part of the name */
{ // field has a non-table value?
lua_pop(L, 2); // remove table and value
return fname; // return problematic part of the name
}
lua_remove(L, -2); /* remove previous table */
lua_remove(L, -2); // remove previous table
fname = e + 1;
} while (*e == '.');
return NULL;
@ -470,11 +470,11 @@ void luaL_pushresultsize(luaL_Buffer* B, size_t size)
luaL_pushresult(B);
}
/* }====================================================== */
// }======================================================
const char* luaL_tolstring(lua_State* L, int idx, size_t* len)
{
if (luaL_callmeta(L, idx, "__tostring")) /* is there a metafield? */
if (luaL_callmeta(L, idx, "__tostring")) // is there a metafield?
{
if (!lua_isstring(L, -1))
luaL_error(L, "'__tostring' must return a string");

View file

@ -11,8 +11,6 @@
#include <stdio.h>
#include <stdlib.h>
LUAU_FASTFLAG(LuauLenTM)
static void writestring(const char* s, size_t l)
{
fwrite(s, 1, l, stdout);
@ -20,15 +18,15 @@ static void writestring(const char* s, size_t l)
static int luaB_print(lua_State* L)
{
int n = lua_gettop(L); /* number of arguments */
int n = lua_gettop(L); // number of arguments
for (int i = 1; i <= n; i++)
{
size_t l;
const char* s = luaL_tolstring(L, i, &l); /* convert to string using __tostring et al */
const char* s = luaL_tolstring(L, i, &l); // convert to string using __tostring et al
if (i > 1)
writestring("\t", 1);
writestring(s, l);
lua_pop(L, 1); /* pop result */
lua_pop(L, 1); // pop result
}
writestring("\n", 1);
return 0;
@ -38,7 +36,7 @@ static int luaB_tonumber(lua_State* L)
{
int base = luaL_optinteger(L, 2, 10);
if (base == 10)
{ /* standard conversion */
{ // standard conversion
int isnum = 0;
double n = lua_tonumberx(L, 1, &isnum);
if (isnum)
@ -46,7 +44,7 @@ static int luaB_tonumber(lua_State* L)
lua_pushnumber(L, n);
return 1;
}
luaL_checkany(L, 1); /* error if we don't have any argument */
luaL_checkany(L, 1); // error if we don't have any argument
}
else
{
@ -56,17 +54,17 @@ static int luaB_tonumber(lua_State* L)
unsigned long long n;
n = strtoull(s1, &s2, base);
if (s1 != s2)
{ /* at least one valid digit? */
{ // at least one valid digit?
while (isspace((unsigned char)(*s2)))
s2++; /* skip trailing spaces */
s2++; // skip trailing spaces
if (*s2 == '\0')
{ /* no invalid trailing characters? */
{ // no invalid trailing characters?
lua_pushnumber(L, (double)n);
return 1;
}
}
}
lua_pushnil(L); /* else not a number */
lua_pushnil(L); // else not a number
return 1;
}
@ -75,7 +73,7 @@ static int luaB_error(lua_State* L)
int level = luaL_optinteger(L, 2, 1);
lua_settop(L, 1);
if (lua_isstring(L, 1) && level > 0)
{ /* add extra information? */
{ // add extra information?
luaL_where(L, level);
lua_pushvalue(L, 1);
lua_concat(L, 2);
@ -89,10 +87,10 @@ static int luaB_getmetatable(lua_State* L)
if (!lua_getmetatable(L, 1))
{
lua_pushnil(L);
return 1; /* no metatable */
return 1; // no metatable
}
luaL_getmetafield(L, 1, "__metatable");
return 1; /* returns either __metatable field (if present) or metatable */
return 1; // returns either __metatable field (if present) or metatable
}
static int luaB_setmetatable(lua_State* L)
@ -126,8 +124,8 @@ static void getfunc(lua_State* L, int opt)
static int luaB_getfenv(lua_State* L)
{
getfunc(L, 1);
if (lua_iscfunction(L, -1)) /* is a C function? */
lua_pushvalue(L, LUA_GLOBALSINDEX); /* return the thread's global env. */
if (lua_iscfunction(L, -1)) // is a C function?
lua_pushvalue(L, LUA_GLOBALSINDEX); // return the thread's global env.
else
lua_getfenv(L, -1);
lua_setsafeenv(L, -1, false);
@ -142,7 +140,7 @@ static int luaB_setfenv(lua_State* L)
lua_setsafeenv(L, -1, false);
if (lua_isnumber(L, 1) && lua_tonumber(L, 1) == 0)
{
/* change environment of current thread */
// change environment of current thread
lua_pushthread(L);
lua_insert(L, -2);
lua_setfenv(L, -2);
@ -182,9 +180,6 @@ static int luaB_rawset(lua_State* L)
static int luaB_rawlen(lua_State* L)
{
if (!FFlag::LuauLenTM)
luaL_error(L, "'rawlen' is not available");
int tt = lua_type(L, 1);
luaL_argcheck(L, tt == LUA_TTABLE || tt == LUA_TSTRING, 1, "table or string expected");
int len = lua_objlen(L, 1);
@ -201,7 +196,7 @@ static int luaB_gcinfo(lua_State* L)
static int luaB_type(lua_State* L)
{
luaL_checkany(L, 1);
/* resulting name doesn't differentiate between userdata types */
// resulting name doesn't differentiate between userdata types
lua_pushstring(L, lua_typename(L, lua_type(L, 1)));
return 1;
}
@ -209,7 +204,7 @@ static int luaB_type(lua_State* L)
static int luaB_typeof(lua_State* L)
{
luaL_checkany(L, 1);
/* resulting name returns __type if specified unless the input is a newproxy-created userdata */
// resulting name returns __type if specified unless the input is a newproxy-created userdata
lua_pushstring(L, luaL_typename(L, 1));
return 1;
}
@ -217,7 +212,7 @@ static int luaB_typeof(lua_State* L)
int luaB_next(lua_State* L)
{
luaL_checktype(L, 1, LUA_TTABLE);
lua_settop(L, 2); /* create a 2nd argument if there isn't one */
lua_settop(L, 2); // create a 2nd argument if there isn't one
if (lua_next(L, 1))
return 2;
else
@ -230,9 +225,9 @@ int luaB_next(lua_State* L)
static int luaB_pairs(lua_State* L)
{
luaL_checktype(L, 1, LUA_TTABLE);
lua_pushvalue(L, lua_upvalueindex(1)); /* return generator, */
lua_pushvalue(L, 1); /* state, */
lua_pushnil(L); /* and initial value */
lua_pushvalue(L, lua_upvalueindex(1)); // return generator,
lua_pushvalue(L, 1); // state,
lua_pushnil(L); // and initial value
return 3;
}
@ -240,7 +235,7 @@ int luaB_inext(lua_State* L)
{
int i = luaL_checkinteger(L, 2);
luaL_checktype(L, 1, LUA_TTABLE);
i++; /* next value */
i++; // next value
lua_pushinteger(L, i);
lua_rawgeti(L, 1, i);
return (lua_isnil(L, -1)) ? 0 : 2;
@ -249,9 +244,9 @@ int luaB_inext(lua_State* L)
static int luaB_ipairs(lua_State* L)
{
luaL_checktype(L, 1, LUA_TTABLE);
lua_pushvalue(L, lua_upvalueindex(1)); /* return generator, */
lua_pushvalue(L, 1); /* state, */
lua_pushinteger(L, 0); /* and initial value */
lua_pushvalue(L, lua_upvalueindex(1)); // return generator,
lua_pushvalue(L, 1); // state,
lua_pushinteger(L, 0); // and initial value
return 3;
}
@ -340,12 +335,12 @@ static int luaB_xpcally(lua_State* L)
{
luaL_checktype(L, 2, LUA_TFUNCTION);
/* swap function & error function */
// swap function & error function
lua_pushvalue(L, 1);
lua_pushvalue(L, 2);
lua_replace(L, 1);
lua_replace(L, 2);
/* at this point the stack looks like err, f, args */
// at this point the stack looks like err, f, args
// any errors from this point on are handled by continuation
L->ci->flags |= LUA_CALLINFO_HANDLE;
@ -386,7 +381,7 @@ static int luaB_xpcallcont(lua_State* L, int status)
lua_rawcheckstack(L, 1);
lua_pushboolean(L, true);
lua_replace(L, 1); // replace error function with status
return lua_gettop(L); /* return status + all results */
return lua_gettop(L); // return status + all results
}
else
{
@ -462,16 +457,16 @@ static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFuncti
int luaopen_base(lua_State* L)
{
/* set global _G */
// set global _G
lua_pushvalue(L, LUA_GLOBALSINDEX);
lua_setglobal(L, "_G");
/* open lib into global table */
// open lib into global table
luaL_register(L, "_G", base_funcs);
lua_pushliteral(L, "Luau");
lua_setglobal(L, "_VERSION"); /* set global _VERSION */
lua_setglobal(L, "_VERSION"); // set global _VERSION
/* `ipairs' and `pairs' need auxiliary functions as upvalues */
// `ipairs' and `pairs' need auxiliary functions as upvalues
auxopen(L, "ipairs", luaB_ipairs, luaB_inext);
auxopen(L, "pairs", luaB_pairs, luaB_next);

View file

@ -8,10 +8,10 @@
#define ALLONES ~0u
#define NBITS int(8 * sizeof(unsigned))
/* macro to trim extra bits */
// macro to trim extra bits
#define trim(x) ((x)&ALLONES)
/* builds a number with 'n' ones (1 <= n <= NBITS) */
// builds a number with 'n' ones (1 <= n <= NBITS)
#define mask(n) (~((ALLONES << 1) << ((n)-1)))
typedef unsigned b_uint;
@ -69,7 +69,7 @@ static int b_not(lua_State* L)
static int b_shift(lua_State* L, b_uint r, int i)
{
if (i < 0)
{ /* shift right? */
{ // shift right?
i = -i;
r = trim(r);
if (i >= NBITS)
@ -78,7 +78,7 @@ static int b_shift(lua_State* L, b_uint r, int i)
r >>= i;
}
else
{ /* shift left */
{ // shift left
if (i >= NBITS)
r = 0;
else
@ -106,11 +106,11 @@ static int b_arshift(lua_State* L)
if (i < 0 || !(r & ((b_uint)1 << (NBITS - 1))))
return b_shift(L, r, -i);
else
{ /* arithmetic shift for 'negative' number */
{ // arithmetic shift for 'negative' number
if (i >= NBITS)
r = ALLONES;
else
r = trim((r >> i) | ~(~(b_uint)0 >> i)); /* add signal bit */
r = trim((r >> i) | ~(~(b_uint)0 >> i)); // add signal bit
lua_pushunsigned(L, r);
return 1;
}
@ -119,9 +119,9 @@ static int b_arshift(lua_State* L)
static int b_rot(lua_State* L, int i)
{
b_uint r = luaL_checkunsigned(L, 1);
i &= (NBITS - 1); /* i = i % NBITS */
i &= (NBITS - 1); // i = i % NBITS
r = trim(r);
if (i != 0) /* avoid undefined shift of NBITS when i == 0 */
if (i != 0) // avoid undefined shift of NBITS when i == 0
r = (r << i) | (r >> (NBITS - i));
lua_pushunsigned(L, trim(r));
return 1;
@ -172,7 +172,7 @@ static int b_replace(lua_State* L)
b_uint v = luaL_checkunsigned(L, 2);
int f = fieldargs(L, 3, &w);
int m = mask(w);
v &= m; /* erase bits outside given width */
v &= m; // erase bits outside given width
r = (r & ~(m << f)) | (v << f);
lua_pushunsigned(L, r);
return 1;

View file

@ -11,7 +11,7 @@
typedef LUAI_USER_ALIGNMENT_T L_Umaxalign;
/* internal assertions for in-house debugging */
// internal assertions for in-house debugging
#define check_exp(c, e) (LUAU_ASSERT(c), (e))
#define api_check(l, e) LUAU_ASSERT(e)

View file

@ -5,9 +5,9 @@
#include "lstate.h"
#include "lvm.h"
#define CO_RUN 0 /* running */
#define CO_SUS 1 /* suspended */
#define CO_NOR 2 /* 'normal' (it resumed another coroutine) */
#define CO_RUN 0 // running
#define CO_SUS 1 // suspended
#define CO_NOR 2 // 'normal' (it resumed another coroutine)
#define CO_DEAD 3
#define CO_STATUS_ERROR -1
@ -23,13 +23,13 @@ static int auxstatus(lua_State* L, lua_State* co)
return CO_SUS;
if (co->status == LUA_BREAK)
return CO_NOR;
if (co->status != 0) /* some error occurred */
if (co->status != 0) // some error occurred
return CO_DEAD;
if (co->ci != co->base_ci) /* does it have frames? */
if (co->ci != co->base_ci) // does it have frames?
return CO_NOR;
if (co->top == co->base)
return CO_DEAD;
return CO_SUS; /* initial state */
return CO_SUS; // initial state
}
static int costatus(lua_State* L)
@ -68,10 +68,10 @@ static int auxresume(lua_State* L, lua_State* co, int narg)
int nres = cast_int(co->top - co->base);
if (nres)
{
/* +1 accounts for true/false status in resumefinish */
// +1 accounts for true/false status in resumefinish
if (nres + 1 > LUA_MINSTACK && !lua_checkstack(L, nres + 1))
luaL_error(L, "too many results to resume");
lua_xmove(co, L, nres); /* move yielded values */
lua_xmove(co, L, nres); // move yielded values
}
return nres;
}
@ -81,7 +81,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg)
}
else
{
lua_xmove(co, L, 1); /* move error message */
lua_xmove(co, L, 1); // move error message
return CO_STATUS_ERROR;
}
}
@ -102,13 +102,13 @@ static int auxresumecont(lua_State* L, lua_State* co)
int nres = cast_int(co->top - co->base);
if (!lua_checkstack(L, nres + 1))
luaL_error(L, "too many results to resume");
lua_xmove(co, L, nres); /* move yielded values */
lua_xmove(co, L, nres); // move yielded values
return nres;
}
else
{
lua_rawcheckstack(L, 2);
lua_xmove(co, L, 1); /* move error message */
lua_xmove(co, L, 1); // move error message
return CO_STATUS_ERROR;
}
}
@ -119,13 +119,13 @@ static int coresumefinish(lua_State* L, int r)
{
lua_pushboolean(L, 0);
lua_insert(L, -2);
return 2; /* return false + error message */
return 2; // return false + error message
}
else
{
lua_pushboolean(L, 1);
lua_insert(L, -(r + 1));
return r + 1; /* return true + `resume' returns */
return r + 1; // return true + `resume' returns
}
}
@ -161,12 +161,12 @@ static int auxwrapfinish(lua_State* L, int r)
if (r < 0)
{
if (lua_isstring(L, -1))
{ /* error object is a string? */
luaL_where(L, 1); /* add extra info */
{ // error object is a string?
luaL_where(L, 1); // add extra info
lua_insert(L, -2);
lua_concat(L, 2);
}
lua_error(L); /* propagate error */
lua_error(L); // propagate error
}
return r;
}
@ -221,7 +221,7 @@ static int coyield(lua_State* L)
static int corunning(lua_State* L)
{
if (lua_pushthread(L))
lua_pushnil(L); /* main thread is not a coroutine */
lua_pushnil(L); // main thread is not a coroutine
return 1;
}
@ -250,7 +250,7 @@ static int coclose(lua_State* L)
{
lua_pushboolean(L, false);
if (lua_gettop(co))
lua_xmove(co, L, 1); /* move error message */
lua_xmove(co, L, 1); // move error message
lua_resetthread(co);
return 2;
}

View file

@ -82,9 +82,9 @@ static int db_info(lua_State* L)
case 'f':
if (L1 == L)
lua_pushvalue(L, -1 - results); /* function is right before results */
lua_pushvalue(L, -1 - results); // function is right before results
else
lua_xmove(L1, L, 1); /* function is at top of L1 */
lua_xmove(L1, L, 1); // function is at top of L1
results++;
break;
@ -130,15 +130,14 @@ static int db_traceback(lua_State* L)
if (ar.currentline > 0)
{
char line[32];
#ifdef _MSC_VER
_itoa(ar.currentline, line, 10); // 5x faster than sprintf
#else
sprintf(line, "%d", ar.currentline);
#endif
char line[32]; // manual conversion for performance
char* lineend = line + sizeof(line);
char* lineptr = lineend;
for (unsigned int r = ar.currentline; r > 0; r /= 10)
*--lineptr = '0' + (r % 10);
luaL_addchar(&buf, ':');
luaL_addstring(&buf, line);
luaL_addlstring(&buf, lineptr, lineend - lineptr);
}
if (ar.name)

View file

@ -12,6 +12,8 @@
#include <string.h>
#include <stdio.h>
LUAU_FASTFLAGVARIABLE(LuauDebuggerBreakpointHitOnNextBestLine, false);
static const char* getfuncname(Closure* f);
static int currentpc(lua_State* L, CallInfo* ci)
@ -84,7 +86,7 @@ const char* lua_setlocal(lua_State* L, int level, int n)
const LocVar* var = fp ? luaF_getlocal(fp, n, currentpc(L, ci)) : NULL;
if (var)
setobjs2s(L, ci->base + var->reg, L->top - 1);
L->top--; /* pop value */
L->top--; // pop value
const char* name = var ? getstr(var->varname) : NULL;
return name;
}
@ -267,12 +269,24 @@ l_noret luaG_indexerror(lua_State* L, const TValue* p1, const TValue* p2)
luaG_runerror(L, "attempt to index %s with %s", t1, t2);
}
l_noret luaG_methoderror(lua_State* L, const TValue* p1, const TString* p2)
{
const char* t1 = luaT_objtypename(L, p1);
luaG_runerror(L, "attempt to call missing method '%s' of %s", getstr(p2), t1);
}
l_noret luaG_readonlyerror(lua_State* L)
{
luaG_runerror(L, "attempt to modify a readonly table");
}
static void pusherror(lua_State* L, const char* msg)
{
CallInfo* ci = L->ci;
if (isLua(ci))
{
char buff[LUA_IDSIZE]; /* add file:line information */
char buff[LUA_IDSIZE]; // add file:line information
luaO_chunkid(buff, getstr(getluaproto(ci)->source), LUA_IDSIZE);
int line = currentline(L, ci);
luaO_pushfstring(L, "%s:%d: %s", buff, line, msg);
@ -367,14 +381,6 @@ void lua_singlestep(lua_State* L, int enabled)
L->singlestep = bool(enabled);
}
void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled)
{
const TValue* func = luaA_toobject(L, funcindex);
api_check(L, ttisfunction(func) && !clvalue(func)->isC);
luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled));
}
static int getmaxline(Proto* p)
{
int result = -1;
@ -394,6 +400,71 @@ static int getmaxline(Proto* p)
return result;
}
// Find the line number with instructions. If the provided line doesn't have any instruction, it should return the next line number with
// instructions.
static int getnextline(Proto* p, int line)
{
int closest = -1;
if (p->lineinfo)
{
for (int i = 0; i < p->sizecode; ++i)
{
// note: we keep prologue as is, instead opting to break at the first meaningful instruction
if (LUAU_INSN_OP(p->code[i]) == LOP_PREPVARARGS)
continue;
int current = luaG_getline(p, i);
if (current >= line)
{
closest = current;
break;
}
}
}
for (int i = 0; i < p->sizep; ++i)
{
// Find the closest line number to the intended one.
int candidate = getnextline(p->p[i], line);
if (closest == -1 || (candidate >= line && candidate < closest))
{
closest = candidate;
}
}
return closest;
}
int lua_breakpoint(lua_State* L, int funcindex, int line, int enabled)
{
int target = -1;
if (FFlag::LuauDebuggerBreakpointHitOnNextBestLine)
{
const TValue* func = luaA_toobject(L, funcindex);
api_check(L, ttisfunction(func) && !clvalue(func)->isC);
Proto* p = clvalue(func)->l.p;
// Find line number to add the breakpoint to.
target = getnextline(p, line);
if (target != -1)
{
// Add breakpoint on the exact line
luaG_breakpoint(L, p, target, bool(enabled));
}
}
else
{
const TValue* func = luaA_toobject(L, funcindex);
api_check(L, ttisfunction(func) && !clvalue(func)->isC);
luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled));
}
return target;
}
static void getcoverage(Proto* p, int depth, int* buffer, size_t size, void* context, lua_Coverage callback)
{
memset(buffer, -1, size * sizeof(int));
@ -465,7 +536,7 @@ const char* lua_debugtrace(lua_State* L)
if (ar.currentline > 0)
{
char line[32];
sprintf(line, ":%d", ar.currentline);
snprintf(line, sizeof(line), ":%d", ar.currentline);
offset = append(buf, sizeof(buf), offset, line);
}
@ -481,7 +552,7 @@ const char* lua_debugtrace(lua_State* L)
if (depth > limit1 + limit2 && level == limit1 - 1)
{
char skip[32];
sprintf(skip, "... (+%d frames)\n", int(depth - limit1 - limit2));
snprintf(skip, sizeof(skip), "... (+%d frames)\n", int(depth - limit1 - limit2));
offset = append(buf, sizeof(buf), offset, skip);

View file

@ -19,6 +19,8 @@ LUAI_FUNC l_noret luaG_concaterror(lua_State* L, StkId p1, StkId p2);
LUAI_FUNC l_noret luaG_aritherror(lua_State* L, const TValue* p1, const TValue* p2, TMS op);
LUAI_FUNC l_noret luaG_ordererror(lua_State* L, const TValue* p1, const TValue* p2, TMS op);
LUAI_FUNC l_noret luaG_indexerror(lua_State* L, const TValue* p1, const TValue* p2);
LUAI_FUNC l_noret luaG_methoderror(lua_State* L, const TValue* p1, const TString* p2);
LUAI_FUNC l_noret luaG_readonlyerror(lua_State* L);
LUAI_FUNC LUA_PRINTF_ATTR(2, 3) l_noret luaG_runerrorL(lua_State* L, const char* fmt, ...);
LUAI_FUNC void luaG_pusherror(lua_State* L, const char* error);

View file

@ -31,7 +31,7 @@ struct lua_jmpbuf
jmp_buf buf;
};
/* use POSIX versions of setjmp/longjmp if possible: they don't save/restore signal mask and are therefore faster */
// use POSIX versions of setjmp/longjmp if possible: they don't save/restore signal mask and are therefore faster
#if defined(__linux__) || defined(__APPLE__)
#define LUAU_SETJMP(buf) _setjmp(buf)
#define LUAU_LONGJMP(buf, code) _longjmp(buf, code)
@ -153,7 +153,7 @@ l_noret luaD_throw(lua_State* L, int errcode)
}
#endif
/* }====================================================== */
// }======================================================
static void correctstack(lua_State* L, TValue* oldstack)
{
@ -177,7 +177,7 @@ void luaD_reallocstack(lua_State* L, int newsize)
luaM_reallocarray(L, L->stack, L->stacksize, realsize, TValue, L->memcat);
TValue* newstack = L->stack;
for (int i = L->stacksize; i < realsize; i++)
setnilvalue(newstack + i); /* erase new segment */
setnilvalue(newstack + i); // erase new segment
L->stacksize = realsize;
L->stack_last = newstack + newsize;
correctstack(L, oldstack);
@ -194,7 +194,7 @@ void luaD_reallocCI(lua_State* L, int newsize)
void luaD_growstack(lua_State* L, int n)
{
if (n <= L->stacksize) /* double size is enough? */
if (n <= L->stacksize) // double size is enough?
luaD_reallocstack(L, 2 * L->stacksize);
else
luaD_reallocstack(L, L->stacksize + n);
@ -202,11 +202,11 @@ void luaD_growstack(lua_State* L, int n)
CallInfo* luaD_growCI(lua_State* L)
{
/* allow extra stack space to handle stack overflow in xpcall */
// allow extra stack space to handle stack overflow in xpcall
const int hardlimit = LUAI_MAXCALLS + (LUAI_MAXCALLS >> 3);
if (L->size_ci >= hardlimit)
luaD_throw(L, LUA_ERRERR); /* error while handling stack error */
luaD_throw(L, LUA_ERRERR); // error while handling stack error
int request = L->size_ci * 2;
luaD_reallocCI(L, L->size_ci >= LUAI_MAXCALLS ? hardlimit : request < LUAI_MAXCALLS ? request : LUAI_MAXCALLS);
@ -219,13 +219,13 @@ CallInfo* luaD_growCI(lua_State* L)
void luaD_checkCstack(lua_State* L)
{
/* allow extra stack space to handle stack overflow in xpcall */
// allow extra stack space to handle stack overflow in xpcall
const int hardlimit = LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3);
if (L->nCcalls == LUAI_MAXCCALLS)
luaG_runerror(L, "C stack overflow");
else if (L->nCcalls >= hardlimit)
luaD_throw(L, LUA_ERRERR); /* error while handling stack error */
luaD_throw(L, LUA_ERRERR); // error while handling stack error
}
/*
@ -240,14 +240,14 @@ void luaD_call(lua_State* L, StkId func, int nResults)
luaD_checkCstack(L);
if (luau_precall(L, func, nResults) == PCRLUA)
{ /* is a Lua function? */
L->ci->flags |= LUA_CALLINFO_RETURN; /* luau_execute will stop after returning from the stack frame */
{ // is a Lua function?
L->ci->flags |= LUA_CALLINFO_RETURN; // luau_execute will stop after returning from the stack frame
int oldactive = luaC_threadactive(L);
l_setbit(L->stackstate, THREAD_ACTIVEBIT);
luaC_checkthreadsleep(L);
luau_execute(L); /* call it */
luau_execute(L); // call it
if (!oldactive)
resetbit(L->stackstate, THREAD_ACTIVEBIT);
@ -263,18 +263,18 @@ static void seterrorobj(lua_State* L, int errcode, StkId oldtop)
{
case LUA_ERRMEM:
{
setsvalue2s(L, oldtop, luaS_newliteral(L, LUA_MEMERRMSG)); /* can not fail because string is pinned in luaopen */
setsvalue2s(L, oldtop, luaS_newliteral(L, LUA_MEMERRMSG)); // can not fail because string is pinned in luaopen
break;
}
case LUA_ERRERR:
{
setsvalue2s(L, oldtop, luaS_newliteral(L, LUA_ERRERRMSG)); /* can not fail because string is pinned in luaopen */
setsvalue2s(L, oldtop, luaS_newliteral(L, LUA_ERRERRMSG)); // can not fail because string is pinned in luaopen
break;
}
case LUA_ERRSYNTAX:
case LUA_ERRRUN:
{
setobjs2s(L, oldtop, L->top - 1); /* error message on current top */
setobjs2s(L, oldtop, L->top - 1); // error message on current top
break;
}
}
@ -430,8 +430,8 @@ static void resume_finish(lua_State* L, int status)
resetbit(L->stackstate, THREAD_ACTIVEBIT);
if (status != 0)
{ /* error? */
L->status = cast_byte(status); /* mark thread as `dead' */
{ // error?
L->status = cast_byte(status); // mark thread as `dead'
seterrorobj(L, status, L->top);
L->ci->top = L->top;
}
@ -503,7 +503,7 @@ int lua_yield(lua_State* L, int nresults)
{
if (L->nCcalls > L->baseCcalls)
luaG_runerror(L, "attempt to yield across metamethod/C-call boundary");
L->base = L->top - nresults; /* protect stack slots below */
L->base = L->top - nresults; // protect stack slots below
L->status = LUA_YIELD;
return -1;
}
@ -535,9 +535,9 @@ static void restore_stack_limit(lua_State* L)
{
LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK);
if (L->size_ci > LUAI_MAXCALLS)
{ /* there was an overflow? */
{ // there was an overflow?
int inuse = cast_int(L->ci - L->base_ci);
if (inuse + 1 < LUAI_MAXCALLS) /* can `undo' overflow? */
if (inuse + 1 < LUAI_MAXCALLS) // can `undo' overflow?
luaD_reallocCI(L, LUAI_MAXCALLS);
}
}
@ -576,7 +576,7 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
}
StkId oldtop = restorestack(L, old_top);
luaF_close(L, oldtop); /* close eventual pending closures */
luaF_close(L, oldtop); // close eventual pending closures
seterrorobj(L, status, oldtop);
L->ci = restoreci(L, old_ci);
L->base = L->ci->base;

Some files were not shown because too many files have changed in this diff Show more