Merge branch 'master' of https://github.com/Roblox/luau into pr1004

Alexander McCord 2023-09-20 13:46:11 -07:00
commit 94dc9d40c1
218 changed files with 9879 additions and 3604 deletions

View file

@@ -1,185 +0,0 @@
name: benchmark-dev

on:
  push:
    branches:
      - master
    paths-ignore:
      - "docs/**"
      - "papers/**"
      - "rfcs/**"
      - "*.md"

jobs:
  windows:
    name: windows-${{matrix.arch}}
    strategy:
      fail-fast: false
      matrix:
        os: [windows-latest]
        arch: [Win32, x64]
        bench:
          - {
              script: "run-benchmarks",
              timeout: 12,
              title: "Luau Benchmarks",
            }
        benchResultsRepo:
          - { name: "luau-lang/benchmark-data", branch: "main" }
    runs-on: ${{ matrix.os }}
    steps:
      - name: Checkout Luau repository
        uses: actions/checkout@v3
      - name: Build Luau
        shell: bash # necessary for fail-fast
        run: |
          mkdir build && cd build
          cmake .. -DCMAKE_BUILD_TYPE=Release
          cmake --build . --target Luau.Repl.CLI --config Release
          cmake --build . --target Luau.Analyze.CLI --config Release
      - name: Move build files to root
        run: |
          move build/Release/* .
      - uses: actions/setup-python@v3
        with:
          python-version: "3.9"
          architecture: "x64"
      - name: Install python dependencies
        run: |
          python -m pip install requests
          python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose
      - name: Run benchmark
        run: |
          python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt
      - name: Push benchmark results
        id: pushBenchmarkAttempt1
        continue-on-error: true
        uses: ./.github/workflows/push-results
        with:
          repository: ${{ matrix.benchResultsRepo.name }}
          branch: ${{ matrix.benchResultsRepo.branch }}
          token: ${{ secrets.BENCH_GITHUB_TOKEN }}
          path: "./gh-pages"
          bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})"
          bench_tool: "benchmarkluau"
          bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
          bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
      - name: Push benchmark results (Attempt 2)
        id: pushBenchmarkAttempt2
        continue-on-error: true
        if: steps.pushBenchmarkAttempt1.outcome == 'failure'
        uses: ./.github/workflows/push-results
        with:
          repository: ${{ matrix.benchResultsRepo.name }}
          branch: ${{ matrix.benchResultsRepo.branch }}
          token: ${{ secrets.BENCH_GITHUB_TOKEN }}
          path: "./gh-pages"
          bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})"
          bench_tool: "benchmarkluau"
          bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
          bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
      - name: Push benchmark results (Attempt 3)
        id: pushBenchmarkAttempt3
        continue-on-error: true
        if: steps.pushBenchmarkAttempt2.outcome == 'failure'
        uses: ./.github/workflows/push-results
        with:
          repository: ${{ matrix.benchResultsRepo.name }}
          branch: ${{ matrix.benchResultsRepo.branch }}
          token: ${{ secrets.BENCH_GITHUB_TOKEN }}
          path: "./gh-pages"
          bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})"
          bench_tool: "benchmarkluau"
          bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
          bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
  unix:
    name: ${{matrix.os}}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-20.04, macos-latest]
        bench:
          - {
              script: "run-benchmarks",
              timeout: 12,
              title: "Luau Benchmarks",
            }
        benchResultsRepo:
          - { name: "luau-lang/benchmark-data", branch: "main" }
    runs-on: ${{ matrix.os }}
    steps:
      - name: Checkout Luau repository
        uses: actions/checkout@v3
      - name: Build Luau
        run: make config=release luau luau-analyze
      - uses: actions/setup-python@v3
        with:
          python-version: "3.9"
          architecture: "x64"
      - name: Install python dependencies
        run: |
          python -m pip install requests
          python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose
      - name: Run benchmark
        run: |
          python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt
      - name: Push benchmark results
        id: pushBenchmarkAttempt1
        continue-on-error: true
        uses: ./.github/workflows/push-results
        with:
          repository: ${{ matrix.benchResultsRepo.name }}
          branch: ${{ matrix.benchResultsRepo.branch }}
          token: ${{ secrets.BENCH_GITHUB_TOKEN }}
          path: "./gh-pages"
          bench_name: ${{ matrix.bench.title }}
          bench_tool: "benchmarkluau"
          bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
          bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
      - name: Push benchmark results (Attempt 2)
        id: pushBenchmarkAttempt2
        continue-on-error: true
        if: steps.pushBenchmarkAttempt1.outcome == 'failure'
        uses: ./.github/workflows/push-results
        with:
          repository: ${{ matrix.benchResultsRepo.name }}
          branch: ${{ matrix.benchResultsRepo.branch }}
          token: ${{ secrets.BENCH_GITHUB_TOKEN }}
          path: "./gh-pages"
          bench_name: ${{ matrix.bench.title }}
          bench_tool: "benchmarkluau"
          bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
          bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"
      - name: Push benchmark results (Attempt 3)
        id: pushBenchmarkAttempt3
        continue-on-error: true
        if: steps.pushBenchmarkAttempt2.outcome == 'failure'
        uses: ./.github/workflows/push-results
        with:
          repository: ${{ matrix.benchResultsRepo.name }}
          branch: ${{ matrix.benchResultsRepo.branch }}
          token: ${{ secrets.BENCH_GITHUB_TOKEN }}
          path: "./gh-pages"
          bench_name: ${{ matrix.bench.title }}
          bench_tool: "benchmarkluau"
          bench_output_file_path: "./${{ matrix.bench.script }}-output.txt"
          bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json"

View file

@@ -25,6 +25,7 @@ jobs:
     - name: Install valgrind
       run: |
+        sudo apt-get update
         sudo apt-get install valgrind
     - name: Build Luau (gcc)

View file

@@ -1,63 +0,0 @@
name: Checkout & push results
description: Checkout a given repo and push results to GitHub

inputs:
  repository:
    required: true
    type: string
    description: The benchmark results repository to check out
  branch:
    required: true
    type: string
    description: The benchmark results repository's branch to check out
  token:
    required: true
    type: string
    description: The GitHub token to use for pushing results
  path:
    required: true
    type: string
    description: The path to check out the results repository to
  bench_name:
    required: true
    type: string
  bench_tool:
    required: true
    type: string
  bench_output_file_path:
    required: true
    type: string
  bench_external_data_json_path:
    required: true
    type: string

runs:
  using: "composite"
  steps:
    - name: Checkout repository
      uses: actions/checkout@v3
      with:
        repository: ${{ inputs.repository }}
        ref: ${{ inputs.branch }}
        token: ${{ inputs.token }}
        path: ${{ inputs.path }}
    - name: Store results
      uses: Roblox/rhysd-github-action-benchmark@v-luau
      with:
        name: ${{ inputs.bench_name }}
        tool: ${{ inputs.bench_tool }}
        gh-pages-branch: ${{ inputs.branch }}
        output-file-path: ${{ inputs.bench_output_file_path }}
        external-data-json-path: ${{ inputs.bench_external_data_json_path }}
    - name: Push benchmark results
      shell: bash
      run: |
        echo "Pushing benchmark results..."
        cd gh-pages
        git config user.name github-actions
        git config user.email github@users.noreply.github.com
        git add *.json
        git commit -m "Add benchmarks results for ${{ github.sha }}"
        git push
        cd ..

.gitignore
View file

@@ -10,4 +10,5 @@
 /luau
 /luau-tests
 /luau-analyze
+/luau-compile
 __pycache__

View file

@@ -14,8 +14,6 @@ struct GlobalTypes;
 struct TypeChecker;
 struct TypeArena;

-void registerBuiltinTypes(GlobalTypes& globals);
-
 void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete = false);
 TypeId makeUnion(TypeArena& arena, std::vector<TypeId>&& types);
 TypeId makeIntersection(TypeArena& arena, std::vector<TypeId>&& types);

View file

@@ -16,6 +16,8 @@ using SeenTypePacks = std::unordered_map<TypePackId, TypePackId>;
 struct CloneState
 {
+    NotNull<BuiltinTypes> builtinTypes;
+
     SeenTypes seenTypes;
     SeenTypePacks seenTypePacks;

View file

@@ -101,9 +101,10 @@ struct ConstraintGraphBuilder
     DcrLogger* logger;

-    ConstraintGraphBuilder(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver, NotNull<BuiltinTypes> builtinTypes,
-        NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope, std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope,
-        DcrLogger* logger, NotNull<DataFlowGraph> dfg, std::vector<RequireCycle> requireCycles);
+    ConstraintGraphBuilder(ModulePtr module, NotNull<Normalizer> normalizer, NotNull<ModuleResolver> moduleResolver,
+        NotNull<BuiltinTypes> builtinTypes, NotNull<InternalErrorReporter> ice, const ScopePtr& globalScope,
+        std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope, DcrLogger* logger, NotNull<DataFlowGraph> dfg,
+        std::vector<RequireCycle> requireCycles);

     /**
      * Fabricates a new free type belonging to a given scope.

View file

@@ -279,8 +279,6 @@ private:
     TypeId errorRecoveryType() const;
     TypePackId errorRecoveryTypePack() const;

-    TypeId unionOfTypes(TypeId a, TypeId b, NotNull<Scope> scope, bool unifyFreeTypes);
-
     TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp);

     void throwTimeLimitError();

View file

@@ -128,10 +128,10 @@ struct DiffError
         checkValidInitialization(left, right);
     }

-    std::string toString() const;
+    std::string toString(bool multiLine = false) const;

 private:
-    std::string toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf) const;
+    std::string toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf, bool multiLine) const;
     void checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right);
     void checkNonMissingPropertyLeavesHaveNulloptTableProperty() const;
 };
@@ -152,12 +152,17 @@ struct DifferEnvironment
 {
     TypeId rootLeft;
     TypeId rootRight;
+    std::optional<std::string> externalSymbolLeft;
+    std::optional<std::string> externalSymbolRight;
     DenseHashMap<TypeId, TypeId> genericMatchedPairs;
     DenseHashMap<TypePackId, TypePackId> genericTpMatchedPairs;

-    DifferEnvironment(TypeId rootLeft, TypeId rootRight)
+    DifferEnvironment(
+        TypeId rootLeft, TypeId rootRight, std::optional<std::string> externalSymbolLeft, std::optional<std::string> externalSymbolRight)
         : rootLeft(rootLeft)
         , rootRight(rootRight)
+        , externalSymbolLeft(externalSymbolLeft)
+        , externalSymbolRight(externalSymbolRight)
         , genericMatchedPairs(nullptr)
         , genericTpMatchedPairs(nullptr)
     {
@@ -170,6 +175,8 @@ struct DifferEnvironment
     void popVisiting();
     std::vector<std::pair<TypeId, TypeId>>::const_reverse_iterator visitingBegin() const;
     std::vector<std::pair<TypeId, TypeId>>::const_reverse_iterator visitingEnd() const;
+    std::string getDevFixFriendlyNameLeft() const;
+    std::string getDevFixFriendlyNameRight() const;

 private:
     // TODO: consider using DenseHashSet
@@ -179,6 +186,7 @@ private:
     std::vector<std::pair<TypeId, TypeId>> visitingStack;
 };
 DifferResult diff(TypeId ty1, TypeId ty2);
+DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional<std::string> symbol1, std::optional<std::string> symbol2);

 /**
  * True if ty is a "simple" type, i.e. cannot contain types.

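The new diffWithSymbols entry point lets callers attach user-facing names to the two types being diffed, and the widened toString(bool multiLine) controls one-line versus multi-line rendering. A minimal sketch of a call site (only the signatures above are from the commit; the symbol strings are illustrative):

    // Illustrative only: name the two sides so the resulting DiffError can
    // say "foo" and "bar" instead of a generated placeholder root.
    DifferResult result = diffWithSymbols(leftTy, rightTy, "foo", "bar");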
View file

@@ -2,6 +2,7 @@
 #pragma once

 #include "Luau/Location.h"
+#include "Luau/NotNull.h"
 #include "Luau/Type.h"
 #include "Luau/Variant.h"
@@ -432,7 +433,7 @@ std::string toString(const TypeError& error, TypeErrorToStringOptions options);
 bool containsParseErrorName(const TypeError& error);

 // Copy any types named in the error into destArena.
-void copyErrors(ErrorVec& errors, struct TypeArena& destArena);
+void copyErrors(ErrorVec& errors, struct TypeArena& destArena, NotNull<BuiltinTypes> builtinTypes);

 // Internal Compiler Error
 struct InternalErrorReporter

View file

@@ -2,12 +2,12 @@
 #pragma once

 #include "Luau/Config.h"
+#include "Luau/GlobalTypes.h"
 #include "Luau/Module.h"
 #include "Luau/ModuleResolver.h"
 #include "Luau/RequireTracer.h"
 #include "Luau/Scope.h"
 #include "Luau/TypeCheckLimits.h"
-#include "Luau/TypeInfer.h"
 #include "Luau/Variant.h"

 #include <mutex>
@@ -100,6 +100,12 @@ struct FrontendOptions
     std::optional<LintOptions> enabledLintWarnings;
     std::shared_ptr<FrontendCancellationToken> cancellationToken;
+
+    // Time limit for typechecking a single module
+    std::optional<double> moduleTimeLimitSec;
+
+    // When true, some internal complexity limits will be scaled down for modules that miss the limit set by moduleTimeLimitSec
+    bool applyInternalLimitScaling = false;
 };

 struct CheckResult

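A hedged sketch of the new FrontendOptions knobs in use (the 2.0-second budget is an arbitrary example; frontend.check(moduleName, opts) mirrors a call that appears verbatim later in this diff):

    // Illustrative only: budget each module's typecheck, and allow the
    // frontend to scale internal complexity limits down for modules that
    // miss the budget instead of letting them run long.
    FrontendOptions opts;
    opts.moduleTimeLimitSec = 2.0;
    opts.applyInternalLimitScaling = true;
    frontend.check(moduleName, opts);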
View file

@@ -0,0 +1,26 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once

#include "Luau/Module.h"
#include "Luau/NotNull.h"
#include "Luau/Scope.h"
#include "Luau/TypeArena.h"

namespace Luau
{

struct BuiltinTypes;

struct GlobalTypes
{
    explicit GlobalTypes(NotNull<BuiltinTypes> builtinTypes);

    NotNull<BuiltinTypes> builtinTypes; // Global types are based on builtin types
    TypeArena globalTypes;
    SourceModule globalNames; // names for symbols entered into globalScope
    ScopePtr globalScope;     // shared by all modules
};

} // namespace Luau

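GlobalTypes moves into this dedicated header (its old definition is deleted from TypeInfer.h below). A small usage sketch, assuming a BuiltinTypes instance is at hand; the addBuiltinTypeBinding call is copied from the registerBuiltinTypes body removed later in this diff:

    // Illustrative only: a GlobalTypes environment is built on one shared
    // BuiltinTypes instance, and globalScope carries the builtin name bindings.
    GlobalTypes globals{NotNull{&builtinTypes}};
    globals.globalScope->addBuiltinTypeBinding("any", TypeFun{{}, globals.builtinTypes->anyType});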
View file

@@ -17,8 +17,8 @@ struct TypeCheckLimits;
 // A substitution which replaces generic types in a given set by free types.
 struct ReplaceGenerics : Substitution
 {
-    ReplaceGenerics(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope, const std::vector<TypeId>& generics,
-        const std::vector<TypePackId>& genericPacks)
+    ReplaceGenerics(const TxnLog* log, TypeArena* arena, NotNull<BuiltinTypes> builtinTypes, TypeLevel level, Scope* scope,
+        const std::vector<TypeId>& generics, const std::vector<TypePackId>& genericPacks)
         : Substitution(log, arena)
         , builtinTypes(builtinTypes)
         , level(level)
@@ -77,6 +77,7 @@ struct Instantiation : Substitution
  * Instantiation fails only when processing the type causes internal recursion
  * limits to be exceeded.
  */
-std::optional<TypeId> instantiate(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, TypeId ty);
+std::optional<TypeId> instantiate(
+    NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, TypeId ty);

 } // namespace Luau

View file

@@ -1,6 +1,7 @@
 // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
 #pragma once

+#include "Luau/LinterConfig.h"
 #include "Luau/Location.h"

 #include <memory>
@@ -15,86 +16,15 @@ class AstStat;
 class AstNameTable;
 struct TypeChecker;
 struct Module;
-struct HotComment;

 using ScopePtr = std::shared_ptr<struct Scope>;

-struct LintWarning
-{
-    // Make sure any new lint codes are documented here: https://luau-lang.org/lint
-    // Note that in Studio, the active set of lint warnings is determined by FStringStudioLuauLints
-    enum Code
-    {
-        Code_Unknown = 0,
-
-        Code_UnknownGlobal = 1, // superseded by type checker
-        Code_DeprecatedGlobal = 2,
-        Code_GlobalUsedAsLocal = 3,
-        Code_LocalShadow = 4,       // disabled in Studio
-        Code_SameLineStatement = 5, // disabled in Studio
-        Code_MultiLineStatement = 6,
-        Code_LocalUnused = 7,    // disabled in Studio
-        Code_FunctionUnused = 8, // disabled in Studio
-        Code_ImportUnused = 9,   // disabled in Studio
-        Code_BuiltinGlobalWrite = 10,
-        Code_PlaceholderRead = 11,
-        Code_UnreachableCode = 12,
-        Code_UnknownType = 13,
-        Code_ForRange = 14,
-        Code_UnbalancedAssignment = 15,
-        Code_ImplicitReturn = 16, // disabled in Studio, superseded by type checker in strict mode
-        Code_DuplicateLocal = 17,
-        Code_FormatString = 18,
-        Code_TableLiteral = 19,
-        Code_UninitializedLocal = 20,
-        Code_DuplicateFunction = 21,
-        Code_DeprecatedApi = 22,
-        Code_TableOperations = 23,
-        Code_DuplicateCondition = 24,
-        Code_MisleadingAndOr = 25,
-        Code_CommentDirective = 26,
-        Code_IntegerParsing = 27,
-        Code_ComparisonPrecedence = 28,
-
-        Code__Count
-    };
-
-    Code code;
-    Location location;
-    std::string text;
-
-    static const char* getName(Code code);
-    static Code parseName(const char* name);
-    static uint64_t parseMask(const std::vector<HotComment>& hotcomments);
-};

 struct LintResult
 {
     std::vector<LintWarning> errors;
     std::vector<LintWarning> warnings;
 };

-struct LintOptions
-{
-    uint64_t warningMask = 0;
-
-    void enableWarning(LintWarning::Code code)
-    {
-        warningMask |= 1ull << code;
-    }
-    void disableWarning(LintWarning::Code code)
-    {
-        warningMask &= ~(1ull << code);
-    }
-
-    bool isEnabled(LintWarning::Code code) const
-    {
-        return 0 != (warningMask & (1ull << code));
-    }
-
-    void setDefaults();
-};

 std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module,
     const std::vector<HotComment>& hotcomments, const LintOptions& options);

View file

@@ -19,6 +19,7 @@ static const std::unordered_map<AstExprBinary::Op, const char*> kBinaryOpMetamet
     {AstExprBinary::Op::Sub, "__sub"},
     {AstExprBinary::Op::Mul, "__mul"},
     {AstExprBinary::Op::Div, "__div"},
+    {AstExprBinary::Op::FloorDiv, "__idiv"},
     {AstExprBinary::Op::Pow, "__pow"},
     {AstExprBinary::Op::Mod, "__mod"},
     {AstExprBinary::Op::Concat, "__concat"},

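(FloorDiv is Luau's new `//` operator; matching Lua 5.3, it falls back to the `__idiv` metamethod when the operands are not plain numbers.)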
View file

@@ -111,6 +111,7 @@ struct Module
     LintResult lintResult;
     Mode mode;
     SourceCode::Type type;
+    double checkDurationSec = 0.0;
     bool timeout = false;
     bool cancelled = false;

View file

@@ -147,9 +147,6 @@ struct Tarjan
     void visitEdge(int index, int parentIndex);
     void visitSCC(int index);

-    TarjanResult loop_DEPRECATED();
-    void visitSCC_DEPRECATED(int index);
-
     // Each subclass can decide to ignore some nodes.
     virtual bool ignoreChildren(TypeId ty)
     {
@@ -178,13 +175,6 @@ struct Tarjan
     virtual bool isDirty(TypePackId tp) = 0;
     virtual void foundDirty(TypeId ty) = 0;
     virtual void foundDirty(TypePackId tp) = 0;
-
-    // TODO: remove with FFlagLuauTarjanSingleArr
-    std::vector<TypeId> indexToType;
-    std::vector<TypePackId> indexToPack;
-    std::vector<bool> onStack;
-    std::vector<int> lowlink;
-    std::vector<bool> dirty;
 };

 // And finally substitution, which finds all the reachable dirty vertices

View file

@@ -0,0 +1,146 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once

#include "Luau/Type.h"
#include "Luau/TypePack.h"
#include "Luau/UnifierSharedState.h"

#include <vector>
#include <optional>

namespace Luau
{

template<typename A, typename B>
struct TryPair;

struct InternalErrorReporter;

class TypeIds;
class Normalizer;
struct NormalizedType;
struct NormalizedClassType;
struct NormalizedStringType;
struct NormalizedFunctionType;

struct SubtypingResult
{
    bool isSubtype = false;
    bool isErrorSuppressing = false;
    bool normalizationTooComplex = false;

    SubtypingResult& andAlso(const SubtypingResult& other);
    SubtypingResult& orElse(const SubtypingResult& other);

    // Only negates the `isSubtype`.
    static SubtypingResult negate(const SubtypingResult& result);
    static SubtypingResult all(const std::vector<SubtypingResult>& results);
    static SubtypingResult any(const std::vector<SubtypingResult>& results);
};

struct Subtyping
{
    NotNull<BuiltinTypes> builtinTypes;
    NotNull<TypeArena> arena;
    NotNull<Normalizer> normalizer;
    NotNull<InternalErrorReporter> iceReporter;
    NotNull<Scope> scope;

    enum class Variance
    {
        Covariant,
        Contravariant
    };

    Variance variance = Variance::Covariant;

    struct GenericBounds
    {
        DenseHashSet<TypeId> lowerBound{nullptr};
        DenseHashSet<TypeId> upperBound{nullptr};
    };

    /*
     * When we encounter a generic over the course of a subtyping test, we need
     * to tentatively map that generic onto a type on the other side.
     */
    DenseHashMap<TypeId, GenericBounds> mappedGenerics{nullptr};
    DenseHashMap<TypePackId, TypePackId> mappedGenericPacks{nullptr};

    using SeenSet = std::unordered_set<std::pair<TypeId, TypeId>, TypeIdPairHash>;

    SeenSet seenTypes;

    Subtyping(const Subtyping&) = delete;
    Subtyping& operator=(const Subtyping&) = delete;

    Subtyping(Subtyping&&) = default;
    Subtyping& operator=(Subtyping&&) = default;

    // TODO cache
    // TODO cyclic types
    // TODO recursion limits

    SubtypingResult isSubtype(TypeId subTy, TypeId superTy);
    SubtypingResult isSubtype(TypePackId subTy, TypePackId superTy);

private:
    SubtypingResult isCovariantWith(TypeId subTy, TypeId superTy);
    SubtypingResult isCovariantWith(TypePackId subTy, TypePackId superTy);

    template<typename SubTy, typename SuperTy>
    SubtypingResult isContravariantWith(SubTy&& subTy, SuperTy&& superTy);

    template<typename SubTy, typename SuperTy>
    SubtypingResult isInvariantWith(SubTy&& subTy, SuperTy&& superTy);

    template<typename SubTy, typename SuperTy>
    SubtypingResult isCovariantWith(const TryPair<const SubTy*, const SuperTy*>& pair);

    template<typename SubTy, typename SuperTy>
    SubtypingResult isContravariantWith(const TryPair<const SubTy*, const SuperTy*>& pair);

    template<typename SubTy, typename SuperTy>
    SubtypingResult isInvariantWith(const TryPair<const SubTy*, const SuperTy*>& pair);

    SubtypingResult isCovariantWith(TypeId subTy, const UnionType* superUnion);
    SubtypingResult isCovariantWith(const UnionType* subUnion, TypeId superTy);
    SubtypingResult isCovariantWith(TypeId subTy, const IntersectionType* superIntersection);
    SubtypingResult isCovariantWith(const IntersectionType* subIntersection, TypeId superTy);
    SubtypingResult isCovariantWith(const NegationType* subNegation, TypeId superTy);
    SubtypingResult isCovariantWith(const TypeId subTy, const NegationType* superNegation);
    SubtypingResult isCovariantWith(const PrimitiveType* subPrim, const PrimitiveType* superPrim);
    SubtypingResult isCovariantWith(const SingletonType* subSingleton, const PrimitiveType* superPrim);
    SubtypingResult isCovariantWith(const SingletonType* subSingleton, const SingletonType* superSingleton);
    SubtypingResult isCovariantWith(const TableType* subTable, const TableType* superTable);
    SubtypingResult isCovariantWith(const MetatableType* subMt, const MetatableType* superMt);
    SubtypingResult isCovariantWith(const MetatableType* subMt, const TableType* superTable);
    SubtypingResult isCovariantWith(const ClassType* subClass, const ClassType* superClass);
    SubtypingResult isCovariantWith(const ClassType* subClass, const TableType* superTable);
    SubtypingResult isCovariantWith(const FunctionType* subFunction, const FunctionType* superFunction);
    SubtypingResult isCovariantWith(const PrimitiveType* subPrim, const TableType* superTable);
    SubtypingResult isCovariantWith(const SingletonType* subSingleton, const TableType* superTable);

    SubtypingResult isCovariantWith(const NormalizedType* subNorm, const NormalizedType* superNorm);
    SubtypingResult isCovariantWith(const NormalizedClassType& subClass, const NormalizedClassType& superClass);
    SubtypingResult isCovariantWith(const NormalizedClassType& subClass, const TypeIds& superTables);
    SubtypingResult isCovariantWith(const NormalizedStringType& subString, const NormalizedStringType& superString);
    SubtypingResult isCovariantWith(const NormalizedStringType& subString, const TypeIds& superTables);
    SubtypingResult isCovariantWith(const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction);
    SubtypingResult isCovariantWith(const TypeIds& subTypes, const TypeIds& superTypes);

    SubtypingResult isCovariantWith(const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic);

    bool bindGeneric(TypeId subTp, TypeId superTp);
    bool bindGeneric(TypePackId subTp, TypePackId superTp);

    template<typename T, typename Container>
    TypeId makeAggregateType(const Container& container, TypeId orElse);

    [[noreturn]] void unexpected(TypePackId tp);
};

} // namespace Luau

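For orientation, a hedged sketch of how the new API would be driven; the commit only declares the interface, so the aggregate construction and the handlers below are assumptions:

    // Illustrative only: the five NotNull dependencies come from the
    // surrounding typechecking session.
    Subtyping subtyping{builtinTypes, arena, normalizer, iceReporter, scope};
    SubtypingResult r = subtyping.isSubtype(subTy, superTy);
    if (r.normalizationTooComplex)
        reportTooComplex(); // hypothetical handler: the test was inconclusive
    else if (r.isSubtype)
        acceptSubtype();    // hypothetical handler: subTy <: superTy holds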
View file

@@ -733,9 +733,17 @@ struct Type final
 using SeenSet = std::set<std::pair<const void*, const void*>>;
 bool areEqual(SeenSet& seen, const Type& lhs, const Type& rhs);

+enum class FollowOption
+{
+    Normal,
+    DisableLazyTypeThunks,
+};
+
 // Follow BoundTypes until we get to something real
 TypeId follow(TypeId t);
+TypeId follow(TypeId t, FollowOption followOption);
 TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId));
+TypeId follow(TypeId t, FollowOption followOption, const void* context, TypeId (*mapper)(const void*, TypeId));

 std::vector<TypeId> flattenIntersection(TypeId ty);

@@ -790,12 +798,13 @@ struct BuiltinTypes
     TypeId errorRecoveryType() const;
     TypePackId errorRecoveryTypePack() const;

+    friend TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes);
+    friend struct GlobalTypes;
+
 private:
     std::unique_ptr<struct TypeArena> arena;
     bool debugFreezeArena = false;

-    TypeId makeStringMetatable();
-
 public:
     const TypeId nilType;
     const TypeId numberType;
@@ -818,6 +827,7 @@ public:
     const TypeId optionalNumberType;
     const TypeId optionalStringType;

+    const TypePackId emptyTypePack;
     const TypePackId anyTypePack;
     const TypePackId neverTypePack;
     const TypePackId uninhabitableTypePack;
@@ -839,6 +849,18 @@ bool isSubclass(const ClassType* cls, const ClassType* parent);

 Type* asMutable(TypeId ty);

+template<typename... Ts, typename T>
+bool is(T&& tv)
+{
+    if (!tv)
+        return false;
+
+    if constexpr (std::is_same_v<TypeId, T> && !(std::is_same_v<BoundType, Ts> || ...))
+        LUAU_ASSERT(get_if<BoundType>(&tv->ty) == nullptr);
+
+    return (get<Ts>(tv) || ...);
+}
+
 template<typename T>
 const T* get(TypeId tv)
 {

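The new variadic is<> helper tests a TypeId against several alternatives at once; a small sketch (the alternatives chosen are illustrative):

    // Illustrative only: true if ty currently holds either variant. The
    // LUAU_ASSERT inside is<> rejects unfollowed BoundTypes, hence follow().
    bool primitiveOrSingleton = is<PrimitiveType, SingletonType>(follow(ty));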
View file

@@ -14,7 +14,7 @@ struct DcrLogger;
 struct TypeCheckLimits;
 struct UnifierSharedState;

-void check(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> sharedState, NotNull<TypeCheckLimits> limits, DcrLogger* logger, const SourceModule& sourceModule,
-    Module* module);
+void check(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> sharedState, NotNull<TypeCheckLimits> limits, DcrLogger* logger,
+    const SourceModule& sourceModule, Module* module);

 } // namespace Luau

View file

@@ -57,17 +57,6 @@ struct HashBoolNamePair
     size_t operator()(const std::pair<bool, Name>& pair) const;
 };

-struct GlobalTypes
-{
-    GlobalTypes(NotNull<BuiltinTypes> builtinTypes);
-
-    NotNull<BuiltinTypes> builtinTypes; // Global types are based on builtin types
-    TypeArena globalTypes;
-    SourceModule globalNames; // names for symbols entered into globalScope
-    ScopePtr globalScope;     // shared by all modules
-};
-
 // All Types are retained via Environment::types. All TypeIds
 // within a program are borrowed pointers into this set.
 struct TypeChecker

View file

@@ -101,6 +101,31 @@ ErrorSuppression shouldSuppressErrors(NotNull<Normalizer> normalizer, TypeId ty1
  */
 ErrorSuppression shouldSuppressErrors(NotNull<Normalizer> normalizer, TypePackId tp1, TypePackId tp2);

+// Similar to `std::optional<std::pair<A, B>>`, but whose `sizeof()` is the same as `std::pair<A, B>`
+// and cooperates with C++'s `if (auto p = ...)` syntax without the extra fatness of `std::optional`.
+template<typename A, typename B>
+struct TryPair
+{
+    A first;
+    B second;
+
+    explicit operator bool() const
+    {
+        return bool(first) && bool(second);
+    }
+};
+
+template<typename A, typename B, typename Ty>
+TryPair<const A*, const B*> get2(Ty one, Ty two)
+{
+    const A* a = get<A>(one);
+    const B* b = get<B>(two);
+    if (a && b)
+        return {a, b};
+    else
+        return {nullptr, nullptr};
+}
+
 template<typename T, typename Ty>
 const T* get(std::optional<Ty> ty)
 {

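The comment on TryPair describes the intended call shape; a minimal sketch, with TableType chosen arbitrarily:

    // Illustrative only: both pointers are non-null exactly when the pair
    // converts to true, so a single test fetches both sides.
    if (auto pair = get2<TableType, TableType>(subTy, superTy))
    {
        const TableType* subTable = pair.first;
        const TableType* superTable = pair.second;
    }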
View file

@@ -105,10 +105,12 @@ struct Unifier
      * Populate the vector errors with any type errors that may arise.
      * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt.
      */
-    void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr);
+    void tryUnify(
+        TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr);

 private:
-    void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr);
+    void tryUnify_(
+        TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr);

     void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy);

     // Traverse the two types provided and block on any BlockedTypes we find.

View file

@@ -55,8 +55,8 @@ struct Unifier2
     bool unify(TypePackId subTp, TypePackId superTp);

     std::optional<TypeId> generalize(NotNull<Scope> scope, TypeId ty);
-private:

+private:
     /**
      * @returns simplify(left | right)
      */
@@ -72,4 +72,4 @@ private:
     OccursCheckResult occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack);
 };

-}
+} // namespace Luau

View file

@@ -10,6 +10,7 @@
 LUAU_FASTINT(LuauVisitRecursionLimit)
 LUAU_FASTFLAG(LuauBoundLazyTypes2)
+LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
 LUAU_FASTFLAG(DebugLuauReadWriteProperties)

 namespace Luau
@@ -220,7 +221,21 @@ struct GenericTypeVisitor
             traverse(btv->boundTo);
         }
         else if (auto ftv = get<FreeType>(ty))
+        {
+            if (FFlag::DebugLuauDeferredConstraintResolution)
+            {
+                if (visit(ty, *ftv))
+                {
+                    LUAU_ASSERT(ftv->lowerBound);
+                    traverse(ftv->lowerBound);
+
+                    LUAU_ASSERT(ftv->upperBound);
+                    traverse(ftv->upperBound);
+                }
+            }
+            else
                 visit(ty, *ftv);
+        }
         else if (auto gtv = get<GenericType>(ty))
             visit(ty, *gtv);
         else if (auto etv = get<ErrorType>(ty))

View file

@@ -8,6 +8,8 @@
 #include <math.h>

+LUAU_FASTFLAG(LuauFloorDivision)
+
 namespace Luau
 {
@@ -514,6 +516,9 @@ struct AstJsonEncoder : public AstVisitor
             return writeString("Mul");
         case AstExprBinary::Div:
             return writeString("Div");
+        case AstExprBinary::FloorDiv:
+            LUAU_ASSERT(FFlag::LuauFloorDivision);
+            return writeString("FloorDiv");
         case AstExprBinary::Mod:
             return writeString("Mod");
         case AstExprBinary::Pow:
@@ -536,6 +541,8 @@ struct AstJsonEncoder : public AstVisitor
             return writeString("And");
         case AstExprBinary::Or:
             return writeString("Or");
+        default:
+            LUAU_ASSERT(!"Unknown Op");
         }
     }

View file

@@ -12,6 +12,7 @@
 #include <algorithm>

 LUAU_FASTFLAG(DebugLuauReadWriteProperties)
+LUAU_FASTFLAGVARIABLE(FixFindBindingAtFunctionName, false);

 namespace Luau
 {
@@ -148,6 +149,23 @@ struct FindNode : public AstVisitor
         return false;
     }

+    bool visit(AstStatFunction* node) override
+    {
+        if (FFlag::FixFindBindingAtFunctionName)
+        {
+            visit(static_cast<AstNode*>(node));
+
+            if (node->name->location.contains(pos))
+                node->name->visit(this);
+            else if (node->func->location.contains(pos))
+                node->func->visit(this);
+
+            return false;
+        }
+        else
+        {
+            return AstVisitor::visit(node);
+        }
+    }
+
     bool visit(AstStatBlock* block) override
     {
         visit(static_cast<AstNode*>(block));
@@ -188,6 +206,23 @@ struct FindFullAncestry final : public AstVisitor
         return false;
     }

+    bool visit(AstStatFunction* node) override
+    {
+        if (FFlag::FixFindBindingAtFunctionName)
+        {
+            visit(static_cast<AstNode*>(node));
+
+            if (node->name->location.contains(pos))
+                node->name->visit(this);
+            else if (node->func->location.contains(pos))
+                node->func->visit(this);
+
+            return false;
+        }
+        else
+        {
+            return AstVisitor::visit(node);
+        }
+    }
+
     bool visit(AstNode* node) override
     {
         if (node->location.contains(pos))

View file

@@ -13,10 +13,8 @@
 #include <utility>

 LUAU_FASTFLAG(DebugLuauReadWriteProperties)
-LUAU_FASTFLAGVARIABLE(LuauDisableCompletionOutsideQuotes, false)
-LUAU_FASTFLAGVARIABLE(LuauAnonymousAutofilled1, false);
-LUAU_FASTFLAGVARIABLE(LuauAutocompleteLastTypecheck, false)
-LUAU_FASTFLAGVARIABLE(LuauAutocompleteHideSelfArg, false)
+LUAU_FASTFLAGVARIABLE(LuauAutocompleteDoEnd, false)
+LUAU_FASTFLAGVARIABLE(LuauAutocompleteStringLiteralBounds, false);

 static const std::unordered_set<std::string> kStatementStartingKeywords = {
     "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
@@ -284,38 +282,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul
             ParenthesesRecommendation parens =
                 indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect);

-            if (FFlag::LuauAutocompleteHideSelfArg)
-            {
-                result[name] = AutocompleteEntry{
-                    AutocompleteEntryKind::Property,
-                    type,
-                    prop.deprecated,
-                    isWrongIndexer(type),
-                    typeCorrect,
-                    containingClass,
-                    &prop,
-                    prop.documentationSymbol,
-                    {},
-                    parens,
-                    {},
-                    indexType == PropIndexType::Colon
-                };
-            }
-            else
-            {
-                result[name] = AutocompleteEntry{
-                    AutocompleteEntryKind::Property,
-                    type,
-                    prop.deprecated,
-                    isWrongIndexer(type),
-                    typeCorrect,
-                    containingClass,
-                    &prop,
-                    prop.documentationSymbol,
-                    {},
-                    parens
-                };
-            }
+            result[name] = AutocompleteEntry{AutocompleteEntryKind::Property, type, prop.deprecated, isWrongIndexer(type), typeCorrect,
+                containingClass, &prop, prop.documentationSymbol, {}, parens, {}, indexType == PropIndexType::Colon};
         }
     }
 };
@@ -485,8 +453,19 @@ AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position posi
     return result;
 }

-static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AutocompleteEntryMap& result)
+static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node, Position position, AutocompleteEntryMap& result)
 {
+    if (FFlag::LuauAutocompleteStringLiteralBounds)
+    {
+        if (position == node->location.begin || position == node->location.end)
+        {
+            if (auto str = node->as<AstExprConstantString>(); str && str->quoteStyle == AstExprConstantString::Quoted)
+                return;
+            else if (node->is<AstExprInterpString>())
+                return;
+        }
+    }
+
     auto formatKey = [addQuotes](const std::string& key) {
         if (addQuotes)
             return "\"" + escape(key) + "\"";
@@ -618,7 +597,6 @@ std::optional<TypeId> getLocalTypeInScopeAt(const Module& module, Position posit
 template<typename T>
 static std::optional<std::string> tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments)
 {
-    LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1);
     ToStringOptions opts;
     opts.useLineBreaks = false;
     opts.hideTableKind = true;
@@ -637,24 +615,8 @@ static std::optional<Name> tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool
     if (!canSuggestInferredType(scope, ty))
         return std::nullopt;

-    if (FFlag::LuauAnonymousAutofilled1)
-    {
-        return tryToStringDetailed(scope, ty, functionTypeArguments);
-    }
-    else
-    {
-        ToStringOptions opts;
-        opts.useLineBreaks = false;
-        opts.hideTableKind = true;
-        opts.scope = scope;
-        ToStringResult name = toStringDetailed(ty, opts);
-
-        if (name.error || name.invalid || name.cycle || name.truncated)
-            return std::nullopt;
-
-        return name.name;
-    }
+    return tryToStringDetailed(scope, ty, functionTypeArguments);
 }

 static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position)
 {
@@ -1097,14 +1059,19 @@ static AutocompleteEntryMap autocompleteStatement(
     {
         if (AstStatForIn* statForIn = (*it)->as<AstStatForIn>(); statForIn && !statForIn->hasEnd)
             result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
-        if (AstStatFor* statFor = (*it)->as<AstStatFor>(); statFor && !statFor->hasEnd)
+        else if (AstStatFor* statFor = (*it)->as<AstStatFor>(); statFor && !statFor->hasEnd)
             result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
-        if (AstStatIf* statIf = (*it)->as<AstStatIf>(); statIf && !statIf->hasEnd)
+        else if (AstStatIf* statIf = (*it)->as<AstStatIf>(); statIf && !statIf->hasEnd)
             result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
-        if (AstStatWhile* statWhile = (*it)->as<AstStatWhile>(); statWhile && !statWhile->hasEnd)
+        else if (AstStatWhile* statWhile = (*it)->as<AstStatWhile>(); statWhile && !statWhile->hasEnd)
             result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
-        if (AstExprFunction* exprFunction = (*it)->as<AstExprFunction>(); exprFunction && !exprFunction->hasEnd)
+        else if (AstExprFunction* exprFunction = (*it)->as<AstExprFunction>(); exprFunction && !exprFunction->hasEnd)
             result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
+
+        if (FFlag::LuauAutocompleteDoEnd)
+        {
+            if (AstStatBlock* exprBlock = (*it)->as<AstStatBlock>(); exprBlock && !exprBlock->hasEnd)
+                result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword});
+        }
     }

     if (ancestry.size() >= 2)
@@ -1239,7 +1206,7 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu
         result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction};

         if (auto ty = findExpectedTypeAt(module, node, position))
-            autocompleteStringSingleton(*ty, true, result);
+            autocompleteStringSingleton(*ty, true, node, position, result);
     }

     return AutocompleteContext::Expression;
@@ -1345,7 +1312,7 @@ static std::optional<AutocompleteEntryMap> autocompleteStringParams(const Source
         return std::nullopt;
     }

-    if (FFlag::LuauDisableCompletionOutsideQuotes && !nodes.back()->is<AstExprError>())
+    if (!nodes.back()->is<AstExprError>())
     {
         if (nodes.back()->location.end == position || nodes.back()->location.begin == position)
         {
@@ -1419,7 +1386,6 @@ static AutocompleteResult autocompleteWhileLoopKeywords(std::vector<AstNode*> an
 static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy)
 {
-    LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1);
     std::string result = "function(";

     auto [args, tail] = Luau::flatten(funcTy.argTypes);
@@ -1483,9 +1449,9 @@ static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& func
     return result;
 }

-static std::optional<AutocompleteEntry> makeAnonymousAutofilled(const ModulePtr& module, Position position, const AstNode* node, const std::vector<AstNode*>& ancestry)
+static std::optional<AutocompleteEntry> makeAnonymousAutofilled(
+    const ModulePtr& module, Position position, const AstNode* node, const std::vector<AstNode*>& ancestry)
 {
-    LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1);
     const AstExprCall* call = node->as<AstExprCall>();
     if (!call && ancestry.size() > 1)
         call = ancestry[ancestry.size() - 2]->as<AstExprCall>();
@@ -1720,7 +1686,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
         auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry);

         if (auto nodeIt = module->astExpectedTypes.find(node->asExpr()))
-            autocompleteStringSingleton(*nodeIt, !node->is<AstExprConstantString>(), result);
+            autocompleteStringSingleton(*nodeIt, !node->is<AstExprConstantString>(), node, position, result);

         if (!key)
         {
@@ -1732,7 +1698,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
             // suggest those too.
             if (auto ttv = get<TableType>(follow(*it)); ttv && ttv->indexer)
             {
-                autocompleteStringSingleton(ttv->indexer->indexType, false, result);
+                autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result);
             }
         }
@@ -1769,7 +1735,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
         AutocompleteEntryMap result;

         if (auto it = module->astExpectedTypes.find(node->asExpr()))
-            autocompleteStringSingleton(*it, false, result);
+            autocompleteStringSingleton(*it, false, node, position, result);

         if (ancestry.size() >= 2)
         {
@@ -1783,7 +1749,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
             if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe)
             {
                 if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left))
-                    autocompleteStringSingleton(*it, false, result);
+                    autocompleteStringSingleton(*it, false, node, position, result);
             }
         }
     }
@@ -1802,19 +1768,12 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
         return {};

     if (node->asExpr())
-    {
-        if (FFlag::LuauAnonymousAutofilled1)
     {
         AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position);
         if (std::optional<AutocompleteEntry> generated = makeAnonymousAutofilled(module, position, node, ancestry))
             ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated);
         return ret;
     }
-        else
-        {
-            return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position);
-        }
-    }
     else if (node->asStat())
         return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement};

@@ -1823,15 +1782,6 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
 AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback)
 {
-    if (!FFlag::LuauAutocompleteLastTypecheck)
-    {
-        // FIXME: We can improve performance here by parsing without checking.
-        // The old type graph is probably fine. (famous last words!)
-        FrontendOptions opts;
-        opts.forAutocomplete = true;
-        frontend.check(moduleName, opts);
-    }
-
     const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
     if (!sourceModule)
         return {};

View file

@@ -201,18 +201,6 @@ void assignPropDocumentationSymbols(TableType::Props& props, const std::string&
     }
 }

-void registerBuiltinTypes(GlobalTypes& globals)
-{
-    globals.globalScope->addBuiltinTypeBinding("any", TypeFun{{}, globals.builtinTypes->anyType});
-    globals.globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, globals.builtinTypes->nilType});
-    globals.globalScope->addBuiltinTypeBinding("number", TypeFun{{}, globals.builtinTypes->numberType});
-    globals.globalScope->addBuiltinTypeBinding("string", TypeFun{{}, globals.builtinTypes->stringType});
-    globals.globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, globals.builtinTypes->booleanType});
-    globals.globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, globals.builtinTypes->threadType});
-    globals.globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, globals.builtinTypes->unknownType});
-    globals.globalScope->addBuiltinTypeBinding("never", TypeFun{{}, globals.builtinTypes->neverType});
-}
-
 void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete)
 {
     LUAU_ASSERT(!globals.globalTypes.types.isFrozen());
@@ -310,6 +298,520 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC
     attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire);
 }
static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
{
    const char* options = "cdiouxXeEfgGqs*";

    std::vector<TypeId> result;

    for (size_t i = 0; i < size; ++i)
    {
        if (data[i] == '%')
        {
            i++;

            if (i < size && data[i] == '%')
                continue;

            // we just ignore all characters (including flags/precision) up until first alphabetic character
            while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*')))
                i++;

            if (i == size)
                break;

            if (data[i] == 'q' || data[i] == 's')
                result.push_back(builtinTypes->stringType);
            else if (data[i] == '*')
                result.push_back(builtinTypes->unknownType);
            else if (strchr(options, data[i]))
                result.push_back(builtinTypes->numberType);
            else
                result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType));
        }
    }

    return result;
}
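// Worked example (illustrative, not part of the commit): for the format
// string "%d of %s", parseFormatString returns {numberType, stringType};
// "%*" maps to unknownType, and the remaining conversions in "cdiouxXeEfgG"
// all map to numberType.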
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
    TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
    auto [paramPack, _predicates] = withPredicate;

    TypeArena& arena = typechecker.currentModule->internalTypes;

    AstExprConstantString* fmt = nullptr;
    if (auto index = expr.func->as<AstExprIndexName>(); index && expr.self)
    {
        if (auto group = index->expr->as<AstExprGroup>())
            fmt = group->expr->as<AstExprConstantString>();
        else
            fmt = index->expr->as<AstExprConstantString>();
    }

    if (!expr.self && expr.args.size > 0)
        fmt = expr.args.data[0]->as<AstExprConstantString>();

    if (!fmt)
        return std::nullopt;

    std::vector<TypeId> expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size);
    const auto& [params, tail] = flatten(paramPack);

    size_t paramOffset = 1;
    size_t dataOffset = expr.self ? 0 : 1;

    // unify the prefix one argument at a time
    for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i)
    {
        Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location;

        typechecker.unify(params[i + paramOffset], expected[i], scope, location);
    }

    // if we know the argument count or if we have too many arguments for sure, we can issue an error
    size_t numActualParams = params.size();
    size_t numExpectedParams = expected.size() + 1; // + 1 for the format string

    if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams))
        typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}});

    return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})};
}
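// Net effect (illustrative): for ("%s #%d"):format(name, id), the two
// arguments unify with string and number respectively, arity mismatches
// surface as CountMismatch, and the call's result type is a single string.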
static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
{
    TypeArena* arena = context.solver->arena;

    AstExprConstantString* fmt = nullptr;
    if (auto index = context.callSite->func->as<AstExprIndexName>(); index && context.callSite->self)
    {
        if (auto group = index->expr->as<AstExprGroup>())
            fmt = group->expr->as<AstExprConstantString>();
        else
            fmt = index->expr->as<AstExprConstantString>();
    }

    if (!context.callSite->self && context.callSite->args.size > 0)
        fmt = context.callSite->args.data[0]->as<AstExprConstantString>();

    if (!fmt)
        return false;

    std::vector<TypeId> expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size);
    const auto& [params, tail] = flatten(context.arguments);

    size_t paramOffset = 1;

    // unify the prefix one argument at a time
    for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i)
    {
        context.solver->unify(context.solver->rootScope, context.callSite->location, params[i + paramOffset], expected[i]);
    }

    // if we know the argument count or if we have too many arguments for sure, we can issue an error
    size_t numActualParams = params.size();
    size_t numExpectedParams = expected.size() + 1; // + 1 for the format string

    if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams))
        context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}});

    TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType});
    asMutable(context.result)->ty.emplace<BoundTypePack>(resultPack);

    return true;
}
static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
{
    std::vector<TypeId> result;
    int depth = 0;
    bool parsingSet = false;

    for (size_t i = 0; i < size; ++i)
    {
        if (data[i] == '%')
        {
            ++i;
            if (!parsingSet && i < size && data[i] == 'b')
                i += 2;
        }
        else if (!parsingSet && data[i] == '[')
        {
            parsingSet = true;
            if (i + 1 < size && data[i + 1] == ']')
                i += 1;
        }
        else if (parsingSet && data[i] == ']')
        {
            parsingSet = false;
        }
        else if (data[i] == '(')
        {
            if (parsingSet)
                continue;

            if (i + 1 < size && data[i + 1] == ')')
            {
                i++;
                result.push_back(builtinTypes->optionalNumberType);
                continue;
            }

            ++depth;
            result.push_back(builtinTypes->optionalStringType);
        }
        else if (data[i] == ')')
        {
            if (parsingSet)
                continue;

            --depth;
            if (depth < 0)
                break;
        }
    }

    if (depth != 0 || parsingSet)
        return std::vector<TypeId>();

    if (result.empty())
        result.push_back(builtinTypes->optionalStringType);

    return result;
}
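// Worked examples (illustrative): "(%a+)=(%d+)" yields two optional string
// captures, a position capture "()" yields an optional number, and a
// pattern with no captures yields one optional string (the whole match).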
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
    TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
    auto [paramPack, _predicates] = withPredicate;
    const auto& [params, tail] = flatten(paramPack);

    if (params.size() != 2)
        return std::nullopt;

    TypeArena& arena = typechecker.currentModule->internalTypes;

    AstExprConstantString* pattern = nullptr;
    size_t index = expr.self ? 0 : 1;
    if (expr.args.size > index)
        pattern = expr.args.data[index]->as<AstExprConstantString>();

    if (!pattern)
        return std::nullopt;

    std::vector<TypeId> returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);

    if (returnTypes.empty())
        return std::nullopt;

    typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);

    const TypePackId emptyPack = arena.addTypePack({});
    const TypePackId returnList = arena.addTypePack(returnTypes);
    const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList});
    return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})};
}
static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() != 2)
return false;
TypeArena* arena = context.solver->arena;
AstExprConstantString* pattern = nullptr;
size_t index = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > index)
pattern = context.callSite->args.data[index]->as<AstExprConstantString>();
if (!pattern)
return false;
std::vector<TypeId> returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType);
const TypePackId emptyPack = arena->addTypePack({});
const TypePackId returnList = arena->addTypePack(returnTypes);
const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList});
const TypePackId resTypePack = arena->addTypePack({iteratorType});
asMutable(context.result)->ty.emplace<BoundTypePack>(resTypePack);
return true;
}
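// Example of what the two gmatch overloads above infer (illustrative): for
// ("a=1,b=2"):gmatch("(%a+)=(%d+)"), parsePatternString yields
// {string?, string?}, so the call is typed as returning an iterator
// () -> (string?, string?).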
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 3)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
std::vector<TypeId> returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() == 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location);
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() < 2 || params.size() > 3)
return false;
TypeArena* arena = context.solver->arena;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > patternIndex)
pattern = context.callSite->args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return false;
std::vector<TypeId> returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType);
const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}});
size_t initIndex = context.callSite->self ? 1 : 2;
if (params.size() == 3 && context.callSite->args.size > initIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber);
const TypePackId returnList = arena->addTypePack(returnTypes);
asMutable(context.result)->ty.emplace<BoundTypePack>(returnList);
return true;
}
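// Example for the match overloads above (illustrative):
// ("a=1"):match("(%a+)=(%d+)") is typed as (string?, string?); a non-constant
// pattern makes the magic function bail out, falling back to the generic
// string.match signature.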
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 4)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
bool plain = false;
size_t plainIndex = expr.self ? 2 : 3;
if (expr.args.size > plainIndex)
{
AstExprConstantBool* p = expr.args.data[plainIndex]->as<AstExprConstantBool>();
plain = p && p->value;
}
std::vector<TypeId> returnTypes;
if (!plain)
{
returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
}
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}});
const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() >= 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location);
if (params.size() == 4 && expr.args.size > plainIndex)
typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionFind(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() < 2 || params.size() > 4)
return false;
TypeArena* arena = context.solver->arena;
NotNull<BuiltinTypes> builtinTypes = context.solver->builtinTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > patternIndex)
pattern = context.callSite->args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return false;
bool plain = false;
size_t plainIndex = context.callSite->self ? 2 : 3;
if (context.callSite->args.size > plainIndex)
{
AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as<AstExprConstantBool>();
plain = p && p->value;
}
std::vector<TypeId> returnTypes;
if (!plain)
{
returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
}
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], builtinTypes->stringType);
const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}});
const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}});
size_t initIndex = context.callSite->self ? 1 : 2;
if (params.size() >= 3 && context.callSite->args.size > initIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber);
if (params.size() == 4 && context.callSite->args.size > plainIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[3], optionalBoolean);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});
const TypePackId returnList = arena->addTypePack(returnTypes);
asMutable(context.result)->ty.emplace<BoundTypePack>(returnList);
return true;
}
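// Example for the find overloads above (illustrative): ("hello"):find("(l+)")
// is typed as (number?, number?, string?), the two numbers being the match
// bounds; with plain = true the pattern is not parsed, so only the bounds
// (number?, number?) are returned.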
TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes)
{
NotNull<TypeArena> arena{builtinTypes->arena.get()};
const TypeId nilType = builtinTypes->nilType;
const TypeId numberType = builtinTypes->numberType;
const TypeId booleanType = builtinTypes->booleanType;
const TypeId stringType = builtinTypes->stringType;
const TypeId anyType = builtinTypes->anyType;
const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}});
const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}});
const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}});
const TypePackId oneStringPack = arena->addTypePack({stringType});
const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true});
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
const TypePackId emptyPack = arena->addTypePack({});
const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}});
const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}});
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType});
const TypeId replArgType =
arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}});
const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType});
const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})});
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
const TypeId matchFunc = arena->addType(
FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})});
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})});
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
TableType::Props stringLib = {
{"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}},
{"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}},
{"find", {findFunc}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, anyTypePack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"unpack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
anyTypePack,
})}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed});
if (TableType* ttv = getMutable<TableType>(tableType))
ttv->name = "typeof(string)";
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
static std::optional<WithPredicate<TypePackId>> magicFunctionSelect(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{

View file

@ -1,8 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Clone.h"
#include "Luau/NotNull.h"
#include "Luau/RecursionCounter.h"
#include "Luau/TxnLog.h"
#include "Luau/Type.h"
#include "Luau/TypePack.h"
#include "Luau/Unifiable.h"
@ -13,12 +15,416 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300)
LUAU_FASTFLAGVARIABLE(LuauCloneCyclicUnions, false)
LUAU_FASTFLAGVARIABLE(LuauStacklessTypeClone2, false)
LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000)
namespace Luau
{
namespace
{
using Kind = Variant<TypeId, TypePackId>;
template<typename T>
const T* get(const Kind& kind)
{
return get_if<T>(&kind);
}
class TypeCloner2
{
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
// A queue of kinds that have been cloned, but whose interior types haven't
// yet been updated to point to the new clones. Once all of a kind's interior
// types have been updated, it is removed from the queue.
std::vector<Kind> queue;
NotNull<SeenTypes> types;
NotNull<SeenTypePacks> packs;
int steps = 0;
public:
TypeCloner2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, NotNull<SeenTypes> types, NotNull<SeenTypePacks> packs)
: arena(arena)
, builtinTypes(builtinTypes)
, types(types)
, packs(packs)
{
}
TypeId clone(TypeId ty)
{
shallowClone(ty);
run();
if (hasExceededIterationLimit())
{
TypeId error = builtinTypes->errorRecoveryType();
(*types)[ty] = error;
return error;
}
return find(ty).value_or(builtinTypes->errorRecoveryType());
}
TypePackId clone(TypePackId tp)
{
shallowClone(tp);
run();
if (hasExceededIterationLimit())
{
TypePackId error = builtinTypes->errorRecoveryTypePack();
(*packs)[tp] = error;
return error;
}
return find(tp).value_or(builtinTypes->errorRecoveryTypePack());
}
private:
bool hasExceededIterationLimit() const
{
if (FInt::LuauTypeCloneIterationLimit == 0)
return false;
return steps + queue.size() >= size_t(FInt::LuauTypeCloneIterationLimit);
}
void run()
{
while (!queue.empty())
{
++steps;
if (hasExceededIterationLimit())
break;
Kind kind = queue.back();
queue.pop_back();
if (find(kind))
continue;
cloneChildren(kind);
}
}
std::optional<TypeId> find(TypeId ty) const
{
ty = follow(ty, FollowOption::DisableLazyTypeThunks);
if (auto it = types->find(ty); it != types->end())
return it->second;
return std::nullopt;
}
std::optional<TypePackId> find(TypePackId tp) const
{
tp = follow(tp);
if (auto it = packs->find(tp); it != packs->end())
return it->second;
return std::nullopt;
}
std::optional<Kind> find(Kind kind) const
{
if (auto ty = get<TypeId>(kind))
return find(*ty);
else if (auto tp = get<TypePackId>(kind))
return find(*tp);
else
{
LUAU_ASSERT(!"Unknown kind?");
return std::nullopt;
}
}
private:
TypeId shallowClone(TypeId ty)
{
// We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s.
ty = follow(ty, FollowOption::DisableLazyTypeThunks);
if (auto clone = find(ty))
return *clone;
else if (ty->persistent)
return ty;
TypeId target = arena->addType(ty->ty);
asMutable(target)->documentationSymbol = ty->documentationSymbol;
(*types)[ty] = target;
queue.push_back(target);
return target;
}
TypePackId shallowClone(TypePackId tp)
{
tp = follow(tp);
if (auto clone = find(tp))
return *clone;
else if (tp->persistent)
return tp;
TypePackId target = arena->addTypePack(tp->ty);
(*packs)[tp] = target;
queue.push_back(target);
return target;
}
Property shallowClone(const Property& p)
{
if (FFlag::DebugLuauReadWriteProperties)
{
std::optional<TypeId> cloneReadTy;
if (auto ty = p.readType())
cloneReadTy = shallowClone(*ty);
std::optional<TypeId> cloneWriteTy;
if (auto ty = p.writeType())
cloneWriteTy = shallowClone(*ty);
std::optional<Property> cloned = Property::create(cloneReadTy, cloneWriteTy);
LUAU_ASSERT(cloned);
cloned->deprecated = p.deprecated;
cloned->deprecatedSuggestion = p.deprecatedSuggestion;
cloned->location = p.location;
cloned->tags = p.tags;
cloned->documentationSymbol = p.documentationSymbol;
return *cloned;
}
else
{
return Property{
shallowClone(p.type()),
p.deprecated,
p.deprecatedSuggestion,
p.location,
p.tags,
p.documentationSymbol,
};
}
}
void cloneChildren(TypeId ty)
{
return visit(
[&](auto&& t) {
return cloneChildren(&t);
},
asMutable(ty)->ty);
}
void cloneChildren(TypePackId tp)
{
return visit(
[&](auto&& t) {
return cloneChildren(&t);
},
asMutable(tp)->ty);
}
void cloneChildren(Kind kind)
{
if (auto ty = get<TypeId>(kind))
return cloneChildren(*ty);
else if (auto tp = get<TypePackId>(kind))
return cloneChildren(*tp);
else
LUAU_ASSERT(!"Item holds neither TypeId nor TypePackId when enqueuing its children?");
}
// ErrorType and ErrorTypePack are aliases of this type.
void cloneChildren(Unifiable::Error* t)
{
// noop.
}
void cloneChildren(BoundType* t)
{
t->boundTo = shallowClone(t->boundTo);
}
void cloneChildren(FreeType* t)
{
// TODO: clone lower and upper bounds.
// TODO: In the new solver, we should ice.
}
void cloneChildren(GenericType* t)
{
// TODO: clone upper bounds.
}
void cloneChildren(PrimitiveType* t)
{
// noop.
}
void cloneChildren(BlockedType* t)
{
// TODO: In the new solver, we should ice.
}
void cloneChildren(PendingExpansionType* t)
{
// TODO: In the new solver, we should ice.
}
void cloneChildren(SingletonType* t)
{
// noop.
}
void cloneChildren(FunctionType* t)
{
for (TypeId& g : t->generics)
g = shallowClone(g);
for (TypePackId& gp : t->genericPacks)
gp = shallowClone(gp);
t->argTypes = shallowClone(t->argTypes);
t->retTypes = shallowClone(t->retTypes);
}
void cloneChildren(TableType* t)
{
if (t->indexer)
{
t->indexer->indexType = shallowClone(t->indexer->indexType);
t->indexer->indexResultType = shallowClone(t->indexer->indexResultType);
}
for (auto& [_, p] : t->props)
p = shallowClone(p);
for (TypeId& ty : t->instantiatedTypeParams)
ty = shallowClone(ty);
for (TypePackId& tp : t->instantiatedTypePackParams)
tp = shallowClone(tp);
}
void cloneChildren(MetatableType* t)
{
t->table = shallowClone(t->table);
t->metatable = shallowClone(t->metatable);
}
void cloneChildren(ClassType* t)
{
for (auto& [_, p] : t->props)
p = shallowClone(p);
if (t->parent)
t->parent = shallowClone(*t->parent);
if (t->metatable)
t->metatable = shallowClone(*t->metatable);
if (t->indexer)
{
t->indexer->indexType = shallowClone(t->indexer->indexType);
t->indexer->indexResultType = shallowClone(t->indexer->indexResultType);
}
}
void cloneChildren(AnyType* t)
{
// noop.
}
void cloneChildren(UnionType* t)
{
for (TypeId& ty : t->options)
ty = shallowClone(ty);
}
void cloneChildren(IntersectionType* t)
{
for (TypeId& ty : t->parts)
ty = shallowClone(ty);
}
void cloneChildren(LazyType* t)
{
if (auto unwrapped = t->unwrapped.load())
t->unwrapped.store(shallowClone(unwrapped));
}
void cloneChildren(UnknownType* t)
{
// noop.
}
void cloneChildren(NeverType* t)
{
// noop.
}
void cloneChildren(NegationType* t)
{
t->ty = shallowClone(t->ty);
}
void cloneChildren(TypeFamilyInstanceType* t)
{
// TODO: In the new solver, we should ice.
}
void cloneChildren(FreeTypePack* t)
{
// TODO: clone lower and upper bounds.
// TODO: In the new solver, we should ice.
}
void cloneChildren(GenericTypePack* t)
{
// TODO: clone upper bounds.
}
void cloneChildren(BlockedTypePack* t)
{
// TODO: In the new solver, we should ice.
}
void cloneChildren(BoundTypePack* t)
{
t->boundTo = shallowClone(t->boundTo);
}
void cloneChildren(VariadicTypePack* t)
{
t->ty = shallowClone(t->ty);
}
void cloneChildren(TypePack* t)
{
for (TypeId& ty : t->head)
ty = shallowClone(ty);
if (t->tail)
t->tail = shallowClone(*t->tail);
}
void cloneChildren(TypeFamilyInstanceTypePack* t)
{
// TODO: In the new solver, we should ice.
}
};
} // namespace
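// A standalone sketch of the worklist scheme used by TypeCloner2 above
// (hypothetical Node type and cloneGraphSketch helper, not part of the
// original source; assumes <unordered_map> and <vector> are available): copy
// each node shallowly, memoize the copy, and rewrite child pointers
// afterwards, so even cyclic graphs clone without unbounded recursion.
namespace clone_sketch
{
struct Node
{
    std::vector<Node*> children;
};
inline Node* cloneGraphSketch(Node* root)
{
    std::unordered_map<Node*, Node*> clones;
    std::vector<Node*> queue;
    auto shallowClone = [&](Node* n) -> Node* {
        if (auto it = clones.find(n); it != clones.end())
            return it->second;
        Node* copy = new Node{n->children}; // children still point at the originals
        clones[n] = copy;
        queue.push_back(copy);
        return copy;
    };
    shallowClone(root);
    while (!queue.empty())
    {
        Node* copy = queue.back();
        queue.pop_back();
        for (Node*& child : copy->children)
            child = shallowClone(child); // rewrite interior pointers to point at clones
    }
    return clones[root];
}
} // namespace clone_sketch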
namespace
{
Property clone(const Property& prop, TypeArena& dest, CloneState& cloneState)
{
if (FFlag::DebugLuauReadWriteProperties)
@ -470,6 +876,13 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
if (tp->persistent)
return tp;
if (FFlag::LuauStacklessTypeClone2)
{
TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
return cloner.clone(tp);
}
else
{
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
TypePackId& res = cloneState.seenTypePacks[tp];
@ -482,12 +895,20 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState)
return res;
}
}
TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
{
if (typeId->persistent)
return typeId;
if (FFlag::LuauStacklessTypeClone2)
{
TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
return cloner.clone(typeId);
}
else
{
RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit);
TypeId& res = cloneState.seenTypes[typeId];
@ -506,8 +927,37 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState)
return res;
}
}
TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
{
if (FFlag::LuauStacklessTypeClone2)
{
TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}};
TypeFun copy = typeFun;
for (auto& param : copy.typeParams)
{
param.ty = cloner.clone(param.ty);
if (param.defaultValue)
param.defaultValue = cloner.clone(*param.defaultValue);
}
for (auto& param : copy.typePackParams)
{
param.tp = cloner.clone(param.tp);
if (param.defaultValue)
param.defaultValue = cloner.clone(*param.defaultValue);
}
copy.type = cloner.clone(copy.type);
return copy;
}
else
{
TypeFun result;
@ -537,5 +987,6 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState)
return result;
}
}
} // namespace Luau

View file

@ -25,6 +25,7 @@ LUAU_FASTFLAG(DebugLuauLogSolverToJson);
LUAU_FASTFLAG(DebugLuauMagicTypes);
LUAU_FASTFLAG(LuauParseDeclareClassIndexer);
LUAU_FASTFLAG(LuauLoopControlFlowAnalysis);
LUAU_FASTFLAG(LuauFloorDivision);
namespace Luau
{
@ -1188,7 +1189,8 @@ static bool isMetamethod(const Name& name)
{
return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" ||
name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" ||
name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len" ||
(FFlag::LuauFloorDivision && name == "__idiv");
}
ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass)

View file

@ -22,6 +22,7 @@
#include "Luau/VisitType.h"
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false);
LUAU_FASTFLAG(LuauFloorDivision);
namespace Luau
{
@ -429,6 +430,35 @@ bool ConstraintSolver::isDone()
return unsolvedConstraints.empty();
}
namespace
{
struct TypeAndLocation
{
TypeId typeId;
Location location;
};
struct FreeTypeSearcher : TypeOnceVisitor
{
std::deque<TypeAndLocation>* result;
Location location;
FreeTypeSearcher(std::deque<TypeAndLocation>* result, Location location)
: result(result)
, location(location)
{
}
bool visit(TypeId ty, const FreeType&) override
{
result->push_back({ty, location});
return false;
}
};
} // namespace
void ConstraintSolver::finalizeModule()
{
Anyification a{arena, rootScope, builtinTypes, &iceReporter, builtinTypes->anyType, builtinTypes->anyTypePack};
@ -445,12 +475,28 @@ void ConstraintSolver::finalizeModule()
Unifier2 u2{NotNull{arena}, builtinTypes, NotNull{&iceReporter}};
std::deque<TypeAndLocation> queue;
for (auto& [name, binding] : rootScope->bindings)
queue.push_back({binding.typeId, binding.location});
DenseHashSet<TypeId> seen{nullptr};
while (!queue.empty())
{
TypeAndLocation binding = queue.front();
queue.pop_front();
TypeId ty = follow(binding.typeId);
if (seen.find(ty))
continue;
seen.insert(ty);
FreeTypeSearcher fts{&queue, binding.location};
fts.traverse(ty);
auto result = u2.generalize(rootScope, ty);
if (!result)
reportError(CodeTooComplex{}, binding.location);
}
}
@ -719,6 +765,8 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull<const Cons
// Metatables go first, even if there is primitive behavior.
if (auto it = kBinaryOpMetamethods.find(c.op); it != kBinaryOpMetamethods.end())
{
LUAU_ASSERT(FFlag::LuauFloorDivision || c.op != AstExprBinary::Op::FloorDiv);
// Metatables are not the same. The metamethod will not be invoked.
if ((c.op == AstExprBinary::Op::CompareEq || c.op == AstExprBinary::Op::CompareNe) &&
getMetatable(leftType, builtinTypes) != getMetatable(rightType, builtinTypes))
@ -806,9 +854,12 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull<const Cons
case AstExprBinary::Op::Sub:
case AstExprBinary::Op::Mul:
case AstExprBinary::Op::Div:
case AstExprBinary::Op::FloorDiv:
case AstExprBinary::Op::Pow:
case AstExprBinary::Op::Mod:
{
LUAU_ASSERT(FFlag::LuauFloorDivision || c.op != AstExprBinary::Op::FloorDiv);
const NormalizedType* normLeftTy = normalizer->normalize(leftType);
if (hasTypeInIntersection<FreeType>(leftType) && force)
asMutable(leftType)->ty.emplace<BoundType>(anyPresent ? builtinTypes->anyType : builtinTypes->numberType);
@ -2636,20 +2687,14 @@ ErrorVec ConstraintSolver::unify(NotNull<Scope> scope, Location location, TypeId
ErrorVec ConstraintSolver::unify(NotNull<Scope> scope, Location location, TypePackId subPack, TypePackId superPack)
{
Unifier2 u{arena, builtinTypes, NotNull{&iceReporter}};
u.unify(subPack, superPack);
unblock(subPack, Location{});
unblock(superPack, Location{});
return {};
}
NotNull<Constraint> ConstraintSolver::pushConstraint(NotNull<Scope> scope, const Location& location, ConstraintV cv)
@ -2727,41 +2772,6 @@ TypePackId ConstraintSolver::errorRecoveryTypePack() const
return builtinTypes->errorRecoveryTypePack();
}
TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull<Scope> scope, bool unifyFreeTypes)
{
a = follow(a);
b = follow(b);
if (unifyFreeTypes && (get<FreeType>(a) || get<FreeType>(b)))
{
Unifier u{normalizer, scope, Location{}, Covariant};
u.enableNewSolver();
u.tryUnify(b, a);
if (u.errors.empty())
{
u.log.commit();
return a;
}
else
{
return builtinTypes->errorRecoveryType(builtinTypes->anyType);
}
}
if (*a == *b)
return a;
std::vector<TypeId> types = reduceUnion({a, b});
if (types.empty())
return builtinTypes->neverType;
if (types.size() == 1)
return types[0];
return arena->addType(UnionType{types});
}
TypePackId ConstraintSolver::anyifyModuleReturnTypePackGenerics(TypePackId tp)
{
tp = follow(tp);

View file

@ -107,15 +107,17 @@ std::string DiffPath::toString(bool prependDot) const
}
return pathStr;
}
std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf, bool multiLine) const
{
std::string conditionalNewline = multiLine ? "\n" : " ";
std::string conditionalIndent = multiLine ? " " : "";
std::string pathStr{rootName + diffPath.toString(true)};
switch (kind)
{
case DiffError::Kind::Normal:
{
checkNonMissingPropertyLeavesHaveNulloptTableProperty();
return pathStr + conditionalNewline + "has type" + conditionalNewline + conditionalIndent + Luau::toString(*leaf.ty);
}
case DiffError::Kind::MissingTableProperty:
{
@ -123,13 +125,14 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea
{
if (!leaf.tableProperty.has_value())
throw InternalCompilerError{"leaf.tableProperty is nullopt"};
return pathStr + "." + *leaf.tableProperty + conditionalNewline + "has type" + conditionalNewline + conditionalIndent +
Luau::toString(*leaf.ty);
}
else if (otherLeaf.ty.has_value())
{
if (!otherLeaf.tableProperty.has_value())
throw InternalCompilerError{"otherLeaf.tableProperty is nullopt"};
return pathStr + conditionalNewline + "is missing the property" + conditionalNewline + conditionalIndent + *otherLeaf.tableProperty;
}
throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"};
}
@ -140,11 +143,11 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea
{
if (!leaf.unionIndex.has_value())
throw InternalCompilerError{"leaf.unionIndex is nullopt"};
return pathStr + conditionalNewline + "is a union containing type" + conditionalNewline + conditionalIndent + Luau::toString(*leaf.ty);
}
else if (otherLeaf.ty.has_value())
{
return pathStr + conditionalNewline + "is a union missing type" + conditionalNewline + conditionalIndent + Luau::toString(*otherLeaf.ty);
}
throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"};
}
@ -157,11 +160,13 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea
{
if (!leaf.unionIndex.has_value())
throw InternalCompilerError{"leaf.unionIndex is nullopt"};
return pathStr + conditionalNewline + "is an intersection containing type" + conditionalNewline + conditionalIndent +
Luau::toString(*leaf.ty);
}
else if (otherLeaf.ty.has_value())
{
return pathStr + conditionalNewline + "is an intersection missing type" + conditionalNewline + conditionalIndent +
Luau::toString(*otherLeaf.ty);
}
throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"};
}
@ -169,13 +174,13 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea
{
if (!leaf.minLength.has_value())
throw InternalCompilerError{"leaf.minLength is nullopt"};
return pathStr + conditionalNewline + "takes " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " arguments";
}
case DiffError::Kind::LengthMismatchInFnRets:
{
if (!leaf.minLength.has_value())
throw InternalCompilerError{"leaf.minLength is nullopt"};
return pathStr + conditionalNewline + "returns " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " values";
}
default:
{
@ -190,8 +195,11 @@ void DiffError::checkNonMissingPropertyLeavesHaveNulloptTableProperty() const
throw InternalCompilerError{"Non-MissingProperty DiffError should have nullopt tableProperty in both leaves"};
}
std::string getDevFixFriendlyName(const std::optional<std::string>& maybeSymbol, TypeId ty)
{
if (maybeSymbol.has_value())
return *maybeSymbol;
if (auto table = get<TableType>(ty))
{
if (table->name.has_value())
@ -206,27 +214,37 @@ std::string getDevFixFriendlyName(TypeId ty)
return *metatable->syntheticName;
}
}
// else if (auto primitive = get<PrimitiveType>(ty))
//{
// return "<unlabeled-symbol>";
//}
return "<unlabeled-symbol>";
}
std::string DifferEnvironment::getDevFixFriendlyNameLeft() const
{
return getDevFixFriendlyName(externalSymbolLeft, rootLeft);
}
std::string DifferEnvironment::getDevFixFriendlyNameRight() const
{
return getDevFixFriendlyName(externalSymbolRight, rootRight);
}
std::string DiffError::toString(bool multiLine) const
{
std::string conditionalNewline = multiLine ? "\n" : " ";
std::string conditionalIndent = multiLine ? " " : "";
switch (kind)
{
case DiffError::Kind::IncompatibleGeneric:
{
std::string diffPathStr{diffPath.toString(true)};
return "DiffError: these two types are not equal because the left generic at" + conditionalNewline + conditionalIndent + leftRootName +
diffPathStr + conditionalNewline + "cannot be the same type parameter as the right generic at" + conditionalNewline +
conditionalIndent + rightRootName + diffPathStr;
}
default:
{
return "DiffError: these two types are not equal because the left type at" + conditionalNewline + conditionalIndent +
toStringALeaf(leftRootName, left, right, multiLine) + "," + conditionalNewline + "while the right type at" + conditionalNewline +
conditionalIndent + toStringALeaf(rightRootName, right, left, multiLine);
}
}
}
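// Example of the two renderings produced by toString above for a Kind::Normal
// diff at a hypothetical property x (illustrative):
//   multiLine == false (joined with spaces, all on one line):
//     DiffError: these two types are not equal because the left type at foo.x has type number, while the right type at foo.x has type string
//   multiLine == true (joined with newlines, leaf details indented):
//     DiffError: these two types are not equal because the left type at
//      foo.x
//     has type
//      number,
//     while the right type at
//      foo.x
//     has type
//      string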
@ -296,8 +314,8 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right)
DiffError::Kind::MissingTableProperty,
DiffPathNodeLeaf::detailsTableProperty(value.type(), field),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
}
@ -307,8 +325,7 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right)
{
// right has a field the left doesn't
return DifferResult{DiffError{DiffError::Kind::MissingTableProperty, DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::detailsTableProperty(value.type(), field), env.getDevFixFriendlyNameLeft(), env.getDevFixFriendlyNameRight()}};
}
}
// left and right have the same set of keys
@ -360,8 +377,8 @@ static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId ri
DiffError::Kind::Normal,
DiffPathNodeLeaf::detailsNormal(left),
DiffPathNodeLeaf::detailsNormal(right),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
return DifferResult{};
@ -380,8 +397,8 @@ static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId ri
DiffError::Kind::Normal,
DiffPathNodeLeaf::detailsNormal(left),
DiffPathNodeLeaf::detailsNormal(right),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
return DifferResult{};
@ -419,8 +436,8 @@ static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId righ
DiffError::Kind::IncompatibleGeneric,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
@ -432,8 +449,8 @@ static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId righ
DiffError::Kind::IncompatibleGeneric,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
@ -468,8 +485,8 @@ static DifferResult diffClass(DifferEnvironment& env, TypeId left, TypeId right)
DiffError::Kind::Normal,
DiffPathNodeLeaf::detailsNormal(left),
DiffPathNodeLeaf::detailsNormal(right),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
@ -521,16 +538,16 @@ static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right)
DiffError::Kind::MissingUnionMember,
DiffPathNodeLeaf::detailsUnionIndex(leftUnion->options[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
else
return DifferResult{DiffError{
DiffError::Kind::MissingUnionMember,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::detailsUnionIndex(rightUnion->options[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
@ -554,16 +571,16 @@ static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId
DiffError::Kind::MissingIntersectionMember,
DiffPathNodeLeaf::detailsUnionIndex(leftIntersection->parts[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
else
return DifferResult{DiffError{
DiffError::Kind::MissingIntersectionMember,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::detailsUnionIndex(rightIntersection->parts[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
@ -583,8 +600,8 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig
DiffError::Kind::Normal,
DiffPathNodeLeaf::detailsNormal(left),
DiffPathNodeLeaf::detailsNormal(right),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
@ -753,8 +770,8 @@ static DifferResult diffCanonicalTpShape(DifferEnvironment& env, DiffError::Kind
possibleNonNormalErrorKind,
DiffPathNodeLeaf::detailsLength(int(left.first.size()), left.second.has_value()),
DiffPathNodeLeaf::detailsLength(int(right.first.size()), right.second.has_value()),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
@ -769,8 +786,8 @@ static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::K
DiffError::Kind::Normal,
DiffPathNodeLeaf::detailsNormal(env.visitingBegin()->first),
DiffPathNodeLeaf::detailsNormal(env.visitingBegin()->second),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
@ -847,8 +864,8 @@ static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypeP
DiffError::Kind::IncompatibleGeneric,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
@ -860,8 +877,8 @@ static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypeP
DiffError::Kind::IncompatibleGeneric,
DiffPathNodeLeaf::nullopts(),
DiffPathNodeLeaf::nullopts(),
env.getDevFixFriendlyNameLeft(),
env.getDevFixFriendlyNameRight(),
}};
}
@ -910,7 +927,13 @@ std::vector<std::pair<TypeId, TypeId>>::const_reverse_iterator DifferEnvironment
DifferResult diff(TypeId ty1, TypeId ty2)
{
DifferEnvironment differEnv{ty1, ty2, std::nullopt, std::nullopt};
return diffUsingEnv(differEnv, ty1, ty2);
}
DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional<std::string> symbol1, std::optional<std::string> symbol2)
{
DifferEnvironment differEnv{ty1, ty2, symbol1, symbol2};
return diffUsingEnv(differEnv, ty1, ty2);
}

View file

@ -148,8 +148,7 @@ declare coroutine: {
resume: <A..., R...>(co: thread, A...) -> (boolean, R...),
running: () -> thread,
status: (co: thread) -> "dead" | "running" | "normal" | "suspended",
wrap: <A..., R...>(f: (A...) -> R...) -> ((A...) -> R...),
yield: <A..., R...>(A...) -> R...,
isyieldable: () -> boolean,
close: (co: thread) -> (boolean, any)
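-- Illustrative note (not part of this commit's sources): with the precise
-- wrap type above, coroutine.wrap(function(x: number) return x * 2 end)
-- is now typed as (number) -> number instead of any.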

View file

@ -4,12 +4,15 @@
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/FileResolver.h"
#include "Luau/NotNull.h"
#include "Luau/StringUtils.h"
#include "Luau/ToString.h"
#include <optional>
#include <stdexcept>
#include <type_traits>
LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10)
static std::string wrongNumberOfArgsString(
size_t expectedCount, std::optional<size_t> maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false)
@ -66,6 +69,20 @@ struct ErrorConverter
std::string result;
auto quote = [&](std::string s) {
return "'" + s + "'";
};
auto constructErrorMessage = [&](std::string givenType, std::string wantedType, std::optional<std::string> givenModule,
std::optional<std::string> wantedModule) -> std::string {
std::string given = givenModule ? quote(givenType) + " from " + quote(*givenModule) : quote(givenType);
std::string wanted = wantedModule ? quote(wantedType) + " from " + quote(*wantedModule) : quote(wantedType);
size_t luauIndentTypeMismatchMaxTypeLength = size_t(FInt::LuauIndentTypeMismatchMaxTypeLength);
if (givenType.length() <= luauIndentTypeMismatchMaxTypeLength || wantedType.length() <= luauIndentTypeMismatchMaxTypeLength)
return "Type " + given + " could not be converted into " + wanted;
return "Type\n " + given + "\ncould not be converted into\n " + wanted;
};
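// Illustrative outputs of constructErrorMessage above (the threshold is
// FInt::LuauIndentTypeMismatchMaxTypeLength, default 10): if either type name
// is short enough, the message stays on one line, e.g.
//   Type 'number' could not be converted into 'string'
// and only when both names exceed the threshold does it switch to the
// indented multi-line form:
//   Type
//       '{ x: number, y: number }'
//   could not be converted into
//       '{ x: string, y: string }'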
if (givenTypeName == wantedTypeName)
{
if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType))
@ -76,20 +93,18 @@ struct ErrorConverter
{
std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule);
std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule);
result = constructErrorMessage(givenTypeName, wantedTypeName, givenModuleName, wantedModuleName);
}
else
{
result = constructErrorMessage(givenTypeName, wantedTypeName, *givenDefinitionModule, *wantedDefinitionModule);
}
}
}
}
if (result.empty())
result = constructErrorMessage(givenTypeName, wantedTypeName, std::nullopt, std::nullopt);
if (tm.error)
@ -97,7 +112,7 @@ struct ErrorConverter
result += "\ncaused by:\n ";
if (!tm.reason.empty())
result += tm.reason + " \n";
result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver});
}
@ -845,7 +860,7 @@ bool containsParseErrorName(const TypeError& error)
}
template<typename T>
void copyError(T& e, TypeArena& destArena, CloneState& cloneState)
{
auto clone = [&](auto&& ty) {
return ::Luau::clone(ty, destArena, cloneState);
@ -998,9 +1013,9 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState)
static_assert(always_false_v<T>, "Non-exhaustive type switch");
}
void copyErrors(ErrorVec& errors, TypeArena& destArena, NotNull<BuiltinTypes> builtinTypes)
{
CloneState cloneState{builtinTypes};
auto visitErrorData = [&](auto&& e) {
copyError(e, destArena, cloneState);

View file

@ -31,11 +31,12 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTINT(LuauTarjanChildLimit)
LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) // TODO: Remove with FFlagLuauTypecheckLimitControls
LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false)
LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false)
LUAU_FASTFLAGVARIABLE(LuauTypecheckLimitControls, false)
LUAU_FASTFLAGVARIABLE(CorrectEarlyReturnInMarkDirty, false)
namespace Luau
{
@ -126,7 +127,7 @@ static ParseResult parseSourceForModule(std::string_view source, Luau::SourceMod
static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, ScopePtr targetScope, const std::string& packageName)
{
CloneState cloneState{globals.builtinTypes};
std::vector<TypeId> typesToPersist;
typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size());
@ -462,7 +463,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional<FrontendOption
checkResult.timeoutHits.push_back(item.name);
// If check was manually cancelled, do not return partial results
if (item.module->cancelled)
return {};
checkResult.errors.insert(checkResult.errors.end(), item.module->errors.begin(), item.module->errors.end());
@ -635,7 +636,7 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
if (item.exception)
itemWithException = i;
if (item.module && item.module->cancelled)
cancelled = true;
if (itemWithException || cancelled)
@ -677,7 +678,7 @@ std::vector<ModuleName> Frontend::checkQueuedModules(std::optional<FrontendOptio
if (remaining != 0 && processing == 0)
{
// Typechecking might have been cancelled by user, don't return partial results
if (cancelled)
return {};
// We might have stopped because of a pending exception
@@ -874,6 +875,14 @@ void Frontend::addBuildQueueItems(std::vector<BuildQueueItem>& items, std::vecto
} }
} }
static void applyInternalLimitScaling(SourceNode& sourceNode, const ModulePtr module, double limit)
{
if (module->timeout)
sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0;
else if (module->checkDurationSec < limit / 2.0)
sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0);
}
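In effect, applyInternalLimitScaling is a multiplicative backoff on sourceNode.autocompleteLimitsMult: a module that times out twice runs its next check at 0.25x of FInt::LuauTarjanChildLimit and FInt::LuauTypeInferIterationLimit, and once a check completes in under half of the module time limit the multiplier doubles again, saturating at 1.0.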
void Frontend::checkBuildQueueItem(BuildQueueItem& item) void Frontend::checkBuildQueueItem(BuildQueueItem& item)
{ {
SourceNode& sourceNode = *item.sourceNode; SourceNode& sourceNode = *item.sourceNode;
@@ -884,13 +893,42 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
double timestamp = getTimestamp(); double timestamp = getTimestamp();
const std::vector<RequireCycle>& requireCycles = item.requireCycles; const std::vector<RequireCycle>& requireCycles = item.requireCycles;
TypeCheckLimits typeCheckLimits;
if (FFlag::LuauTypecheckLimitControls)
{
if (item.options.moduleTimeLimitSec)
typeCheckLimits.finishTime = TimeTrace::getClock() + *item.options.moduleTimeLimitSec;
else
typeCheckLimits.finishTime = std::nullopt;
// TODO: This is a dirty ad hoc solution for autocomplete timeouts
// We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit
// so that we'll have type information for the whole file at lower quality instead of a full abort in the middle
if (item.options.applyInternalLimitScaling)
{
if (FInt::LuauTarjanChildLimit > 0)
typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckLimits.instantiationChildLimit = std::nullopt;
if (FInt::LuauTypeInferIterationLimit > 0)
typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult));
else
typeCheckLimits.unifierIterationLimit = std::nullopt;
}
typeCheckLimits.cancellationToken = item.options.cancellationToken;
}
if (item.options.forAutocomplete) if (item.options.forAutocomplete)
{ {
double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0;
if (!FFlag::LuauTypecheckLimitControls)
{
// The autocomplete typecheck is always in strict mode with DM awareness // The autocomplete typecheck is always in strict mode with DM awareness
// to provide better type information for IDE features // to provide better type information for IDE features
TypeCheckLimits typeCheckLimits;
if (autocompleteTimeLimit != 0.0) if (autocompleteTimeLimit != 0.0)
typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit;
@@ -910,18 +948,29 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
else else
typeCheckLimits.unifierIterationLimit = std::nullopt; typeCheckLimits.unifierIterationLimit = std::nullopt;
if (FFlag::LuauTypecheckCancellation)
typeCheckLimits.cancellationToken = item.options.cancellationToken; typeCheckLimits.cancellationToken = item.options.cancellationToken;
}
// The autocomplete typecheck is always in strict mode with DM awareness to provide better type information for IDE features
ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true,
/*recordJsonLog*/ false, typeCheckLimits); /*recordJsonLog*/ false, typeCheckLimits);
double duration = getTimestamp() - timestamp; double duration = getTimestamp() - timestamp;
if (FFlag::LuauTypecheckLimitControls)
{
moduleForAutocomplete->checkDurationSec = duration;
if (item.options.moduleTimeLimitSec && item.options.applyInternalLimitScaling)
applyInternalLimitScaling(sourceNode, moduleForAutocomplete, *item.options.moduleTimeLimitSec);
}
else
{
if (moduleForAutocomplete->timeout) if (moduleForAutocomplete->timeout)
sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0;
else if (duration < autocompleteTimeLimit / 2.0) else if (duration < autocompleteTimeLimit / 2.0)
sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0);
}
item.stats.timeCheck += duration; item.stats.timeCheck += duration;
item.stats.filesStrict += 1; item.stats.filesStrict += 1;
@@ -930,14 +979,29 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
return; return;
} }
TypeCheckLimits typeCheckLimits; if (!FFlag::LuauTypecheckLimitControls)
{
if (FFlag::LuauTypecheckCancellation)
typeCheckLimits.cancellationToken = item.options.cancellationToken; typeCheckLimits.cancellationToken = item.options.cancellationToken;
}
ModulePtr module = check(sourceModule, mode, requireCycles, environmentScope, /*forAutocomplete*/ false, item.recordJsonLog, typeCheckLimits); ModulePtr module = check(sourceModule, mode, requireCycles, environmentScope, /*forAutocomplete*/ false, item.recordJsonLog, typeCheckLimits);
if (FFlag::LuauTypecheckLimitControls)
{
double duration = getTimestamp() - timestamp;
module->checkDurationSec = duration;
if (item.options.moduleTimeLimitSec && item.options.applyInternalLimitScaling)
applyInternalLimitScaling(sourceNode, module, *item.options.moduleTimeLimitSec);
item.stats.timeCheck += duration;
}
else
{
item.stats.timeCheck += getTimestamp() - timestamp; item.stats.timeCheck += getTimestamp() - timestamp;
}
item.stats.filesStrict += mode == Mode::Strict; item.stats.filesStrict += mode == Mode::Strict;
item.stats.filesNonstrict += mode == Mode::Nonstrict; item.stats.filesNonstrict += mode == Mode::Nonstrict;
@@ -969,7 +1033,7 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item)
// copyErrors needs to allocate into interfaceTypes as it copies // copyErrors needs to allocate into interfaceTypes as it copies
// types out of internalTypes, so we unfreeze it here. // types out of internalTypes, so we unfreeze it here.
unfreeze(module->interfaceTypes); unfreeze(module->interfaceTypes);
copyErrors(module->errors, module->interfaceTypes); copyErrors(module->errors, module->interfaceTypes, builtinTypes);
freeze(module->interfaceTypes); freeze(module->interfaceTypes);
module->internalTypes.clear(); module->internalTypes.clear();
@@ -1014,7 +1078,7 @@ void Frontend::checkBuildQueueItems(std::vector<BuildQueueItem>& items)
{ {
checkBuildQueueItem(item); checkBuildQueueItem(item);
if (FFlag::LuauTypecheckCancellation && item.module && item.module->cancelled) if (item.module && item.module->cancelled)
break; break;
recordItemResult(item); recordItemResult(item);
@@ -1084,9 +1148,17 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const
* It would be nice for this function to be O(1) * It would be nice for this function to be O(1)
*/ */
void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty) void Frontend::markDirty(const ModuleName& name, std::vector<ModuleName>* markedDirty)
{
if (FFlag::CorrectEarlyReturnInMarkDirty)
{
if (sourceNodes.count(name) == 0)
return;
}
else
{ {
if (!moduleResolver.getModule(name) && !moduleResolverForAutocomplete.getModule(name)) if (!moduleResolver.getModule(name) && !moduleResolverForAutocomplete.getModule(name))
return; return;
}
std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps; std::unordered_map<ModuleName, std::vector<ModuleName>> reverseDeps;
for (const auto& module : sourceNodes) for (const auto& module : sourceNodes)
@@ -1217,7 +1289,8 @@ ModulePtr check(const SourceModule& sourceModule, const std::vector<RequireCycle
if (result->timeout || result->cancelled) if (result->timeout || result->cancelled)
{ {
// If solver was interrupted, skip typechecking and replace all module results with error-suppressing types to avoid leaking blocked/pending
// types
ScopePtr moduleScope = result->getModuleScope(); ScopePtr moduleScope = result->getModuleScope();
moduleScope->returnType = builtinTypes->errorRecoveryTypePack(); moduleScope->returnType = builtinTypes->errorRecoveryTypePack();
@@ -1295,8 +1368,6 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect
typeChecker.finishTime = typeCheckLimits.finishTime; typeChecker.finishTime = typeCheckLimits.finishTime;
typeChecker.instantiationChildLimit = typeCheckLimits.instantiationChildLimit; typeChecker.instantiationChildLimit = typeCheckLimits.instantiationChildLimit;
typeChecker.unifierIterationLimit = typeCheckLimits.unifierIterationLimit; typeChecker.unifierIterationLimit = typeCheckLimits.unifierIterationLimit;
if (FFlag::LuauTypecheckCancellation)
typeChecker.cancellationToken = typeCheckLimits.cancellationToken; typeChecker.cancellationToken = typeCheckLimits.cancellationToken;
return typeChecker.check(sourceModule, mode, environmentScope); return typeChecker.check(sourceModule, mode, environmentScope);
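Taken together, the limit controls behind FFlag::LuauTypecheckLimitControls are driven through FrontendOptions. A minimal sketch of a hypothetical caller, using the option fields shown above (the module name is made up, and the token type is assumed to be Luau's FrontendCancellationToken):

    FrontendOptions options;
    options.moduleTimeLimitSec = 0.1;         // per-module time budget, in seconds
    options.applyInternalLimitScaling = true; // shrink Tarjan/unifier limits to fit the budget
    options.cancellationToken = std::make_shared<FrontendCancellationToken>();

    CheckResult result = frontend.check("game/Workspace/Script", options); // hypothetical module name
    // Per the early returns above, a check cancelled through the token yields an
    // empty CheckResult rather than partial results.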

View file

@@ -0,0 +1,34 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/GlobalTypes.h"
LUAU_FASTFLAG(LuauInitializeStringMetatableInGlobalTypes)
namespace Luau
{
GlobalTypes::GlobalTypes(NotNull<BuiltinTypes> builtinTypes)
: builtinTypes(builtinTypes)
{
globalScope = std::make_shared<Scope>(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}));
globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType});
globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType});
globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType});
globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType});
globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType});
globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType});
globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType});
globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType});
if (FFlag::LuauInitializeStringMetatableInGlobalTypes)
{
unfreeze(*builtinTypes->arena);
TypeId stringMetatableTy = makeStringMetatable(builtinTypes);
asMutable(builtinTypes->stringType)->ty.emplace<PrimitiveType>(PrimitiveType::String, stringMetatableTy);
persist(stringMetatableTy);
freeze(*builtinTypes->arena);
}
}
} // namespace Luau
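A small sketch of what this constructor provides, assuming Scope's existing lookupType API (the assertion is illustrative, not from the diff):

    GlobalTypes globals{builtinTypes};
    std::optional<TypeFun> bound = globals.globalScope->lookupType("number");
    LUAU_ASSERT(bound && bound->type == builtinTypes->numberType); // builtin names resolve via the global scope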

View file

@@ -174,7 +174,8 @@ struct Replacer : Substitution
} }
}; };
std::optional<TypeId> instantiate(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, TypeId ty) std::optional<TypeId> instantiate(
NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, NotNull<TypeCheckLimits> limits, NotNull<Scope> scope, TypeId ty)
{ {
ty = follow(ty); ty = follow(ty);

View file

@@ -14,48 +14,9 @@
LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4)
LUAU_FASTFLAGVARIABLE(LuauLintNativeComment, false)
namespace Luau namespace Luau
{ {
// clang-format off
static const char* kWarningNames[] = {
"Unknown",
"UnknownGlobal",
"DeprecatedGlobal",
"GlobalUsedAsLocal",
"LocalShadow",
"SameLineStatement",
"MultiLineStatement",
"LocalUnused",
"FunctionUnused",
"ImportUnused",
"BuiltinGlobalWrite",
"PlaceholderRead",
"UnreachableCode",
"UnknownType",
"ForRange",
"UnbalancedAssignment",
"ImplicitReturn",
"DuplicateLocal",
"FormatString",
"TableLiteral",
"UninitializedLocal",
"DuplicateFunction",
"DeprecatedApi",
"TableOperations",
"DuplicateCondition",
"MisleadingAndOr",
"CommentDirective",
"IntegerParsing",
"ComparisonPrecedence",
};
// clang-format on
static_assert(std::size(kWarningNames) == unsigned(LintWarning::Code__Count), "did you forget to add warning to the list?");
struct LintContext struct LintContext
{ {
struct Global struct Global
@@ -2827,11 +2788,11 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
"optimize directive uses unknown optimization level '%s', 0..2 expected", level); "optimize directive uses unknown optimization level '%s', 0..2 expected", level);
} }
} }
else if (FFlag::LuauLintNativeComment && first == "native") else if (first == "native")
{ {
if (space != std::string::npos) if (space != std::string::npos)
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, emitWarning(
"native directive has extra symbols at the end of the line"); context, LintWarning::Code_CommentDirective, hc.location, "native directive has extra symbols at the end of the line");
} }
else else
{ {
@@ -2855,12 +2816,6 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
} }
} }
void LintOptions::setDefaults()
{
// By default, we enable all warnings
warningMask = ~0ull;
}
std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module,
const std::vector<HotComment>& hotcomments, const LintOptions& options) const std::vector<HotComment>& hotcomments, const LintOptions& options)
{ {
@@ -2952,54 +2907,6 @@ std::vector<LintWarning> lint(AstStat* root, const AstNameTable& names, const Sc
return context.result; return context.result;
} }
const char* LintWarning::getName(Code code)
{
LUAU_ASSERT(unsigned(code) < Code__Count);
return kWarningNames[code];
}
LintWarning::Code LintWarning::parseName(const char* name)
{
for (int code = Code_Unknown; code < Code__Count; ++code)
if (strcmp(name, getName(Code(code))) == 0)
return Code(code);
return Code_Unknown;
}
uint64_t LintWarning::parseMask(const std::vector<HotComment>& hotcomments)
{
uint64_t result = 0;
for (const HotComment& hc : hotcomments)
{
if (!hc.header)
continue;
if (hc.content.compare(0, 6, "nolint") != 0)
continue;
size_t name = hc.content.find_first_not_of(" \t", 6);
// --!nolint disables everything
if (name == std::string::npos)
return ~0ull;
// --!nolint needs to be followed by a whitespace character
if (name == 6)
continue;
// --!nolint name disables the specific lint
LintWarning::Code code = LintWarning::parseName(hc.content.c_str() + name);
if (code != LintWarning::Code_Unknown)
result |= 1ull << int(code);
}
return result;
}
std::vector<AstName> getDeprecatedGlobals(const AstNameTable& names) std::vector<AstName> getDeprecatedGlobals(const AstNameTable& names)
{ {
LintContext context; LintContext context;

View file

@@ -199,7 +199,7 @@ void Module::clonePublicInterface(NotNull<BuiltinTypes> builtinTypes, InternalEr
LUAU_ASSERT(interfaceTypes.types.empty()); LUAU_ASSERT(interfaceTypes.types.empty());
LUAU_ASSERT(interfaceTypes.typePacks.empty()); LUAU_ASSERT(interfaceTypes.typePacks.empty());
CloneState cloneState; CloneState cloneState{builtinTypes};
ScopePtr moduleScope = getModuleScope(); ScopePtr moduleScope = getModuleScope();

View file

@@ -176,7 +176,7 @@ const NormalizedStringType NormalizedStringType::never;
bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr) bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr)
{ {
if (subStr.isUnion() && superStr.isUnion()) if (subStr.isUnion() && (superStr.isUnion() && !superStr.isNever()))
{ {
for (auto [name, ty] : subStr.singletons) for (auto [name, ty] : subStr.singletons)
{ {
@@ -1983,18 +1983,68 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th
void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there)
{ {
/* There are 9 cases to worry about here
   Normalized Left     | Normalized Right
C1 string              | string              ===> trivial
C2 string - {u_1,..}   | string              ===> trivial
C3 {u_1, ..}           | string              ===> trivial
C4 string              | string - {v_1, ..}  ===> string - {v_1, ..}
C5 string - {u_1,..}   | string - {v_1, ..}  ===> string - ({u_s} U {v_s})
C6 {u_1, ..}           | string - {v_1, ..}  ===> {u_s} - {v_s}
C7 string              | {v_1, ..}           ===> {v_s}
C8 string - {u_1,..}   | {v_1, ..}           ===> {v_s} - {u_s}
C9 {u_1, ..}           | {v_1, ..}           ===> {u_s} ∩ {v_s}
*/
// Case 1,2,3
if (there.isString()) if (there.isString())
return; return;
if (here.isString()) // Case 4, Case 7
here.resetToNever(); else if (here.isString())
for (auto it = here.singletons.begin(); it != here.singletons.end();)
{ {
if (there.singletons.count(it->first)) here.singletons.clear();
it++; for (const auto& [key, type] : there.singletons)
else here.singletons[key] = type;
it = here.singletons.erase(it); here.isCofinite = here.isCofinite && there.isCofinite;
} }
// Case 5
else if (here.isIntersection() && there.isIntersection())
{
here.isCofinite = true;
for (const auto& [key, type] : there.singletons)
here.singletons[key] = type;
}
// Case 6
else if (here.isUnion() && there.isIntersection())
{
here.isCofinite = false;
for (const auto& [key, _] : there.singletons)
here.singletons.erase(key);
}
// Case 8
else if (here.isIntersection() && there.isUnion())
{
here.isCofinite = false;
std::map<std::string, TypeId> result(there.singletons);
for (const auto& [key, _] : here.singletons)
result.erase(key);
here.singletons = result;
}
// Case 9
else if (here.isUnion() && there.isUnion())
{
here.isCofinite = false;
std::map<std::string, TypeId> result;
result.insert(here.singletons.begin(), here.singletons.end());
result.insert(there.singletons.begin(), there.singletons.end());
for (auto it = result.begin(); it != result.end();)
if (!here.singletons.count(it->first) || !there.singletons.count(it->first))
it = result.erase(it);
else
++it;
here.singletons = result;
}
else
LUAU_ASSERT(0 && "Internal Error - unrecognized case");
} }
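Concretely, following the case numbering above: C9 is plain set intersection, so ("a" | "b") ∩ ("b" | "c") normalizes to "b"; C5 unions the removed sets, so (string - {"a"}) ∩ (string - {"b"}) is string - {"a", "b"}; and C6 subtracts, so ("a" | "b") ∩ (string - {"b"}) is just "a".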
std::optional<TypePackId> Normalizer::intersectionOfTypePacks(TypePackId here, TypePackId there) std::optional<TypePackId> Normalizer::intersectionOfTypePacks(TypePackId here, TypePackId there)

View file

@@ -2,10 +2,12 @@
#include "Luau/Simplify.h" #include "Luau/Simplify.h"
#include "Luau/Normalize.h" // TypeIds
#include "Luau/RecursionCounter.h" #include "Luau/RecursionCounter.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/Normalize.h" // TypeIds #include "Luau/TypeUtils.h"
#include <algorithm> #include <algorithm>
LUAU_FASTINT(LuauTypeReductionRecursionLimit) LUAU_FASTINT(LuauTypeReductionRecursionLimit)
@@ -47,14 +49,6 @@ struct TypeSimplifier
TypeId simplify(TypeId ty, DenseHashSet<TypeId>& seen); TypeId simplify(TypeId ty, DenseHashSet<TypeId>& seen);
}; };
template<typename A, typename B, typename TID>
static std::pair<const A*, const B*> get2(TID one, TID two)
{
const A* a = get<A>(one);
const B* b = get<B>(two);
return a && b ? std::make_pair(a, b) : std::make_pair(nullptr, nullptr);
}
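get2 pairs two get<>() probes so a two-type pattern match stays a single branch; judging by the include added above, it now comes in via Luau/TypeUtils.h. A usage sketch (left and right stand for any two TypeIds; intersectTables is a hypothetical helper):

    // Either both casts succeed, or both pointers come back null.
    if (auto [leftTable, rightTable] = get2<TableType, TableType>(left, right); leftTable && rightTable)
        return intersectTables(leftTable, rightTable);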
// Match the exact type false|nil // Match the exact type false|nil
static bool isFalsyType(TypeId ty) static bool isFalsyType(TypeId ty)
{ {

View file

@@ -10,7 +10,6 @@
LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000)
LUAU_FASTFLAG(DebugLuauReadWriteProperties) LUAU_FASTFLAG(DebugLuauReadWriteProperties)
LUAU_FASTFLAGVARIABLE(LuauTarjanSingleArr, false)
namespace Luau namespace Luau
{ {
@@ -269,8 +268,6 @@ std::pair<int, bool> Tarjan::indexify(TypeId ty)
{ {
ty = log->follow(ty); ty = log->follow(ty);
if (FFlag::LuauTarjanSingleArr)
{
auto [index, fresh] = typeToIndex.try_insert(ty, false); auto [index, fresh] = typeToIndex.try_insert(ty, false);
if (fresh) if (fresh)
@@ -281,29 +278,11 @@ std::pair<int, bool> Tarjan::indexify(TypeId ty)
return {index, fresh}; return {index, fresh};
} }
else
{
bool fresh = !typeToIndex.contains(ty);
int& index = typeToIndex[ty];
if (fresh)
{
index = int(indexToType.size());
indexToType.push_back(ty);
indexToPack.push_back(nullptr);
onStack.push_back(false);
lowlink.push_back(index);
}
return {index, fresh};
}
}
std::pair<int, bool> Tarjan::indexify(TypePackId tp) std::pair<int, bool> Tarjan::indexify(TypePackId tp)
{ {
tp = log->follow(tp); tp = log->follow(tp);
if (FFlag::LuauTarjanSingleArr)
{
auto [index, fresh] = packToIndex.try_insert(tp, false); auto [index, fresh] = packToIndex.try_insert(tp, false);
if (fresh) if (fresh)
@@ -314,23 +293,6 @@ std::pair<int, bool> Tarjan::indexify(TypePackId tp)
return {index, fresh}; return {index, fresh};
} }
else
{
bool fresh = !packToIndex.contains(tp);
int& index = packToIndex[tp];
if (fresh)
{
index = int(indexToPack.size());
indexToType.push_back(nullptr);
indexToPack.push_back(tp);
onStack.push_back(false);
lowlink.push_back(index);
}
return {index, fresh};
}
}
void Tarjan::visitChild(TypeId ty) void Tarjan::visitChild(TypeId ty)
{ {
@@ -350,9 +312,6 @@ void Tarjan::visitChild(TypePackId tp)
TarjanResult Tarjan::loop() TarjanResult Tarjan::loop()
{ {
if (!FFlag::LuauTarjanSingleArr)
return loop_DEPRECATED();
// Normally Tarjan is presented recursively, but this is a hot loop, so worth optimizing // Normally Tarjan is presented recursively, but this is a hot loop, so worth optimizing
while (!worklist.empty()) while (!worklist.empty())
{ {
@@ -475,28 +434,12 @@ TarjanResult Tarjan::visitRoot(TypePackId tp)
} }
void Tarjan::clearTarjan() void Tarjan::clearTarjan()
{
if (FFlag::LuauTarjanSingleArr)
{ {
typeToIndex.clear(); typeToIndex.clear();
packToIndex.clear(); packToIndex.clear();
nodes.clear(); nodes.clear();
stack.clear(); stack.clear();
}
else
{
dirty.clear();
typeToIndex.clear();
packToIndex.clear();
indexToType.clear();
indexToPack.clear();
stack.clear();
onStack.clear();
lowlink.clear();
}
edgesTy.clear(); edgesTy.clear();
edgesTp.clear(); edgesTp.clear();
@@ -504,34 +447,16 @@ void Tarjan::clearTarjan()
} }
bool Tarjan::getDirty(int index) bool Tarjan::getDirty(int index)
{
if (FFlag::LuauTarjanSingleArr)
{ {
LUAU_ASSERT(size_t(index) < nodes.size()); LUAU_ASSERT(size_t(index) < nodes.size());
return nodes[index].dirty; return nodes[index].dirty;
} }
else
{
if (dirty.size() <= size_t(index))
dirty.resize(index + 1, false);
return dirty[index];
}
}
void Tarjan::setDirty(int index, bool d) void Tarjan::setDirty(int index, bool d)
{
if (FFlag::LuauTarjanSingleArr)
{ {
LUAU_ASSERT(size_t(index) < nodes.size()); LUAU_ASSERT(size_t(index) < nodes.size());
nodes[index].dirty = d; nodes[index].dirty = d;
} }
else
{
if (dirty.size() <= size_t(index))
dirty.resize(index + 1, false);
dirty[index] = d;
}
}
void Tarjan::visitEdge(int index, int parentIndex) void Tarjan::visitEdge(int index, int parentIndex)
{ {
@@ -541,9 +466,6 @@ void Tarjan::visitEdge(int index, int parentIndex)
void Tarjan::visitSCC(int index) void Tarjan::visitSCC(int index)
{ {
if (!FFlag::LuauTarjanSingleArr)
return visitSCC_DEPRECATED(index);
bool d = getDirty(index); bool d = getDirty(index);
for (auto it = stack.rbegin(); !d && it != stack.rend(); it++) for (auto it = stack.rbegin(); !d && it != stack.rend(); it++)
@@ -588,132 +510,6 @@ TarjanResult Tarjan::findDirty(TypePackId tp)
return visitRoot(tp); return visitRoot(tp);
} }
TarjanResult Tarjan::loop_DEPRECATED()
{
// Normally Tarjan is presented recursively, but this is a hot loop, so worth optimizing
while (!worklist.empty())
{
auto [index, currEdge, lastEdge] = worklist.back();
// First visit
if (currEdge == -1)
{
++childCount;
if (childLimit > 0 && childLimit <= childCount)
return TarjanResult::TooManyChildren;
stack.push_back(index);
onStack[index] = true;
currEdge = int(edgesTy.size());
// Fill in edge list of this vertex
if (TypeId ty = indexToType[index])
visitChildren(ty, index);
else if (TypePackId tp = indexToPack[index])
visitChildren(tp, index);
lastEdge = int(edgesTy.size());
}
// Visit children
bool foundFresh = false;
for (; currEdge < lastEdge; currEdge++)
{
int childIndex = -1;
bool fresh = false;
if (auto ty = edgesTy[currEdge])
std::tie(childIndex, fresh) = indexify(ty);
else if (auto tp = edgesTp[currEdge])
std::tie(childIndex, fresh) = indexify(tp);
else
LUAU_ASSERT(false);
if (fresh)
{
// Original recursion point, update the parent continuation point and start the new element
worklist.back() = {index, currEdge + 1, lastEdge};
worklist.push_back({childIndex, -1, -1});
// We need to continue the top-level loop from the start with the new worklist element
foundFresh = true;
break;
}
else if (onStack[childIndex])
{
lowlink[index] = std::min(lowlink[index], childIndex);
}
visitEdge(childIndex, index);
}
if (foundFresh)
continue;
if (lowlink[index] == index)
{
visitSCC(index);
while (!stack.empty())
{
int popped = stack.back();
stack.pop_back();
onStack[popped] = false;
if (popped == index)
break;
}
}
worklist.pop_back();
// Original return from recursion into a child
if (!worklist.empty())
{
auto [parentIndex, _, parentEndEdge] = worklist.back();
// No need to keep child edges around
edgesTy.resize(parentEndEdge);
edgesTp.resize(parentEndEdge);
lowlink[parentIndex] = std::min(lowlink[parentIndex], lowlink[index]);
visitEdge(index, parentIndex);
}
}
return TarjanResult::Ok;
}
void Tarjan::visitSCC_DEPRECATED(int index)
{
bool d = getDirty(index);
for (auto it = stack.rbegin(); !d && it != stack.rend(); it++)
{
if (TypeId ty = indexToType[*it])
d = isDirty(ty);
else if (TypePackId tp = indexToPack[*it])
d = isDirty(tp);
if (*it == index)
break;
}
if (!d)
return;
for (auto it = stack.rbegin(); it != stack.rend(); it++)
{
setDirty(*it, true);
if (TypeId ty = indexToType[*it])
foundDirty(ty);
else if (TypePackId tp = indexToPack[*it])
foundDirty(tp);
if (*it == index)
return;
}
}
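With LuauTarjanSingleArr gone, the worklist formulation is the only Tarjan loop left. Its shape, reduced to a plain iterative depth-first traversal over an adjacency list (a standalone sketch, not the Luau API), is:

    #include <vector>

    // Each worklist frame is a suspended "recursive call": a node plus the
    // range of edges it still has to scan. Tarjan::loop above has the same
    // shape, with lowlink folding and SCC extraction at the pop.
    struct Frame
    {
        int node;
        int currEdge;
        int lastEdge;
    };

    void depthFirst(int root, const std::vector<std::vector<int>>& edges)
    {
        std::vector<bool> seen(edges.size(), false);
        std::vector<Frame> worklist{Frame{root, -1, -1}};
        seen[root] = true;

        while (!worklist.empty())
        {
            Frame& top = worklist.back();

            if (top.currEdge == -1) // first visit: begin the edge scan
            {
                top.currEdge = 0;
                top.lastEdge = int(edges[top.node].size());
            }

            bool descended = false;
            while (top.currEdge < top.lastEdge)
            {
                int child = edges[top.node][top.currEdge++];
                if (!seen[child])
                {
                    seen[child] = true;
                    worklist.push_back(Frame{child, -1, -1}); // the "recursive call"
                    descended = true;
                    break;
                }
            }

            if (descended)
                continue; // run the new top frame first

            worklist.pop_back(); // post-order point: fold results into the parent here
        }
    }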
std::optional<TypeId> Substitution::substitute(TypeId ty) std::optional<TypeId> Substitution::substitute(TypeId ty)
{ {
ty = log->follow(ty); ty = log->follow(ty);

Analysis/src/Subtyping.cpp: new file, 1050 lines (diff suppressed because it is too large)

View file

@@ -9,6 +9,8 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
namespace Luau namespace Luau
{ {
@@ -52,7 +54,7 @@ bool StateDot::canDuplicatePrimitive(TypeId ty)
if (get<BoundType>(ty)) if (get<BoundType>(ty))
return false; return false;
return get<PrimitiveType>(ty) || get<AnyType>(ty); return get<PrimitiveType>(ty) || get<AnyType>(ty) || get<UnknownType>(ty) || get<NeverType>(ty);
} }
void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName)
@@ -76,6 +78,10 @@ void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName)
formatAppend(result, "n%d [label=\"%s\"];\n", index, toString(ty).c_str()); formatAppend(result, "n%d [label=\"%s\"];\n", index, toString(ty).c_str());
else if (get<AnyType>(ty)) else if (get<AnyType>(ty))
formatAppend(result, "n%d [label=\"any\"];\n", index); formatAppend(result, "n%d [label=\"any\"];\n", index);
else if (get<UnknownType>(ty))
formatAppend(result, "n%d [label=\"unknown\"];\n", index);
else if (get<NeverType>(ty))
formatAppend(result, "n%d [label=\"never\"];\n", index);
} }
else else
{ {
@@ -139,142 +145,184 @@ void StateDot::visitChildren(TypeId ty, int index)
startNode(index); startNode(index);
startNodeLabel(); startNodeLabel();
if (const BoundType* btv = get<BoundType>(ty)) auto go = [&](auto&& t) {
using T = std::decay_t<decltype(t)>;
if constexpr (std::is_same_v<T, BoundType>)
{ {
formatAppend(result, "BoundType %d", index); formatAppend(result, "BoundType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
visitChild(btv->boundTo, index); visitChild(t.boundTo, index);
} }
else if (const FunctionType* ftv = get<FunctionType>(ty)) else if constexpr (std::is_same_v<T, BlockedType>)
{
formatAppend(result, "BlockedType %d", index);
finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, FunctionType>)
{ {
formatAppend(result, "FunctionType %d", index); formatAppend(result, "FunctionType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
visitChild(ftv->argTypes, index, "arg"); visitChild(t.argTypes, index, "arg");
visitChild(ftv->retTypes, index, "ret"); visitChild(t.retTypes, index, "ret");
} }
else if (const TableType* ttv = get<TableType>(ty)) else if constexpr (std::is_same_v<T, TableType>)
{ {
if (ttv->name) if (t.name)
formatAppend(result, "TableType %s", ttv->name->c_str()); formatAppend(result, "TableType %s", t.name->c_str());
else if (ttv->syntheticName) else if (t.syntheticName)
formatAppend(result, "TableType %s", ttv->syntheticName->c_str()); formatAppend(result, "TableType %s", t.syntheticName->c_str());
else else
formatAppend(result, "TableType %d", index); formatAppend(result, "TableType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
if (ttv->boundTo) if (t.boundTo)
return visitChild(*ttv->boundTo, index, "boundTo"); return visitChild(*t.boundTo, index, "boundTo");
for (const auto& [name, prop] : ttv->props) for (const auto& [name, prop] : t.props)
visitChild(prop.type(), index, name.c_str()); visitChild(prop.type(), index, name.c_str());
if (ttv->indexer) if (t.indexer)
{ {
visitChild(ttv->indexer->indexType, index, "[index]"); visitChild(t.indexer->indexType, index, "[index]");
visitChild(ttv->indexer->indexResultType, index, "[value]"); visitChild(t.indexer->indexResultType, index, "[value]");
} }
for (TypeId itp : ttv->instantiatedTypeParams) for (TypeId itp : t.instantiatedTypeParams)
visitChild(itp, index, "typeParam"); visitChild(itp, index, "typeParam");
for (TypePackId itp : ttv->instantiatedTypePackParams) for (TypePackId itp : t.instantiatedTypePackParams)
visitChild(itp, index, "typePackParam"); visitChild(itp, index, "typePackParam");
} }
else if (const MetatableType* mtv = get<MetatableType>(ty)) else if constexpr (std::is_same_v<T, MetatableType>)
{ {
formatAppend(result, "MetatableType %d", index); formatAppend(result, "MetatableType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
visitChild(mtv->table, index, "table"); visitChild(t.table, index, "table");
visitChild(mtv->metatable, index, "metatable"); visitChild(t.metatable, index, "metatable");
} }
else if (const UnionType* utv = get<UnionType>(ty)) else if constexpr (std::is_same_v<T, UnionType>)
{ {
formatAppend(result, "UnionType %d", index); formatAppend(result, "UnionType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
for (TypeId opt : utv->options) for (TypeId opt : t.options)
visitChild(opt, index); visitChild(opt, index);
} }
else if (const IntersectionType* itv = get<IntersectionType>(ty)) else if constexpr (std::is_same_v<T, IntersectionType>)
{ {
formatAppend(result, "IntersectionType %d", index); formatAppend(result, "IntersectionType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
for (TypeId part : itv->parts) for (TypeId part : t.parts)
visitChild(part, index); visitChild(part, index);
} }
else if (const GenericType* gtv = get<GenericType>(ty)) else if constexpr (std::is_same_v<T, LazyType>)
{ {
if (gtv->explicitName) formatAppend(result, "LazyType %d", index);
formatAppend(result, "GenericType %s", gtv->name.c_str()); finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, PendingExpansionType>)
{
formatAppend(result, "PendingExpansionType %d", index);
finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, GenericType>)
{
if (t.explicitName)
formatAppend(result, "GenericType %s", t.name.c_str());
else else
formatAppend(result, "GenericType %d", index); formatAppend(result, "GenericType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else if (const FreeType* ftv = get<FreeType>(ty)) else if constexpr (std::is_same_v<T, FreeType>)
{ {
formatAppend(result, "FreeType %d", index); formatAppend(result, "FreeType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
if (FFlag::DebugLuauDeferredConstraintResolution)
{
if (!get<NeverType>(t.lowerBound))
visitChild(t.lowerBound, index, "[lowerBound]");
if (!get<UnknownType>(t.upperBound))
visitChild(t.upperBound, index, "[upperBound]");
} }
else if (get<AnyType>(ty)) }
else if constexpr (std::is_same_v<T, AnyType>)
{ {
formatAppend(result, "AnyType %d", index); formatAppend(result, "AnyType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else if (get<PrimitiveType>(ty)) else if constexpr (std::is_same_v<T, UnknownType>)
{
formatAppend(result, "UnknownType %d", index);
finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, NeverType>)
{
formatAppend(result, "NeverType %d", index);
finishNodeLabel(ty);
finishNode();
}
else if constexpr (std::is_same_v<T, PrimitiveType>)
{ {
formatAppend(result, "PrimitiveType %s", toString(ty).c_str()); formatAppend(result, "PrimitiveType %s", toString(ty).c_str());
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else if (get<ErrorType>(ty)) else if constexpr (std::is_same_v<T, ErrorType>)
{ {
formatAppend(result, "ErrorType %d", index); formatAppend(result, "ErrorType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else if (const ClassType* ctv = get<ClassType>(ty)) else if constexpr (std::is_same_v<T, ClassType>)
{ {
formatAppend(result, "ClassType %s", ctv->name.c_str()); formatAppend(result, "ClassType %s", t.name.c_str());
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
for (const auto& [name, prop] : ctv->props) for (const auto& [name, prop] : t.props)
visitChild(prop.type(), index, name.c_str()); visitChild(prop.type(), index, name.c_str());
if (ctv->parent) if (t.parent)
visitChild(*ctv->parent, index, "[parent]"); visitChild(*t.parent, index, "[parent]");
if (ctv->metatable) if (t.metatable)
visitChild(*ctv->metatable, index, "[metatable]"); visitChild(*t.metatable, index, "[metatable]");
if (ctv->indexer) if (t.indexer)
{ {
visitChild(ctv->indexer->indexType, index, "[index]"); visitChild(t.indexer->indexType, index, "[index]");
visitChild(ctv->indexer->indexResultType, index, "[value]"); visitChild(t.indexer->indexResultType, index, "[value]");
} }
} }
else if (const SingletonType* stv = get<SingletonType>(ty)) else if constexpr (std::is_same_v<T, SingletonType>)
{ {
std::string res; std::string res;
if (const StringSingleton* ss = get<StringSingleton>(stv)) if (const StringSingleton* ss = get<StringSingleton>(&t))
{ {
// Don't put in quotes anywhere. If it's outside of the call to escape, // Don't put in quotes anywhere. If it's outside of the call to escape,
// then it's invalid syntax. If it's inside, then escaping is super noisy. // then it's invalid syntax. If it's inside, then escaping is super noisy.
res = "string: " + escape(ss->value); res = "string: " + escape(ss->value);
} }
else if (const BooleanSingleton* bs = get<BooleanSingleton>(stv)) else if (const BooleanSingleton* bs = get<BooleanSingleton>(&t))
{ {
res = "boolean: "; res = "boolean: ";
res += bs->value ? "true" : "false"; res += bs->value ? "true" : "false";
@@ -286,12 +334,25 @@ void StateDot::visitChildren(TypeId ty, int index)
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else else if constexpr (std::is_same_v<T, NegationType>)
{ {
LUAU_ASSERT(!"unknown type kind"); formatAppend(result, "NegationType %d", index);
finishNodeLabel(ty);
finishNode();
visitChild(t.ty, index, "[negated]");
}
else if constexpr (std::is_same_v<T, TypeFamilyInstanceType>)
{
formatAppend(result, "TypeFamilyInstanceType %d", index);
finishNodeLabel(ty); finishNodeLabel(ty);
finishNode(); finishNode();
} }
else
static_assert(always_false_v<T>, "unknown type kind");
};
visit(go, ty->ty);
} }
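The rewrite above trades a chain of get<T>() probes for one generic lambda dispatched over the type variant, and the dependent static_assert turns a missed alternative into a compile-time error rather than a runtime LUAU_ASSERT. A reduced sketch of the same pattern over std::variant, with the usual dependent-false helper:

    #include <cstdio>
    #include <type_traits>
    #include <variant>

    template<typename T>
    inline constexpr bool always_false_v = false; // dependent false: only fires if instantiated

    using Node = std::variant<int, double>;

    void dump(const Node& node)
    {
        auto go = [](auto&& v) {
            using T = std::decay_t<decltype(v)>;
            if constexpr (std::is_same_v<T, int>)
                std::printf("int %d\n", v);
            else if constexpr (std::is_same_v<T, double>)
                std::printf("double %f\n", v);
            else
                static_assert(always_false_v<T>, "unhandled variant alternative");
        };
        std::visit(go, node);
    }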
void StateDot::visitChildren(TypePackId tp, int index) void StateDot::visitChildren(TypePackId tp, int index)

View file

@@ -10,6 +10,8 @@
#include <limits> #include <limits>
#include <math.h> #include <math.h>
LUAU_FASTFLAG(LuauFloorDivision)
namespace namespace
{ {
bool isIdentifierStartChar(char c) bool isIdentifierStartChar(char c)
@@ -467,10 +469,13 @@ struct Printer
case AstExprBinary::Sub: case AstExprBinary::Sub:
case AstExprBinary::Mul: case AstExprBinary::Mul:
case AstExprBinary::Div: case AstExprBinary::Div:
case AstExprBinary::FloorDiv:
case AstExprBinary::Mod: case AstExprBinary::Mod:
case AstExprBinary::Pow: case AstExprBinary::Pow:
case AstExprBinary::CompareLt: case AstExprBinary::CompareLt:
case AstExprBinary::CompareGt: case AstExprBinary::CompareGt:
LUAU_ASSERT(FFlag::LuauFloorDivision || a->op != AstExprBinary::FloorDiv);
writer.maybeSpace(a->right->location.begin, 2); writer.maybeSpace(a->right->location.begin, 2);
writer.symbol(toString(a->op)); writer.symbol(toString(a->op));
break; break;
@@ -487,6 +492,8 @@ struct Printer
writer.maybeSpace(a->right->location.begin, 4); writer.maybeSpace(a->right->location.begin, 4);
writer.keyword(toString(a->op)); writer.keyword(toString(a->op));
break; break;
default:
LUAU_ASSERT(!"Unknown Op");
} }
visualize(*a->right); visualize(*a->right);
@@ -753,6 +760,12 @@ struct Printer
writer.maybeSpace(a->value->location.begin, 2); writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("/="); writer.symbol("/=");
break; break;
case AstExprBinary::FloorDiv:
LUAU_ASSERT(FFlag::LuauFloorDivision);
writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("//=");
break;
case AstExprBinary::Mod: case AstExprBinary::Mod:
writer.maybeSpace(a->value->location.begin, 2); writer.maybeSpace(a->value->location.begin, 2);
writer.symbol("%="); writer.symbol("%=");

View file

@@ -27,26 +27,11 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(LuauNormalizeBlockedTypes)
LUAU_FASTFLAG(DebugLuauReadWriteProperties) LUAU_FASTFLAG(DebugLuauReadWriteProperties)
LUAU_FASTFLAGVARIABLE(LuauInitializeStringMetatableInGlobalTypes, false)
namespace Luau namespace Luau
{ {
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static bool dcrMagicFunctionFormat(MagicFunctionCallContext context);
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context);
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static bool dcrMagicFunctionMatch(MagicFunctionCallContext context);
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static bool dcrMagicFunctionFind(MagicFunctionCallContext context);
// LUAU_NOINLINE prevents unwrapLazy from being inlined into advance below; advance is important to keep inlineable // LUAU_NOINLINE prevents unwrapLazy from being inlined into advance below; advance is important to keep inlineable
static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv) static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv)
{ {
@@ -69,14 +54,24 @@ static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv)
TypeId follow(TypeId t) TypeId follow(TypeId t)
{ {
return follow(t, nullptr, [](const void*, TypeId t) -> TypeId { return follow(t, FollowOption::Normal);
}
TypeId follow(TypeId t, FollowOption followOption)
{
return follow(t, followOption, nullptr, [](const void*, TypeId t) -> TypeId {
return t; return t;
}); });
} }
TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)) TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId))
{ {
auto advance = [context, mapper](TypeId ty) -> std::optional<TypeId> { return follow(t, FollowOption::Normal, context, mapper);
}
TypeId follow(TypeId t, FollowOption followOption, const void* context, TypeId (*mapper)(const void*, TypeId))
{
auto advance = [followOption, context, mapper](TypeId ty) -> std::optional<TypeId> {
TypeId mapped = mapper(context, ty); TypeId mapped = mapper(context, ty);
if (auto btv = get<Unifiable::Bound<TypeId>>(mapped)) if (auto btv = get<Unifiable::Bound<TypeId>>(mapped))
@@ -85,7 +80,7 @@ TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeI
if (auto ttv = get<TableType>(mapped)) if (auto ttv = get<TableType>(mapped))
return ttv->boundTo; return ttv->boundTo;
if (auto ltv = getMutable<LazyType>(mapped)) if (auto ltv = getMutable<LazyType>(mapped); ltv && followOption != FollowOption::DisableLazyTypeThunks)
return unwrapLazy(ltv); return unwrapLazy(ltv);
return std::nullopt; return std::nullopt;
@@ -923,6 +918,8 @@ TypeId makeFunction(TypeArena& arena, std::optional<TypeId> selfType, std::initi
std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypePackId> genericPacks, std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames,
std::initializer_list<TypeId> retTypes); std::initializer_list<TypeId> retTypes);
TypeId makeStringMetatable(NotNull<BuiltinTypes> builtinTypes); // BuiltinDefinitions.cpp
BuiltinTypes::BuiltinTypes() BuiltinTypes::BuiltinTypes()
: arena(new TypeArena) : arena(new TypeArena)
, debugFreezeArena(FFlag::DebugLuauFreezeArena) , debugFreezeArena(FFlag::DebugLuauFreezeArena)
@@ -945,14 +942,18 @@ BuiltinTypes::BuiltinTypes()
, truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true})) , truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true}))
, optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true})) , optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true}))
, optionalStringType(arena->addType(Type{UnionType{{stringType, nilType}}, /*persistent*/ true})) , optionalStringType(arena->addType(Type{UnionType{{stringType, nilType}}, /*persistent*/ true}))
, emptyTypePack(arena->addTypePack(TypePackVar{TypePack{{}}, /*persistent*/ true}))
, anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true}))
, neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true}))
, uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true})) , uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true}))
, errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true})) , errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true}))
{ {
TypeId stringMetatable = makeStringMetatable(); if (!FFlag::LuauInitializeStringMetatableInGlobalTypes)
{
TypeId stringMetatable = makeStringMetatable(NotNull{this});
asMutable(stringType)->ty = PrimitiveType{PrimitiveType::String, stringMetatable}; asMutable(stringType)->ty = PrimitiveType{PrimitiveType::String, stringMetatable};
persist(stringMetatable); persist(stringMetatable);
}
freeze(*arena); freeze(*arena);
} }
@@ -969,82 +970,6 @@ BuiltinTypes::~BuiltinTypes()
FFlag::DebugLuauFreezeArena.value = prevFlag; FFlag::DebugLuauFreezeArena.value = prevFlag;
} }
TypeId BuiltinTypes::makeStringMetatable()
{
const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}});
const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}});
const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}});
const TypePackId oneStringPack = arena->addTypePack({stringType});
const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true});
FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack};
formatFTV.magicFunction = &magicFunctionFormat;
const TypeId formatFn = arena->addType(formatFTV);
attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat);
const TypePackId emptyPack = arena->addTypePack({});
const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}});
const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}});
const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType});
const TypeId replArgType =
arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)),
makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}});
const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType});
const TypeId gmatchFunc =
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})});
attachMagicFunction(gmatchFunc, magicFunctionGmatch);
attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch);
const TypeId matchFunc = arena->addType(
FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})});
attachMagicFunction(matchFunc, magicFunctionMatch);
attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch);
const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})});
attachMagicFunction(findFunc, magicFunctionFind);
attachDcrMagicFunction(findFunc, dcrMagicFunctionFind);
TableType::Props stringLib = {
{"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}},
{"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}},
{"find", {findFunc}},
{"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"lower", {stringToStringType}},
{"match", {matchFunc}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}},
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}},
{"pack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType}, anyTypePack}),
oneStringPack,
})}},
{"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"unpack", {arena->addType(FunctionType{
arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}),
anyTypePack,
})}},
};
assignPropDocumentationSymbols(stringLib, "@luau/global/string");
TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed});
if (TableType* ttv = getMutable<TableType>(tableType))
ttv->name = "typeof(string)";
return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed});
}
TypeId BuiltinTypes::errorRecoveryType() const TypeId BuiltinTypes::errorRecoveryType() const
{ {
return errorType; return errorType;
@@ -1250,436 +1175,6 @@ IntersectionTypeIterator end(const IntersectionType* itv)
return IntersectionTypeIterator{}; return IntersectionTypeIterator{};
} }
static std::vector<TypeId> parseFormatString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
{
const char* options = "cdiouxXeEfgGqs*";
std::vector<TypeId> result;
for (size_t i = 0; i < size; ++i)
{
if (data[i] == '%')
{
i++;
if (i < size && data[i] == '%')
continue;
// we just ignore all characters (including flags/precision) up until first alphabetic character
while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*')))
i++;
if (i == size)
break;
if (data[i] == 'q' || data[i] == 's')
result.push_back(builtinTypes->stringType);
else if (data[i] == '*')
result.push_back(builtinTypes->unknownType);
else if (strchr(options, data[i]))
result.push_back(builtinTypes->numberType);
else
result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType));
}
}
return result;
}
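As a worked example of the mapping: "%d of %s" parses to {number, string}, "%*" contributes unknown, and an unrecognized directive falls back to the error-recovery any. The magic functions below unify these element types with the call's arguments one position at a time, so string.format("%d of %s", 1, "x") checks cleanly, swapping the two arguments reports mismatches, and a surplus argument with no tail trips the CountMismatch error.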
std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* fmt = nullptr;
if (auto index = expr.func->as<AstExprIndexName>(); index && expr.self)
{
if (auto group = index->expr->as<AstExprGroup>())
fmt = group->expr->as<AstExprConstantString>();
else
fmt = index->expr->as<AstExprConstantString>();
}
if (!expr.self && expr.args.size > 0)
fmt = expr.args.data[0]->as<AstExprConstantString>();
if (!fmt)
return std::nullopt;
std::vector<TypeId> expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size);
const auto& [params, tail] = flatten(paramPack);
size_t paramOffset = 1;
size_t dataOffset = expr.self ? 0 : 1;
// unify the prefix one argument at a time
for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i)
{
Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location;
typechecker.unify(params[i + paramOffset], expected[i], scope, location);
}
// if we know the argument count or if we have too many arguments for sure, we can issue an error
size_t numActualParams = params.size();
size_t numExpectedParams = expected.size() + 1; // + 1 for the format string
if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams))
typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}});
return WithPredicate<TypePackId>{arena.addTypePack({typechecker.stringType})};
}
static bool dcrMagicFunctionFormat(MagicFunctionCallContext context)
{
TypeArena* arena = context.solver->arena;
AstExprConstantString* fmt = nullptr;
if (auto index = context.callSite->func->as<AstExprIndexName>(); index && context.callSite->self)
{
if (auto group = index->expr->as<AstExprGroup>())
fmt = group->expr->as<AstExprConstantString>();
else
fmt = index->expr->as<AstExprConstantString>();
}
if (!context.callSite->self && context.callSite->args.size > 0)
fmt = context.callSite->args.data[0]->as<AstExprConstantString>();
if (!fmt)
return false;
std::vector<TypeId> expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size);
const auto& [params, tail] = flatten(context.arguments);
size_t paramOffset = 1;
// unify the prefix one argument at a time
for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i)
{
context.solver->unify(context.solver->rootScope, context.callSite->location, params[i + paramOffset], expected[i]);
}
// if we know the argument count or if we have too many arguments for sure, we can issue an error
size_t numActualParams = params.size();
size_t numExpectedParams = expected.size() + 1; // + 1 for the format string
if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams))
context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}});
TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType});
asMutable(context.result)->ty.emplace<BoundTypePack>(resultPack);
return true;
}
static std::vector<TypeId> parsePatternString(NotNull<BuiltinTypes> builtinTypes, const char* data, size_t size)
{
std::vector<TypeId> result;
int depth = 0;
bool parsingSet = false;
for (size_t i = 0; i < size; ++i)
{
if (data[i] == '%')
{
++i;
if (!parsingSet && i < size && data[i] == 'b')
i += 2;
}
else if (!parsingSet && data[i] == '[')
{
parsingSet = true;
if (i + 1 < size && data[i + 1] == ']')
i += 1;
}
else if (parsingSet && data[i] == ']')
{
parsingSet = false;
}
else if (data[i] == '(')
{
if (parsingSet)
continue;
if (i + 1 < size && data[i + 1] == ')')
{
i++;
result.push_back(builtinTypes->optionalNumberType);
continue;
}
++depth;
result.push_back(builtinTypes->optionalStringType);
}
else if (data[i] == ')')
{
if (parsingSet)
continue;
--depth;
if (depth < 0)
break;
}
}
if (depth != 0 || parsingSet)
return std::vector<TypeId>();
if (result.empty())
result.push_back(builtinTypes->optionalStringType);
return result;
}
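A worked example of the capture rules: "(%a+)%s*=%s*(%d+)" yields {string?, string?}, the positional capture "()" yields {number?}, an unbalanced pattern yields nothing (so the magic functions bail out), and a capture-free pattern falls through to the single whole-match string?. The gmatch/match/find magic functions below all reuse this, typing e.g. string.match("key = 42", "(%a+)%s*=%s*(%d+)") as (string?, string?).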
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() != 2)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t index = expr.self ? 0 : 1;
if (expr.args.size > index)
pattern = expr.args.data[index]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
std::vector<TypeId> returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypePackId emptyPack = arena.addTypePack({});
const TypePackId returnList = arena.addTypePack(returnTypes);
const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList});
return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})};
}
static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() != 2)
return false;
TypeArena* arena = context.solver->arena;
AstExprConstantString* pattern = nullptr;
size_t index = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > index)
pattern = context.callSite->args.data[index]->as<AstExprConstantString>();
if (!pattern)
return false;
std::vector<TypeId> returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType);
const TypePackId emptyPack = arena->addTypePack({});
const TypePackId returnList = arena->addTypePack(returnTypes);
const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList});
const TypePackId resTypePack = arena->addTypePack({iteratorType});
asMutable(context.result)->ty.emplace<BoundTypePack>(resTypePack);
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 3)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
std::vector<TypeId> returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() == 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location);
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionMatch(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() < 2 || params.size() > 3)
return false;
TypeArena* arena = context.solver->arena;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > patternIndex)
pattern = context.callSite->args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return false;
std::vector<TypeId> returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType);
const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}});
size_t initIndex = context.callSite->self ? 1 : 2;
if (params.size() == 3 && context.callSite->args.size > initIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber);
const TypePackId returnList = arena->addTypePack(returnTypes);
asMutable(context.result)->ty.emplace<BoundTypePack>(returnList);
return true;
}
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 4)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
bool plain = false;
size_t plainIndex = expr.self ? 2 : 3;
if (expr.args.size > plainIndex)
{
AstExprConstantBool* p = expr.args.data[plainIndex]->as<AstExprConstantBool>();
plain = p && p->value;
}
std::vector<TypeId> returnTypes;
if (!plain)
{
returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
}
typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}});
const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() >= 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location);
if (params.size() == 4 && expr.args.size > plainIndex)
typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
static bool dcrMagicFunctionFind(MagicFunctionCallContext context)
{
const auto& [params, tail] = flatten(context.arguments);
if (params.size() < 2 || params.size() > 4)
return false;
TypeArena* arena = context.solver->arena;
NotNull<BuiltinTypes> builtinTypes = context.solver->builtinTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = context.callSite->self ? 0 : 1;
if (context.callSite->args.size > patternIndex)
pattern = context.callSite->args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return false;
bool plain = false;
size_t plainIndex = context.callSite->self ? 2 : 3;
if (context.callSite->args.size > plainIndex)
{
AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as<AstExprConstantBool>();
plain = p && p->value;
}
std::vector<TypeId> returnTypes;
if (!plain)
{
returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return false;
}
context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], builtinTypes->stringType);
const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}});
const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}});
size_t initIndex = context.callSite->self ? 1 : 2;
if (params.size() >= 3 && context.callSite->args.size > initIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber);
if (params.size() == 4 && context.callSite->args.size > plainIndex)
context.solver->unify(context.solver->rootScope, context.callSite->location, params[3], optionalBoolean);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});
const TypePackId returnList = arena->addTypePack(returnTypes);
asMutable(context.result)->ty.emplace<BoundTypePack>(returnList);
return true;
}
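All of the string-library magic functions above delegate pattern analysis to parsePatternString, which maps a Lua pattern literal to one type per capture. A standalone sketch of that idea, assuming a simplified rule (position captures "()" produce numbers, all other captures produce strings, and a capture-free pattern yields the whole match); sketchParsePattern and CaptureKind are invented names, not the real implementation:

// Standalone sketch: walk a Lua pattern and emit one entry per capture.
#include <cstdio>
#include <string>
#include <vector>

enum class CaptureKind { String, Number };

static std::vector<CaptureKind> sketchParsePattern(const std::string& pattern)
{
    std::vector<CaptureKind> captures;

    for (size_t i = 0; i < pattern.size(); ++i)
    {
        if (pattern[i] == '%')
            ++i; // skip the escaped character, e.g. "%(" is a literal paren
        else if (pattern[i] == '(')
            captures.push_back(i + 1 < pattern.size() && pattern[i + 1] == ')' ? CaptureKind::Number : CaptureKind::String);
    }

    if (captures.empty())
        captures.push_back(CaptureKind::String); // whole match

    return captures;
}

int main()
{
    // "(%d+) (%a+)" has two captures; the gmatch iterator above is built
    // from exactly this kind of per-capture type list.
    for (CaptureKind kind : sketchParsePattern("(%d+) (%a+)"))
        std::puts(kind == CaptureKind::Number ? "number" : "string");
}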
TypeId freshType(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, Scope* scope)
{
    return arena->addType(FreeType{scope, builtinTypes->neverType, builtinTypes->unknownType});

View file

@@ -23,6 +23,7 @@
 #include <algorithm>

 LUAU_FASTFLAG(DebugLuauMagicTypes)
+LUAU_FASTFLAG(LuauFloorDivision);

 namespace Luau
 {
@@ -241,8 +242,8 @@ struct TypeChecker2
     Normalizer normalizer;

-    TypeChecker2(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> unifierState, NotNull<TypeCheckLimits> limits, DcrLogger* logger, const SourceModule* sourceModule,
-        Module* module)
+    TypeChecker2(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> unifierState, NotNull<TypeCheckLimits> limits, DcrLogger* logger,
+        const SourceModule* sourceModule, Module* module)
         : builtinTypes(builtinTypes)
         , logger(logger)
         , limits(limits)
@@ -1294,13 +1295,8 @@ struct TypeChecker2
         else if (auto assertion = expr->as<AstExprTypeAssertion>())
            return isLiteral(assertion->expr);

-        return
-            expr->is<AstExprConstantNil>() ||
-            expr->is<AstExprConstantBool>() ||
-            expr->is<AstExprConstantNumber>() ||
-            expr->is<AstExprConstantString>() ||
-            expr->is<AstExprFunction>() ||
-            expr->is<AstExprTable>();
+        return expr->is<AstExprConstantNil>() || expr->is<AstExprConstantBool>() || expr->is<AstExprConstantNumber>() ||
+               expr->is<AstExprConstantString>() || expr->is<AstExprFunction>() || expr->is<AstExprTable>();
     }

     static std::unique_ptr<LiteralProperties> buildLiteralPropertiesSet(AstExpr* expr)
@@ -1817,6 +1813,8 @@ struct TypeChecker2
         bool typesHaveIntersection = normalizer.isIntersectionInhabited(leftType, rightType);
         if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end())
         {
+            LUAU_ASSERT(FFlag::LuauFloorDivision || expr->op != AstExprBinary::Op::FloorDiv);
+
             std::optional<TypeId> leftMt = getMetatable(leftType, builtinTypes);
             std::optional<TypeId> rightMt = getMetatable(rightType, builtinTypes);
             bool matches = leftMt == rightMt;
@@ -2002,8 +2000,11 @@ struct TypeChecker2
         case AstExprBinary::Op::Sub:
         case AstExprBinary::Op::Mul:
         case AstExprBinary::Op::Div:
+        case AstExprBinary::Op::FloorDiv:
         case AstExprBinary::Op::Pow:
         case AstExprBinary::Op::Mod:
+            LUAU_ASSERT(FFlag::LuauFloorDivision || expr->op != AstExprBinary::Op::FloorDiv);
+
             reportErrors(tryUnify(scope, expr->left->location, leftType, builtinTypes->numberType));
             reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType));
@@ -2680,15 +2681,15 @@ struct TypeChecker2
     }
 };

-void check(
-    NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> unifierState, NotNull<TypeCheckLimits> limits, DcrLogger* logger, const SourceModule& sourceModule, Module* module)
+void check(NotNull<BuiltinTypes> builtinTypes, NotNull<UnifierSharedState> unifierState, NotNull<TypeCheckLimits> limits, DcrLogger* logger,
+    const SourceModule& sourceModule, Module* module)
 {
     TypeChecker2 typeChecker{builtinTypes, unifierState, limits, logger, &sourceModule, module};

     typeChecker.visit(sourceModule.root);

     unfreeze(module->interfaceTypes);
-    copyErrors(module->errors, module->interfaceTypes);
+    copyErrors(module->errors, module->interfaceTypes, builtinTypes);
     freeze(module->interfaceTypes);
 }
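The new LUAU_ASSERT(FFlag::LuauFloorDivision || ...) lines follow the usual fast-flag idiom: the FloorDiv enum value can only be produced upstream while the flag is on, so reaching it with the flag off indicates a bug. A minimal sketch of the idiom, with invented names standing in for the LUAU_FASTFLAG machinery:

#include <cassert>

static bool FlagNewOperator = false; // invented stand-in for a fast flag

enum class Op
{
    Add,
    NewOp, // only producible when FlagNewOperator is set
};

void handle(Op op)
{
    // If the flag is off, the front end can never have produced Op::NewOp,
    // so hitting it here would indicate a bug.
    assert(FlagNewOperator || op != Op::NewOp);
}

int main()
{
    handle(Op::Add);
}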

View file

@@ -36,14 +36,13 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
 LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false)
 LUAU_FASTFLAG(LuauInstantiateInSubtyping)
 LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false)
-LUAU_FASTFLAGVARIABLE(LuauFixCyclicModuleExports, false)
 LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure)
-LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false)
 LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false)
 LUAU_FASTFLAGVARIABLE(LuauLoopControlFlowAnalysis, false)
+LUAU_FASTFLAGVARIABLE(LuauVariadicOverloadFix, false)
 LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false)
 LUAU_FASTFLAG(LuauParseDeclareClassIndexer)
-LUAU_FASTFLAGVARIABLE(LuauIndexTableIntersectionStringExpr, false)
+LUAU_FASTFLAG(LuauFloorDivision);

 namespace Luau
 {
@@ -204,7 +203,8 @@ static bool isMetamethod(const Name& name)
 {
     return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" ||
            name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" ||
-           name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len";
+           name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len" ||
+           (FFlag::LuauFloorDivision && name == "__idiv");
 }

 size_t HashBoolNamePair::operator()(const std::pair<bool, Name>& pair) const
@@ -212,21 +212,6 @@ size_t HashBoolNamePair::operator()(const std::pair<bool, Name>& pair) const
     return std::hash<bool>()(pair.first) ^ std::hash<Name>()(pair.second);
 }

-GlobalTypes::GlobalTypes(NotNull<BuiltinTypes> builtinTypes)
-    : builtinTypes(builtinTypes)
-{
-    globalScope = std::make_shared<Scope>(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}));
-
-    globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType});
-    globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType});
-    globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType});
-    globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType});
-    globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType});
-    globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType});
-    globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType});
-    globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType});
-}
-
 TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver, NotNull<BuiltinTypes> builtinTypes, InternalErrorReporter* iceHandler)
     : globalScope(globalScope)
     , resolver(resolver)
@@ -1215,8 +1200,6 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local)
                 scope->importedTypeBindings[name] = module->exportedTypeBindings;
                 scope->importedModules[name] = moduleInfo->name;

-                if (FFlag::LuauFixCyclicModuleExports)
-                {
                 // Imported types of requires that transitively refer to current module have to be replaced with 'any'
                 for (const auto& [location, path] : requireCycles)
                 {
@@ -1227,7 +1210,6 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local)
                     }
                 }
             }
-                }

     // In non-strict mode we force the module type on the variable, in strict mode it is already unified
     if (isNonstrictMode())
@@ -2595,6 +2577,9 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op)
         return "__mul";
     case AstExprBinary::Div:
         return "__div";
+    case AstExprBinary::FloorDiv:
+        LUAU_ASSERT(FFlag::LuauFloorDivision);
+        return "__idiv";
     case AstExprBinary::Mod:
         return "__mod";
     case AstExprBinary::Pow:
@@ -3088,8 +3073,11 @@ TypeId TypeChecker::checkBinaryOperation(
     case AstExprBinary::Sub:
     case AstExprBinary::Mul:
     case AstExprBinary::Div:
+    case AstExprBinary::FloorDiv:
     case AstExprBinary::Mod:
     case AstExprBinary::Pow:
+        LUAU_ASSERT(FFlag::LuauFloorDivision || expr.op != AstExprBinary::FloorDiv);
+
         reportErrors(tryUnify(lhsType, numberType, scope, expr.left->location));
         reportErrors(tryUnify(rhsType, numberType, scope, expr.right->location));
         return numberType;
@@ -3128,22 +3116,13 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
     }
     else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe)
     {
-        if (!FFlag::LuauTypecheckTypeguards)
-        {
-            if (auto predicate = tryGetTypeGuardPredicate(expr))
-                return {booleanType, {std::move(*predicate)}};
-        }
-
         // For these, passing expectedType is worse than simply forcing them, because their implementation
         // may inadvertently check if expectedTypes exist first and use it, instead of forceSingleton first.
         WithPredicate<TypeId> lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true);
         WithPredicate<TypeId> rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true);

-        if (FFlag::LuauTypecheckTypeguards)
-        {
         if (auto predicate = tryGetTypeGuardPredicate(expr))
             return {booleanType, {std::move(*predicate)}};
-        }

         PredicateVec predicates;
@@ -3423,7 +3402,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
         reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}});
         return errorRecoveryType(scope);
     }
-    else if (FFlag::LuauIndexTableIntersectionStringExpr && get<IntersectionType>(exprType))
+    else if (get<IntersectionType>(exprType))
     {
         Name name = std::string(value->value.data, value->value.size);
@@ -4063,6 +4042,12 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam
                 if (argIndex < argLocations.size())
                     location = argLocations[argIndex];

+                if (FFlag::LuauVariadicOverloadFix)
+                {
+                    state.location = location;
+                    state.tryUnify(*argIter, vtp->ty);
+                }
+                else
                     unify(*argIter, vtp->ty, scope, location);
                 ++argIter;
                 ++argIndex;

View file

@@ -25,7 +25,7 @@ LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false)
 LUAU_FASTFLAG(LuauNormalizeBlockedTypes)
 LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls)
 LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
-LUAU_FASTFLAGVARIABLE(LuauTableUnifyRecursionLimit, false)
+LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false)

 namespace Luau
 {
@@ -605,6 +605,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
     {
         // TODO: there are probably cheaper ways to check if any <: T.
         const NormalizedType* superNorm = normalizer->normalize(superTy);
+
+        if (!superNorm)
+            return reportError(location, UnificationTooComplex{});
+
         if (!log.get<AnyType>(superNorm->tops))
             failure = true;
     }
@@ -2255,8 +2259,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
     TableType* newSubTable = log.getMutable<TableType>(subTyNew);

     if (superTable != newSuperTable || subTable != newSubTable)
     {
-        if (FFlag::LuauTableUnifyRecursionLimit)
-        {
         if (errors.empty())
         {
@@ -2266,14 +2268,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
             return;
         }
-        else
-        {
-            if (errors.empty())
-                return tryUnifyTables(subTy, superTy, isIntersection);
-            else
-                return;
-        }
-        }
     }

     for (const auto& [name, prop] : subTable->props)
@@ -2292,7 +2286,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
             variance = Invariant;

             Unifier innerState = makeChildUnifier();
-            if (useNewSolver)
+            if (useNewSolver || FFlag::LuauFixIndexerSubtypingOrdering)
                 innerState.tryUnify_(prop.type(), superTable->indexer->indexResultType);
             else
             {
@@ -2346,8 +2340,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
     TableType* newSubTable = log.getMutable<TableType>(subTyNew);

     if (superTable != newSuperTable || subTable != newSubTable)
     {
-        if (FFlag::LuauTableUnifyRecursionLimit)
-        {
         if (errors.empty())
         {
@@ -2357,14 +2349,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection,
             return;
         }
-        else
-        {
-            if (errors.empty())
-                return tryUnifyTables(subTy, superTy, isIntersection);
-            else
-                return;
-        }
-        }
     }

     // Unify indexers

View file

@@ -26,7 +26,6 @@ Unifier2::Unifier2(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes,
     , ice(ice)
     , recursionLimit(FInt::LuauTypeInferRecursionLimit)
 {
 }

 bool Unifier2::unify(TypeId subTy, TypeId superTy)
@@ -99,10 +98,7 @@ bool Unifier2::unify(TypePackId subTp, TypePackId superTp)
         return true;
     }

-    size_t maxLength = std::max(
-        flatten(subTp).first.size(),
-        flatten(superTp).first.size()
-    );
+    size_t maxLength = std::max(flatten(subTp).first.size(), flatten(superTp).first.size());

     auto [subTypes, subTail] = extendTypePack(*arena, builtinTypes, subTp, maxLength);
     auto [superTypes, superTail] = extendTypePack(*arena, builtinTypes, superTp, maxLength);
@@ -123,16 +119,25 @@ struct FreeTypeSearcher : TypeVisitor
     explicit FreeTypeSearcher(NotNull<Scope> scope)
         : TypeVisitor(/*skipBoundTypes*/ true)
         , scope(scope)
-    {}
+    {
+    }

-    enum { Positive, Negative } polarity = Positive;
+    enum
+    {
+        Positive,
+        Negative
+    } polarity = Positive;

     void flip()
     {
         switch (polarity)
         {
-        case Positive: polarity = Negative; break;
-        case Negative: polarity = Positive; break;
+        case Positive:
+            polarity = Negative;
+            break;
+        case Negative:
+            polarity = Positive;
+            break;
         }
     }
@@ -152,8 +157,12 @@ struct FreeTypeSearcher : TypeVisitor
         switch (polarity)
         {
-        case Positive: positiveTypes.insert(ty); break;
-        case Negative: negativeTypes.insert(ty); break;
+        case Positive:
+            positiveTypes.insert(ty);
+            break;
+        case Negative:
+            negativeTypes.insert(ty);
+            break;
         }

         return true;
@@ -180,13 +189,17 @@ struct MutatingGeneralizer : TypeOnceVisitor
     std::unordered_set<TypeId> negativeTypes;
     std::vector<TypeId> generics;

-    MutatingGeneralizer(NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, std::unordered_set<TypeId> positiveTypes, std::unordered_set<TypeId> negativeTypes)
+    bool isWithinFunction = false;
+
+    MutatingGeneralizer(
+        NotNull<BuiltinTypes> builtinTypes, NotNull<Scope> scope, std::unordered_set<TypeId> positiveTypes, std::unordered_set<TypeId> negativeTypes)
         : TypeOnceVisitor(/* skipBoundTypes */ true)
         , builtinTypes(builtinTypes)
         , scope(scope)
         , positiveTypes(std::move(positiveTypes))
         , negativeTypes(std::move(negativeTypes))
-    {}
+    {
+    }

     static void replace(DenseHashSet<TypeId>& seen, TypeId haystack, TypeId needle, TypeId replacement)
     {
@@ -211,10 +224,7 @@ struct MutatingGeneralizer : TypeOnceVisitor
             // FIXME: I bet this function has reentrancy problems
             option = follow(option);

             if (option == needle)
-            {
-                LUAU_ASSERT(!seen.find(option));
                 option = replacement;
-            }

             // TODO seen set
             else if (get<UnionType>(option))
@@ -224,6 +234,20 @@ struct MutatingGeneralizer : TypeOnceVisitor
         }
     }

+    bool visit(TypeId ty, const FunctionType& ft) override
+    {
+        const bool oldValue = isWithinFunction;
+        isWithinFunction = true;
+
+        traverse(ft.argTypes);
+        traverse(ft.retTypes);
+
+        isWithinFunction = oldValue;
+
+        return false;
+    }
+
     bool visit(TypeId ty, const FreeType&) override
     {
         const FreeType* ft = get<FreeType>(ty);
@@ -232,7 +256,8 @@ struct MutatingGeneralizer : TypeOnceVisitor
         traverse(ft->lowerBound);
         traverse(ft->upperBound);

-        // ft is potentially invalid now.
+        // It is possible for the above traverse() calls to cause ty to be
+        // transmuted. We must reaquire ft if this happens.
         ty = follow(ty);
         ft = get<FreeType>(ty);
         if (!ft)
@@ -251,10 +276,15 @@ struct MutatingGeneralizer : TypeOnceVisitor
         seen.insert(ty);

         if (!hasLowerBound && !hasUpperBound)
+        {
+            if (isWithinFunction)
             {
                 emplaceType<GenericType>(asMutable(ty), scope);
                 generics.push_back(ty);
             }
+            else
+                emplaceType<BoundType>(asMutable(ty), builtinTypes->unknownType);
+        }

         // It is possible that this free type has other free types in its upper
         // or lower bounds. If this is the case, we must replace those
@@ -264,19 +294,27 @@ struct MutatingGeneralizer : TypeOnceVisitor
         // If we do not do this, we get tautological bounds like a <: a <: unknown.
         else if (isPositive && !hasUpperBound)
         {
-            if (FreeType* lowerFree = getMutable<FreeType>(ft->lowerBound); lowerFree && lowerFree->upperBound == ty)
+            TypeId lb = follow(ft->lowerBound);
+            if (FreeType* lowerFree = getMutable<FreeType>(lb); lowerFree && lowerFree->upperBound == ty)
                 lowerFree->upperBound = builtinTypes->unknownType;
             else
-                replace(seen, ft->lowerBound, ty, builtinTypes->unknownType);
-            emplaceType<BoundType>(asMutable(ty), ft->lowerBound);
+            {
+                DenseHashSet<TypeId> replaceSeen{nullptr};
+                replace(replaceSeen, lb, ty, builtinTypes->unknownType);
+            }
+            emplaceType<BoundType>(asMutable(ty), lb);
         }
         else
         {
-            if (FreeType* upperFree = getMutable<FreeType>(ft->upperBound); upperFree && upperFree->lowerBound == ty)
+            TypeId ub = follow(ft->upperBound);
+            if (FreeType* upperFree = getMutable<FreeType>(ub); upperFree && upperFree->lowerBound == ty)
                 upperFree->lowerBound = builtinTypes->neverType;
             else
-                replace(seen, ft->upperBound, ty, builtinTypes->neverType);
-            emplaceType<BoundType>(asMutable(ty), ft->upperBound);
+            {
+                DenseHashSet<TypeId> replaceSeen{nullptr};
+                replace(replaceSeen, ub, ty, builtinTypes->neverType);
+            }
+            emplaceType<BoundType>(asMutable(ty), ub);
         }

         return false;
@@ -363,4 +401,4 @@ OccursCheckResult Unifier2::occursCheck(DenseHashSet<TypePackId>& seen, TypePack

     return OccursCheckResult::Pass;
 }

-}
+} // namespace Luau
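The isWithinFunction flag changes what generalization does with a free type that has no bounds at all: inside a function it can be quantified into a generic, while elsewhere it is simply bound to unknown. The rule reduced to a toy function (an assumed reading of the visit(FreeType) code above, with invented names):

#include <cstdio>

// Toy reduction of the generalization rule for a free type with neither a
// lower nor an upper bound.
const char* generalizeUnbounded(bool isWithinFunction)
{
    // Inside a function body the type can be quantified into a generic;
    // elsewhere there is nothing to quantify over, so widen to unknown.
    return isWithinFunction ? "GenericType" : "BoundType(unknown)";
}

int main()
{
    std::printf("in function:  %s\n", generalizeUnbounded(true));
    std::printf("at top level: %s\n", generalizeUnbounded(false));
}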

View file

@@ -272,11 +272,18 @@ class AstExprConstantString : public AstExpr
 public:
     LUAU_RTTI(AstExprConstantString)

-    AstExprConstantString(const Location& location, const AstArray<char>& value);
+    enum QuoteStyle
+    {
+        Quoted,
+        Unquoted
+    };
+
+    AstExprConstantString(const Location& location, const AstArray<char>& value, QuoteStyle quoteStyle = Quoted);

     void visit(AstVisitor* visitor) override;

     AstArray<char> value;
+    QuoteStyle quoteStyle = Quoted;
 };

 class AstExprLocal : public AstExpr
@@ -450,6 +457,7 @@ public:
         Sub,
         Mul,
         Div,
+        FloorDiv,
         Mod,
         Pow,
         Concat,
@@ -460,7 +468,9 @@ public:
         CompareGt,
         CompareGe,
         And,
-        Or
+        Or,
+
+        Op__Count
     };

     AstExprBinary(const Location& location, Op op, AstExpr* left, AstExpr* right);
@@ -524,11 +534,12 @@ class AstStatBlock : public AstStat
 public:
     LUAU_RTTI(AstStatBlock)

-    AstStatBlock(const Location& location, const AstArray<AstStat*>& body);
+    AstStatBlock(const Location& location, const AstArray<AstStat*>& body, bool hasEnd = true);

     void visit(AstVisitor* visitor) override;

     AstArray<AstStat*> body;
+    bool hasEnd = false;
 };

 class AstStatIf : public AstStat

View file

@@ -62,6 +62,7 @@ struct Lexeme
         Dot3,
         SkinnyArrow,
         DoubleColon,
+        FloorDiv,

         InterpStringBegin,
         InterpStringMid,
@@ -73,6 +74,7 @@ struct Lexeme
         SubAssign,
         MulAssign,
         DivAssign,
+        FloorDivAssign,
         ModAssign,
         PowAssign,
         ConcatAssign,
@@ -204,7 +206,9 @@ private:
     Position position() const;

+    // consume() assumes current character is not a newline for performance; when that is not known, consumeAny() should be used instead.
     void consume();
+    void consumeAny();

     Lexeme readCommentBody();
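The consume()/consumeAny() split keeps newline bookkeeping off the hot path: callers that already know the current character is not a newline skip the line-counter update entirely. The same split in a self-contained sketch (MiniLexer is an invented stand-in, not the real Lexer):

#include <cassert>
#include <cstddef>
#include <string>

struct MiniLexer
{
    std::string buffer;
    size_t offset = 0;
    unsigned line = 0;

    static bool isNewline(char ch)
    {
        return ch == '\n';
    }

    // Fast path: caller guarantees the current character is not a newline,
    // so no line accounting is needed.
    void consume()
    {
        assert(!isNewline(buffer[offset]));
        offset++;
    }

    // General path: handles newlines and updates the line counter.
    void consumeAny()
    {
        if (isNewline(buffer[offset]))
            line++;
        offset++;
    }
};

int main()
{
    MiniLexer lexer{"a\nb"};
    lexer.consume();    // 'a' is known not to be a newline
    lexer.consumeAny(); // '\n' bumps the line counter
    lexer.consume();    // 'b'
    return lexer.line == 1 ? 0 : 1;
}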

View file

@@ -3,6 +3,8 @@
 #include "Luau/Common.h"

+LUAU_FASTFLAG(LuauFloorDivision)
+
 namespace Luau
 {
@@ -62,9 +64,10 @@ void AstExprConstantNumber::visit(AstVisitor* visitor)
     visitor->visit(this);
 }

-AstExprConstantString::AstExprConstantString(const Location& location, const AstArray<char>& value)
+AstExprConstantString::AstExprConstantString(const Location& location, const AstArray<char>& value, QuoteStyle quoteStyle)
     : AstExpr(ClassIndex(), location)
     , value(value)
+    , quoteStyle(quoteStyle)
 {
 }
@@ -278,6 +281,9 @@ std::string toString(AstExprBinary::Op op)
         return "*";
     case AstExprBinary::Div:
         return "/";
+    case AstExprBinary::FloorDiv:
+        LUAU_ASSERT(FFlag::LuauFloorDivision);
+        return "//";
     case AstExprBinary::Mod:
         return "%";
     case AstExprBinary::Pow:
@@ -374,9 +380,10 @@ void AstExprError::visit(AstVisitor* visitor)
     }
 }

-AstStatBlock::AstStatBlock(const Location& location, const AstArray<AstStat*>& body)
+AstStatBlock::AstStatBlock(const Location& location, const AstArray<AstStat*>& body, bool hasEnd)
     : AstStat(ClassIndex(), location)
     , body(body)
+    , hasEnd(hasEnd)
 {
 }

View file

@@ -6,6 +6,9 @@
 #include <limits.h>

+LUAU_FASTFLAGVARIABLE(LuauFloorDivision, false)
+LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false)
+
 namespace Luau
 {
@@ -136,6 +139,9 @@ std::string Lexeme::toString() const
     case DoubleColon:
         return "'::'";

+    case FloorDiv:
+        return FFlag::LuauFloorDivision ? "'//'" : "<unknown>";
+
     case AddAssign:
         return "'+='";
@@ -148,6 +154,9 @@ std::string Lexeme::toString() const
     case DivAssign:
         return "'/='";

+    case FloorDivAssign:
+        return FFlag::LuauFloorDivision ? "'//='" : "<unknown>";
+
     case ModAssign:
         return "'%='";
@@ -373,7 +382,7 @@ const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation)
     {
         // consume whitespace before the token
         while (isSpace(peekch()))
-            consume();
+            consumeAny();

         if (updatePrevLocation)
             prevLocation = lexeme.location;
@@ -400,6 +409,8 @@ Lexeme Lexer::lookahead()
     unsigned int currentLineOffset = lineOffset;
     Lexeme currentLexeme = lexeme;
     Location currentPrevLocation = prevLocation;
+    size_t currentBraceStackSize = braceStack.size();
+    BraceType currentBraceType = braceStack.empty() ? BraceType::Normal : braceStack.back();

     Lexeme result = next();
@@ -408,6 +419,13 @@ Lexeme Lexer::lookahead()
     lineOffset = currentLineOffset;
     lexeme = currentLexeme;
     prevLocation = currentPrevLocation;
+    if (FFlag::LuauLexerLookaheadRemembersBraceType)
+    {
+        if (braceStack.size() < currentBraceStackSize)
+            braceStack.push_back(currentBraceType);
+        else if (braceStack.size() > currentBraceStackSize)
+            braceStack.pop_back();
+    }

     return result;
 }
@@ -438,7 +456,17 @@ Position Lexer::position() const
     return Position(line, offset - lineOffset);
 }

+LUAU_FORCEINLINE
 void Lexer::consume()
+{
+    // consume() assumes current character is known to not be a newline; use consumeAny if this is not guaranteed
+    LUAU_ASSERT(!isNewline(buffer[offset]));
+
+    offset++;
+}
+
+LUAU_FORCEINLINE
+void Lexer::consumeAny()
 {
     if (isNewline(buffer[offset]))
     {
@@ -524,7 +552,7 @@ Lexeme Lexer::readLongString(const Position& start, int sep, Lexeme::Type ok, Le
         }
         else
         {
-            consume();
+            consumeAny();
         }
     }
@@ -540,7 +568,7 @@ void Lexer::readBackslashInString()
     case '\r':
         consume();
         if (peekch() == '\n')
-            consume();
+            consumeAny();
         break;

     case 0:
@@ -549,11 +577,11 @@ void Lexer::readBackslashInString()
     case 'z':
         consume();
         while (isSpace(peekch()))
-            consume();
+            consumeAny();
         break;

     default:
-        consume();
+        consumeAny();
     }
 }
@@ -878,6 +906,35 @@ Lexeme Lexer::readNext()
         return Lexeme(Location(start, 1), '+');

     case '/':
+    {
+        if (FFlag::LuauFloorDivision)
+        {
+            consume();
+
+            char ch = peekch();
+
+            if (ch == '=')
+            {
+                consume();
+                return Lexeme(Location(start, 2), Lexeme::DivAssign);
+            }
+            else if (ch == '/')
+            {
+                consume();
+
+                if (peekch() == '=')
+                {
+                    consume();
+                    return Lexeme(Location(start, 3), Lexeme::FloorDivAssign);
+                }
+                else
+                    return Lexeme(Location(start, 2), Lexeme::FloorDiv);
+            }
+            else
+                return Lexeme(Location(start, 1), '/');
+        }
+        else
+        {
             consume();

             if (peekch() == '=')
@@ -887,6 +944,8 @@ Lexeme Lexer::readNext()
             }
             else
                 return Lexeme(Location(start, 1), '/');
+        }
+    }

     case '*':
         consume();
@@ -939,6 +998,9 @@ Lexeme Lexer::readNext()
     case ';':
     case ',':
     case '#':
+    case '?':
+    case '&':
+    case '|':
     {
         char ch = peekch();
         consume();
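With LuauFloorDivision on, '/' heads a four-way longest-match decision: '/', '/=', '//' and '//='. The same decision tree in isolation, as a toy scanner (scanSlash is an invented name, not lexer API):

#include <cstdio>

// Toy scanner for the slash family of tokens: the longest match wins.
const char* scanSlash(const char*& p)
{
    p++; // consume '/'
    if (*p == '=')
    {
        p++;
        return "DivAssign"; // "/="
    }
    if (*p == '/')
    {
        p++;
        if (*p == '=')
        {
            p++;
            return "FloorDivAssign"; // "//="
        }
        return "FloorDiv"; // "//"
    }
    return "Div"; // "/"
}

int main()
{
    const char* samples[] = {"/", "/=", "//", "//="};
    for (const char* s : samples)
    {
        const char* p = s;
        std::printf("%-4s -> %s\n", s, scanSlash(p));
    }
}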

View file

@@ -14,8 +14,7 @@
 LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000)
 LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
 LUAU_FASTFLAGVARIABLE(LuauParseDeclareClassIndexer, false)
-
-#define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?"
+LUAU_FASTFLAG(LuauFloorDivision)

 namespace Luau
 {
@@ -462,11 +461,11 @@ AstStat* Parser::parseDo()
     Lexeme matchDo = lexer.current();
     nextLexeme(); // do

-    AstStat* body = parseBlock();
+    AstStatBlock* body = parseBlock();

     body->location.begin = start.begin;

-    expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo);
+    body->hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo);

     return body;
 }
@@ -899,13 +898,13 @@ AstStat* Parser::parseDeclaration(const Location& start)
                 expectAndConsume(':', "property type annotation");
                 AstType* type = parseType();

-                // TODO: since AstName conains a char*, it can't contain null
+                // since AstName contains a char*, it can't contain null
                 bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size);

                 if (chars && !containsNull)
                     props.push_back(AstDeclaredClassProp{AstName(chars->data), type, false});
                 else
-                    report(begin.location, "String literal contains malformed escape sequence");
+                    report(begin.location, "String literal contains malformed escape sequence or \\0");
             }
             else if (lexer.current().type == '[' && FFlag::LuauParseDeclareClassIndexer)
             {
@@ -1328,13 +1327,13 @@ AstType* Parser::parseTableType()
             AstType* type = parseType();

-            // TODO: since AstName conains a char*, it can't contain null
+            // since AstName contains a char*, it can't contain null
             bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size);

             if (chars && !containsNull)
                 props.push_back({AstName(chars->data), begin.location, type});
             else
-                report(begin.location, "String literal contains malformed escape sequence");
+                report(begin.location, "String literal contains malformed escape sequence or \\0");
         }
         else if (lexer.current().type == '[')
         {
@@ -1622,7 +1621,7 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack)
     else if (lexer.current().type == Lexeme::BrokenString)
     {
         nextLexeme();
-        return {reportTypeError(start, {}, "Malformed string")};
+        return {reportTypeError(start, {}, "Malformed string; did you forget to finish it?")};
     }
     else if (lexer.current().type == Lexeme::Name)
     {
@@ -1741,7 +1740,8 @@ AstTypePack* Parser::parseTypePack()
         return allocator.alloc<AstTypePackGeneric>(Location(name.location, end), name.name);
     }

-    // No type pack annotation exists here.
+    // TODO: shouldParseTypePack can be removed and parseTypePack can be called unconditionally instead
+    LUAU_ASSERT(!"parseTypePack can't be called if shouldParseTypePack() returned false");
     return nullptr;
 }
@@ -1767,6 +1767,12 @@ std::optional<AstExprBinary::Op> Parser::parseBinaryOp(const Lexeme& l)
         return AstExprBinary::Mul;
     else if (l.type == '/')
         return AstExprBinary::Div;
+    else if (l.type == Lexeme::FloorDiv)
+    {
+        LUAU_ASSERT(FFlag::LuauFloorDivision);
+        return AstExprBinary::FloorDiv;
+    }
     else if (l.type == '%')
         return AstExprBinary::Mod;
     else if (l.type == '^')
@@ -1803,6 +1809,12 @@ std::optional<AstExprBinary::Op> Parser::parseCompoundOp(const Lexeme& l)
         return AstExprBinary::Mul;
     else if (l.type == Lexeme::DivAssign)
         return AstExprBinary::Div;
+    else if (l.type == Lexeme::FloorDivAssign)
+    {
+        LUAU_ASSERT(FFlag::LuauFloorDivision);
+        return AstExprBinary::FloorDiv;
+    }
     else if (l.type == Lexeme::ModAssign)
         return AstExprBinary::Mod;
     else if (l.type == Lexeme::PowAssign)
@@ -1826,7 +1838,7 @@ std::optional<AstExprUnary::Op> Parser::checkUnaryConfusables()
     if (curr.type == '!')
     {
-        report(start, "Unexpected '!', did you mean 'not'?");
+        report(start, "Unexpected '!'; did you mean 'not'?");
         return AstExprUnary::Not;
     }
@@ -1848,20 +1860,20 @@ std::optional<AstExprBinary::Op> Parser::checkBinaryConfusables(const BinaryOpPr
     if (curr.type == '&' && next.type == '&' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::And].left > limit)
     {
         nextLexeme();
-        report(Location(start, next.location), "Unexpected '&&', did you mean 'and'?");
+        report(Location(start, next.location), "Unexpected '&&'; did you mean 'and'?");
         return AstExprBinary::And;
     }
     else if (curr.type == '|' && next.type == '|' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::Or].left > limit)
     {
         nextLexeme();
-        report(Location(start, next.location), "Unexpected '||', did you mean 'or'?");
+        report(Location(start, next.location), "Unexpected '||'; did you mean 'or'?");
         return AstExprBinary::Or;
     }
     else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin &&
             binaryPriority[AstExprBinary::CompareNe].left > limit)
     {
         nextLexeme();
-        report(Location(start, next.location), "Unexpected '!=', did you mean '~='?");
+        report(Location(start, next.location), "Unexpected '!='; did you mean '~='?");
         return AstExprBinary::CompareNe;
     }
@@ -1873,12 +1885,13 @@ AstExpr* Parser::parseExpr(unsigned int limit)
     static const BinaryOpPriority binaryPriority[] = {
-        {6, 6}, {6, 6}, {7, 7}, {7, 7}, {7, 7}, // `+' `-' `*' `/' `%'
+        {6, 6}, {6, 6}, {7, 7}, {7, 7}, {7, 7}, {7, 7}, // `+' `-' `*' `/' `//' `%'
         {10, 9}, {5, 4},                                // power and concat (right associative)
         {3, 3}, {3, 3},                                 // equality and inequality
         {3, 3}, {3, 3}, {3, 3}, {3, 3},                 // order
         {2, 2}, {1, 1}                                  // logical (and/or)
     };
+    static_assert(sizeof(binaryPriority) / sizeof(binaryPriority[0]) == size_t(AstExprBinary::Op__Count), "binaryPriority needs an entry per op");

     unsigned int recursionCounterOld = recursionCounter;
@@ -2169,12 +2182,12 @@ AstExpr* Parser::parseSimpleExpr()
     else if (lexer.current().type == Lexeme::BrokenString)
     {
         nextLexeme();
-        return reportExprError(start, {}, "Malformed string");
+        return reportExprError(start, {}, "Malformed string; did you forget to finish it?");
     }
     else if (lexer.current().type == Lexeme::BrokenInterpDoubleBrace)
     {
         nextLexeme();
-        return reportExprError(start, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE);
+        return reportExprError(start, {}, "Double braces are not permitted within interpolated strings; did you mean '\\{'?");
     }
     else if (lexer.current().type == Lexeme::Dot3)
     {
@@ -2312,7 +2325,7 @@ AstExpr* Parser::parseTableConstructor()
             nameString.data = const_cast<char*>(name.name.value);
             nameString.size = strlen(name.name.value);

-            AstExpr* key = allocator.alloc<AstExprConstantString>(name.location, nameString);
+            AstExpr* key = allocator.alloc<AstExprConstantString>(name.location, nameString, AstExprConstantString::Unquoted);
             AstExpr* value = parseExpr();

             if (AstExprFunction* func = value->as<AstExprFunction>())
@@ -2661,7 +2674,7 @@ AstExpr* Parser::parseInterpString()
             {
                 errorWhileChecking = true;
                 nextLexeme();
-                expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '`'?"));
+                expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string; did you forget to add a '`'?"));
                 break;
             }
             default:
@@ -2681,10 +2694,10 @@ AstExpr* Parser::parseInterpString()
             break;
         case Lexeme::BrokenInterpDoubleBrace:
             nextLexeme();
-            return reportExprError(endLocation, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE);
+            return reportExprError(endLocation, {}, "Double braces are not permitted within interpolated strings; did you mean '\\{'?");
         case Lexeme::BrokenString:
             nextLexeme();
-            return reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '}'?");
+            return reportExprError(endLocation, {}, "Malformed interpolated string; did you forget to add a '}'?");
         default:
             return reportExprError(endLocation, {}, "Malformed interpolated string, got %s", lexer.current().toString().c_str());
         }
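The Op__Count sentinel added to AstExprBinary exists so tables indexed by the operator enum, like binaryPriority above, can be size-checked at compile time; adding an operator without a priority entry now fails the static_assert. The idiom in miniature, with invented names:

// Sentinel-checked lookup table: Op__Count tracks the enum size, and the
// static_assert fails to compile if the table and the enum drift apart.
enum Op
{
    Add,
    Sub,
    FloorDiv,
    Op__Count // not a real operator; must stay last
};

struct Priority
{
    int left, right;
};

static const Priority kPriority[] = {
    {6, 6}, // Add
    {6, 6}, // Sub
    {7, 7}, // FloorDiv
};

static_assert(sizeof(kPriority) / sizeof(kPriority[0]) == static_cast<unsigned>(Op__Count), "kPriority needs an entry per op");

int main() {}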

View file

@@ -89,13 +89,14 @@ static void reportError(const char* name, const Luau::CompileError& error)
     report(name, error.getLocation(), "CompileError", error.what());
 }

-static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options)
+static std::string getCodegenAssembly(
+    const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options, Luau::CodeGen::LoweringStats* stats)
 {
     std::unique_ptr<lua_State, void (*)(lua_State*)> globalState(luaL_newstate(), lua_close);
     lua_State* L = globalState.get();

     if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0)
-        return Luau::CodeGen::getAssembly(L, -1, options);
+        return Luau::CodeGen::getAssembly(L, -1, options, stats);

     fprintf(stderr, "Error loading bytecode %s\n", name);
     return "";
@@ -119,6 +120,8 @@ struct CompileStats
     double parseTime;
     double compileTime;
     double codegenTime;
+
+    Luau::CodeGen::LoweringStats lowerStats;
 };

 static double recordDeltaTime(double& timer)
@@ -213,10 +216,10 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A
     case CompileFormat::CodegenAsm:
     case CompileFormat::CodegenIr:
     case CompileFormat::CodegenVerbose:
-        printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str());
+        printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options, &stats.lowerStats).c_str());
         break;
     case CompileFormat::CodegenNull:
-        stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size();
+        stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options, &stats.lowerStats).size();
         stats.codegenTime += recordDeltaTime(currts);
         break;
     case CompileFormat::Null:
@@ -355,13 +358,22 @@ int main(int argc, char** argv)
         failed += !compileFile(path.c_str(), compileFormat, assemblyTarget, stats);

     if (compileFormat == CompileFormat::Null)
+    {
         printf("Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", int(stats.lines / 1000), int(stats.bytecode / 1024),
             stats.readTime, stats.parseTime, stats.compileTime);
+    }
     else if (compileFormat == CompileFormat::CodegenNull)
+    {
         printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n",
             int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024),
             stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), stats.readTime, stats.parseTime, stats.compileTime,
             stats.codegenTime);
+
+        printf("Lowering stats:\n");
+        printf("- spills to stack: %d, spills to restore: %d, max spill slot %u\n", stats.lowerStats.spillsToSlot, stats.lowerStats.spillsToRestore,
+            stats.lowerStats.maxSpillSlotsUsed);
+        printf("- regalloc failed: %d, lowering failed %d\n", stats.lowerStats.regAllocErrors, stats.lowerStats.loweringErrors);
+    }

     return failed ? 1 : 0;
 }

View file

@@ -15,7 +15,7 @@
 #define VERBOSE 0 // 1 - print out commandline invocations. 2 - print out stdout

-#ifdef _WIN32
+#if defined(_WIN32) && !defined(__MINGW32__)
 const auto popen = &_popen;
 const auto pclose = &_pclose;

View file

@@ -757,14 +757,6 @@ int replMain(int argc, char** argv)
     }
 #endif

-#if !LUA_CUSTOM_EXECUTION
-    if (codegen)
-    {
-        fprintf(stderr, "To run with --codegen, Luau has to be built with LUA_CUSTOM_EXECUTION enabled\n");
-        return 1;
-    }
-#endif
-
     if (codegenPerf)
     {
 #if __linux__
@@ -784,10 +776,7 @@ int replMain(int argc, char** argv)
     }

     if (codegen && !Luau::CodeGen::isSupported())
-    {
-        fprintf(stderr, "Cannot enable --codegen, native code generation is not supported in current configuration\n");
-        return 1;
-    }
+        fprintf(stderr, "Warning: Native code generation is not supported in current configuration\n");

     const std::vector<std::string> files = getSourceFiles(argc, argv);

View file

@@ -12,7 +12,6 @@ option(LUAU_BUILD_WEB "Build Web module" OFF)
 option(LUAU_WERROR "Warnings as errors" OFF)
 option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF)
 option(LUAU_EXTERN_C "Use extern C for all APIs" OFF)
-option(LUAU_NATIVE "Enable support for native code generation" OFF)

 cmake_policy(SET CMP0054 NEW)
 cmake_policy(SET CMP0091 NEW)
@@ -26,6 +25,7 @@ project(Luau LANGUAGES CXX C)
 add_library(Luau.Common INTERFACE)
 add_library(Luau.Ast STATIC)
 add_library(Luau.Compiler STATIC)
+add_library(Luau.Config STATIC)
 add_library(Luau.Analysis STATIC)
 add_library(Luau.CodeGen STATIC)
 add_library(Luau.VM STATIC)
@@ -71,9 +71,13 @@ target_compile_features(Luau.Compiler PUBLIC cxx_std_17)
 target_include_directories(Luau.Compiler PUBLIC Compiler/include)
 target_link_libraries(Luau.Compiler PUBLIC Luau.Ast)

+target_compile_features(Luau.Config PUBLIC cxx_std_17)
+target_include_directories(Luau.Config PUBLIC Config/include)
+target_link_libraries(Luau.Config PUBLIC Luau.Ast)
+
 target_compile_features(Luau.Analysis PUBLIC cxx_std_17)
 target_include_directories(Luau.Analysis PUBLIC Analysis/include)
-target_link_libraries(Luau.Analysis PUBLIC Luau.Ast)
+target_link_libraries(Luau.Analysis PUBLIC Luau.Ast Luau.Config)

 target_compile_features(Luau.CodeGen PRIVATE cxx_std_17)
 target_include_directories(Luau.CodeGen PUBLIC CodeGen/include)
@@ -141,14 +145,8 @@ if(LUAU_EXTERN_C)
     target_compile_definitions(Luau.VM PUBLIC LUA_USE_LONGJMP=1)
     target_compile_definitions(Luau.VM PUBLIC LUA_API=extern\"C\")
     target_compile_definitions(Luau.Compiler PUBLIC LUACODE_API=extern\"C\")
-endif()
-
-if(LUAU_NATIVE)
-    target_compile_definitions(Luau.VM PUBLIC LUA_CUSTOM_EXECUTION=1)
-    if(LUAU_EXTERN_C)
     target_compile_definitions(Luau.CodeGen PUBLIC LUACODEGEN_API=extern\"C\")
-    endif()
 endif()
-endif()

 if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC" AND MSVC_VERSION GREATER_EQUAL 1924)
     # disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022:

View file

@@ -221,6 +221,7 @@ private:
     void placeFMOV(const char* name, RegisterA64 dst, double src, uint32_t op);
     void placeBM(const char* name, RegisterA64 dst, RegisterA64 src1, uint32_t src2, uint8_t op);
     void placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op, int immr, int imms);
+    void placeER(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift);

     void place(uint32_t word);

View file

@@ -1,6 +1,8 @@
 // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
 #pragma once

+#include "Luau/CodeGen.h"
+
 #include <vector>

 #include <stddef.h>
@@ -16,6 +18,7 @@ constexpr uint32_t kCodeAlignment = 32;
 struct CodeAllocator
 {
     CodeAllocator(size_t blockSize, size_t maxTotalSize);
+    CodeAllocator(size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext);
     ~CodeAllocator();

     // Places data and code into the executable page area
@@ -24,7 +27,7 @@ struct CodeAllocator
     bool allocate(
         const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart);

-    // Provided to callbacks
+    // Provided to unwind info callbacks
     void* context = nullptr;

     // Called when new block is created to create and setup the unwinding information for all the code in the block
@@ -34,12 +37,16 @@ struct CodeAllocator
     // Called to destroy unwinding information returned by 'createBlockUnwindInfo'
     void (*destroyBlockUnwindInfo)(void* context, void* unwindData) = nullptr;

+private:
     // Unwind information can be placed inside the block with some implementation-specific reservations at the beginning
     // But to simplify block space checks, we limit the max size of all that data
     static const size_t kMaxReservedDataSize = 256;

     bool allocateNewBlock(size_t& unwindInfoSize);

+    uint8_t* allocatePages(size_t size) const;
+    void freePages(uint8_t* mem, size_t size) const;
+
     // Current block we use for allocations
     uint8_t* blockPos = nullptr;
     uint8_t* blockEnd = nullptr;
@@ -50,6 +57,9 @@ struct CodeAllocator
     size_t blockSize = 0;
     size_t maxTotalSize = 0;
+
+    AllocationCallback* allocationCallback = nullptr;
+    void* allocationCallbackContext = nullptr;
 };

 } // namespace CodeGen

View file

@@ -3,6 +3,7 @@
 #include <string>

+#include <stddef.h>
 #include <stdint.h>

 struct lua_State;
@@ -18,12 +19,35 @@ enum CodeGenFlags
     CodeGen_OnlyNativeModules = 1 << 0,
 };

+enum class CodeGenCompilationResult
+{
+    Success,               // Successfully generated code for at least one function
+    NothingToCompile,      // There were no new functions to compile
+
+    CodeGenNotInitialized, // Native codegen system is not initialized
+    CodeGenFailed,         // Native codegen failed due to an internal compiler error
+    AllocationFailed,      // Native codegen failed due to an allocation error
+};
+
+struct CompilationStats
+{
+    size_t bytecodeSizeBytes = 0;
+    size_t nativeCodeSizeBytes = 0;
+    size_t nativeDataSizeBytes = 0;
+    size_t nativeMetadataSizeBytes = 0;
+
+    uint32_t functionsCompiled = 0;
+};
+
+using AllocationCallback = void(void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize);
+
 bool isSupported();

+void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext);
 void create(lua_State* L);

 // Builds target function and all inner functions
-void compile(lua_State* L, int idx, unsigned int flags = 0);
+CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr);

 using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos);
@@ -51,8 +75,18 @@ struct AssemblyOptions
     void* annotatorContext = nullptr;
 };

+struct LoweringStats
+{
+    int spillsToSlot = 0;
+    int spillsToRestore = 0;
+    unsigned maxSpillSlotsUsed = 0;
+
+    int regAllocErrors = 0;
+    int loweringErrors = 0;
+};
+
 // Generates assembly for target function and all inner functions
-std::string getAssembly(lua_State* L, int idx, AssemblyOptions options = {});
+std::string getAssembly(lua_State* L, int idx, AssemblyOptions options = {}, LoweringStats* stats = nullptr);

 using PerfLogFn = void (*)(void* context, uintptr_t addr, unsigned size, const char* symbol);
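Taken together, these additions let an embedder observe every executable-page allocation the JIT makes. A minimal host-side sketch (the tally struct and function names are illustrative, not part of the Luau API; the callback semantics, with nullptr on whichever side is not being mapped or unmapped, follow CodeAllocator::allocatePages/freePages further down in this diff):

```cpp
#include "Luau/CodeGen.h"

// Hypothetical host-side bookkeeping; not part of the Luau API.
struct CodegenMemoryTally
{
    size_t totalBytes = 0; // executable pages currently mapped
};

static void trackCodegenMemory(void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize)
{
    CodegenMemoryTally* tally = static_cast<CodegenMemoryTally*>(context);

    // Pages being mapped arrive as (newPointer, newSize); pages being unmapped as (oldPointer, oldSize)
    tally->totalBytes += newSize;
    tally->totalBytes -= oldSize;
}

void setupNativeCodegen(lua_State* L, CodegenMemoryTally& tally)
{
    if (Luau::CodeGen::isSupported())
        Luau::CodeGen::create(L, trackCodegenMemory, &tally);
}
```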

View file

@@ -1,6 +1,8 @@
 // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
 #pragma once

+#include "Luau/Common.h"
+
 namespace Luau
 {
 namespace CodeGen
@@ -43,5 +45,68 @@ enum class ConditionX64 : uint8_t
     Count
 };

+inline ConditionX64 getReverseCondition(ConditionX64 cond)
+{
+    switch (cond)
+    {
+    case ConditionX64::Overflow:
+        return ConditionX64::NoOverflow;
+    case ConditionX64::NoOverflow:
+        return ConditionX64::Overflow;
+    case ConditionX64::Carry:
+        return ConditionX64::NoCarry;
+    case ConditionX64::NoCarry:
+        return ConditionX64::Carry;
+    case ConditionX64::Below:
+        return ConditionX64::NotBelow;
+    case ConditionX64::BelowEqual:
+        return ConditionX64::NotBelowEqual;
+    case ConditionX64::Above:
+        return ConditionX64::NotAbove;
+    case ConditionX64::AboveEqual:
+        return ConditionX64::NotAboveEqual;
+    case ConditionX64::Equal:
+        return ConditionX64::NotEqual;
+    case ConditionX64::Less:
+        return ConditionX64::NotLess;
+    case ConditionX64::LessEqual:
+        return ConditionX64::NotLessEqual;
+    case ConditionX64::Greater:
+        return ConditionX64::NotGreater;
+    case ConditionX64::GreaterEqual:
+        return ConditionX64::NotGreaterEqual;
+    case ConditionX64::NotBelow:
+        return ConditionX64::Below;
+    case ConditionX64::NotBelowEqual:
+        return ConditionX64::BelowEqual;
+    case ConditionX64::NotAbove:
+        return ConditionX64::Above;
+    case ConditionX64::NotAboveEqual:
+        return ConditionX64::AboveEqual;
+    case ConditionX64::NotEqual:
+        return ConditionX64::Equal;
+    case ConditionX64::NotLess:
+        return ConditionX64::Less;
+    case ConditionX64::NotLessEqual:
+        return ConditionX64::LessEqual;
+    case ConditionX64::NotGreater:
+        return ConditionX64::Greater;
+    case ConditionX64::NotGreaterEqual:
+        return ConditionX64::GreaterEqual;
+    case ConditionX64::Zero:
+        return ConditionX64::NotZero;
+    case ConditionX64::NotZero:
+        return ConditionX64::Zero;
+    case ConditionX64::Parity:
+        return ConditionX64::NotParity;
+    case ConditionX64::NotParity:
+        return ConditionX64::Parity;
+    case ConditionX64::Count:
+        LUAU_ASSERT(!"invalid ConditionX64 value");
+    }
+
+    return ConditionX64::Count;
+}
+
 } // namespace CodeGen
 } // namespace Luau
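Since every case simply flips to its negation, the helper's main use is branch layout: when one successor of a conditional is the next block to be emitted, lowering can get away with a single reversed branch instead of two. A sketch of that pattern (not taken from the Luau sources):

```cpp
// Illustrative only: lower "if (cond) goto onTrue; else goto onFalse"
// with at most one branch when either successor is the fallthrough block.
void jumpOrFallthrough(AssemblyBuilderX64& build, ConditionX64 cond, Label& onTrue, Label& onFalse, Label& nextBlock)
{
    if (&onFalse == &nextBlock)
    {
        build.jcc(cond, onTrue); // false case falls through
    }
    else if (&onTrue == &nextBlock)
    {
        build.jcc(getReverseCondition(cond), onFalse); // true case falls through
    }
    else
    {
        build.jcc(cond, onTrue);
        build.jmp(onFalse);
    }
}
```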

View file

@@ -4,6 +4,7 @@
 #include "Luau/Common.h"

 #include <bitset>
+#include <queue>
 #include <utility>
 #include <vector>
@@ -96,6 +97,46 @@ struct CfgInfo
 void computeCfgImmediateDominators(IrFunction& function);
 void computeCfgDominanceTreeChildren(IrFunction& function);

+struct IdfContext
+{
+    struct BlockAndOrdering
+    {
+        uint32_t blockIdx;
+        BlockOrdering ordering;
+
+        bool operator<(const BlockAndOrdering& rhs) const
+        {
+            if (ordering.depth != rhs.ordering.depth)
+                return ordering.depth < rhs.ordering.depth;
+
+            return ordering.preOrder < rhs.ordering.preOrder;
+        }
+    };
+
+    // Using priority queue to work on nodes in the order from the bottom of the dominator tree to the top
+    // If the depth of keys is equal, DFS order is used to provide strong ordering
+    std::priority_queue<BlockAndOrdering> queue;
+    std::vector<uint32_t> worklist;
+
+    struct IdfVisitMarks
+    {
+        bool seenInQueue = false;
+        bool seenInWorklist = false;
+    };
+
+    std::vector<IdfVisitMarks> visits;
+
+    std::vector<uint32_t> idf;
+};
+
+// Compute iterated dominance frontier (IDF or DF+) for a variable, given the set of blocks where that variable is defined
+// Providing a set of blocks where the variable is a live-in at the entry helps produce a pruned SSA form (inserted phi nodes will not be dead)
+//
+// 'Iterated' comes from the definition where we recompute the IDFn+1 = DF(S) while adding IDFn to S until a fixed point is reached
+// Iterated dominance frontier has been shown to be equal to the set of nodes where phi instructions have to be inserted
+void computeIteratedDominanceFrontierForDefs(
+    IdfContext& ctx, const IrFunction& function, const std::vector<uint32_t>& defBlocks, const std::vector<uint32_t>& liveInBlocks);
+
 // Function used to update all CFG data
 void computeCfgInfo(IrFunction& function);
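The fixed point described in that comment can be spelled out directly. A naive reference version, assuming a callable `df` that returns the dominance frontier of a block (the shipped implementation instead walks the dominator tree bottom-up with the priority queue above and never materializes frontiers):

```cpp
#include <cstdint>
#include <functional>
#include <set>
#include <vector>

// Reference for the IDF definition above, not the shipped algorithm.
std::set<uint32_t> computeIdfNaive(const std::vector<uint32_t>& defBlocks, const std::function<std::set<uint32_t>(uint32_t)>& df)
{
    std::set<uint32_t> s(defBlocks.begin(), defBlocks.end()); // S, grows by IDFn each round
    std::set<uint32_t> idf;

    for (bool changed = true; changed;)
    {
        changed = false;

        // IDFn+1 = DF(S): union of frontiers over the current S (copied so we can grow S while iterating)
        for (uint32_t b : std::set<uint32_t>(s))
            for (uint32_t f : df(b))
                if (idf.insert(f).second)
                {
                    s.insert(f); // adding IDFn to S for the next round
                    changed = true;
                }
    }

    return idf;
}
```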

View file

@@ -66,6 +66,8 @@ struct IrBuilder
     bool inTerminatedBlock = false;

+    bool interruptRequested = false;
+
     bool activeFastcallFallback = false;
     IrOp fastcallFallbackReturn;
     int fastcallSkipTarget = -1;
@@ -76,6 +78,8 @@ struct IrBuilder
     std::vector<uint32_t> instIndexToBlock; // Block index at the bytecode instruction

+    std::vector<IrOp> loopStepStack;
+
     // Similar to BytecodeBuilder, duplicate constants are removed used the same method
     struct ConstantKey
     {

View file

@@ -53,12 +53,9 @@ enum class IrCmd : uint8_t
     // Load a TValue from memory
     // A: Rn or Kn or pointer (TValue)
+    // B: int (optional 'A' pointer offset)
     LOAD_TVALUE,

-    // Load a TValue from table node value
-    // A: pointer (LuaNode)
-    LOAD_NODE_VALUE_TV, // TODO: we should find a way to generalize LOAD_TVALUE
-
     // Load current environment table
     LOAD_ENV,
@@ -70,6 +67,7 @@ enum class IrCmd : uint8_t
     // Get pointer (LuaNode) to table node element at the active cached slot index
     // A: pointer (Table)
     // B: unsigned int (pcpos)
+    // C: Kn
     GET_SLOT_NODE_ADDR,

     // Get pointer (LuaNode) to table node element at the main position of the specified key hash
@@ -113,12 +111,15 @@ enum class IrCmd : uint8_t
     // Store a TValue into memory
     // A: Rn or pointer (TValue)
     // B: TValue
+    // C: int (optional 'A' pointer offset)
     STORE_TVALUE,

-    // Store a TValue into table node value
-    // A: pointer (LuaNode)
-    // B: TValue
-    STORE_NODE_VALUE_TV, // TODO: we should find a way to generalize STORE_TVALUE
+    // Store a pair of tag and value into memory
+    // A: Rn or pointer (TValue)
+    // B: tag (must be a constant)
+    // C: int/double/pointer
+    // D: int (optional 'A' pointer offset)
+    STORE_SPLIT_TVALUE,

     // Add/Sub two integers together
     // A, B: int
@@ -132,6 +133,7 @@ enum class IrCmd : uint8_t
     SUB_NUM,
     MUL_NUM,
     DIV_NUM,
+    IDIV_NUM,
     MOD_NUM,

     // Get the minimum/maximum of two numbers
@@ -176,7 +178,7 @@ enum class IrCmd : uint8_t
     CMP_ANY,

     // Unconditional jump
-    // A: block/vmexit
+    // A: block/vmexit/undef
     JUMP,

     // Jump if TValue is truthy
@@ -197,24 +199,12 @@ enum class IrCmd : uint8_t
     // D: block (if false)
     JUMP_EQ_TAG,

-    // Jump if two int numbers are equal
-    // A, B: int
-    // C: block (if true)
-    // D: block (if false)
-    JUMP_EQ_INT,
-
-    // Jump if A < B
-    // A, B: int
-    // C: block (if true)
-    // D: block (if false)
-    JUMP_LT_INT,
-
-    // Jump if unsigned(A) >= unsigned(B)
-    // A, B: int
-    // C: block (if true)
-    // D: block (if false)
-    JUMP_GE_UINT,
+    // Perform a conditional jump based on the result of integer comparison
+    // A, B: int
+    // C: condition
+    // D: block (if true)
+    // E: block (if false)
+    JUMP_CMP_INT,

     // Jump if pointers are equal
     // A, B: pointer (*)
@@ -245,14 +235,19 @@ enum class IrCmd : uint8_t
     STRING_LEN,

     // Allocate new table
-    // A: int (array element count)
-    // B: int (node element count)
+    // A: unsigned int (array element count)
+    // B: unsigned int (node element count)
     NEW_TABLE,

     // Duplicate a table
     // A: pointer (Table)
     DUP_TABLE,

+    // Insert an integer key into a table
+    // A: pointer (Table)
+    // B: int (key)
+    TABLE_SETNUM,
+
     // Try to convert a double number into a table index (int) or jump if it's not an integer
     // A: double
     // B: block
@@ -356,23 +351,16 @@ enum class IrCmd : uint8_t
     // Store TValue from stack slot into a function upvalue
     // A: UPn
     // B: Rn
+    // C: tag/undef (tag of the value that was written)
     SET_UPVALUE,

-    // Convert TValues into numbers for a numerical for loop
-    // A: Rn (start)
-    // B: Rn (end)
-    // C: Rn (step)
-    PREPARE_FORN,
-
     // Guards and checks (these instructions are not block terminators even though they jump to fallback)

     // Guard against tag mismatch
     // A, B: tag
     // C: block/vmexit/undef
-    // D: bool (finish execution in VM on failure)
     // In final x64 lowering, A can also be Rn
-    // When undef is specified instead of a block, execution is aborted on check failure; if D is true, execution is continued in VM interpreter
-    // instead.
+    // When undef is specified instead of a block, execution is aborted on check failure
     CHECK_TAG,

     // Guard against a falsy tag+value
@@ -418,6 +406,12 @@ enum class IrCmd : uint8_t
     // When undef is specified instead of a block, execution is aborted on check failure
     CHECK_NODE_NO_NEXT,

+    // Guard against table node with 'nil' value
+    // A: pointer (LuaNode)
+    // B: block/vmexit/undef
+    // When undef is specified instead of a block, execution is aborted on check failure
+    CHECK_NODE_VALUE,
+
     // Special operations

     // Check interrupt handler
@@ -464,6 +458,7 @@ enum class IrCmd : uint8_t
     // C: Rn (source start)
     // D: int (count or -1 to assign values up to stack top)
     // E: unsigned int (table index to start from)
+    // F: undef/unsigned int (target table known size)
     SETLIST,

     // Call specified function
@@ -689,6 +684,10 @@ enum class IrOpKind : uint32_t
     VmExit,
 };

+// VmExit uses a special value to indicate that pcpos update should be skipped
+// This is only used during type checking at function entry
+constexpr uint32_t kVmExitEntryGuardPc = (1u << 28) - 1;
+
 struct IrOp
 {
     IrOpKind kind : 4;
@@ -834,6 +833,8 @@ struct IrBlock
     uint32_t finish = ~0u;

     uint32_t sortkey = ~0u;
+    uint32_t chainkey = 0;
+    uint32_t expectedNextBlock = ~0u;

     Label label;
 };
@@ -851,6 +852,8 @@ struct IrFunction
     std::vector<IrConst> constants;
     std::vector<BytecodeMapping> bcMapping;
+    uint32_t entryBlock = 0;
+    uint32_t entryLocation = 0;

     // For each instruction, an operand that can be used to recompute the value
     std::vector<IrOp> valueRestoreOps;
@@ -993,23 +996,26 @@ struct IrFunction
         valueRestoreOps[instIdx] = location;
     }

-    IrOp findRestoreOp(uint32_t instIdx) const
+    IrOp findRestoreOp(uint32_t instIdx, bool limitToCurrentBlock) const
     {
         if (instIdx >= valueRestoreOps.size())
             return {};

         const IrBlock& block = blocks[validRestoreOpBlockIdx];

-        // Values can only reference restore operands in the current block
-        if (instIdx < block.start || instIdx > block.finish)
-            return {};
+        // When spilled, values can only reference restore operands in the current block
+        if (limitToCurrentBlock)
+        {
+            if (instIdx < block.start || instIdx > block.finish)
+                return {};
+        }

         return valueRestoreOps[instIdx];
     }

-    IrOp findRestoreOp(const IrInst& inst) const
+    IrOp findRestoreOp(const IrInst& inst, bool limitToCurrentBlock) const
     {
-        return findRestoreOp(getInstIndex(inst));
+        return findRestoreOp(getInstIndex(inst), limitToCurrentBlock);
     }
 };
@@ -1037,5 +1043,11 @@ inline int vmUpvalueOp(IrOp op)
     return op.index;
 }

+inline uint32_t vmExitOp(IrOp op)
+{
+    LUAU_ASSERT(op.kind == IrOpKind::VmExit);
+    return op.index;
+}
+
 } // namespace CodeGen
 } // namespace Luau
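A small aside on the sentinel introduced above: IrOp packs a 4-bit kind (visible in the struct earlier in this diff) with what appears to be a 28-bit index, so (1u << 28) - 1 = 268435455 is the largest index the encoding can hold; reserving it for "entry type guard, skip the pcpos update" costs nothing, since no real bytecode position gets that large.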

View file

@@ -12,6 +12,9 @@ namespace Luau
 {
 namespace CodeGen
 {

+struct LoweringStats;
+
 namespace X64
 {
@@ -33,7 +36,7 @@ struct IrSpillX64
 struct IrRegAllocX64
 {
-    IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function);
+    IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function, LoweringStats* stats);

     RegisterX64 allocReg(SizeX64 size, uint32_t instIdx);
     RegisterX64 allocRegOrReuse(SizeX64 size, uint32_t instIdx, std::initializer_list<IrOp> oprefs);
@@ -70,6 +73,7 @@ struct IrRegAllocX64
     AssemblyBuilderX64& build;
     IrFunction& function;
+    LoweringStats* stats = nullptr;

     uint32_t currInstIdx = ~0u;
@@ -77,6 +81,7 @@ struct IrRegAllocX64
     std::array<uint32_t, 16> gprInstUsers;
     std::array<bool, 16> freeXmmMap;
     std::array<uint32_t, 16> xmmInstUsers;
+    uint8_t usableXmmRegCount = 0;

     std::bitset<256> usedSpillSlots;
     unsigned maxUsedSlot = 0;

View file

@@ -94,9 +94,7 @@ inline bool isBlockTerminator(IrCmd cmd)
     case IrCmd::JUMP_IF_TRUTHY:
     case IrCmd::JUMP_IF_FALSY:
     case IrCmd::JUMP_EQ_TAG:
-    case IrCmd::JUMP_EQ_INT:
-    case IrCmd::JUMP_LT_INT:
-    case IrCmd::JUMP_GE_UINT:
+    case IrCmd::JUMP_CMP_INT:
     case IrCmd::JUMP_EQ_POINTER:
     case IrCmd::JUMP_CMP_NUM:
     case IrCmd::JUMP_SLOT_MATCH:
@@ -128,6 +126,7 @@ inline bool isNonTerminatingJump(IrCmd cmd)
     case IrCmd::CHECK_ARRAY_SIZE:
     case IrCmd::CHECK_SLOT_MATCH:
     case IrCmd::CHECK_NODE_NO_NEXT:
+    case IrCmd::CHECK_NODE_VALUE:
         return true;
     default:
         break;
@@ -145,7 +144,6 @@ inline bool hasResult(IrCmd cmd)
     case IrCmd::LOAD_DOUBLE:
     case IrCmd::LOAD_INT:
     case IrCmd::LOAD_TVALUE:
-    case IrCmd::LOAD_NODE_VALUE_TV:
     case IrCmd::LOAD_ENV:
     case IrCmd::GET_ARR_ADDR:
     case IrCmd::GET_SLOT_NODE_ADDR:
@@ -157,6 +155,7 @@ inline bool hasResult(IrCmd cmd)
     case IrCmd::SUB_NUM:
     case IrCmd::MUL_NUM:
     case IrCmd::DIV_NUM:
+    case IrCmd::IDIV_NUM:
     case IrCmd::MOD_NUM:
     case IrCmd::MIN_NUM:
     case IrCmd::MAX_NUM:
@@ -169,6 +168,7 @@ inline bool hasResult(IrCmd cmd)
     case IrCmd::NOT_ANY:
     case IrCmd::CMP_ANY:
     case IrCmd::TABLE_LEN:
+    case IrCmd::TABLE_SETNUM:
     case IrCmd::STRING_LEN:
     case IrCmd::NEW_TABLE:
     case IrCmd::DUP_TABLE:
@@ -264,5 +264,14 @@ uint32_t getNativeContextOffset(int bfid);
 // Cleans up blocks that were created with no users
 void killUnusedBlocks(IrFunction& function);

+// Get blocks in order that tries to maximize fallthrough between them during lowering
+// We want to mostly preserve build order with fallbacks outlined
+// But we also use hints from optimization passes that chain blocks together where there's only one out-in edge between them
+std::vector<uint32_t> getSortedBlockOrder(IrFunction& function);
+
+// Returns first non-dead block that comes after block at index 'i' in the sorted blocks array
+// 'dummy' block is returned if the end of array was reached
+IrBlock& getNextBlock(IrFunction& function, std::vector<uint32_t>& sortedBlocks, IrBlock& dummy, size_t i);
+
 } // namespace CodeGen
 } // namespace Luau

View file

@@ -47,18 +47,6 @@ constexpr RegisterA64 castReg(KindA64 kind, RegisterA64 reg)
     return RegisterA64{kind, reg.index};
 }

-// This is equivalent to castReg(KindA64::x), but is separate because it implies different semantics
-// Specifically, there are cases when it's useful to treat a wN register as an xN register *after* it has been assigned a value
-// Since all A64 instructions that write to wN implicitly zero the top half, this works when we need zero extension semantics
-// Crucially, this is *not* safe on an ABI boundary - an int parameter in wN register may have anything in its top half in certain cases
-// However, as long as our codegen doesn't use 32-bit truncation by using castReg x=>w, we can safely rely on this.
-constexpr RegisterA64 zextReg(RegisterA64 reg)
-{
-    LUAU_ASSERT(reg.kind == KindA64::w);
-
-    return RegisterA64{KindA64::x, reg.index};
-}
-
 constexpr RegisterA64 noreg{KindA64::none, 0};

 constexpr RegisterA64 w0{KindA64::w, 0};

View file

@@ -5,6 +5,7 @@
 #include "Luau/RegisterX64.h"

 #include <initializer_list>
+#include <vector>

 #include <stddef.h>
 #include <stdint.h>
@@ -48,7 +49,8 @@ public:
     // mov rbp, rsp
     // push reg in the order specified in regs
     // sub rsp, stackSize
-    virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> regs) = 0;
+    virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr,
+        const std::vector<X64::RegisterX64>& simd) = 0;

     virtual size_t getSize() const = 0;
     virtual size_t getFunctionCount() const = 0;

View file

@@ -30,7 +30,8 @@ public:
     void finishInfo() override;

     void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs) override;
-    void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> regs) override;
+    void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr,
+        const std::vector<X64::RegisterX64>& simd) override;

     size_t getSize() const override;
     size_t getFunctionCount() const override;

View file

@@ -50,7 +50,8 @@ public:
     void finishInfo() override;

     void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list<A64::RegisterA64> regs) override;
-    void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> regs) override;
+    void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list<X64::RegisterX64> gpr,
+        const std::vector<X64::RegisterX64>& simd) override;

     size_t getSize() const override;
     size_t getFunctionCount() const override;

View file

@@ -5,6 +5,7 @@
 #include "ByteUtils.h"

 #include <stdarg.h>
+#include <stdio.h>

 namespace Luau
 {
@@ -104,6 +105,9 @@ void AssemblyBuilderA64::movk(RegisterA64 dst, uint16_t src, int shift)

 void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
 {
-    placeSR3("add", dst, src1, src2, 0b00'01011, shift);
+    if (src1.kind == KindA64::x && src2.kind == KindA64::w)
+        placeER("add", dst, src1, src2, 0b00'01011, shift);
+    else
+        placeSR3("add", dst, src1, src2, 0b00'01011, shift);
 }
@@ -114,6 +118,9 @@ void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, uint16_t src2)

 void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift)
 {
-    placeSR3("sub", dst, src1, src2, 0b10'01011, shift);
+    if (src1.kind == KindA64::x && src2.kind == KindA64::w)
+        placeER("sub", dst, src1, src2, 0b10'01011, shift);
+    else
+        placeSR3("sub", dst, src1, src2, 0b10'01011, shift);
 }
@@ -1074,6 +1081,22 @@ void AssemblyBuilderA64::placeBFM(const char* name, RegisterA64 dst, RegisterA64
     commit();
 }

+void AssemblyBuilderA64::placeER(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift)
+{
+    if (logText)
+        log(name, dst, src1, src2, shift);
+
+    LUAU_ASSERT(dst.kind == KindA64::x && src1.kind == KindA64::x);
+    LUAU_ASSERT(src2.kind == KindA64::w);
+    LUAU_ASSERT(shift >= 0 && shift <= 4);
+
+    uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; // could be useful in the future for byte->word extends
+    int option = 0b010; // UXTW
+
+    place(dst.index | (src1.index << 5) | (shift << 10) | (option << 13) | (src2.index << 16) | (1 << 21) | (op << 24) | sf);
+    commit();
+}
+
 void AssemblyBuilderA64::place(uint32_t word)
 {
     LUAU_ASSERT(codePos < codeEnd);
@@ -1166,7 +1189,9 @@ void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 sr
     log(src1);
     text.append(",");
     log(src2);
-    if (shift > 0)
+    if (src1.kind == KindA64::x && src2.kind == KindA64::w)
+        logAppend(" UXTW #%d", shift);
+    else if (shift > 0)
         logAppend(" LSL #%d", shift);
     else if (shift < 0)
         logAppend(" LSR #%d", -shift);

View file

@@ -5,7 +5,6 @@
 #include <stdarg.h>
 #include <stdio.h>
-#include <string.h>

 namespace Luau
 {

View file

@@ -33,13 +33,17 @@ static size_t alignToPageSize(size_t size)
 }

 #if defined(_WIN32)
-static uint8_t* allocatePages(size_t size)
+static uint8_t* allocatePagesImpl(size_t size)
 {
-    return (uint8_t*)VirtualAlloc(nullptr, alignToPageSize(size), MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE);
+    LUAU_ASSERT(size == alignToPageSize(size));
+
+    return (uint8_t*)VirtualAlloc(nullptr, size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE);
 }

-static void freePages(uint8_t* mem, size_t size)
+static void freePagesImpl(uint8_t* mem, size_t size)
 {
+    LUAU_ASSERT(size == alignToPageSize(size));
+
     if (VirtualFree(mem, 0, MEM_RELEASE) == 0)
         LUAU_ASSERT(!"failed to deallocate block memory");
 }
@@ -62,14 +66,24 @@ static void flushInstructionCache(uint8_t* mem, size_t size)
 #endif
 }
 #else
-static uint8_t* allocatePages(size_t size)
+static uint8_t* allocatePagesImpl(size_t size)
 {
-    return (uint8_t*)mmap(nullptr, alignToPageSize(size), PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0);
+    LUAU_ASSERT(size == alignToPageSize(size));
+
+#ifdef __APPLE__
+    void* result = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON | MAP_JIT, -1, 0);
+#else
+    void* result = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0);
+#endif
+
+    return (result == MAP_FAILED) ? nullptr : static_cast<uint8_t*>(result);
 }

-static void freePages(uint8_t* mem, size_t size)
+static void freePagesImpl(uint8_t* mem, size_t size)
 {
-    if (munmap(mem, alignToPageSize(size)) != 0)
+    LUAU_ASSERT(size == alignToPageSize(size));
+
+    if (munmap(mem, size) != 0)
         LUAU_ASSERT(!"Failed to deallocate block memory");
 }
@@ -94,8 +108,15 @@ namespace CodeGen
 {

 CodeAllocator::CodeAllocator(size_t blockSize, size_t maxTotalSize)
-    : blockSize(blockSize)
-    , maxTotalSize(maxTotalSize)
+    : CodeAllocator(blockSize, maxTotalSize, nullptr, nullptr)
+{
+}
+
+CodeAllocator::CodeAllocator(size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext)
+    : blockSize{blockSize}
+    , maxTotalSize{maxTotalSize}
+    , allocationCallback{allocationCallback}
+    , allocationCallbackContext{allocationCallbackContext}
 {
     LUAU_ASSERT(blockSize > kMaxReservedDataSize);
     LUAU_ASSERT(maxTotalSize >= blockSize);
@@ -207,5 +228,29 @@ bool CodeAllocator::allocateNewBlock(size_t& unwindInfoSize)
     return true;
 }

+uint8_t* CodeAllocator::allocatePages(size_t size) const
+{
+    const size_t pageAlignedSize = alignToPageSize(size);
+
+    uint8_t* const mem = allocatePagesImpl(pageAlignedSize);
+    if (mem == nullptr)
+        return nullptr;
+
+    if (allocationCallback)
+        allocationCallback(allocationCallbackContext, nullptr, 0, mem, pageAlignedSize);
+
+    return mem;
+}
+
+void CodeAllocator::freePages(uint8_t* mem, size_t size) const
+{
+    const size_t pageAlignedSize = alignToPageSize(size);
+
+    if (allocationCallback)
+        allocationCallback(allocationCallbackContext, mem, pageAlignedSize, nullptr, 0);
+
+    freePagesImpl(mem, pageAlignedSize);
+}
+
 } // namespace CodeGen
 } // namespace Luau
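Two behavior changes worth calling out in this file: allocation failure is now reported as nullptr instead of leaking MAP_FAILED to callers, and on macOS the pages are created with MAP_JIT, which is what later allows them to be made executable under the hardened runtime. The new allocatePages/freePages wrappers then route every page mapping and unmapping through the embedder's AllocationCallback, giving a complete view of executable-page lifetime.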

View file

@@ -65,7 +65,7 @@ static NativeProto createNativeProto(Proto* proto, const IrBuilder& ir)
     int sizecode = proto->sizecode;

     uint32_t* instOffsets = new uint32_t[sizecode];
-    uint32_t instTarget = ir.function.bcMapping[0].asmLocation;
+    uint32_t instTarget = ir.function.entryLocation;

     for (int i = 0; i < sizecode; i++)
     {
@@ -74,6 +74,9 @@ static NativeProto createNativeProto(Proto* proto, const IrBuilder& ir)
         instOffsets[i] = ir.function.bcMapping[i].asmLocation - instTarget;
     }

+    // Set first instruction offset to 0 so that entering this function still executes any generated entry code.
+    instOffsets[0] = 0;
+
     // entry target will be relocated when assembly is finalized
     return {proto, instOffsets, instTarget};
 }
@@ -103,7 +106,7 @@ static std::optional<NativeProto> createNativeFunction(AssemblyBuilder& build, M
     IrBuilder ir;
     ir.buildFunctionIr(proto);

-    if (!lowerFunction(ir, build, helpers, proto, {}))
+    if (!lowerFunction(ir, build, helpers, proto, {}, /* stats */ nullptr))
         return std::nullopt;

     return createNativeProto(proto, ir);
@@ -159,9 +162,6 @@ unsigned int getCpuFeaturesA64()

 bool isSupported()
 {
-    if (!LUA_CUSTOM_EXECUTION)
-        return false;
-
     if (LUA_EXTRA_SIZE != 1)
         return false;

@@ -202,11 +202,11 @@ bool isSupported()
 #endif
 }

-void create(lua_State* L)
+void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext)
 {
     LUAU_ASSERT(isSupported());

-    std::unique_ptr<NativeState> data = std::make_unique<NativeState>();
+    std::unique_ptr<NativeState> data = std::make_unique<NativeState>(allocationCallback, allocationCallbackContext);

 #if defined(_WIN32)
     data->unwindBuilder = std::make_unique<UnwindBuilderWin>();
@@ -239,23 +239,38 @@ void create(lua_State* L)
     ecb->enter = onEnter;
 }

-void compile(lua_State* L, int idx, unsigned int flags)
+void create(lua_State* L)
+{
+    create(L, nullptr, nullptr);
+}
+
+CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats)
 {
     LUAU_ASSERT(lua_isLfunction(L, idx));
     const TValue* func = luaA_toobject(L, idx);

+    Proto* root = clvalue(func)->l.p;
+    if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
+        return CodeGenCompilationResult::NothingToCompile;
+
     // If initialization has failed, do not compile any functions
     NativeState* data = getNativeState(L);
     if (!data)
-        return;
-
-    Proto* root = clvalue(func)->l.p;
-    if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0)
-        return;
+        return CodeGenCompilationResult::CodeGenNotInitialized;

     std::vector<Proto*> protos;
     gatherFunctions(protos, root);

+    // Skip protos that have been compiled during previous invocations of CodeGen::compile
+    protos.erase(std::remove_if(protos.begin(), protos.end(),
+                     [](Proto* p) {
+                         return p == nullptr || p->execdata != nullptr;
+                     }),
+        protos.end());
+
+    if (protos.empty())
+        return CodeGenCompilationResult::NothingToCompile;
+
 #if defined(__aarch64__)
     static unsigned int cpuFeatures = getCpuFeaturesA64();
     A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures);
@@ -273,9 +288,7 @@ void compile(lua_State* L, int idx, unsigned int flags)
     std::vector<NativeProto> results;
     results.reserve(protos.size());

-    // Skip protos that have been compiled during previous invocations of CodeGen::compile
     for (Proto* p : protos)
-        if (p && p->execdata == nullptr)
-            if (std::optional<NativeProto> np = createNativeFunction(build, helpers, p))
-                results.push_back(*np);
+        if (std::optional<NativeProto> np = createNativeFunction(build, helpers, p))
+            results.push_back(*np);
@@ -285,12 +298,12 @@ void compile(lua_State* L, int idx, unsigned int flags)
         for (NativeProto result : results)
             destroyExecData(result.execdata);

-        return;
+        return CodeGenCompilationResult::CodeGenFailed;
     }

     // If no functions were assembled, we don't need to allocate/copy executable pages for helpers
     if (results.empty())
-        return;
+        return CodeGenCompilationResult::CodeGenFailed;

     uint8_t* nativeData = nullptr;
     size_t sizeNativeData = 0;
@@ -301,7 +314,7 @@ void compile(lua_State* L, int idx, unsigned int flags)
         for (NativeProto result : results)
             destroyExecData(result.execdata);

-        return;
+        return CodeGenCompilationResult::AllocationFailed;
     }

     if (gPerfLogFn && results.size() > 0)
@@ -318,13 +331,30 @@ void compile(lua_State* L, int idx, unsigned int flags)
         }
     }

-    for (NativeProto result : results)
+    for (const NativeProto& result : results)
     {
         // the memory is now managed by VM and will be freed via onDestroyFunction
         result.p->execdata = result.execdata;
         result.p->exectarget = uintptr_t(codeStart) + result.exectarget;
         result.p->codeentry = &kCodeEntryInsn;
     }

+    if (stats != nullptr)
+    {
+        for (const NativeProto& result : results)
+        {
+            stats->bytecodeSizeBytes += result.p->sizecode * sizeof(Instruction);
+
+            // Account for the native -> bytecode instruction offsets mapping:
+            stats->nativeMetadataSizeBytes += result.p->sizecode * sizeof(uint32_t);
+        }
+
+        stats->functionsCompiled += uint32_t(results.size());
+        stats->nativeCodeSizeBytes += build.code.size();
+        stats->nativeDataSizeBytes += build.data.size();
+    }
+
+    return CodeGenCompilationResult::Success;
 }

 void setPerfLog(void* context, PerfLogFn logFn)
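With compile now returning a status and optional stats, an embedder can tell "nothing left to do" apart from a real failure. A usage sketch (the wrapper function and diagnostic strings are illustrative, not part of the Luau API):

```cpp
#include <cstdio>

#include "Luau/CodeGen.h"

void compileNatively(lua_State* L)
{
    Luau::CodeGen::CompilationStats stats;

    switch (Luau::CodeGen::compile(L, /* idx= */ -1, /* flags= */ 0, &stats))
    {
    case Luau::CodeGen::CodeGenCompilationResult::Success:
        printf("compiled %u function(s), %zu bytes of native code\n", stats.functionsCompiled, stats.nativeCodeSizeBytes);
        break;
    case Luau::CodeGen::CodeGenCompilationResult::NothingToCompile:
        break; // everything was already compiled, or filtered out by CodeGen_OnlyNativeModules
    default:
        fprintf(stderr, "native codegen is unavailable or failed; the VM will interpret this function\n");
        break;
    }
}
```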

View file

@@ -24,15 +24,6 @@ struct EntryLocations
     Label epilogueStart;
 };

-static void emitClearNativeFlag(AssemblyBuilderA64& build)
-{
-    build.ldr(x0, mem(rState, offsetof(lua_State, ci)));
-    build.ldr(w1, mem(x0, offsetof(CallInfo, flags)));
-    build.mov(w2, ~LUA_CALLINFO_NATIVE);
-    build.and_(w1, w1, w2);
-    build.str(w1, mem(x0, offsetof(CallInfo, flags)));
-}
-
 static void emitExit(AssemblyBuilderA64& build, bool continueInVm)
 {
     build.mov(x0, continueInVm);
@@ -40,14 +31,21 @@ static void emitExit(AssemblyBuilderA64& build, bool continueInVm)
     build.br(x1);
 }

-static void emitUpdatePcAndContinueInVm(AssemblyBuilderA64& build)
+static void emitUpdatePcForExit(AssemblyBuilderA64& build)
 {
     // x0 = pcpos * sizeof(Instruction)
     build.add(x0, rCode, x0);

     build.ldr(x1, mem(rState, offsetof(lua_State, ci)));
     build.str(x0, mem(x1, offsetof(CallInfo, savedpc)));
-
-    emitExit(build, /* continueInVm */ true);
+}
+
+static void emitClearNativeFlag(AssemblyBuilderA64& build)
+{
+    build.ldr(x0, mem(rState, offsetof(lua_State, ci)));
+    build.ldr(w1, mem(x0, offsetof(CallInfo, flags)));
+    build.mov(w2, ~LUA_CALLINFO_NATIVE);
+    build.and_(w1, w1, w2);
+    build.str(w1, mem(x0, offsetof(CallInfo, flags)));
 }

 static void emitInterrupt(AssemblyBuilderA64& build)
@@ -227,6 +225,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde
     build.stp(x19, x20, mem(sp, 16));
     build.stp(x21, x22, mem(sp, 32));
     build.stp(x23, x24, mem(sp, 48));
+    build.str(x25, mem(sp, 64));

     build.mov(x29, sp); // this is only necessary if we maintain frame pointers, which we do in the JIT for now
@@ -237,6 +236,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde
     // Setup native execution environment
     build.mov(rState, x0);
     build.mov(rNativeContext, x3);
+    build.ldr(rGlobalState, mem(x0, offsetof(lua_State, global)));

     build.ldr(rBase, mem(x0, offsetof(lua_State, base))); // L->base
@@ -254,6 +254,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde
     locations.epilogueStart = build.setLabel();

     // Cleanup and exit
+    build.ldr(x25, mem(sp, 64));
     build.ldp(x23, x24, mem(sp, 48));
     build.ldp(x21, x22, mem(sp, 32));
     build.ldp(x19, x20, mem(sp, 16));
@@ -264,7 +265,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde
     // Our entry function is special, it spans the whole remaining code area
     unwind.startFunction();
-    unwind.prologueA64(prologueSize, kStackSize, {x29, x30, x19, x20, x21, x22, x23, x24});
+    unwind.prologueA64(prologueSize, kStackSize, {x29, x30, x19, x20, x21, x22, x23, x24, x25});
     unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton);

     return locations;
@@ -305,6 +306,11 @@ bool initHeaderFunctions(NativeState& data)

 void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers)
 {
+    if (build.logText)
+        build.logAppend("; updatePcAndContinueInVm\n");
+    build.setLabel(helpers.updatePcAndContinueInVm);
+    emitUpdatePcForExit(build);
+
     if (build.logText)
         build.logAppend("; exitContinueVmClearNativeFlag\n");
     build.setLabel(helpers.exitContinueVmClearNativeFlag);
@@ -320,11 +326,6 @@ void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers)
     build.setLabel(helpers.exitNoContinueVm);
     emitExit(build, /* continueInVm */ false);

-    if (build.logText)
-        build.logAppend("; updatePcAndContinueInVm\n");
-    build.setLabel(helpers.updatePcAndContinueInVm);
-    emitUpdatePcAndContinueInVm(build);
-
     if (build.logText)
         build.logAppend("; reentry\n");
     build.setLabel(helpers.reentry);

View file

@@ -43,7 +43,7 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto)
 }

 template<typename AssemblyBuilder>
-static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, AssemblyOptions options)
+static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, AssemblyOptions options, LoweringStats* stats)
 {
     std::vector<Proto*> protos;
     gatherFunctions(protos, clvalue(func)->l.p);
@@ -66,7 +66,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
         if (options.includeAssembly || options.includeIr)
             logFunctionHeader(build, p);

-        if (!lowerFunction(ir, build, helpers, p, options))
+        if (!lowerFunction(ir, build, helpers, p, options, stats))
         {
             if (build.logText)
                 build.logAppend("; skipping (can't lower)\n");
@@ -90,7 +90,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A
 unsigned int getCpuFeaturesA64();
 #endif

-std::string getAssembly(lua_State* L, int idx, AssemblyOptions options)
+std::string getAssembly(lua_State* L, int idx, AssemblyOptions options, LoweringStats* stats)
 {
     LUAU_ASSERT(lua_isLfunction(L, idx));
     const TValue* func = luaA_toobject(L, idx);
@@ -106,35 +106,35 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options)
         X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly);
 #endif

-        return getAssemblyImpl(build, func, options);
+        return getAssemblyImpl(build, func, options, stats);
     }
     case AssemblyOptions::A64:
     {
         A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, /* features= */ A64::Feature_JSCVT);

-        return getAssemblyImpl(build, func, options);
+        return getAssemblyImpl(build, func, options, stats);
     }
     case AssemblyOptions::A64_NoFeatures:
     {
         A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, /* features= */ 0);

-        return getAssemblyImpl(build, func, options);
+        return getAssemblyImpl(build, func, options, stats);
     }
     case AssemblyOptions::X64_Windows:
     {
         X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly, X64::ABIX64::Windows);

-        return getAssemblyImpl(build, func, options);
+        return getAssemblyImpl(build, func, options, stats);
     }
     case AssemblyOptions::X64_SystemV:
     {
         X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly, X64::ABIX64::SystemV);

-        return getAssemblyImpl(build, func, options);
+        return getAssemblyImpl(build, func, options, stats);
     }
     default:

View file

@@ -44,38 +44,10 @@ inline void gatherFunctions(std::vector<Proto*>& results, Proto* proto)
         gatherFunctions(results, proto->p[i]);
 }

-inline IrBlock& getNextBlock(IrFunction& function, std::vector<uint32_t>& sortedBlocks, IrBlock& dummy, size_t i)
-{
-    for (size_t j = i + 1; j < sortedBlocks.size(); ++j)
-    {
-        IrBlock& block = function.blocks[sortedBlocks[j]];
-        if (block.kind != IrBlockKind::Dead)
-            return block;
-    }
-
-    return dummy;
-}
-
 template<typename AssemblyBuilder, typename IrLowering>
 inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options)
 {
-    // While we will need a better block ordering in the future, right now we want to mostly preserve build order with fallbacks outlined
-    std::vector<uint32_t> sortedBlocks;
-    sortedBlocks.reserve(function.blocks.size());
-    for (uint32_t i = 0; i < function.blocks.size(); i++)
-        sortedBlocks.push_back(i);
-
-    std::sort(sortedBlocks.begin(), sortedBlocks.end(), [&](uint32_t idxA, uint32_t idxB) {
-        const IrBlock& a = function.blocks[idxA];
-        const IrBlock& b = function.blocks[idxB];
-
-        // Place fallback blocks at the end
-        if ((a.kind == IrBlockKind::Fallback) != (b.kind == IrBlockKind::Fallback))
-            return (a.kind == IrBlockKind::Fallback) < (b.kind == IrBlockKind::Fallback);
-
-        // Try to order by instruction order
-        return a.sortkey < b.sortkey;
-    });
+    std::vector<uint32_t> sortedBlocks = getSortedBlockOrder(function);

     // For each IR instruction that begins a bytecode instruction, which bytecode instruction is it?
     std::vector<uint32_t> bcLocations(function.instructions.size() + 1, ~0u);
@@ -100,6 +72,9 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction&
     IrBlock dummy;
     dummy.start = ~0u;

+    // Make sure entry block is first
+    LUAU_ASSERT(sortedBlocks[0] == 0);
+
     for (size_t i = 0; i < sortedBlocks.size(); ++i)
     {
         uint32_t blockIndex = sortedBlocks[i];
@@ -130,8 +105,18 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction&

         build.setLabel(block.label);

+        if (blockIndex == function.entryBlock)
+        {
+            function.entryLocation = build.getLabelOffset(block.label);
+        }
+
         IrBlock& nextBlock = getNextBlock(function, sortedBlocks, dummy, i);

+        // Optimizations often propagate information between blocks
+        // To make sure the register and spill state is correct when blocks are lowered, we check that sorted block order matches the expected one
+        if (block.expectedNextBlock != ~0u)
+            LUAU_ASSERT(function.getBlockIndex(nextBlock) == block.expectedNextBlock);
+
         for (uint32_t index = block.start; index <= block.finish; index++)
         {
             LUAU_ASSERT(index < function.instructions.size());
@@ -189,7 +174,7 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction&
             }
         }

-        lowering.finishBlock();
+        lowering.finishBlock(block, nextBlock);

         if (options.includeIr)
             build.logAppend("#\n");
@@ -214,24 +199,26 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction&
     return true;
 }

-inline bool lowerIr(X64::AssemblyBuilderX64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options)
+inline bool lowerIr(
+    X64::AssemblyBuilderX64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options, LoweringStats* stats)
 {
     optimizeMemoryOperandsX64(ir.function);

-    X64::IrLoweringX64 lowering(build, helpers, ir.function);
+    X64::IrLoweringX64 lowering(build, helpers, ir.function, stats);

     return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options);
 }

-inline bool lowerIr(A64::AssemblyBuilderA64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options)
+inline bool lowerIr(
+    A64::AssemblyBuilderA64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options, LoweringStats* stats)
 {
-    A64::IrLoweringA64 lowering(build, helpers, ir.function);
+    A64::IrLoweringA64 lowering(build, helpers, ir.function, stats);

     return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options);
 }

 template<typename AssemblyBuilder>
-inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options)
+inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options, LoweringStats* stats)
 {
     killUnusedBlocks(ir.function);
@@ -247,7 +234,7 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers&
         createLinearBlocks(ir, useValueNumbering);
     }

-    return lowerIr(build, ir, helpers, proto, options);
+    return lowerIr(build, ir, helpers, proto, options, stats);
 }

 } // namespace CodeGen

View file

@ -16,10 +16,24 @@
* | rdx home space | (unused) * | rdx home space | (unused)
* | rcx home space | (unused) * | rcx home space | (unused)
* | return address | * | return address |
* | ... saved non-volatile registers ... <-- rsp + kStackSize + kLocalsSize * | ... saved non-volatile registers ... <-- rsp + kStackSizeFull
* | unused | for 16 byte alignment of the stack * | alignment |
* | xmm9 non-vol |
* | xmm9 cont. |
* | xmm8 non-vol |
* | xmm8 cont. |
* | xmm7 non-vol |
* | xmm7 cont. |
* | xmm6 non-vol |
* | xmm6 cont. |
* | spill slot 5 |
* | spill slot 4 |
* | spill slot 3 |
* | spill slot 2 |
* | spill slot 1 | <-- rsp + kStackOffsetToSpillSlots
* | sTemporarySlot |
* | sCode | * | sCode |
* | sClosure | <-- rsp + kStackSize * | sClosure | <-- rsp + kStackOffsetToLocals
* | argument 6 | <-- rsp + 40 * | argument 6 | <-- rsp + 40
* | argument 5 | <-- rsp + 32 * | argument 5 | <-- rsp + 32
* | r9 home space | * | r9 home space |
@ -81,24 +95,43 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde
build.push(rdi); build.push(rdi);
build.push(rsi); build.push(rsi);
// On Windows, rbp is available as a general-purpose non-volatile register; we currently don't use it, but we need to push an even number // On Windows, rbp is available as a general-purpose non-volatile register and this might be freed up
// of registers for stack alignment...
    build.push(rbp);

        // TODO: once we start using non-volatile SIMD registers on Windows, we will save those here
    }

-   // Allocate stack space (reg home area + local data)
-   build.sub(rsp, kStackSize + kLocalsSize);
+   // Allocate stack space
+   uint8_t usableXmmRegCount = getXmmRegisterCount(build.abi);
+   unsigned xmmStorageSize = getNonVolXmmStorageSize(build.abi, usableXmmRegCount);
+   unsigned fullStackSize = getFullStackSize(build.abi, usableXmmRegCount);
+
+   build.sub(rsp, fullStackSize);
+
+   OperandX64 xmmStorageOffset = rsp + (fullStackSize - (kStackAlign + xmmStorageSize));
+
+   // On Windows, we have to save non-volatile xmm registers
+   std::vector<RegisterX64> savedXmmRegs;
+
+   if (build.abi == ABIX64::Windows)
+   {
+       if (usableXmmRegCount > kWindowsFirstNonVolXmmReg)
+           savedXmmRegs.reserve(usableXmmRegCount - kWindowsFirstNonVolXmmReg);
+
+       for (uint8_t i = kWindowsFirstNonVolXmmReg, offset = 0; i < usableXmmRegCount; i++, offset += 16)
+       {
+           RegisterX64 xmmReg = RegisterX64{SizeX64::xmmword, i};
+           build.vmovaps(xmmword[xmmStorageOffset + offset], xmmReg);
+           savedXmmRegs.push_back(xmmReg);
+       }
+   }

    locations.prologueEnd = build.setLabel();

    uint32_t prologueSize = build.getLabelOffset(locations.prologueEnd) - build.getLabelOffset(locations.start);

    if (build.abi == ABIX64::SystemV)
-       unwind.prologueX64(prologueSize, kStackSize + kLocalsSize, /* setupFrame= */ true, {rbx, r12, r13, r14, r15});
+       unwind.prologueX64(prologueSize, fullStackSize, /* setupFrame= */ true, {rbx, r12, r13, r14, r15}, {});
    else if (build.abi == ABIX64::Windows)
-       unwind.prologueX64(prologueSize, kStackSize + kLocalsSize, /* setupFrame= */ false, {rbx, r12, r13, r14, r15, rdi, rsi, rbp});
+       unwind.prologueX64(prologueSize, fullStackSize, /* setupFrame= */ false, {rbx, r12, r13, r14, r15, rdi, rsi, rbp}, savedXmmRegs);

    // Setup native execution environment
    build.mov(rState, rArg1);
@@ -118,8 +151,15 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde
    // Even though we jumped away, we will return here in the end
    locations.epilogueStart = build.setLabel();

-   // Cleanup and exit
-   build.add(rsp, kStackSize + kLocalsSize);
+   // Epilogue and exit
+   if (build.abi == ABIX64::Windows)
+   {
+       // xmm registers are restored before the official epilogue that has to start with 'add rsp/lea rsp'
+       for (uint8_t i = kWindowsFirstNonVolXmmReg, offset = 0; i < usableXmmRegCount; i++, offset += 16)
+           build.vmovaps(RegisterX64{SizeX64::xmmword, i}, xmmword[xmmStorageOffset + offset]);
+   }
+
+   build.add(rsp, fullStackSize);

    if (build.abi == ABIX64::Windows)
    {
@@ -180,6 +220,11 @@ bool initHeaderFunctions(NativeState& data)
void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers)
{
+   if (build.logText)
+       build.logAppend("; updatePcAndContinueInVm\n");
+   build.setLabel(helpers.updatePcAndContinueInVm);
+   emitUpdatePcForExit(build);
+
    if (build.logText)
        build.logAppend("; exitContinueVmClearNativeFlag\n");
    build.setLabel(helpers.exitContinueVmClearNativeFlag);
@@ -195,11 +240,6 @@ void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers)
    build.setLabel(helpers.exitNoContinueVm);
    emitExit(build, /* continueInVm */ false);

-   if (build.logText)
-       build.logAppend("; updatePcAndContinueInVm\n");
-   build.setLabel(helpers.updatePcAndContinueInVm);
-   emitUpdatePcAndContinueInVm(build);
-
    if (build.logText)
        build.logAppend("; continueCallInVm\n");
    build.setLabel(helpers.continueCallInVm);
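
Note: the non-volatile xmm save area is carved out of the top of the newly allocated frame, just below the alignment padding, which is why the epilogue's restore loop can mirror the prologue's save loop exactly. A standalone sketch of the offset math, using the default constants introduced in EmitCommonX64.h further down (illustrative only):

    #include <cassert>

    int main()
    {
        // Windows ABI with the default 10 usable xmm registers (xmm6..xmm9 get saved)
        unsigned fullStackSize = 184, kStackAlign = 8;
        unsigned xmmStorageSize = (10 - 6) * 16; // one 16-byte xmmword per saved register
        unsigned xmmStorageOffset = fullStackSize - (kStackAlign + xmmStorageSize);
        assert(xmmStorageOffset == 112); // sits right above the spill area (72 + 40)
    }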

View file

@@ -1,8 +1,9 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "EmitBuiltinsX64.h"

+#include "Luau/AssemblyBuilderX64.h"
#include "Luau/Bytecode.h"
-#include "Luau/AssemblyBuilderX64.h"

#include "Luau/IrCallWrapperX64.h"
#include "Luau/IrRegAllocX64.h"

View file

@@ -25,7 +25,7 @@ struct ModuleHelpers
    Label exitContinueVm;
    Label exitNoContinueVm;
    Label exitContinueVmClearNativeFlag;
-   Label updatePcAndContinueInVm;
+   Label updatePcAndContinueInVm; // no reentry
    Label return_;
    Label interrupt;

View file

@@ -31,23 +31,24 @@ namespace A64
// 1. Constant registers (only loaded during codegen entry)
constexpr RegisterA64 rState = x19;         // lua_State* L
constexpr RegisterA64 rNativeContext = x20; // NativeContext* context
+constexpr RegisterA64 rGlobalState = x21;   // global_State* L->global

// 2. Frame registers (reloaded when call frame changes; rBase is also reloaded after all calls that may reallocate stack)
-constexpr RegisterA64 rConstants = x21; // TValue* k
-constexpr RegisterA64 rClosure = x22;   // Closure* cl
-constexpr RegisterA64 rCode = x23;      // Instruction* code
-constexpr RegisterA64 rBase = x24;      // StkId base
+constexpr RegisterA64 rConstants = x22; // TValue* k
+constexpr RegisterA64 rClosure = x23;   // Closure* cl
+constexpr RegisterA64 rCode = x24;      // Instruction* code
+constexpr RegisterA64 rBase = x25;      // StkId base

// Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point
// See CodeGenA64.cpp for layout
-constexpr unsigned kStashSlots = 8;  // stashed non-volatile registers
-constexpr unsigned kTempSlots = 1;   // 8 bytes of temporary space, such luxury!
+constexpr unsigned kStashSlots = 9;  // stashed non-volatile registers
constexpr unsigned kSpillSlots = 22; // slots for spilling temporary registers
+constexpr unsigned kTempSlots = 2;   // 16 bytes of temporary space, such luxury!

-constexpr unsigned kStackSize = (kStashSlots + kSpillSlots + kTempSlots) * 8;
+constexpr unsigned kStackSize = (kStashSlots + kTempSlots + kSpillSlots) * 8;

-constexpr AddressA64 sSpillArea = mem(sp, kStashSlots * 8);
-constexpr AddressA64 sTemporary = mem(sp, (kStashSlots + kSpillSlots) * 8);
+constexpr AddressA64 sSpillArea = mem(sp, (kStashSlots + kTempSlots) * 8);
+constexpr AddressA64 sTemporary = mem(sp, kStashSlots * 8);

inline void emitUpdateBase(AssemblyBuilderA64& build)
{
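
Note: with the extra stashed register and the second temporary slot, the new A64 frame layout works out as below; a standalone check mirroring the constants above (illustrative only):

    #include <cassert>

    int main()
    {
        unsigned kStashSlots = 9, kTempSlots = 2, kSpillSlots = 22;
        unsigned kStackSize = (kStashSlots + kTempSlots + kSpillSlots) * 8; // 264 bytes total
        unsigned sTemporary = kStashSlots * 8;                // 72: temps now sit right after the stash
        unsigned sSpillArea = (kStashSlots + kTempSlots) * 8; // 88: spill slots moved past the temps
        assert(sSpillArea + kSpillSlots * 8 == kStackSize);
    }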

View file

@@ -12,6 +12,8 @@
#include "lgc.h"
#include "lstate.h"

+#include <utility>

namespace Luau
{
namespace CodeGen
@@ -22,10 +24,15 @@ namespace X64
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label)
{
    // Refresher on comi/ucomi EFLAGS:
+   // all zero: greater
    // CF only: less
    // ZF only: equal
    // PF+CF+ZF: unordered (NaN)

+   // There are no conditional jumps that check "greater" conditions in an IEEE 754 compliant way, so we swap the operands and use the "less" forms to emulate them
+   if (cond == IrCondition::Greater || cond == IrCondition::GreaterEqual || cond == IrCondition::NotGreater || cond == IrCondition::NotGreaterEqual)
+       std::swap(lhs, rhs);
+
    if (rhs.cat == CategoryX64::reg)
    {
        build.vucomisd(rhs, lhs);
@@ -41,18 +48,22 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs,
    switch (cond)
    {
    case IrCondition::NotLessEqual:
+   case IrCondition::NotGreaterEqual:
        // (b < a) is the same as !(a <= b). jnae checks CF=1 which means < or NaN
        build.jcc(ConditionX64::NotAboveEqual, label);
        break;
    case IrCondition::LessEqual:
+   case IrCondition::GreaterEqual:
        // (b >= a) is the same as (a <= b). jae checks CF=0 which means >= and not NaN
        build.jcc(ConditionX64::AboveEqual, label);
        break;
    case IrCondition::NotLess:
+   case IrCondition::NotGreater:
        // (b <= a) is the same as !(a < b). jna checks CF=1 or ZF=1 which means <= or NaN
        build.jcc(ConditionX64::NotAbove, label);
        break;
    case IrCondition::Less:
+   case IrCondition::Greater:
        // (b > a) is the same as (a < b). ja checks CF=0 and ZF=0 which means > and not NaN
        build.jcc(ConditionX64::Above, label);
        break;
@@ -66,6 +77,44 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs,
    }
}

+ConditionX64 getConditionInt(IrCondition cond)
+{
+   switch (cond)
+   {
+   case IrCondition::Equal:
+       return ConditionX64::Equal;
+   case IrCondition::NotEqual:
+       return ConditionX64::NotEqual;
+   case IrCondition::Less:
+       return ConditionX64::Less;
+   case IrCondition::NotLess:
+       return ConditionX64::NotLess;
+   case IrCondition::LessEqual:
+       return ConditionX64::LessEqual;
+   case IrCondition::NotLessEqual:
+       return ConditionX64::NotLessEqual;
+   case IrCondition::Greater:
+       return ConditionX64::Greater;
+   case IrCondition::NotGreater:
+       return ConditionX64::NotGreater;
+   case IrCondition::GreaterEqual:
+       return ConditionX64::GreaterEqual;
+   case IrCondition::NotGreaterEqual:
+       return ConditionX64::NotGreaterEqual;
+   case IrCondition::UnsignedLess:
+       return ConditionX64::Below;
+   case IrCondition::UnsignedLessEqual:
+       return ConditionX64::BelowEqual;
+   case IrCondition::UnsignedGreater:
+       return ConditionX64::Above;
+   case IrCondition::UnsignedGreaterEqual:
+       return ConditionX64::AboveEqual;
+   default:
+       LUAU_ASSERT(!"Unsupported condition");
+       return ConditionX64::Zero;
+   }
+}
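
Note: the operand swap above is sound because under IEEE 754 every "greater" comparison is the mirrored "less" comparison, and NaN makes both ordered forms false. A minimal standalone check (not part of the commit):

    #include <cassert>
    #include <cmath>

    int main()
    {
        double a = 1.0, b = 2.0, nan = std::nan("");
        assert((a > b) == (b < a));
        assert((a >= b) == (b <= a));
        assert(!(nan > b) && !(b < nan));   // both mirrored forms are false for NaN
        assert(!(nan >= b) && !(b <= nan));
    }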
void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos)
{
    LUAU_ASSERT(tmp != node);
@@ -123,16 +172,6 @@ void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, in
    emitUpdateBase(build);
}

-void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, int step, int init)
-{
-   IrCallWrapperX64 callWrap(regs, build);
-   callWrap.addArgument(SizeX64::qword, rState);
-   callWrap.addArgument(SizeX64::qword, luauRegAddress(limit));
-   callWrap.addArgument(SizeX64::qword, luauRegAddress(step));
-   callWrap.addArgument(SizeX64::qword, luauRegAddress(init));
-   callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]);
-}

void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra)
{
    IrCallWrapperX64 callWrap(regs, build);
@@ -157,13 +196,14 @@ void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, Operan
    emitUpdateBase(build);
}

-void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, int ratag, Label& skip)
+void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, IrOp ra, int ratag, Label& skip)
{
    // Barrier should've been optimized away if we know that it's not collectable, checking for correctness
    if (ratag == -1 || !isGCO(ratag))
    {
        // iscollectable(ra)
-       build.cmp(luauRegTag(ra), LUA_TSTRING);
+       OperandX64 tag = (ra.kind == IrOpKind::VmReg) ? luauRegTag(vmRegOp(ra)) : luauConstantTag(vmConstOp(ra));
+       build.cmp(tag, LUA_TSTRING);
        build.jcc(ConditionX64::Less, skip);
    }
@@ -172,12 +212,14 @@ void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, Re
    build.jcc(ConditionX64::Zero, skip);

    // iswhite(gcvalue(ra))
-   build.mov(tmp, luauRegValue(ra));
+   OperandX64 value = (ra.kind == IrOpKind::VmReg) ? luauRegValue(vmRegOp(ra)) : luauConstantValue(vmConstOp(ra));
+   build.mov(tmp, value);
    build.test(byte[tmp + offsetof(GCheader, marked)], bit2mask(WHITE0BIT, WHITE1BIT));
    build.jcc(ConditionX64::Zero, skip);
}

-void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra, int ratag)
+void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, IrOp ra, int ratag)
{
    Label skip;
@@ -328,14 +370,12 @@ void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, in
    emitUpdateBase(build);
}

-void emitUpdatePcAndContinueInVm(AssemblyBuilderX64& build)
+void emitUpdatePcForExit(AssemblyBuilderX64& build)
{
    // edx = pcpos * sizeof(Instruction)
    build.add(rdx, sCode);
    build.mov(rax, qword[rState + offsetof(lua_State, ci)]);
    build.mov(qword[rax + offsetof(CallInfo, savedpc)], rdx);
-
-   emitExit(build, /* continueInVm */ true);
}

void emitContinueCallInVm(AssemblyBuilderX64& build)

View file

@@ -42,16 +42,55 @@ constexpr RegisterX64 rBase = r14;          // StkId base
constexpr RegisterX64 rNativeContext = r13; // NativeContext* context
constexpr RegisterX64 rConstants = r12;     // TValue* k

-// Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point
-// See CodeGenX64.cpp for layout
-constexpr unsigned kStackSize = 32 + 16;               // 4 home locations for registers, 16 bytes for additional function call arguments
-constexpr unsigned kSpillSlots = 4;                    // locations for register allocator to spill data into
-constexpr unsigned kLocalsSize = 24 + 8 * kSpillSlots; // 3 extra slots for our custom locals (also aligns the stack to 16 byte boundary)
-
-constexpr OperandX64 sClosure = qword[rsp + kStackSize + 0]; // Closure* cl
-constexpr OperandX64 sCode = qword[rsp + kStackSize + 8];    // Instruction* code
-constexpr OperandX64 sTemporarySlot = addr[rsp + kStackSize + 16];
-constexpr OperandX64 sSpillArea = addr[rsp + kStackSize + 24];
+constexpr unsigned kExtraLocals = 3; // Number of 8 byte slots available for specialized local variables specified below
+constexpr unsigned kSpillSlots = 5;  // Number of 8 byte slots available for register allocator to spill data into
+
+static_assert((kExtraLocals + kSpillSlots) * 8 % 16 == 0, "locals have to preserve 16 byte alignment");
+
+constexpr uint8_t kWindowsFirstNonVolXmmReg = 6;
+
+constexpr uint8_t kWindowsUsableXmmRegs = 10; // Some xmm regs are non-volatile, we have to balance how many we want to use/preserve
+constexpr uint8_t kSystemVUsableXmmRegs = 16; // All xmm regs are volatile
+
+inline uint8_t getXmmRegisterCount(ABIX64 abi)
+{
+   return abi == ABIX64::SystemV ? kSystemVUsableXmmRegs : kWindowsUsableXmmRegs;
+}
+
+// Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point
+// Stack is separated into sections for different data. See CodeGenX64.cpp for layout overview
+constexpr unsigned kStackAlign = 8; // Bytes we need to align the stack for non-vol xmm register storage
+constexpr unsigned kStackLocalStorage = 8 * kExtraLocals;
+constexpr unsigned kStackSpillStorage = 8 * kSpillSlots;
+constexpr unsigned kStackExtraArgumentStorage = 2 * 8; // Bytes for 5th and 6th function call arguments used under Windows ABI
+constexpr unsigned kStackRegHomeStorage = 4 * 8;       // Register 'home' locations that can be used by callees under Windows ABI
+
+inline unsigned getNonVolXmmStorageSize(ABIX64 abi, uint8_t xmmRegCount)
+{
+   if (abi == ABIX64::SystemV)
+       return 0;
+
+   // First 6 are volatile
+   if (xmmRegCount <= kWindowsFirstNonVolXmmReg)
+       return 0;
+
+   LUAU_ASSERT(xmmRegCount <= 16);
+   return (xmmRegCount - kWindowsFirstNonVolXmmReg) * 16;
+}
+
+// Useful offsets to specific parts
+constexpr unsigned kStackOffsetToLocals = kStackExtraArgumentStorage + kStackRegHomeStorage;
+constexpr unsigned kStackOffsetToSpillSlots = kStackOffsetToLocals + kStackLocalStorage;
+
+inline unsigned getFullStackSize(ABIX64 abi, uint8_t xmmRegCount)
+{
+   return kStackOffsetToSpillSlots + kStackSpillStorage + getNonVolXmmStorageSize(abi, xmmRegCount) + kStackAlign;
+}
+
+constexpr OperandX64 sClosure = qword[rsp + kStackOffsetToLocals + 0]; // Closure* cl
+constexpr OperandX64 sCode = qword[rsp + kStackOffsetToLocals + 8];    // Instruction* code
+constexpr OperandX64 sTemporarySlot = addr[rsp + kStackOffsetToLocals + 16];
+constexpr OperandX64 sSpillArea = addr[rsp + kStackOffsetToSpillSlots];
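
Note: plugging the defaults into getFullStackSize gives 184 bytes on Windows and 120 on SystemV. A standalone mirror of that arithmetic (values copied for illustration, not the real header):

    #include <cassert>

    int main()
    {
        unsigned offsetToLocals = 2 * 8 + 4 * 8;       // extra argument storage + register home area = 48
        unsigned offsetToSpills = offsetToLocals + 24; // + 3 local slots = 72
        unsigned spill = 8 * 5, align = 8;
        assert(offsetToSpills + spill + (10 - 6) * 16 + align == 184); // Windows, xmm6..xmm9 saved
        assert(offsetToSpills + spill + 0 + align == 120);             // SystemV, no xmm saves
    }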
inline OperandX64 luauReg(int ri)
{
@@ -114,11 +153,6 @@ inline OperandX64 luauNodeKeyTag(RegisterX64 node)
    return dword[node + offsetof(LuaNode, key) + kOffsetOfTKeyTagNext];
}

-inline OperandX64 luauNodeValue(RegisterX64 node)
-{
-   return xmmword[node + offsetof(LuaNode, val)];
-}

inline void setLuauReg(AssemblyBuilderX64& build, RegisterX64 tmp, int ri, OperandX64 op)
{
    LUAU_ASSERT(op.cat == CategoryX64::mem);
@@ -161,16 +195,17 @@ inline void jumpIfTruthy(AssemblyBuilderX64& build, int ri, Label& target, Label
void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label);
+ConditionX64 getConditionInt(IrCondition cond);

void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos);
void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label);

void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm);
void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb);
-void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, int step, int init);
void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra);
void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra);
-void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, int ratag, Label& skip);
-void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra, int ratag);
+void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, IrOp ra, int ratag, Label& skip);
+void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, IrOp ra, int ratag);
void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp);
void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build);
@@ -180,7 +215,7 @@ void emitUpdateBase(AssemblyBuilderX64& build);
void emitInterrupt(AssemblyBuilderX64& build);
void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, int pcpos);

-void emitUpdatePcAndContinueInVm(AssemblyBuilderX64& build);
+void emitUpdatePcForExit(AssemblyBuilderX64& build);
void emitContinueCallInVm(AssemblyBuilderX64& build);

void emitReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers);

View file

@@ -251,7 +251,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i
    }
}

-void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index)
+void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index, int knownSize)
{
    // TODO: This should use IrCallWrapperX64
    RegisterX64 rArg1 = (build.abi == ABIX64::Windows) ? rcx : rdi;
@@ -285,12 +285,14 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int
        build.add(last, index - 1);
    }

-   Label skipResize;
-
    RegisterX64 table = regs.takeReg(rax, kInvalidInstIdx);

    build.mov(table, luauRegValue(ra));

+   if (count == LUA_MULTRET || knownSize < 0 || knownSize < int(index + count - 1))
+   {
+       Label skipResize;
+
        // Resize if h->sizearray < last
        build.cmp(dword[table + offsetof(Table, sizearray)], last);
        build.jcc(ConditionX64::NotBelow, skipResize);
@@ -301,9 +303,10 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int
        build.mov(rArg2, table);
        build.mov(rArg1, rState);
        build.call(qword[rNativeContext + offsetof(NativeContext, luaH_resizearray)]);
-       build.mov(table, luauRegValue(ra)); // Reload cloberred register value
+       build.mov(table, luauRegValue(ra)); // Reload clobbered register value

        build.setLabel(skipResize);
+   }
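
Note: knownSize lets SETLIST drop the luaH_resizearray call entirely when the array part was preallocated large enough. A hypothetical helper mirroring the guard above (illustrative only):

    // SETLIST writes t[index .. index+count-1]; count == -1 stands for LUA_MULTRET
    inline bool needsResizeCheck(int count, int knownSize, unsigned index)
    {
        return count == -1 || knownSize < 0 || knownSize < int(index + count - 1);
    }

For example, a constructor like local t = {1, 2, 3} should create the table with array size 3, so needsResizeCheck(3, 3, 1) is false and the whole resize path is omitted from the generated code.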
    RegisterX64 arrayDst = rdx;
    RegisterX64 offset = rcx;

View file

@@ -19,7 +19,7 @@ struct IrRegAllocX64;

void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults);
void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults, bool functionVariadic);
-void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index);
+void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index, int knownSize);
void emitInstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat);

} // namespace X64

View file

@@ -186,75 +186,12 @@ void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, ui
    }
}

-static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& block, RegisterSet& defRs, std::bitset<256>& capturedRegs)
+template<typename T>
+static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrBlock& block)
{
-   RegisterSet inRs;
-
-   auto def = [&](IrOp op, int offset = 0) {
-       defRs.regs.set(vmRegOp(op) + offset, true);
-   };
-
-   auto use = [&](IrOp op, int offset = 0) {
-       if (!defRs.regs.test(vmRegOp(op) + offset))
-           inRs.regs.set(vmRegOp(op) + offset, true);
-   };
-
-   auto maybeDef = [&](IrOp op) {
-       if (op.kind == IrOpKind::VmReg)
-           defRs.regs.set(vmRegOp(op), true);
-   };
-
-   auto maybeUse = [&](IrOp op) {
-       if (op.kind == IrOpKind::VmReg)
-       {
-           if (!defRs.regs.test(vmRegOp(op)))
-               inRs.regs.set(vmRegOp(op), true);
-       }
-   };
-
-   auto defVarargs = [&](uint8_t varargStart) {
-       defRs.varargSeq = true;
-       defRs.varargStart = varargStart;
-   };
-
-   auto useVarargs = [&](uint8_t varargStart) {
-       requireVariadicSequence(inRs, defRs, varargStart);
-
-       // Variadic sequence has been consumed
-       defRs.varargSeq = false;
-       defRs.varargStart = 0;
-   };
-
-   auto defRange = [&](int start, int count) {
-       if (count == -1)
-       {
-           defVarargs(start);
-       }
-       else
-       {
-           for (int i = start; i < start + count; i++)
-               defRs.regs.set(i, true);
-       }
-   };
-
-   auto useRange = [&](int start, int count) {
-       if (count == -1)
-       {
-           useVarargs(start);
-       }
-       else
-       {
-           for (int i = start; i < start + count; i++)
-           {
-               if (!defRs.regs.test(i))
-                   inRs.regs.set(i, true);
-           }
-       }
-   };
-
    for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++)
    {
-       const IrInst& inst = function.instructions[instIdx];
+       IrInst& inst = function.instructions[instIdx];

        // For correct analysis, all instruction uses must be handled before handling the definitions
        switch (inst.cmd)
@@ -264,7 +201,7 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock&
        case IrCmd::LOAD_DOUBLE:
        case IrCmd::LOAD_INT:
        case IrCmd::LOAD_TVALUE:
-           maybeUse(inst.a); // Argument can also be a VmConst
+           visitor.maybeUse(inst.a); // Argument can also be a VmConst
            break;
        case IrCmd::STORE_TAG:
        case IrCmd::STORE_POINTER:
@@ -272,63 +209,55 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock&
        case IrCmd::STORE_INT:
        case IrCmd::STORE_VECTOR:
        case IrCmd::STORE_TVALUE:
-           maybeDef(inst.a); // Argument can also be a pointer value
+       case IrCmd::STORE_SPLIT_TVALUE:
+           visitor.maybeDef(inst.a); // Argument can also be a pointer value
            break;
        case IrCmd::CMP_ANY:
-           use(inst.a);
-           use(inst.b);
+           visitor.use(inst.a);
+           visitor.use(inst.b);
            break;
        case IrCmd::JUMP_IF_TRUTHY:
        case IrCmd::JUMP_IF_FALSY:
-           use(inst.a);
+           visitor.use(inst.a);
            break;
            // A <- B, C
        case IrCmd::DO_ARITH:
        case IrCmd::GET_TABLE:
-           use(inst.b);
-           maybeUse(inst.c); // Argument can also be a VmConst
-           def(inst.a);
+           visitor.use(inst.b);
+           visitor.maybeUse(inst.c); // Argument can also be a VmConst
+           visitor.def(inst.a);
            break;
        case IrCmd::SET_TABLE:
-           use(inst.a);
-           use(inst.b);
-           maybeUse(inst.c); // Argument can also be a VmConst
+           visitor.use(inst.a);
+           visitor.use(inst.b);
+           visitor.maybeUse(inst.c); // Argument can also be a VmConst
            break;
            // A <- B
        case IrCmd::DO_LEN:
-           use(inst.b);
-           def(inst.a);
+           visitor.use(inst.b);
+           visitor.def(inst.a);
            break;
        case IrCmd::GET_IMPORT:
-           def(inst.a);
+           visitor.def(inst.a);
            break;
        case IrCmd::CONCAT:
-           useRange(vmRegOp(inst.a), function.uintOp(inst.b));
-           defRange(vmRegOp(inst.a), function.uintOp(inst.b));
+           visitor.useRange(vmRegOp(inst.a), function.uintOp(inst.b));
+           visitor.defRange(vmRegOp(inst.a), function.uintOp(inst.b));
            break;
        case IrCmd::GET_UPVALUE:
-           def(inst.a);
+           visitor.def(inst.a);
            break;
        case IrCmd::SET_UPVALUE:
-           use(inst.b);
-           break;
-       case IrCmd::PREPARE_FORN:
-           use(inst.a);
-           use(inst.b);
-           use(inst.c);
-           def(inst.a);
-           def(inst.b);
-           def(inst.c);
+           visitor.use(inst.b);
            break;
        case IrCmd::INTERRUPT:
            break;
        case IrCmd::BARRIER_OBJ:
        case IrCmd::BARRIER_TABLE_FORWARD:
-           use(inst.b);
+           visitor.maybeUse(inst.b);
            break;
        case IrCmd::CLOSE_UPVALS:
            // Closing an upvalue should be counted as a register use (it copies the fresh register value)
@@ -336,23 +265,23 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock&
            // Because we don't plan to optimize captured registers atm, we skip full dataflow analysis for them right now
            break;
        case IrCmd::CAPTURE:
-           maybeUse(inst.a);
+           visitor.maybeUse(inst.a);

            if (function.uintOp(inst.b) == 1)
-               capturedRegs.set(vmRegOp(inst.a), true);
+               visitor.capture(vmRegOp(inst.a));
            break;
        case IrCmd::SETLIST:
-           use(inst.b);
-           useRange(vmRegOp(inst.c), function.intOp(inst.d));
+           visitor.use(inst.b);
+           visitor.useRange(vmRegOp(inst.c), function.intOp(inst.d));
            break;
        case IrCmd::CALL:
-           use(inst.a);
-           useRange(vmRegOp(inst.a) + 1, function.intOp(inst.b));
-           defRange(vmRegOp(inst.a), function.intOp(inst.c));
+           visitor.use(inst.a);
+           visitor.useRange(vmRegOp(inst.a) + 1, function.intOp(inst.b));
+           visitor.defRange(vmRegOp(inst.a), function.intOp(inst.c));
            break;
        case IrCmd::RETURN:
-           useRange(vmRegOp(inst.a), function.intOp(inst.b));
+           visitor.useRange(vmRegOp(inst.a), function.intOp(inst.b));
            break;
            // TODO: FASTCALL is more restrictive than INVOKE_FASTCALL; we should either determine the exact semantics, or rework it
@@ -364,89 +293,89 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock&
            {
                LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1);
-               useRange(vmRegOp(inst.c), count);
+               visitor.useRange(vmRegOp(inst.c), count);
            }
            else
            {
                if (count >= 1)
-                   use(inst.c);
+                   visitor.use(inst.c);

                if (count >= 2)
-                   maybeUse(inst.d); // Argument can also be a VmConst
+                   visitor.maybeUse(inst.d); // Argument can also be a VmConst
            }
        }
        else
        {
-           useVarargs(vmRegOp(inst.c));
+           visitor.useVarargs(vmRegOp(inst.c));
        }

        // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG
        if (int count = function.intOp(inst.f); count != -1)
-           defRange(vmRegOp(inst.b), count);
+           visitor.defRange(vmRegOp(inst.b), count);
        break;
    case IrCmd::FORGLOOP:
        // First register is not used by instruction, we check that it's still 'nil' with CHECK_TAG
-       use(inst.a, 1);
-       use(inst.a, 2);
+       visitor.use(inst.a, 1);
+       visitor.use(inst.a, 2);

-       def(inst.a, 2);
-       defRange(vmRegOp(inst.a) + 3, function.intOp(inst.b));
+       visitor.def(inst.a, 2);
+       visitor.defRange(vmRegOp(inst.a) + 3, function.intOp(inst.b));
        break;
    case IrCmd::FORGLOOP_FALLBACK:
-       useRange(vmRegOp(inst.a), 3);
+       visitor.useRange(vmRegOp(inst.a), 3);

-       def(inst.a, 2);
-       defRange(vmRegOp(inst.a) + 3, uint8_t(function.intOp(inst.b))); // ignore most significant bit
+       visitor.def(inst.a, 2);
+       visitor.defRange(vmRegOp(inst.a) + 3, uint8_t(function.intOp(inst.b))); // ignore most significant bit
        break;
    case IrCmd::FORGPREP_XNEXT_FALLBACK:
-       use(inst.b);
+       visitor.use(inst.b);
        break;
    case IrCmd::FALLBACK_GETGLOBAL:
-       def(inst.b);
+       visitor.def(inst.b);
        break;
    case IrCmd::FALLBACK_SETGLOBAL:
-       use(inst.b);
+       visitor.use(inst.b);
        break;
    case IrCmd::FALLBACK_GETTABLEKS:
-       use(inst.c);
-       def(inst.b);
+       visitor.use(inst.c);
+       visitor.def(inst.b);
        break;
    case IrCmd::FALLBACK_SETTABLEKS:
-       use(inst.b);
-       use(inst.c);
+       visitor.use(inst.b);
+       visitor.use(inst.c);
        break;
    case IrCmd::FALLBACK_NAMECALL:
-       use(inst.c);
-       defRange(vmRegOp(inst.b), 2);
+       visitor.use(inst.c);
+       visitor.defRange(vmRegOp(inst.b), 2);
        break;
    case IrCmd::FALLBACK_PREPVARARGS:
        // No effect on explicitly referenced registers
        break;
    case IrCmd::FALLBACK_GETVARARGS:
-       defRange(vmRegOp(inst.b), function.intOp(inst.c));
+       visitor.defRange(vmRegOp(inst.b), function.intOp(inst.c));
        break;
    case IrCmd::FALLBACK_DUPCLOSURE:
-       def(inst.b);
+       visitor.def(inst.b);
        break;
    case IrCmd::FALLBACK_FORGPREP:
-       use(inst.b);
-       defRange(vmRegOp(inst.b), 3);
+       visitor.use(inst.b);
+       visitor.defRange(vmRegOp(inst.b), 3);
        break;
    case IrCmd::ADJUST_STACK_TO_REG:
-       defRange(vmRegOp(inst.a), -1);
+       visitor.defRange(vmRegOp(inst.a), -1);
        break;
    case IrCmd::ADJUST_STACK_TO_TOP:
        // While this can be considered to be a vararg consumer, it is already handled in fastcall instructions
        break;
    case IrCmd::GET_TYPEOF:
-       use(inst.a);
+       visitor.use(inst.a);
        break;
    case IrCmd::FINDUPVAL:
-       use(inst.a);
+       visitor.use(inst.a);
        break;

    default:
@@ -460,8 +389,102 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock&
        break;
        }
    }
+}

-   return inRs;
+struct BlockVmRegLiveInComputation
+{
+   BlockVmRegLiveInComputation(RegisterSet& defRs, std::bitset<256>& capturedRegs)
+       : defRs(defRs)
+       , capturedRegs(capturedRegs)
+   {
+   }
+
+   RegisterSet& defRs;
+   std::bitset<256>& capturedRegs;
+
+   RegisterSet inRs;
+
+   void def(IrOp op, int offset = 0)
+   {
+       defRs.regs.set(vmRegOp(op) + offset, true);
+   }
+
+   void use(IrOp op, int offset = 0)
+   {
+       if (!defRs.regs.test(vmRegOp(op) + offset))
+           inRs.regs.set(vmRegOp(op) + offset, true);
+   }
+
+   void maybeDef(IrOp op)
+   {
+       if (op.kind == IrOpKind::VmReg)
+           defRs.regs.set(vmRegOp(op), true);
+   }
+
+   void maybeUse(IrOp op)
+   {
+       if (op.kind == IrOpKind::VmReg)
+       {
+           if (!defRs.regs.test(vmRegOp(op)))
+               inRs.regs.set(vmRegOp(op), true);
+       }
+   }
+
+   void defVarargs(uint8_t varargStart)
+   {
+       defRs.varargSeq = true;
+       defRs.varargStart = varargStart;
+   }
+
+   void useVarargs(uint8_t varargStart)
+   {
+       requireVariadicSequence(inRs, defRs, varargStart);
+
+       // Variadic sequence has been consumed
+       defRs.varargSeq = false;
+       defRs.varargStart = 0;
+   }
+
+   void defRange(int start, int count)
+   {
+       if (count == -1)
+       {
+           defVarargs(start);
+       }
+       else
+       {
+           for (int i = start; i < start + count; i++)
+               defRs.regs.set(i, true);
+       }
+   }
+
+   void useRange(int start, int count)
+   {
+       if (count == -1)
+       {
+           useVarargs(start);
+       }
+       else
+       {
+           for (int i = start; i < start + count; i++)
+           {
+               if (!defRs.regs.test(i))
+                   inRs.regs.set(i, true);
+           }
+       }
+   }
+
+   void capture(int reg)
+   {
+       capturedRegs.set(reg, true);
+   }
+};
+
+static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& block, RegisterSet& defRs, std::bitset<256>& capturedRegs)
+{
+   BlockVmRegLiveInComputation visitor(defRs, capturedRegs);
+   visitVmRegDefsUses(visitor, function, block);
+   return visitor.inRs;
}
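
Note: the point of replacing the lambdas with a template visitor is that visitVmRegDefsUses can now drive more than one analysis over a block's VM register accesses. A hypothetical second client (a sketch against the codebase's IrOp/IrOpKind types, not compilable standalone):

    struct VmRegUseCounter
    {
        int uses = 0;

        void def(IrOp, int = 0) {}
        void maybeDef(IrOp) {}
        void use(IrOp, int = 0) { uses++; }
        void maybeUse(IrOp op) { if (op.kind == IrOpKind::VmReg) uses++; }
        void defVarargs(uint8_t) {}
        void useVarargs(uint8_t) { uses++; }
        void defRange(int, int) {}
        void useRange(int, int count) { uses += count == -1 ? 1 : count; }
        void capture(int) {}
    };

    // VmRegUseCounter counter;
    // visitVmRegDefsUses(counter, function, block);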
// The algorithm used here is commonly known as backwards data-flow analysis.
@@ -866,6 +889,79 @@ void computeCfgDominanceTreeChildren(IrFunction& function)
    computeBlockOrdering<domChildren>(function, info.domOrdering, /* preOrder */ nullptr, /* postOrder */ nullptr);
}

+// This algorithm is based on 'A Linear Time Algorithm for Placing Phi-Nodes' [Vugranam C. Sreedhar]
+// It uses the optimized form from LLVM that relies on an implicit DJ-graph (join edges are edges of the CFG that are not part of the dominance tree)
+void computeIteratedDominanceFrontierForDefs(
+   IdfContext& ctx, const IrFunction& function, const std::vector<uint32_t>& defBlocks, const std::vector<uint32_t>& liveInBlocks)
+{
+   LUAU_ASSERT(!function.cfg.domOrdering.empty());
+
+   LUAU_ASSERT(ctx.queue.empty());
+   LUAU_ASSERT(ctx.worklist.empty());
+
+   ctx.idf.clear();
+
+   ctx.visits.clear();
+   ctx.visits.resize(function.blocks.size());
+
+   for (uint32_t defBlock : defBlocks)
+   {
+       const BlockOrdering& ordering = function.cfg.domOrdering[defBlock];
+       ctx.queue.push({defBlock, ordering});
+   }
+
+   while (!ctx.queue.empty())
+   {
+       IdfContext::BlockAndOrdering root = ctx.queue.top();
+       ctx.queue.pop();
+
+       LUAU_ASSERT(ctx.worklist.empty());
+       ctx.worklist.push_back(root.blockIdx);
+       ctx.visits[root.blockIdx].seenInWorklist = true;
+
+       while (!ctx.worklist.empty())
+       {
+           uint32_t blockIdx = ctx.worklist.back();
+           ctx.worklist.pop_back();
+
+           // Check if successor node is the node where dominance of the current root ends, making it a part of dominance frontier set
+           for (uint32_t succIdx : successors(function.cfg, blockIdx))
+           {
+               const BlockOrdering& succOrdering = function.cfg.domOrdering[succIdx];
+
+               // Nodes in the DF of root always have a level that is less than or equal to the level of root
+               if (succOrdering.depth > root.ordering.depth)
+                   continue;
+
+               if (ctx.visits[succIdx].seenInQueue)
+                   continue;
+
+               ctx.visits[succIdx].seenInQueue = true;
+
+               // Skip successor block if it doesn't have our variable as a live in there
+               if (std::find(liveInBlocks.begin(), liveInBlocks.end(), succIdx) == liveInBlocks.end())
+                   continue;
+
+               ctx.idf.push_back(succIdx);
+
+               // If block doesn't have its own definition of the variable, add it to the queue
+               if (std::find(defBlocks.begin(), defBlocks.end(), succIdx) == defBlocks.end())
+                   ctx.queue.push({succIdx, succOrdering});
+           }
+
+           // Add dominance tree children that haven't been processed yet to the worklist
+           for (uint32_t domChildIdx : domChildren(function.cfg, blockIdx))
+           {
+               if (ctx.visits[domChildIdx].seenInWorklist)
+                   continue;
+
+               ctx.visits[domChildIdx].seenInWorklist = true;
+               ctx.worklist.push_back(domChildIdx);
+           }
+       }
+   }
+}
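
Note: for readers new to dominance frontiers, here is a standalone brute-force computation on a diamond CFG, which is where phi-placement questions start; it is illustrative only and far slower than the DJ-graph walk above:

    #include <cstdio>
    #include <vector>

    int main()
    {
        // Diamond CFG: 0 -> {1, 2}, 1 -> 3, 2 -> 3; block 0 is the entry
        std::vector<std::vector<int>> preds = {{}, {0}, {0}, {1, 2}};
        int n = 4;

        // dom[b]: bitset of blocks dominating b, computed by iterating to a fixed point
        std::vector<unsigned> dom(n, (1u << n) - 1);
        dom[0] = 1u;
        for (bool changed = true; changed;)
        {
            changed = false;
            for (int b = 1; b < n; b++)
            {
                unsigned d = (1u << n) - 1;
                for (int p : preds[b])
                    d &= dom[p];
                d |= 1u << b;
                if (d != dom[b])
                    dom[b] = d, changed = true;
            }
        }

        // DF(b): blocks m where b dominates a predecessor of m, but doesn't strictly dominate m
        for (int b = 0; b < n; b++)
            for (int m = 0; m < n; m++)
                for (int p : preds[m])
                    if ((dom[p] >> b & 1) && !(b != m && (dom[m] >> b & 1)))
                    {
                        printf("block %d is in DF(%d)\n", m, b); // prints DF(1) = DF(2) = {3}
                        break;
                    }
    }

A definition placed in blocks 1 and 2 therefore needs a phi node at block 3, which is exactly the kind of set computeIteratedDominanceFrontierForDefs collects into ctx.idf.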
void computeCfgInfo(IrFunction& function)
{
    computeCfgBlockEdges(function);

View file

@@ -1,7 +1,9 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/IrBuilder.h"

-#include "Luau/IrAnalysis.h"
+#include "Luau/Bytecode.h"
+#include "Luau/BytecodeUtils.h"
+#include "Luau/IrData.h"
#include "Luau/IrUtils.h"

#include "IrTranslation.h"
@@ -22,10 +24,14 @@ IrBuilder::IrBuilder()
{
}

+static bool hasTypedParameters(Proto* proto)
+{
+   return proto->typeinfo && proto->numparams != 0;
+}
+
static void buildArgumentTypeChecks(IrBuilder& build, Proto* proto)
{
-   if (!proto->typeinfo || proto->numparams == 0)
-       return;
+   LUAU_ASSERT(hasTypedParameters(proto));

    for (int i = 0; i < proto->numparams; ++i)
    {
@@ -53,31 +59,31 @@ static void buildArgumentTypeChecks(IrBuilder& build, Proto* proto)
        switch (tag)
        {
        case LBC_TYPE_NIL:
-           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNIL), build.undef(), build.constInt(1));
+           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNIL), build.vmExit(kVmExitEntryGuardPc));
            break;
        case LBC_TYPE_BOOLEAN:
-           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBOOLEAN), build.undef(), build.constInt(1));
+           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBOOLEAN), build.vmExit(kVmExitEntryGuardPc));
            break;
        case LBC_TYPE_NUMBER:
-           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNUMBER), build.undef(), build.constInt(1));
+           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNUMBER), build.vmExit(kVmExitEntryGuardPc));
            break;
        case LBC_TYPE_STRING:
-           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TSTRING), build.undef(), build.constInt(1));
+           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TSTRING), build.vmExit(kVmExitEntryGuardPc));
            break;
        case LBC_TYPE_TABLE:
-           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTABLE), build.undef(), build.constInt(1));
+           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTABLE), build.vmExit(kVmExitEntryGuardPc));
            break;
        case LBC_TYPE_FUNCTION:
-           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TFUNCTION), build.undef(), build.constInt(1));
+           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TFUNCTION), build.vmExit(kVmExitEntryGuardPc));
            break;
        case LBC_TYPE_THREAD:
-           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTHREAD), build.undef(), build.constInt(1));
+           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTHREAD), build.vmExit(kVmExitEntryGuardPc));
            break;
        case LBC_TYPE_USERDATA:
-           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.undef(), build.constInt(1));
+           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.vmExit(kVmExitEntryGuardPc));
            break;
        case LBC_TYPE_VECTOR:
-           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TVECTOR), build.undef(), build.constInt(1));
+           build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TVECTOR), build.vmExit(kVmExitEntryGuardPc));
            break;
        }
@@ -103,11 +109,28 @@ void IrBuilder::buildFunctionIr(Proto* proto)
    function.proto = proto;
    function.variadic = proto->is_vararg != 0;

+   // Reserve entry block
+   bool generateTypeChecks = hasTypedParameters(proto);
+   IrOp entry = generateTypeChecks ? block(IrBlockKind::Internal) : IrOp{};
+
    // Rebuild original control flow blocks
    rebuildBytecodeBasicBlocks(proto);

    function.bcMapping.resize(proto->sizecode, {~0u, ~0u});

+   if (generateTypeChecks)
+   {
+       beginBlock(entry);
+       buildArgumentTypeChecks(*this, proto);
+       inst(IrCmd::JUMP, blockAtInst(0));
+   }
+   else
+   {
+       entry = blockAtInst(0);
+   }
+
+   function.entryBlock = entry.index;
+
    // Translate all instructions to IR inside blocks
    for (int i = 0; i < proto->sizecode;)
    {
@@ -123,12 +146,15 @@ void IrBuilder::buildFunctionIr(Proto* proto)
        if (instIndexToBlock[i] != kNoAssociatedBlockIndex)
            beginBlock(blockAtInst(i));

-       if (i == 0)
-           buildArgumentTypeChecks(*this, proto);
-
        // We skip dead bytecode instructions when they appear after block was already terminated
        if (!inTerminatedBlock)
        {
+           if (interruptRequested)
+           {
+               interruptRequested = false;
+               inst(IrCmd::INTERRUPT, constUint(i));
+           }
+
            translateInst(op, pc, i);

            if (fastcallSkipTarget != -1)
@@ -313,6 +339,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
    case LOP_DIV:
        translateInstBinary(*this, pc, i, TM_DIV);
        break;
+   case LOP_IDIV:
+       translateInstBinary(*this, pc, i, TM_IDIV);
+       break;
    case LOP_MOD:
        translateInstBinary(*this, pc, i, TM_MOD);
        break;
@@ -331,6 +360,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
    case LOP_DIVK:
        translateInstBinaryK(*this, pc, i, TM_DIV);
        break;
+   case LOP_IDIVK:
+       translateInstBinaryK(*this, pc, i, TM_IDIV);
+       break;
    case LOP_MODK:
        translateInstBinaryK(*this, pc, i, TM_MOD);
        break;
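
Note: LOP_IDIV/LOP_IDIVK are the floor-division opcodes; for numbers, Luau's // operator rounds the quotient toward negative infinity. A quick standalone check of the semantics (illustrative only):

    #include <cassert>
    #include <cmath>

    int main()
    {
        // a // b on numbers behaves like floor(a / b)
        assert(std::floor(7.0 / 2.0) == 3.0);
        assert(std::floor(-7.0 / 2.0) == -4.0); // rounds down, unlike C integer division
    }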
@@ -353,7 +385,8 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
        translateInstDupTable(*this, pc, i);
        break;
    case LOP_SETLIST:
-       inst(IrCmd::SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1]));
+       inst(IrCmd::SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1]),
+           undef());
        break;
    case LOP_GETUPVAL:
        translateInstGetUpval(*this, pc, i);

View file

@@ -13,7 +13,7 @@ namespace CodeGen
namespace X64
{

-static const std::array<OperandX64, 6> kWindowsGprOrder = {rcx, rdx, r8, r9, addr[rsp + 32], addr[rsp + 40]};
+static const std::array<OperandX64, 6> kWindowsGprOrder = {rcx, rdx, r8, r9, addr[rsp + kStackRegHomeStorage], addr[rsp + kStackRegHomeStorage + 8]};
static const std::array<OperandX64, 6> kSystemvGprOrder = {rdi, rsi, rdx, rcx, r8, r9};

static const std::array<OperandX64, 4> kXmmOrder = {xmm0, xmm1, xmm2, xmm3}; // Common order for first 4 fp arguments on Windows/SystemV

View file

@@ -89,8 +89,6 @@ const char* getCmdName(IrCmd cmd)
        return "LOAD_INT";
    case IrCmd::LOAD_TVALUE:
        return "LOAD_TVALUE";
-   case IrCmd::LOAD_NODE_VALUE_TV:
-       return "LOAD_NODE_VALUE_TV";
    case IrCmd::LOAD_ENV:
        return "LOAD_ENV";
    case IrCmd::GET_ARR_ADDR:
@@ -113,8 +111,8 @@ const char* getCmdName(IrCmd cmd)
        return "STORE_VECTOR";
    case IrCmd::STORE_TVALUE:
        return "STORE_TVALUE";
-   case IrCmd::STORE_NODE_VALUE_TV:
-       return "STORE_NODE_VALUE_TV";
+   case IrCmd::STORE_SPLIT_TVALUE:
+       return "STORE_SPLIT_TVALUE";
    case IrCmd::ADD_INT:
        return "ADD_INT";
    case IrCmd::SUB_INT:
@@ -127,6 +125,8 @@ const char* getCmdName(IrCmd cmd)
        return "MUL_NUM";
    case IrCmd::DIV_NUM:
        return "DIV_NUM";
+   case IrCmd::IDIV_NUM:
+       return "IDIV_NUM";
    case IrCmd::MOD_NUM:
        return "MOD_NUM";
    case IrCmd::MIN_NUM:
@@ -157,12 +157,8 @@ const char* getCmdName(IrCmd cmd)
        return "JUMP_IF_FALSY";
    case IrCmd::JUMP_EQ_TAG:
        return "JUMP_EQ_TAG";
-   case IrCmd::JUMP_EQ_INT:
-       return "JUMP_EQ_INT";
-   case IrCmd::JUMP_LT_INT:
-       return "JUMP_LT_INT";
-   case IrCmd::JUMP_GE_UINT:
-       return "JUMP_GE_UINT";
+   case IrCmd::JUMP_CMP_INT:
+       return "JUMP_CMP_INT";
    case IrCmd::JUMP_EQ_POINTER:
        return "JUMP_EQ_POINTER";
    case IrCmd::JUMP_CMP_NUM:
@@ -171,6 +167,8 @@ const char* getCmdName(IrCmd cmd)
        return "JUMP_SLOT_MATCH";
    case IrCmd::TABLE_LEN:
        return "TABLE_LEN";
+   case IrCmd::TABLE_SETNUM:
+       return "TABLE_SETNUM";
    case IrCmd::STRING_LEN:
        return "STRING_LEN";
    case IrCmd::NEW_TABLE:
@@ -215,8 +213,6 @@ const char* getCmdName(IrCmd cmd)
        return "GET_UPVALUE";
    case IrCmd::SET_UPVALUE:
        return "SET_UPVALUE";
-   case IrCmd::PREPARE_FORN:
-       return "PREPARE_FORN";
    case IrCmd::CHECK_TAG:
        return "CHECK_TAG";
    case IrCmd::CHECK_TRUTHY:
@@ -233,6 +229,8 @@ const char* getCmdName(IrCmd cmd)
        return "CHECK_SLOT_MATCH";
    case IrCmd::CHECK_NODE_NO_NEXT:
        return "CHECK_NODE_NO_NEXT";
+   case IrCmd::CHECK_NODE_VALUE:
+       return "CHECK_NODE_VALUE";
    case IrCmd::INTERRUPT:
        return "INTERRUPT";
    case IrCmd::CHECK_GC:
@@ -402,7 +400,7 @@ void toString(IrToStringContext& ctx, IrOp op)
        append(ctx.result, "U%d", vmUpvalueOp(op));
        break;
    case IrOpKind::VmExit:
-       append(ctx.result, "exit(%d)", op.index);
+       append(ctx.result, "exit(%d)", vmExitOp(op));
        break;
    }
}

View file

@@ -1,10 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "IrLoweringA64.h"

-#include "Luau/CodeGen.h"
#include "Luau/DenseHash.h"
-#include "Luau/IrAnalysis.h"
-#include "Luau/IrDump.h"
+#include "Luau/IrData.h"
#include "Luau/IrUtils.h"

#include "EmitCommonA64.h"
@@ -60,28 +58,56 @@ inline ConditionA64 getConditionFP(IrCondition cond)
    }
}

-static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64 object, RegisterA64 temp, int ra, int ratag, Label& skip)
+inline ConditionA64 getConditionInt(IrCondition cond)
{
-   RegisterA64 tempw = castReg(KindA64::w, temp);
-
-   // Barrier should've been optimized away if we know that it's not collectable, checking for correctness
-   if (ratag == -1 || !isGCO(ratag))
-   {
-       // iscollectable(ra)
-       build.ldr(tempw, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, tt)));
-       build.cmp(tempw, LUA_TSTRING);
-       build.b(ConditionA64::Less, skip);
-   }
-
-   // isblack(obj2gco(o))
-   build.ldrb(tempw, mem(object, offsetof(GCheader, marked)));
-   build.tbz(tempw, BLACKBIT, skip);
-
-   // iswhite(gcvalue(ra))
-   build.ldr(temp, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, value)));
-   build.ldrb(tempw, mem(temp, offsetof(GCheader, marked)));
-   build.tst(tempw, bit2mask(WHITE0BIT, WHITE1BIT));
-   build.b(ConditionA64::Equal, skip); // Equal = Zero after tst
+   switch (cond)
+   {
+   case IrCondition::Equal:
+       return ConditionA64::Equal;
+   case IrCondition::NotEqual:
+       return ConditionA64::NotEqual;
+   case IrCondition::Less:
+       return ConditionA64::Minus;
+   case IrCondition::NotLess:
+       return ConditionA64::Plus;
+   case IrCondition::LessEqual:
+       return ConditionA64::LessEqual;
+   case IrCondition::NotLessEqual:
+       return ConditionA64::Greater;
+   case IrCondition::Greater:
+       return ConditionA64::Greater;
+   case IrCondition::NotGreater:
+       return ConditionA64::LessEqual;
+   case IrCondition::GreaterEqual:
+       return ConditionA64::GreaterEqual;
+   case IrCondition::NotGreaterEqual:
+       return ConditionA64::Less;
+   case IrCondition::UnsignedLess:
+       return ConditionA64::CarryClear;
+   case IrCondition::UnsignedLessEqual:
+       return ConditionA64::UnsignedLessEqual;
+   case IrCondition::UnsignedGreater:
+       return ConditionA64::UnsignedGreater;
+   case IrCondition::UnsignedGreaterEqual:
+       return ConditionA64::CarrySet;
+   default:
+       LUAU_ASSERT(!"Unexpected condition code");
+       return ConditionA64::Always;
+   }
}

static void emitAddOffset(AssemblyBuilderA64& build, RegisterA64 dst, RegisterA64 src, size_t offset)
@@ -100,6 +126,47 @@ static void emitAddOffset(AssemblyBuilderA64& build, RegisterA64 dst, RegisterA6
    }
}

+static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64 object, RegisterA64 temp, IrOp ra, int ratag, Label& skip)
+{
+   RegisterA64 tempw = castReg(KindA64::w, temp);
+   AddressA64 addr = temp;
+
+   // iscollectable(ra)
+   if (ratag == -1 || !isGCO(ratag))
+   {
+       if (ra.kind == IrOpKind::VmReg)
+       {
+           addr = mem(rBase, vmRegOp(ra) * sizeof(TValue) + offsetof(TValue, tt));
+       }
+       else if (ra.kind == IrOpKind::VmConst)
+       {
+           emitAddOffset(build, temp, rConstants, vmConstOp(ra) * sizeof(TValue) + offsetof(TValue, tt));
+       }
+
+       build.ldr(tempw, addr);
+       build.cmp(tempw, LUA_TSTRING);
+       build.b(ConditionA64::Less, skip);
+   }
+
+   // isblack(obj2gco(o))
+   build.ldrb(tempw, mem(object, offsetof(GCheader, marked)));
+   build.tbz(tempw, BLACKBIT, skip);
+
+   // iswhite(gcvalue(ra))
+   if (ra.kind == IrOpKind::VmReg)
+   {
+       addr = mem(rBase, vmRegOp(ra) * sizeof(TValue) + offsetof(TValue, value));
+   }
+   else if (ra.kind == IrOpKind::VmConst)
+   {
+       emitAddOffset(build, temp, rConstants, vmConstOp(ra) * sizeof(TValue) + offsetof(TValue, value));
+   }
+
+   build.ldr(temp, addr);
+   build.ldrb(tempw, mem(temp, offsetof(GCheader, marked)));
+   build.tst(tempw, bit2mask(WHITE0BIT, WHITE1BIT));
+   build.b(ConditionA64::Equal, skip); // Equal = Zero after tst
+}
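
Note: the emitted checks evaluate the classic incremental-GC write barrier predicate: do nothing unless the stored value is collectable, the object is black, and the value is white. A standalone sketch of that predicate (bit positions assumed to match lgc.h; illustrative only):

    #include <cstdint>

    constexpr int WHITE0BIT = 0, WHITE1BIT = 1, BLACKBIT = 2; // assumed to match lgc.h

    constexpr uint8_t bit2mask(int a, int b)
    {
        return uint8_t((1 << a) | (1 << b));
    }

    bool barrierFires(uint8_t objectMarked, uint8_t valueMarked, bool valueIsCollectable)
    {
        // a black object pointing at a white value is the only case the barrier must handle
        return valueIsCollectable && (objectMarked & (1 << BLACKBIT)) != 0 &&
               (valueMarked & bit2mask(WHITE0BIT, WHITE1BIT)) != 0;
    }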
static void emitAbort(AssemblyBuilderA64& build, Label& abort) static void emitAbort(AssemblyBuilderA64& build, Label& abort)
{ {
Label skip; Label skip;
@ -124,6 +191,7 @@ static void emitFallback(AssemblyBuilderA64& build, int offset, int pcpos)
static void emitInvokeLibm1P(AssemblyBuilderA64& build, size_t func, int arg) static void emitInvokeLibm1P(AssemblyBuilderA64& build, size_t func, int arg)
{ {
LUAU_ASSERT(kTempSlots >= 1);
build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n)));
build.add(x0, sp, sTemporary.data); // sp-relative offset build.add(x0, sp, sTemporary.data); // sp-relative offset
build.ldr(x1, mem(rNativeContext, uint32_t(func))); build.ldr(x1, mem(rNativeContext, uint32_t(func)));
@ -172,11 +240,12 @@ static bool emitBuiltin(
} }
} }
IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function) IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats)
: build(build) : build(build)
, helpers(helpers) , helpers(helpers)
, function(function) , function(function)
, regs(function, {{x0, x15}, {x16, x17}, {q0, q7}, {q16, q31}}) , stats(stats)
, regs(function, stats, {{x0, x15}, {x16, x17}, {q0, q7}, {q16, q31}})
, valueTracker(function) , valueTracker(function)
, exitHandlerMap(~0u) , exitHandlerMap(~0u)
{ {
@ -189,7 +258,7 @@ IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers,
}); });
} }
void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{ {
valueTracker.beforeInstLowering(inst); valueTracker.beforeInstLowering(inst);
@ -226,16 +295,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::LOAD_TVALUE: case IrCmd::LOAD_TVALUE:
{ {
inst.regA64 = regs.allocReg(KindA64::q, index); inst.regA64 = regs.allocReg(KindA64::q, index);
AddressA64 addr = tempAddr(inst.a, 0);
int addrOffset = inst.b.kind != IrOpKind::None ? intOp(inst.b) : 0;
AddressA64 addr = tempAddr(inst.a, addrOffset);
build.ldr(inst.regA64, addr); build.ldr(inst.regA64, addr);
break; break;
} }
case IrCmd::LOAD_NODE_VALUE_TV:
{
inst.regA64 = regs.allocReg(KindA64::q, index);
build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(LuaNode, val)));
break;
}
case IrCmd::LOAD_ENV: case IrCmd::LOAD_ENV:
inst.regA64 = regs.allocReg(KindA64::x, index); inst.regA64 = regs.allocReg(KindA64::x, index);
build.ldr(inst.regA64, mem(rClosure, offsetof(Closure, env))); build.ldr(inst.regA64, mem(rClosure, offsetof(Closure, env)));
@ -247,7 +312,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
if (inst.b.kind == IrOpKind::Inst) if (inst.b.kind == IrOpKind::Inst)
{ {
build.add(inst.regA64, inst.regA64, zextReg(regOp(inst.b)), kTValueSizeLog2); build.add(inst.regA64, inst.regA64, regOp(inst.b), kTValueSizeLog2);
} }
else if (inst.b.kind == IrOpKind::Constant) else if (inst.b.kind == IrOpKind::Constant)
{ {
@ -276,6 +341,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp1 = regs.allocTemp(KindA64::x);
RegisterA64 temp1w = castReg(KindA64::w, temp1); RegisterA64 temp1w = castReg(KindA64::w, temp1);
RegisterA64 temp2 = regs.allocTemp(KindA64::w); RegisterA64 temp2 = regs.allocTemp(KindA64::w);
RegisterA64 temp2x = castReg(KindA64::x, temp2);
// note: since the stride of the load is the same as the destination register size, we can range check the array index, not the byte offset // note: since the stride of the load is the same as the destination register size, we can range check the array index, not the byte offset
if (uintOp(inst.b) <= AddressA64::kMaxOffset) if (uintOp(inst.b) <= AddressA64::kMaxOffset)
@ -293,7 +359,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
// note: this may clobber inst.a, so it's important that we don't use it after this // note: this may clobber inst.a, so it's important that we don't use it after this
build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node)));
build.add(inst.regA64, inst.regA64, zextReg(temp2), kLuaNodeSizeLog2); build.add(inst.regA64, inst.regA64, temp2x, kLuaNodeSizeLog2); // "zero extend" temp2 to get a larger shift (top 32 bits are zero)
break; break;
} }
case IrCmd::GET_HASH_NODE_ADDR: case IrCmd::GET_HASH_NODE_ADDR:
@ -301,6 +367,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a});
RegisterA64 temp1 = regs.allocTemp(KindA64::w); RegisterA64 temp1 = regs.allocTemp(KindA64::w);
RegisterA64 temp2 = regs.allocTemp(KindA64::w); RegisterA64 temp2 = regs.allocTemp(KindA64::w);
RegisterA64 temp2x = castReg(KindA64::x, temp2);
// hash & ((1 << lsizenode) - 1) == hash & ~(-1 << lsizenode) // hash & ((1 << lsizenode) - 1) == hash & ~(-1 << lsizenode)
build.mov(temp1, -1); build.mov(temp1, -1);
@ -311,7 +378,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
// note: this may clobber inst.a, so it's important that we don't use it after this // note: this may clobber inst.a, so it's important that we don't use it after this
build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node)));
build.add(inst.regA64, inst.regA64, zextReg(temp2), kLuaNodeSizeLog2); build.add(inst.regA64, inst.regA64, temp2x, kLuaNodeSizeLog2); // "zero extend" temp2 to get a larger shift (top 32 bits are zero)
break; break;
} }
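The mask identity in the comment above is what lets the lowering build the mask starting from -1 (the mov above) instead of computing (1 << lsizenode) - 1 with a subtraction. An exhaustive check over the valid shift range (an illustrative sketch, not part of this change):

#include <cassert>
#include <cstdint>

int main()
{
    for (unsigned n = 0; n < 32; ++n)
    {
        uint32_t viaSub = (uint32_t(1) << n) - 1; // (1 << lsizenode) - 1
        uint32_t viaNot = ~(~uint32_t(0) << n);   // ~(-1 << lsizenode)
        assert(viaSub == viaNot);
    }
    return 0;
}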
case IrCmd::GET_CLOSURE_UPVAL_ADDR: case IrCmd::GET_CLOSURE_UPVAL_ADDR:
@ -324,10 +391,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
} }
case IrCmd::STORE_TAG: case IrCmd::STORE_TAG:
{ {
RegisterA64 temp = regs.allocTemp(KindA64::w);
AddressA64 addr = tempAddr(inst.a, offsetof(TValue, tt)); AddressA64 addr = tempAddr(inst.a, offsetof(TValue, tt));
if (tagOp(inst.b) == 0)
{
build.str(wzr, addr);
}
else
{
RegisterA64 temp = regs.allocTemp(KindA64::w);
build.mov(temp, tagOp(inst.b)); build.mov(temp, tagOp(inst.b));
build.str(temp, addr); build.str(temp, addr);
}
break; break;
} }
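Several hunks in this change (STORE_TAG here, STORE_INT just below, the wzr loop that clears extra FORN variables, and the cbz in CHECK_NODE_VALUE) rest on the same encoding fact: the nil tag, a false boolean, and a zero integer payload are all represented as 0, so AArch64's always-zero wzr register can be stored directly and the mov into a temporary disappears. A standalone illustration (a sketch, assuming Luau's public lua.h header):

#include "lua.h"

int main()
{
    // the wzr stores are only sound because the nil tag is zero; the A64
    // lowering asserts exactly this before using cbz/wzr elsewhere in
    // this change
    static_assert(LUA_TNIL == 0, "nil tag must be zero for wzr stores");
    return 0;
}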
case IrCmd::STORE_POINTER: case IrCmd::STORE_POINTER:
@ -345,9 +419,16 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
} }
case IrCmd::STORE_INT: case IrCmd::STORE_INT:
{ {
RegisterA64 temp = tempInt(inst.b);
AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value));
if (inst.b.kind == IrOpKind::Constant && intOp(inst.b) == 0)
{
build.str(wzr, addr);
}
else
{
RegisterA64 temp = tempInt(inst.b);
build.str(temp, addr); build.str(temp, addr);
}
break; break;
} }
case IrCmd::STORE_VECTOR: case IrCmd::STORE_VECTOR:
@ -370,13 +451,48 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
} }
case IrCmd::STORE_TVALUE: case IrCmd::STORE_TVALUE:
{ {
AddressA64 addr = tempAddr(inst.a, 0); int addrOffset = inst.c.kind != IrOpKind::None ? intOp(inst.c) : 0;
AddressA64 addr = tempAddr(inst.a, addrOffset);
build.str(regOp(inst.b), addr); build.str(regOp(inst.b), addr);
break; break;
} }
case IrCmd::STORE_NODE_VALUE_TV: case IrCmd::STORE_SPLIT_TVALUE:
build.str(regOp(inst.b), mem(regOp(inst.a), offsetof(LuaNode, val))); {
int addrOffset = inst.d.kind != IrOpKind::None ? intOp(inst.d) : 0;
RegisterA64 tempt = regs.allocTemp(KindA64::w);
AddressA64 addrt = tempAddr(inst.a, offsetof(TValue, tt) + addrOffset);
build.mov(tempt, tagOp(inst.b));
build.str(tempt, addrt);
AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value) + addrOffset);
if (tagOp(inst.b) == LUA_TBOOLEAN)
{
if (inst.c.kind == IrOpKind::Constant)
{
// note: we reuse the tag temp register as the value for true booleans, and use the built-in zero register for false values
LUAU_ASSERT(LUA_TBOOLEAN == 1);
build.str(intOp(inst.c) ? tempt : wzr, addr);
}
else
build.str(regOp(inst.c), addr);
}
else if (tagOp(inst.b) == LUA_TNUMBER)
{
RegisterA64 temp = tempDouble(inst.c);
build.str(temp, addr);
}
else if (isGCO(tagOp(inst.b)))
{
build.str(regOp(inst.c), addr);
}
else
{
LUAU_ASSERT(!"Unsupported instruction form");
}
break; break;
}
case IrCmd::ADD_INT: case IrCmd::ADD_INT:
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});
if (inst.b.kind == IrOpKind::Constant && unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate) if (inst.b.kind == IrOpKind::Constant && unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate)
@ -433,6 +549,15 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.fdiv(inst.regA64, temp1, temp2); build.fdiv(inst.regA64, temp1, temp2);
break; break;
} }
case IrCmd::IDIV_NUM:
{
inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b});
RegisterA64 temp1 = tempDouble(inst.a);
RegisterA64 temp2 = tempDouble(inst.b);
build.fdiv(inst.regA64, temp1, temp2);
build.frintm(inst.regA64, inst.regA64);
break;
}
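IDIV_NUM is Luau's floor division; frintm rounds toward negative infinity, and the X64 lowering later in this commit uses vroundsd with RoundToNegativeInfinity for the same effect. The intended semantics as plain C++ (an illustrative sketch, not part of this change):

#include <cassert>
#include <cmath>

// a // b == floor(a / b): divide, then round down, matching the
// fdiv+frintm (or vdivsd+vroundsd) sequence emitted above
double idiv(double a, double b)
{
    return std::floor(a / b);
}

int main()
{
    assert(idiv(7.0, 2.0) == 3.0);
    assert(idiv(-7.0, 2.0) == -4.0); // rounds down, not toward zero
    return 0;
}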
case IrCmd::MOD_NUM: case IrCmd::MOD_NUM:
{ {
inst.regA64 = regs.allocReg(KindA64::d, index); // can't allocReuse because both A and B are used twice inst.regA64 = regs.allocReg(KindA64::d, index); // can't allocReuse because both A and B are used twice
@ -560,13 +685,11 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
emitUpdateBase(build); emitUpdateBase(build);
// since w0 came from a call, we need to move it so that we don't violate zextReg safety contract inst.regA64 = regs.takeReg(w0, index);
inst.regA64 = regs.allocReg(KindA64::w, index);
build.mov(inst.regA64, w0);
break; break;
} }
case IrCmd::JUMP: case IrCmd::JUMP:
if (inst.a.kind == IrOpKind::VmExit) if (inst.a.kind == IrOpKind::Undef || inst.a.kind == IrOpKind::VmExit)
{ {
Label fresh; Label fresh;
build.b(getTargetLabel(inst.a, fresh)); build.b(getTargetLabel(inst.a, fresh));
@ -644,31 +767,25 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
} }
break; break;
} }
case IrCmd::JUMP_EQ_INT: case IrCmd::JUMP_CMP_INT:
if (intOp(inst.b) == 0)
{ {
build.cbz(regOp(inst.a), labelOp(inst.c)); IrCondition cond = conditionOp(inst.c);
if (cond == IrCondition::Equal && intOp(inst.b) == 0)
{
build.cbz(regOp(inst.a), labelOp(inst.d));
}
else if (cond == IrCondition::NotEqual && intOp(inst.b) == 0)
{
build.cbnz(regOp(inst.a), labelOp(inst.d));
} }
else else
{ {
LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate);
build.cmp(regOp(inst.a), uint16_t(intOp(inst.b))); build.cmp(regOp(inst.a), uint16_t(intOp(inst.b)));
build.b(ConditionA64::Equal, labelOp(inst.c)); build.b(getConditionInt(cond), labelOp(inst.d));
} }
jumpOrFallthrough(blockOp(inst.d), next); jumpOrFallthrough(blockOp(inst.e), next);
break;
case IrCmd::JUMP_LT_INT:
LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate);
build.cmp(regOp(inst.a), uint16_t(intOp(inst.b)));
build.b(ConditionA64::Less, labelOp(inst.c));
jumpOrFallthrough(blockOp(inst.d), next);
break;
case IrCmd::JUMP_GE_UINT:
{
LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate);
build.cmp(regOp(inst.a), uint16_t(unsigned(intOp(inst.b))));
build.b(ConditionA64::CarrySet, labelOp(inst.c));
jumpOrFallthrough(blockOp(inst.d), next);
break; break;
} }
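getConditionInt is introduced elsewhere in this commit; the mapping it must perform for the opcodes folded into JUMP_CMP_INT can be reconstructed from the deleted cases above. A hypothetical sketch only, since the actual table is not shown in this hunk:

// sketch: reproduces the condition codes the dedicated opcodes used to
// hardcode (Less from JUMP_LT_INT, CarrySet from JUMP_GE_UINT)
static ConditionA64 getConditionIntSketch(IrCondition cond)
{
    switch (cond)
    {
    case IrCondition::Equal:
        return ConditionA64::Equal;
    case IrCondition::NotEqual:
        return ConditionA64::NotEqual;
    case IrCondition::Less:
        return ConditionA64::Less;
    case IrCondition::UnsignedGreaterEqual:
        return ConditionA64::CarrySet;
    default:
        LUAU_ASSERT(!"Unsupported condition");
        return ConditionA64::Equal;
    }
}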
case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_EQ_POINTER:
@ -706,16 +823,42 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.mov(x0, reg); build.mov(x0, reg);
build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaH_getn))); build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaH_getn)));
build.blr(x1); build.blr(x1);
inst.regA64 = regs.allocReg(KindA64::d, index);
build.scvtf(inst.regA64, x0); inst.regA64 = regs.takeReg(w0, index);
break; break;
} }
case IrCmd::STRING_LEN: case IrCmd::STRING_LEN:
{ {
RegisterA64 reg = regOp(inst.a);
inst.regA64 = regs.allocReg(KindA64::w, index); inst.regA64 = regs.allocReg(KindA64::w, index);
build.ldr(inst.regA64, mem(reg, offsetof(TString, len))); build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(TString, len)));
break;
}
case IrCmd::TABLE_SETNUM:
{
// note: we need to call regOp before spill so that we don't do redundant reloads
RegisterA64 table = regOp(inst.a);
RegisterA64 key = regOp(inst.b);
RegisterA64 temp = regs.allocTemp(KindA64::w);
regs.spill(build, index, {table, key});
if (w1 != key)
{
build.mov(x1, table);
build.mov(w2, key);
}
else
{
build.mov(temp, w1);
build.mov(x1, table);
build.mov(w2, temp);
}
build.mov(x0, rState);
build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaH_setnum)));
build.blr(x3);
inst.regA64 = regs.takeReg(x0, index);
break; break;
} }
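The w1 != key test above handles a parallel-move hazard: luaH_setnum's arguments go in x1 and w2, and if the key already occupies w1, writing the table pointer first would clobber it. The same situation with plain values (an illustrative analogue, not the diff's code):

#include <cassert>

int main()
{
    int regs[3] = {0, 42, 0}; // regs[1] plays w1/x1 and currently holds `key`
    int& arg1 = regs[1];      // outgoing argument register x1
    int& arg2 = regs[2];      // outgoing argument register w2
    int& key = regs[1];       // key aliases the first argument register
    int table = 7;

    // naive order (arg1 = table; arg2 = key;) would leave arg2 == 7
    int temp = key; // save the aliased input first (mov temp, w1)
    arg1 = table;   // mov x1, table
    arg2 = temp;    // mov w2, temp
    assert(arg1 == 7 && arg2 == 42);
    return 0;
}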
case IrCmd::NEW_TABLE: case IrCmd::NEW_TABLE:
@ -776,8 +919,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
regs.spill(build, index, {temp1}); regs.spill(build, index, {temp1});
build.mov(x0, temp1); build.mov(x0, temp1);
build.mov(w1, intOp(inst.b)); build.mov(w1, intOp(inst.b));
build.ldr(x2, mem(rState, offsetof(lua_State, global))); build.ldr(x2, mem(rGlobalState, offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*)));
build.ldr(x2, mem(x2, offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*)));
build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaT_gettm))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaT_gettm)));
build.blr(x3); build.blr(x3);
@ -812,8 +954,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
inst.regA64 = regs.allocReg(KindA64::w, index); inst.regA64 = regs.allocReg(KindA64::w, index);
RegisterA64 temp = tempDouble(inst.a); RegisterA64 temp = tempDouble(inst.a);
build.fcvtzs(castReg(KindA64::x, inst.regA64), temp); build.fcvtzs(castReg(KindA64::x, inst.regA64), temp);
// truncation needs to clear high bits to preserve zextReg safety contract
build.mov(inst.regA64, inst.regA64);
break; break;
} }
case IrCmd::ADJUST_STACK_TO_REG: case IrCmd::ADJUST_STACK_TO_REG:
@ -828,7 +968,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
else if (inst.b.kind == IrOpKind::Inst) else if (inst.b.kind == IrOpKind::Inst)
{ {
build.add(temp, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); build.add(temp, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue)));
build.add(temp, temp, zextReg(regOp(inst.b)), kTValueSizeLog2); build.add(temp, temp, regOp(inst.b), kTValueSizeLog2);
build.str(temp, mem(rState, offsetof(lua_State, top))); build.str(temp, mem(rState, offsetof(lua_State, top)));
} }
else else
@ -877,9 +1017,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.ldr(x6, mem(rNativeContext, offsetof(NativeContext, luauF_table) + uintOp(inst.a) * sizeof(luau_FastFunction))); build.ldr(x6, mem(rNativeContext, offsetof(NativeContext, luauF_table) + uintOp(inst.a) * sizeof(luau_FastFunction)));
build.blr(x6); build.blr(x6);
// since w0 came from a call, we need to move it so that we don't violate zextReg safety contract inst.regA64 = regs.takeReg(w0, index);
inst.regA64 = regs.allocReg(KindA64::w, index);
build.mov(inst.regA64, w0);
break; break;
} }
case IrCmd::CHECK_FASTCALL_RES: case IrCmd::CHECK_FASTCALL_RES:
@ -974,8 +1112,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::CONCAT: case IrCmd::CONCAT:
regs.spill(build, index); regs.spill(build, index);
build.mov(x0, rState); build.mov(x0, rState);
build.mov(x1, uintOp(inst.b)); build.mov(w1, uintOp(inst.b));
build.mov(x2, vmRegOp(inst.a) + uintOp(inst.b) - 1); build.mov(w2, vmRegOp(inst.a) + uintOp(inst.b) - 1);
build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_concat))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_concat)));
build.blr(x3); build.blr(x3);
@ -1018,8 +1156,10 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.ldr(temp3, mem(rBase, vmRegOp(inst.b) * sizeof(TValue))); build.ldr(temp3, mem(rBase, vmRegOp(inst.b) * sizeof(TValue)));
build.str(temp3, temp2); build.str(temp3, temp2);
if (inst.c.kind == IrOpKind::Undef || isGCO(tagOp(inst.c)))
{
Label skip; Label skip;
checkObjectBarrierConditions(build, temp1, temp2, vmRegOp(inst.b), /* ratag */ -1, skip); checkObjectBarrierConditions(build, temp1, temp2, inst.b, inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip);
size_t spills = regs.spill(build, index, {temp1}); size_t spills = regs.spill(build, index, {temp1});
@ -1033,23 +1173,13 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
// note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack
build.setLabel(skip); build.setLabel(skip);
}
break; break;
} }
case IrCmd::PREPARE_FORN:
regs.spill(build, index);
build.mov(x0, rState);
build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue)));
build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue)));
build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue)));
build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_prepareFORN)));
build.blr(x4);
// note: no emitUpdateBase necessary because prepareFORN does not reallocate stack
break;
case IrCmd::CHECK_TAG: case IrCmd::CHECK_TAG:
{ {
bool continueInVm = (inst.d.kind == IrOpKind::Constant && intOp(inst.d));
Label fresh; // used when guard aborts execution or jumps to a VM exit Label fresh; // used when guard aborts execution or jumps to a VM exit
Label& fail = continueInVm ? helpers.exitContinueVmClearNativeFlag : getTargetLabel(inst.c, fresh); Label& fail = getTargetLabel(inst.c, fresh);
// To support DebugLuauAbortingChecks, CHECK_TAG with VmReg has to be handled // To support DebugLuauAbortingChecks, CHECK_TAG with VmReg has to be handled
RegisterA64 tag = inst.a.kind == IrOpKind::VmReg ? regs.allocTemp(KindA64::w) : regOp(inst.a); RegisterA64 tag = inst.a.kind == IrOpKind::VmReg ? regs.allocTemp(KindA64::w) : regOp(inst.a);
@ -1066,7 +1196,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.cmp(tag, tagOp(inst.b)); build.cmp(tag, tagOp(inst.b));
build.b(ConditionA64::NotEqual, fail); build.b(ConditionA64::NotEqual, fail);
} }
if (!continueInVm)
finalizeTargetLabel(inst.c, fresh); finalizeTargetLabel(inst.c, fresh);
break; break;
} }
@ -1210,14 +1340,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
finalizeTargetLabel(inst.b, fresh); finalizeTargetLabel(inst.b, fresh);
break; break;
} }
case IrCmd::CHECK_NODE_VALUE:
{
Label fresh; // used when guard aborts execution or jumps to a VM exit
RegisterA64 temp = regs.allocTemp(KindA64::w);
build.ldr(temp, mem(regOp(inst.a), offsetof(LuaNode, val.tt)));
LUAU_ASSERT(LUA_TNIL == 0);
build.cbz(temp, getTargetLabel(inst.b, fresh));
finalizeTargetLabel(inst.b, fresh);
break;
}
case IrCmd::INTERRUPT: case IrCmd::INTERRUPT:
{ {
regs.spill(build, index); regs.spill(build, index);
Label self; Label self;
build.ldr(x0, mem(rState, offsetof(lua_State, global))); build.ldr(x0, mem(rGlobalState, offsetof(global_State, cb.interrupt)));
build.ldr(x0, mem(x0, offsetof(global_State, cb.interrupt)));
build.cbnz(x0, self); build.cbnz(x0, self);
Label next = build.setLabel(); Label next = build.setLabel();
@ -1230,11 +1370,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp1 = regs.allocTemp(KindA64::x);
RegisterA64 temp2 = regs.allocTemp(KindA64::x); RegisterA64 temp2 = regs.allocTemp(KindA64::x);
LUAU_ASSERT(offsetof(global_State, totalbytes) == offsetof(global_State, GCthreshold) + 8);
Label skip; Label skip;
build.ldr(temp1, mem(rState, offsetof(lua_State, global))); build.ldp(temp1, temp2, mem(rGlobalState, offsetof(global_State, GCthreshold)));
// TODO: totalbytes and GCthreshold loads can be fused with ldp
build.ldr(temp2, mem(temp1, offsetof(global_State, totalbytes)));
build.ldr(temp1, mem(temp1, offsetof(global_State, GCthreshold)));
build.cmp(temp1, temp2); build.cmp(temp1, temp2);
build.b(ConditionA64::UnsignedGreater, skip); build.b(ConditionA64::UnsignedGreater, skip);
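The ldp above resolves the old TODO by loading GCthreshold and totalbytes in one instruction, which is only legal because the two fields sit 8 bytes apart; the LUAU_ASSERT above encodes that precondition. The same check at compile time (a sketch, assuming global_State from the VM's lstate.h):

#include <cstddef>
#include "lstate.h"

// ldp temp1, temp2, [rGlobalState + offsetof(global_State, GCthreshold)]
// fills temp1 with GCthreshold and temp2 with totalbytes in one load
static_assert(offsetof(global_State, totalbytes) == offsetof(global_State, GCthreshold) + 8,
    "ldp fusion requires GCthreshold and totalbytes to be adjacent");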
@ -1242,8 +1380,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.mov(x0, rState); build.mov(x0, rState);
build.mov(w1, 1); build.mov(w1, 1);
build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaC_step))); build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaC_step)));
build.blr(x1); build.blr(x2);
emitUpdateBase(build); emitUpdateBase(build);
@ -1257,7 +1395,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
RegisterA64 temp = regs.allocTemp(KindA64::x); RegisterA64 temp = regs.allocTemp(KindA64::x);
Label skip; Label skip;
checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip); checkObjectBarrierConditions(build, regOp(inst.a), temp, inst.b, inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip);
RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads
size_t spills = regs.spill(build, index, {reg}); size_t spills = regs.spill(build, index, {reg});
@ -1301,13 +1439,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
RegisterA64 temp = regs.allocTemp(KindA64::x); RegisterA64 temp = regs.allocTemp(KindA64::x);
Label skip; Label skip;
checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip); checkObjectBarrierConditions(build, regOp(inst.a), temp, inst.b, inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip);
RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads
AddressA64 addr = tempAddr(inst.b, offsetof(TValue, value));
size_t spills = regs.spill(build, index, {reg}); size_t spills = regs.spill(build, index, {reg});
build.mov(x1, reg); build.mov(x1, reg);
build.mov(x0, rState); build.mov(x0, rState);
build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); build.ldr(x2, addr);
build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barriertable))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barriertable)));
build.blr(x3); build.blr(x3);
@ -1453,9 +1592,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
// clear extra variables since we might have more than two // clear extra variables since we might have more than two
if (intOp(inst.b) > 2) if (intOp(inst.b) > 2)
{ {
build.mov(w0, LUA_TNIL); LUAU_ASSERT(LUA_TNIL == 0);
for (int i = 2; i < intOp(inst.b); ++i) for (int i = 2; i < intOp(inst.b); ++i)
build.str(w0, mem(rBase, (vmRegOp(inst.a) + 3 + i) * sizeof(TValue) + offsetof(TValue, tt))); build.str(wzr, mem(rBase, (vmRegOp(inst.a) + 3 + i) * sizeof(TValue) + offsetof(TValue, tt)));
} }
// we use full iter fallback for now; in the future it could be worthwhile to accelerate array iteration here // we use full iter fallback for now; in the future it could be worthwhile to accelerate array iteration here
build.mov(x0, rState); build.mov(x0, rState);
@ -1564,7 +1703,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
{ {
emitAddOffset(build, x1, rCode, uintOp(inst.a) * sizeof(Instruction)); emitAddOffset(build, x1, rCode, uintOp(inst.a) * sizeof(Instruction));
build.mov(x2, rBase); build.mov(x2, rBase);
build.mov(x3, vmRegOp(inst.b)); build.mov(w3, vmRegOp(inst.b));
build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, executeGETVARARGSMultRet))); build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, executeGETVARARGSMultRet)));
build.blr(x4); build.blr(x4);
@ -1573,10 +1712,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
else else
{ {
build.mov(x1, rBase); build.mov(x1, rBase);
build.mov(x2, vmRegOp(inst.b)); build.mov(w2, vmRegOp(inst.b));
build.mov(x3, intOp(inst.c)); build.mov(w3, intOp(inst.c));
build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, executeGETVARARGSConst))); build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, executeGETVARARGSConst)));
build.blr(x4); build.blr(x4);
// note: no emitUpdateBase necessary because executeGETVARARGSConst does not reallocate stack
} }
break; break;
case IrCmd::NEWCLOSURE: case IrCmd::NEWCLOSURE:
@ -1793,13 +1934,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
{ {
inst.regA64 = regs.allocReg(KindA64::x, index); inst.regA64 = regs.allocReg(KindA64::x, index);
build.ldr(inst.regA64, mem(rState, offsetof(lua_State, global)));
LUAU_ASSERT(sizeof(TString*) == 8); LUAU_ASSERT(sizeof(TString*) == 8);
if (inst.a.kind == IrOpKind::Inst) if (inst.a.kind == IrOpKind::Inst)
build.add(inst.regA64, inst.regA64, zextReg(regOp(inst.a)), 3); build.add(inst.regA64, rGlobalState, regOp(inst.a), 3);
else if (inst.a.kind == IrOpKind::Constant) else if (inst.a.kind == IrOpKind::Constant)
build.add(inst.regA64, inst.regA64, uint16_t(tagOp(inst.a)) * 8); build.add(inst.regA64, rGlobalState, uint16_t(tagOp(inst.a)) * 8);
else else
LUAU_ASSERT(!"Unsupported instruction form"); LUAU_ASSERT(!"Unsupported instruction form");
@ -1839,9 +1979,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
regs.freeTempRegs(); regs.freeTempRegs();
} }
void IrLoweringA64::finishBlock() void IrLoweringA64::finishBlock(const IrBlock& curr, const IrBlock& next)
{ {
regs.assertNoSpills(); if (!regs.spills.empty())
{
// If we have spills remaining, we have to immediately lower the successor block
for (uint32_t predIdx : predecessors(function.cfg, function.getBlockIndex(next)))
LUAU_ASSERT(predIdx == function.getBlockIndex(curr));
// And the next block cannot be a join block in the CFG

LUAU_ASSERT(next.useCount == 1);
}
} }
void IrLoweringA64::finishFunction() void IrLoweringA64::finishFunction()
@ -1862,10 +2010,22 @@ void IrLoweringA64::finishFunction()
for (ExitHandler& handler : exitHandlers) for (ExitHandler& handler : exitHandlers)
{ {
LUAU_ASSERT(handler.pcpos != kVmExitEntryGuardPc);
build.setLabel(handler.self); build.setLabel(handler.self);
build.mov(x0, handler.pcpos * sizeof(Instruction)); build.mov(x0, handler.pcpos * sizeof(Instruction));
build.b(helpers.updatePcAndContinueInVm); build.b(helpers.updatePcAndContinueInVm);
} }
if (stats)
{
if (error)
stats->loweringErrors++;
if (regs.error)
stats->regAllocErrors++;
}
} }
bool IrLoweringA64::hasError() const bool IrLoweringA64::hasError() const
@ -1873,12 +2033,12 @@ bool IrLoweringA64::hasError() const
return error || regs.error; return error || regs.error;
} }
bool IrLoweringA64::isFallthroughBlock(IrBlock target, IrBlock next) bool IrLoweringA64::isFallthroughBlock(const IrBlock& target, const IrBlock& next)
{ {
return target.start == next.start; return target.start == next.start;
} }
void IrLoweringA64::jumpOrFallthrough(IrBlock& target, IrBlock& next) void IrLoweringA64::jumpOrFallthrough(IrBlock& target, const IrBlock& next)
{ {
if (!isFallthroughBlock(target, next)) if (!isFallthroughBlock(target, next))
build.b(target.label); build.b(target.label);
@ -1891,7 +2051,11 @@ Label& IrLoweringA64::getTargetLabel(IrOp op, Label& fresh)
if (op.kind == IrOpKind::VmExit) if (op.kind == IrOpKind::VmExit)
{ {
if (uint32_t* index = exitHandlerMap.find(op.index)) // Special exit case that doesn't have to update pcpos
if (vmExitOp(op) == kVmExitEntryGuardPc)
return helpers.exitContinueVmClearNativeFlag;
if (uint32_t* index = exitHandlerMap.find(vmExitOp(op)))
return exitHandlers[*index].self; return exitHandlers[*index].self;
return fresh; return fresh;
@ -1906,10 +2070,10 @@ void IrLoweringA64::finalizeTargetLabel(IrOp op, Label& fresh)
{ {
emitAbort(build, fresh); emitAbort(build, fresh);
} }
else if (op.kind == IrOpKind::VmExit && fresh.id != 0) else if (op.kind == IrOpKind::VmExit && fresh.id != 0 && fresh.id != helpers.exitContinueVmClearNativeFlag.id)
{ {
exitHandlerMap[op.index] = uint32_t(exitHandlers.size()); exitHandlerMap[vmExitOp(op)] = uint32_t(exitHandlers.size());
exitHandlers.push_back({fresh, op.index}); exitHandlers.push_back({fresh, vmExitOp(op)});
} }
} }


@ -17,22 +17,23 @@ namespace CodeGen
struct ModuleHelpers; struct ModuleHelpers;
struct AssemblyOptions; struct AssemblyOptions;
struct LoweringStats;
namespace A64 namespace A64
{ {
struct IrLoweringA64 struct IrLoweringA64
{ {
IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function); IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats);
void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); void lowerInst(IrInst& inst, uint32_t index, const IrBlock& next);
void finishBlock(); void finishBlock(const IrBlock& curr, const IrBlock& next);
void finishFunction(); void finishFunction();
bool hasError() const; bool hasError() const;
bool isFallthroughBlock(IrBlock target, IrBlock next); bool isFallthroughBlock(const IrBlock& target, const IrBlock& next);
void jumpOrFallthrough(IrBlock& target, IrBlock& next); void jumpOrFallthrough(IrBlock& target, const IrBlock& next);
Label& getTargetLabel(IrOp op, Label& fresh); Label& getTargetLabel(IrOp op, Label& fresh);
void finalizeTargetLabel(IrOp op, Label& fresh); void finalizeTargetLabel(IrOp op, Label& fresh);
@ -74,6 +75,7 @@ struct IrLoweringA64
ModuleHelpers& helpers; ModuleHelpers& helpers;
IrFunction& function; IrFunction& function;
LoweringStats* stats = nullptr;
IrRegAllocA64 regs; IrRegAllocA64 regs;


@ -1,19 +1,19 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "IrLoweringX64.h" #include "IrLoweringX64.h"
#include "Luau/CodeGen.h"
#include "Luau/DenseHash.h" #include "Luau/DenseHash.h"
#include "Luau/IrAnalysis.h" #include "Luau/IrData.h"
#include "Luau/IrCallWrapperX64.h"
#include "Luau/IrDump.h"
#include "Luau/IrUtils.h" #include "Luau/IrUtils.h"
#include "Luau/IrCallWrapperX64.h"
#include "EmitBuiltinsX64.h" #include "EmitBuiltinsX64.h"
#include "EmitCommonX64.h" #include "EmitCommonX64.h"
#include "EmitInstructionX64.h" #include "EmitInstructionX64.h"
#include "NativeState.h" #include "NativeState.h"
#include "lstate.h" #include "lstate.h"
#include "lgc.h"
namespace Luau namespace Luau
{ {
@ -22,11 +22,12 @@ namespace CodeGen
namespace X64 namespace X64
{ {
IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function) IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats)
: build(build) : build(build)
, helpers(helpers) , helpers(helpers)
, function(function) , function(function)
, regs(build, function) , stats(stats)
, regs(build, function, stats)
, valueTracker(function) , valueTracker(function)
, exitHandlerMap(~0u) , exitHandlerMap(~0u)
{ {
@ -59,7 +60,7 @@ void IrLoweringX64::storeDoubleAsFloat(OperandX64 dst, IrOp src)
build.vmovss(dst, tmp.reg); build.vmovss(dst, tmp.reg);
} }
void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{ {
regs.currInstIdx = index; regs.currInstIdx = index;
@ -111,22 +112,21 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.mov(inst.regX64, luauRegValueInt(vmRegOp(inst.a))); build.mov(inst.regX64, luauRegValueInt(vmRegOp(inst.a)));
break; break;
case IrCmd::LOAD_TVALUE: case IrCmd::LOAD_TVALUE:
{
inst.regX64 = regs.allocReg(SizeX64::xmmword, index); inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
int addrOffset = inst.b.kind != IrOpKind::None ? intOp(inst.b) : 0;
if (inst.a.kind == IrOpKind::VmReg) if (inst.a.kind == IrOpKind::VmReg)
build.vmovups(inst.regX64, luauReg(vmRegOp(inst.a))); build.vmovups(inst.regX64, luauReg(vmRegOp(inst.a)));
else if (inst.a.kind == IrOpKind::VmConst) else if (inst.a.kind == IrOpKind::VmConst)
build.vmovups(inst.regX64, luauConstant(vmConstOp(inst.a))); build.vmovups(inst.regX64, luauConstant(vmConstOp(inst.a)));
else if (inst.a.kind == IrOpKind::Inst) else if (inst.a.kind == IrOpKind::Inst)
build.vmovups(inst.regX64, xmmword[regOp(inst.a)]); build.vmovups(inst.regX64, xmmword[regOp(inst.a) + addrOffset]);
else else
LUAU_ASSERT(!"Unsupported instruction form"); LUAU_ASSERT(!"Unsupported instruction form");
break; break;
case IrCmd::LOAD_NODE_VALUE_TV: }
inst.regX64 = regs.allocReg(SizeX64::xmmword, index);
build.vmovups(inst.regX64, luauNodeValue(regOp(inst.a)));
break;
case IrCmd::LOAD_ENV: case IrCmd::LOAD_ENV:
inst.regX64 = regs.allocReg(SizeX64::qword, index); inst.regX64 = regs.allocReg(SizeX64::qword, index);
@ -252,16 +252,59 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 2), inst.d); storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 2), inst.d);
break; break;
case IrCmd::STORE_TVALUE: case IrCmd::STORE_TVALUE:
{
int addrOffset = inst.c.kind != IrOpKind::None ? intOp(inst.c) : 0;
if (inst.a.kind == IrOpKind::VmReg) if (inst.a.kind == IrOpKind::VmReg)
build.vmovups(luauReg(vmRegOp(inst.a)), regOp(inst.b)); build.vmovups(luauReg(vmRegOp(inst.a)), regOp(inst.b));
else if (inst.a.kind == IrOpKind::Inst) else if (inst.a.kind == IrOpKind::Inst)
build.vmovups(xmmword[regOp(inst.a)], regOp(inst.b)); build.vmovups(xmmword[regOp(inst.a) + addrOffset], regOp(inst.b));
else else
LUAU_ASSERT(!"Unsupported instruction form"); LUAU_ASSERT(!"Unsupported instruction form");
break; break;
case IrCmd::STORE_NODE_VALUE_TV: }
build.vmovups(luauNodeValue(regOp(inst.a)), regOp(inst.b)); case IrCmd::STORE_SPLIT_TVALUE:
{
int addrOffset = inst.d.kind != IrOpKind::None ? intOp(inst.d) : 0;
OperandX64 tagLhs = inst.a.kind == IrOpKind::Inst ? dword[regOp(inst.a) + offsetof(TValue, tt) + addrOffset] : luauRegTag(vmRegOp(inst.a));
build.mov(tagLhs, tagOp(inst.b));
if (tagOp(inst.b) == LUA_TBOOLEAN)
{
OperandX64 valueLhs =
inst.a.kind == IrOpKind::Inst ? dword[regOp(inst.a) + offsetof(TValue, value) + addrOffset] : luauRegValueInt(vmRegOp(inst.a));
build.mov(valueLhs, inst.c.kind == IrOpKind::Constant ? OperandX64(intOp(inst.c)) : regOp(inst.c));
}
else if (tagOp(inst.b) == LUA_TNUMBER)
{
OperandX64 valueLhs =
inst.a.kind == IrOpKind::Inst ? qword[regOp(inst.a) + offsetof(TValue, value) + addrOffset] : luauRegValue(vmRegOp(inst.a));
if (inst.c.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, build.f64(doubleOp(inst.c)));
build.vmovsd(valueLhs, tmp.reg);
}
else
{
build.vmovsd(valueLhs, regOp(inst.c));
}
}
else if (isGCO(tagOp(inst.b)))
{
OperandX64 valueLhs =
inst.a.kind == IrOpKind::Inst ? qword[regOp(inst.a) + offsetof(TValue, value) + addrOffset] : luauRegValue(vmRegOp(inst.a));
build.mov(valueLhs, regOp(inst.c));
}
else
{
LUAU_ASSERT(!"Unsupported instruction form");
}
break; break;
}
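STORE_SPLIT_TVALUE writes a TValue in two pieces: the tag first, then a payload whose width follows the tag (32-bit int for booleans, 64-bit double for numbers, 64-bit pointer for GC objects). The shape in plain C++, with field offsets stated as assumptions (the real code takes them from offsetof(TValue, tt) and offsetof(TValue, value); 16-byte slots match the kTValueSizeLog2 shifts used elsewhere):

#include <cstddef>
#include <cstdint>
#include <cstring>

// sketch only: assumes value at offset 0 and tt at offset 12
void storeSplitTValue(unsigned char* slot, int tag, const void* payload, std::size_t payloadSize)
{
    std::memcpy(slot + 12, &tag, sizeof(int));   // STORE_SPLIT_TVALUE writes tt first
    std::memcpy(slot + 0, payload, payloadSize); // then the tag-dependent payload
}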
case IrCmd::ADD_INT: case IrCmd::ADD_INT:
{ {
inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a});
@ -365,6 +408,22 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.vdivsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b)); build.vdivsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
} }
break; break;
case IrCmd::IDIV_NUM:
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
if (inst.a.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, memRegDoubleOp(inst.a));
build.vdivsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b));
}
else
{
build.vdivsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
}
build.vroundsd(inst.regX64, inst.regX64, inst.regX64, RoundingModeX64::RoundToNegativeInfinity);
break;
case IrCmd::MOD_NUM: case IrCmd::MOD_NUM:
{ {
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});
@ -565,24 +624,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
break; break;
} }
case IrCmd::JUMP: case IrCmd::JUMP:
if (inst.a.kind == IrOpKind::VmExit) jumpOrAbortOnUndef(inst.a, next);
{
if (uint32_t* index = exitHandlerMap.find(inst.a.index))
{
build.jmp(exitHandlers[*index].self);
}
else
{
Label self;
build.jmp(self);
exitHandlerMap[inst.a.index] = uint32_t(exitHandlers.size());
exitHandlers.push_back({self, inst.a.index});
}
}
else
{
jumpOrFallthrough(blockOp(inst.a), next);
}
break; break;
case IrCmd::JUMP_IF_TRUTHY: case IrCmd::JUMP_IF_TRUTHY:
jumpIfTruthy(build, vmRegOp(inst.a), labelOp(inst.b), labelOp(inst.c)); jumpIfTruthy(build, vmRegOp(inst.a), labelOp(inst.b), labelOp(inst.c));
@ -597,6 +639,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(inst.b.kind == IrOpKind::Inst || inst.b.kind == IrOpKind::Constant); LUAU_ASSERT(inst.b.kind == IrOpKind::Inst || inst.b.kind == IrOpKind::Constant);
OperandX64 opb = inst.b.kind == IrOpKind::Inst ? regOp(inst.b) : OperandX64(tagOp(inst.b)); OperandX64 opb = inst.b.kind == IrOpKind::Inst ? regOp(inst.b) : OperandX64(tagOp(inst.b));
if (inst.a.kind == IrOpKind::Constant)
build.cmp(opb, tagOp(inst.a));
else
build.cmp(memRegTagOp(inst.a), opb); build.cmp(memRegTagOp(inst.a), opb);
if (isFallthroughBlock(blockOp(inst.d), next)) if (isFallthroughBlock(blockOp(inst.d), next))
@ -611,42 +656,36 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
} }
break; break;
} }
case IrCmd::JUMP_EQ_INT: case IrCmd::JUMP_CMP_INT:
if (intOp(inst.b) == 0)
{ {
IrCondition cond = conditionOp(inst.c);
if ((cond == IrCondition::Equal || cond == IrCondition::NotEqual) && intOp(inst.b) == 0)
{
bool invert = cond == IrCondition::NotEqual;
build.test(regOp(inst.a), regOp(inst.a)); build.test(regOp(inst.a), regOp(inst.a));
if (isFallthroughBlock(blockOp(inst.c), next)) if (isFallthroughBlock(blockOp(inst.d), next))
{ {
build.jcc(ConditionX64::NotZero, labelOp(inst.d)); build.jcc(invert ? ConditionX64::Zero : ConditionX64::NotZero, labelOp(inst.e));
jumpOrFallthrough(blockOp(inst.c), next); jumpOrFallthrough(blockOp(inst.d), next);
} }
else else
{ {
build.jcc(ConditionX64::Zero, labelOp(inst.c)); build.jcc(invert ? ConditionX64::NotZero : ConditionX64::Zero, labelOp(inst.d));
jumpOrFallthrough(blockOp(inst.d), next); jumpOrFallthrough(blockOp(inst.e), next);
} }
} }
else else
{ {
build.cmp(regOp(inst.a), intOp(inst.b)); build.cmp(regOp(inst.a), intOp(inst.b));
build.jcc(ConditionX64::Equal, labelOp(inst.c)); build.jcc(getConditionInt(cond), labelOp(inst.d));
jumpOrFallthrough(blockOp(inst.d), next); jumpOrFallthrough(blockOp(inst.e), next);
} }
break; break;
case IrCmd::JUMP_LT_INT: }
build.cmp(regOp(inst.a), intOp(inst.b));
build.jcc(ConditionX64::Less, labelOp(inst.c));
jumpOrFallthrough(blockOp(inst.d), next);
break;
case IrCmd::JUMP_GE_UINT:
build.cmp(regOp(inst.a), unsigned(intOp(inst.b)));
build.jcc(ConditionX64::AboveEqual, labelOp(inst.c));
jumpOrFallthrough(blockOp(inst.d), next);
break;
case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_EQ_POINTER:
build.cmp(regOp(inst.a), regOp(inst.b)); build.cmp(regOp(inst.a), regOp(inst.b));
@ -659,7 +698,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
ScopedRegX64 tmp{regs, SizeX64::xmmword}; ScopedRegX64 tmp{regs, SizeX64::xmmword};
// TODO: jumpOnNumberCmp should work on IrCondition directly
jumpOnNumberCmp(build, tmp.reg, memRegDoubleOp(inst.a), memRegDoubleOp(inst.b), cond, labelOp(inst.d)); jumpOnNumberCmp(build, tmp.reg, memRegDoubleOp(inst.a), memRegDoubleOp(inst.b), cond, labelOp(inst.d));
jumpOrFallthrough(blockOp(inst.e), next); jumpOrFallthrough(blockOp(inst.e), next);
break; break;
@ -669,9 +707,17 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
IrCallWrapperX64 callWrap(regs, build, index); IrCallWrapperX64 callWrap(regs, build, index);
callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a);
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]);
inst.regX64 = regs.takeReg(eax, index);
inst.regX64 = regs.allocReg(SizeX64::xmmword, index); break;
build.vcvtsi2sd(inst.regX64, inst.regX64, eax); }
case IrCmd::TABLE_SETNUM:
{
IrCallWrapperX64 callWrap(regs, build, index);
callWrap.addArgument(SizeX64::qword, rState);
callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a);
callWrap.addArgument(SizeX64::dword, regOp(inst.b), inst.b);
callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_setnum)]);
inst.regX64 = regs.takeReg(rax, index);
break; break;
} }
case IrCmd::STRING_LEN: case IrCmd::STRING_LEN:
@ -968,19 +1014,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
tmp1.free(); tmp1.free();
callBarrierObject(regs, build, tmp2.release(), {}, vmRegOp(inst.b), /* ratag */ -1); if (inst.c.kind == IrOpKind::Undef || isGCO(tagOp(inst.c)))
callBarrierObject(regs, build, tmp2.release(), {}, inst.b, inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c));
break; break;
} }
case IrCmd::PREPARE_FORN:
callPrepareForN(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c));
break;
case IrCmd::CHECK_TAG: case IrCmd::CHECK_TAG:
{
bool continueInVm = (inst.d.kind == IrOpKind::Constant && intOp(inst.d));
build.cmp(memRegTagOp(inst.a), tagOp(inst.b)); build.cmp(memRegTagOp(inst.a), tagOp(inst.b));
jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.c, continueInVm); jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.c, next);
break; break;
}
case IrCmd::CHECK_TRUTHY: case IrCmd::CHECK_TRUTHY:
{ {
// Constant tags which don't require boolean value check should've been removed in constant folding // Constant tags which don't require boolean value check should've been removed in constant folding
@ -992,7 +1033,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
{ {
// Fail to fallback on 'nil' (falsy) // Fail to fallback on 'nil' (falsy)
build.cmp(memRegTagOp(inst.a), LUA_TNIL); build.cmp(memRegTagOp(inst.a), LUA_TNIL);
jumpOrAbortOnUndef(ConditionX64::Equal, ConditionX64::NotEqual, inst.c); jumpOrAbortOnUndef(ConditionX64::Equal, inst.c, next);
// Skip value test if it's not a boolean (truthy) // Skip value test if it's not a boolean (truthy)
build.cmp(memRegTagOp(inst.a), LUA_TBOOLEAN); build.cmp(memRegTagOp(inst.a), LUA_TBOOLEAN);
@ -1001,7 +1042,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
// fail to fallback on 'false' boolean value (falsy) // fail to fallback on 'false' boolean value (falsy)
build.cmp(memRegUintOp(inst.b), 0); build.cmp(memRegUintOp(inst.b), 0);
jumpOrAbortOnUndef(ConditionX64::Equal, ConditionX64::NotEqual, inst.c); jumpOrAbortOnUndef(ConditionX64::Equal, inst.c, next);
if (inst.a.kind != IrOpKind::Constant) if (inst.a.kind != IrOpKind::Constant)
build.setLabel(skip); build.setLabel(skip);
@ -1009,11 +1050,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
} }
case IrCmd::CHECK_READONLY: case IrCmd::CHECK_READONLY:
build.cmp(byte[regOp(inst.a) + offsetof(Table, readonly)], 0); build.cmp(byte[regOp(inst.a) + offsetof(Table, readonly)], 0);
jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.b); jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.b, next);
break; break;
case IrCmd::CHECK_NO_METATABLE: case IrCmd::CHECK_NO_METATABLE:
build.cmp(qword[regOp(inst.a) + offsetof(Table, metatable)], 0); build.cmp(qword[regOp(inst.a) + offsetof(Table, metatable)], 0);
jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.b); jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.b, next);
break; break;
case IrCmd::CHECK_SAFE_ENV: case IrCmd::CHECK_SAFE_ENV:
{ {
@ -1023,7 +1064,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]); build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]);
build.cmp(byte[tmp.reg + offsetof(Table, safeenv)], 0); build.cmp(byte[tmp.reg + offsetof(Table, safeenv)], 0);
jumpOrAbortOnUndef(ConditionX64::Equal, ConditionX64::NotEqual, inst.a); jumpOrAbortOnUndef(ConditionX64::Equal, inst.a, next);
break; break;
} }
case IrCmd::CHECK_ARRAY_SIZE: case IrCmd::CHECK_ARRAY_SIZE:
@ -1034,7 +1075,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
else else
LUAU_ASSERT(!"Unsupported instruction form"); LUAU_ASSERT(!"Unsupported instruction form");
jumpOrAbortOnUndef(ConditionX64::BelowEqual, ConditionX64::NotBelowEqual, inst.c); jumpOrAbortOnUndef(ConditionX64::BelowEqual, inst.c, next);
break; break;
case IrCmd::JUMP_SLOT_MATCH: case IrCmd::JUMP_SLOT_MATCH:
case IrCmd::CHECK_SLOT_MATCH: case IrCmd::CHECK_SLOT_MATCH:
@ -1080,7 +1121,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.mov(tmp.reg, dword[regOp(inst.a) + offsetof(LuaNode, key) + kOffsetOfTKeyTagNext]); build.mov(tmp.reg, dword[regOp(inst.a) + offsetof(LuaNode, key) + kOffsetOfTKeyTagNext]);
build.shr(tmp.reg, kTKeyTagBits); build.shr(tmp.reg, kTKeyTagBits);
jumpOrAbortOnUndef(ConditionX64::NotZero, ConditionX64::Zero, inst.b); jumpOrAbortOnUndef(ConditionX64::NotZero, inst.b, next);
break;
}
case IrCmd::CHECK_NODE_VALUE:
{
build.cmp(dword[regOp(inst.a) + offsetof(LuaNode, val) + offsetof(TValue, tt)], LUA_TNIL);
jumpOrAbortOnUndef(ConditionX64::Equal, inst.b, next);
break; break;
} }
case IrCmd::INTERRUPT: case IrCmd::INTERRUPT:
@ -1109,7 +1156,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
callStepGc(regs, build); callStepGc(regs, build);
break; break;
case IrCmd::BARRIER_OBJ: case IrCmd::BARRIER_OBJ:
callBarrierObject(regs, build, regOp(inst.a), inst.a, vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c)); callBarrierObject(regs, build, regOp(inst.a), inst.a, inst.b, inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c));
break; break;
case IrCmd::BARRIER_TABLE_BACK: case IrCmd::BARRIER_TABLE_BACK:
callBarrierTableFast(regs, build, regOp(inst.a), inst.a); callBarrierTableFast(regs, build, regOp(inst.a), inst.a);
@ -1119,7 +1166,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
Label skip; Label skip;
ScopedRegX64 tmp{regs, SizeX64::qword}; ScopedRegX64 tmp{regs, SizeX64::qword};
checkObjectBarrierConditions(build, tmp.reg, regOp(inst.a), vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip);
checkObjectBarrierConditions(build, tmp.reg, regOp(inst.a), inst.b, inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip);
{ {
ScopedSpills spillGuard(regs); ScopedSpills spillGuard(regs);
@ -1182,7 +1230,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
// Fallbacks to non-IR instruction implementations // Fallbacks to non-IR instruction implementations
case IrCmd::SETLIST: case IrCmd::SETLIST:
regs.assertAllFree(); regs.assertAllFree();
emitInstSetList(regs, build, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e)); emitInstSetList(
regs, build, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e), inst.f.kind == IrOpKind::Undef ? -1 : int(uintOp(inst.f)));
break; break;
case IrCmd::CALL: case IrCmd::CALL:
regs.assertAllFree(); regs.assertAllFree();
@ -1560,9 +1609,17 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
regs.freeLastUseRegs(inst, index); regs.freeLastUseRegs(inst, index);
} }
void IrLoweringX64::finishBlock() void IrLoweringX64::finishBlock(const IrBlock& curr, const IrBlock& next)
{ {
regs.assertNoSpills(); if (!regs.spills.empty())
{
// If we have spills remaining, we have to immediately lower the successor block
for (uint32_t predIdx : predecessors(function.cfg, function.getBlockIndex(next)))
LUAU_ASSERT(predIdx == function.getBlockIndex(curr));
// And the next block cannot be a join block in the CFG
LUAU_ASSERT(next.useCount == 1);
}
} }
void IrLoweringX64::finishFunction() void IrLoweringX64::finishFunction()
@ -1583,10 +1640,22 @@ void IrLoweringX64::finishFunction()
for (ExitHandler& handler : exitHandlers) for (ExitHandler& handler : exitHandlers)
{ {
LUAU_ASSERT(handler.pcpos != kVmExitEntryGuardPc);
build.setLabel(handler.self); build.setLabel(handler.self);
build.mov(edx, handler.pcpos * sizeof(Instruction)); build.mov(edx, handler.pcpos * sizeof(Instruction));
build.jmp(helpers.updatePcAndContinueInVm); build.jmp(helpers.updatePcAndContinueInVm);
} }
if (stats)
{
if (regs.maxUsedSlot > kSpillSlots)
stats->regAllocErrors++;
if (regs.maxUsedSlot > stats->maxSpillSlotsUsed)
stats->maxSpillSlotsUsed = regs.maxUsedSlot;
}
} }
bool IrLoweringX64::hasError() const bool IrLoweringX64::hasError() const
@ -1598,50 +1667,81 @@ bool IrLoweringX64::hasError() const
return false; return false;
} }
bool IrLoweringX64::isFallthroughBlock(IrBlock target, IrBlock next) bool IrLoweringX64::isFallthroughBlock(const IrBlock& target, const IrBlock& next)
{ {
return target.start == next.start; return target.start == next.start;
} }
void IrLoweringX64::jumpOrFallthrough(IrBlock& target, IrBlock& next) Label& IrLoweringX64::getTargetLabel(IrOp op, Label& fresh)
{
if (op.kind == IrOpKind::Undef)
return fresh;
if (op.kind == IrOpKind::VmExit)
{
// Special exit case that doesn't have to update pcpos
if (vmExitOp(op) == kVmExitEntryGuardPc)
return helpers.exitContinueVmClearNativeFlag;
if (uint32_t* index = exitHandlerMap.find(vmExitOp(op)))
return exitHandlers[*index].self;
return fresh;
}
return labelOp(op);
}
void IrLoweringX64::finalizeTargetLabel(IrOp op, Label& fresh)
{
if (op.kind == IrOpKind::VmExit && fresh.id != 0 && fresh.id != helpers.exitContinueVmClearNativeFlag.id)
{
exitHandlerMap[vmExitOp(op)] = uint32_t(exitHandlers.size());
exitHandlers.push_back({fresh, vmExitOp(op)});
}
}
void IrLoweringX64::jumpOrFallthrough(IrBlock& target, const IrBlock& next)
{ {
if (!isFallthroughBlock(target, next)) if (!isFallthroughBlock(target, next))
build.jmp(target.label); build.jmp(target.label);
} }
void IrLoweringX64::jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef, bool continueInVm) void IrLoweringX64::jumpOrAbortOnUndef(ConditionX64 cond, IrOp target, const IrBlock& next)
{ {
if (targetOrUndef.kind == IrOpKind::Undef) Label fresh;
Label& label = getTargetLabel(target, fresh);
if (target.kind == IrOpKind::Undef)
{ {
if (continueInVm) if (cond == ConditionX64::Count)
{ {
build.jcc(cond, helpers.exitContinueVmClearNativeFlag); build.ud2(); // An unconditional jump to an undefined target is just an abort
return; }
else
{
build.jcc(getReverseCondition(cond), label);
build.ud2();
build.setLabel(label);
}
}
else if (cond == ConditionX64::Count)
{
// Unconditional jump can be skipped if it's a fallthrough
if (target.kind == IrOpKind::VmExit || !isFallthroughBlock(blockOp(target), next))
build.jmp(label);
}
else
{
build.jcc(cond, label);
} }
Label skip; finalizeTargetLabel(target, fresh);
build.jcc(condInverse, skip);
build.ud2();
build.setLabel(skip);
} }
else if (targetOrUndef.kind == IrOpKind::VmExit)
void IrLoweringX64::jumpOrAbortOnUndef(IrOp target, const IrBlock& next)
{ {
if (uint32_t* index = exitHandlerMap.find(targetOrUndef.index)) jumpOrAbortOnUndef(ConditionX64::Count, target, next);
{
build.jcc(cond, exitHandlers[*index].self);
}
else
{
Label self;
build.jcc(cond, self);
exitHandlerMap[targetOrUndef.index] = uint32_t(exitHandlers.size());
exitHandlers.push_back({self, targetOrUndef.index});
}
}
else
{
build.jcc(cond, labelOp(targetOrUndef));
}
} }
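With ConditionX64::Count acting as a "no condition" sentinel, the two overloads above cover all three shapes the old helper special-cased: conditional guards, unconditional jumps, and aborts on undefined targets. Hypothetical call sites mirroring the ones in this diff (a sketch, given the repo headers):

void examples(IrLoweringX64& lowering, IrInst& inst, const IrBlock& next)
{
    lowering.jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.c, next); // guard, e.g. CHECK_TAG
    lowering.jumpOrAbortOnUndef(inst.a, next); // unconditional, e.g. JUMP; forwards Count internally
}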
OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op)


@ -19,23 +19,29 @@ namespace CodeGen
struct ModuleHelpers; struct ModuleHelpers;
struct AssemblyOptions; struct AssemblyOptions;
struct LoweringStats;
namespace X64 namespace X64
{ {
struct IrLoweringX64 struct IrLoweringX64
{ {
IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function); IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats);
void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); void lowerInst(IrInst& inst, uint32_t index, const IrBlock& next);
void finishBlock(); void finishBlock(const IrBlock& curr, const IrBlock& next);
void finishFunction(); void finishFunction();
bool hasError() const; bool hasError() const;
bool isFallthroughBlock(IrBlock target, IrBlock next); bool isFallthroughBlock(const IrBlock& target, const IrBlock& next);
void jumpOrFallthrough(IrBlock& target, IrBlock& next); void jumpOrFallthrough(IrBlock& target, const IrBlock& next);
void jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef, bool continueInVm = false);
Label& getTargetLabel(IrOp op, Label& fresh);
void finalizeTargetLabel(IrOp op, Label& fresh);
void jumpOrAbortOnUndef(ConditionX64 cond, IrOp target, const IrBlock& next);
void jumpOrAbortOnUndef(IrOp target, const IrBlock& next);
void storeDoubleAsFloat(OperandX64 dst, IrOp src); void storeDoubleAsFloat(OperandX64 dst, IrOp src);
@ -71,6 +77,7 @@ struct IrLoweringX64
ModuleHelpers& helpers; ModuleHelpers& helpers;
IrFunction& function; IrFunction& function;
LoweringStats* stats = nullptr;
IrRegAllocX64 regs; IrRegAllocX64 regs;


@ -2,6 +2,7 @@
#include "IrRegAllocA64.h" #include "IrRegAllocA64.h"
#include "Luau/AssemblyBuilderA64.h" #include "Luau/AssemblyBuilderA64.h"
#include "Luau/CodeGen.h"
#include "Luau/IrUtils.h" #include "Luau/IrUtils.h"
#include "BitUtils.h" #include "BitUtils.h"
@ -70,9 +71,9 @@ static int getReloadOffset(IrCmd cmd)
LUAU_UNREACHABLE(); LUAU_UNREACHABLE();
} }
static AddressA64 getReloadAddress(const IrFunction& function, const IrInst& inst) static AddressA64 getReloadAddress(const IrFunction& function, const IrInst& inst, bool limitToCurrentBlock)
{ {
IrOp location = function.findRestoreOp(inst); IrOp location = function.findRestoreOp(inst, limitToCurrentBlock);
if (location.kind == IrOpKind::VmReg) if (location.kind == IrOpKind::VmReg)
return mem(rBase, vmRegOp(location) * sizeof(TValue) + getReloadOffset(inst.cmd)); return mem(rBase, vmRegOp(location) * sizeof(TValue) + getReloadOffset(inst.cmd));
@ -99,7 +100,7 @@ static void restoreInst(AssemblyBuilderA64& build, uint32_t& freeSpillSlots, IrF
else else
{ {
LUAU_ASSERT(!inst.spilled && inst.needsReload); LUAU_ASSERT(!inst.spilled && inst.needsReload);
AddressA64 addr = getReloadAddress(function, function.instructions[s.inst]); AddressA64 addr = getReloadAddress(function, function.instructions[s.inst], /*limitToCurrentBlock*/ false);
LUAU_ASSERT(addr.base != xzr); LUAU_ASSERT(addr.base != xzr);
build.ldr(reg, addr); build.ldr(reg, addr);
} }
@ -109,8 +110,9 @@ static void restoreInst(AssemblyBuilderA64& build, uint32_t& freeSpillSlots, IrF
inst.regA64 = reg; inst.regA64 = reg;
} }
IrRegAllocA64::IrRegAllocA64(IrFunction& function, std::initializer_list<std::pair<RegisterA64, RegisterA64>> regs) IrRegAllocA64::IrRegAllocA64(IrFunction& function, LoweringStats* stats, std::initializer_list<std::pair<RegisterA64, RegisterA64>> regs)
: function(function) : function(function)
, stats(stats)
{ {
for (auto& p : regs) for (auto& p : regs)
{ {
@ -321,7 +323,7 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init
{ {
// instead of spilling the register to never reload it, we assume the register is not needed anymore // instead of spilling the register to never reload it, we assume the register is not needed anymore
} }
else if (getReloadAddress(function, def).base != xzr) else if (getReloadAddress(function, def, /*limitToCurrentBlock*/ true).base != xzr)
{ {
// instead of spilling the register to stack, we can reload it from VM stack/constants // instead of spilling the register to stack, we can reload it from VM stack/constants
// we still need to record the spill for restore(start) to work // we still need to record the spill for restore(start) to work
@ -329,6 +331,9 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init
spills.push_back(s); spills.push_back(s);
def.needsReload = true; def.needsReload = true;
if (stats)
stats->spillsToRestore++;
} }
else else
{ {
@ -345,6 +350,14 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init
spills.push_back(s); spills.push_back(s);
def.spilled = true; def.spilled = true;
if (stats)
{
stats->spillsToSlot++;
if (slot != kInvalidSpill && unsigned(slot + 1) > stats->maxSpillSlotsUsed)
stats->maxSpillSlotsUsed = slot + 1;
}
} }
def.regA64 = noreg; def.regA64 = noreg;
@ -411,11 +424,6 @@ void IrRegAllocA64::restoreReg(AssemblyBuilderA64& build, IrInst& inst)
LUAU_ASSERT(!"Expected to find a spill record"); LUAU_ASSERT(!"Expected to find a spill record");
} }
void IrRegAllocA64::assertNoSpills() const
{
LUAU_ASSERT(spills.empty());
}
IrRegAllocA64::Set& IrRegAllocA64::getSet(KindA64 kind) IrRegAllocA64::Set& IrRegAllocA64::getSet(KindA64 kind)
{ {
switch (kind) switch (kind)


@ -12,6 +12,9 @@ namespace Luau
{ {
namespace CodeGen namespace CodeGen
{ {
struct LoweringStats;
namespace A64 namespace A64
{ {
@ -19,7 +22,7 @@ class AssemblyBuilderA64;
struct IrRegAllocA64 struct IrRegAllocA64
{ {
IrRegAllocA64(IrFunction& function, std::initializer_list<std::pair<RegisterA64, RegisterA64>> regs); IrRegAllocA64(IrFunction& function, LoweringStats* stats, std::initializer_list<std::pair<RegisterA64, RegisterA64>> regs);
RegisterA64 allocReg(KindA64 kind, uint32_t index); RegisterA64 allocReg(KindA64 kind, uint32_t index);
RegisterA64 allocTemp(KindA64 kind); RegisterA64 allocTemp(KindA64 kind);
@ -43,8 +46,6 @@ struct IrRegAllocA64
// Restores register for a single instruction; may not assign the previously used register! // Restores register for a single instruction; may not assign the previously used register!
void restoreReg(AssemblyBuilderA64& build, IrInst& inst); void restoreReg(AssemblyBuilderA64& build, IrInst& inst);
void assertNoSpills() const;
struct Set struct Set
{ {
// which registers are in the set that the allocator manages (initialized at construction) // which registers are in the set that the allocator manages (initialized at construction)
@ -71,6 +72,7 @@ struct IrRegAllocA64
Set& getSet(KindA64 kind); Set& getSet(KindA64 kind);
IrFunction& function; IrFunction& function;
LoweringStats* stats = nullptr;
Set gpr, simd; Set gpr, simd;
std::vector<Spill> spills; std::vector<Spill> spills;


@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/IrRegAllocX64.h" #include "Luau/IrRegAllocX64.h"
#include "Luau/CodeGen.h"
#include "Luau/IrUtils.h" #include "Luau/IrUtils.h"
#include "EmitCommonX64.h" #include "EmitCommonX64.h"
@ -14,9 +15,11 @@ namespace X64
static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11}; static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11};
IrRegAllocX64::IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function) IrRegAllocX64::IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function, LoweringStats* stats)
: build(build) : build(build)
, function(function) , function(function)
, stats(stats)
, usableXmmRegCount(getXmmRegisterCount(build.abi))
{ {
freeGprMap.fill(true); freeGprMap.fill(true);
gprInstUsers.fill(kInvalidInstIdx); gprInstUsers.fill(kInvalidInstIdx);
@ -28,7 +31,7 @@ RegisterX64 IrRegAllocX64::allocReg(SizeX64 size, uint32_t instIdx)
{ {
if (size == SizeX64::xmmword) if (size == SizeX64::xmmword)
{ {
for (size_t i = 0; i < freeXmmMap.size(); ++i) for (size_t i = 0; i < usableXmmRegCount; ++i)
{ {
if (freeXmmMap[i]) if (freeXmmMap[i])
{ {
@ -54,7 +57,12 @@ RegisterX64 IrRegAllocX64::allocReg(SizeX64 size, uint32_t instIdx)
// Out of registers, spill the value with the furthest next use // Out of registers, spill the value with the furthest next use
const std::array<uint32_t, 16>& regInstUsers = size == SizeX64::xmmword ? xmmInstUsers : gprInstUsers; const std::array<uint32_t, 16>& regInstUsers = size == SizeX64::xmmword ? xmmInstUsers : gprInstUsers;
if (uint32_t furthestUseTarget = findInstructionWithFurthestNextUse(regInstUsers); furthestUseTarget != kInvalidInstIdx) if (uint32_t furthestUseTarget = findInstructionWithFurthestNextUse(regInstUsers); furthestUseTarget != kInvalidInstIdx)
return takeReg(function.instructions[furthestUseTarget].regX64, instIdx); {
RegisterX64 reg = function.instructions[furthestUseTarget].regX64;
reg.size = size; // Adjust size to the requested width
return takeReg(reg, instIdx);
}
LUAU_ASSERT(!"Out of registers to allocate"); LUAU_ASSERT(!"Out of registers to allocate");
return noreg; return noreg;
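Why the new braces re-tag the size: a RegisterX64 value pairs a physical register index with an operand size, and the register stolen from the furthest-next-use instruction still carries its previous owner's size tag. A toy mirror of the type (the real one lives in the CodeGen headers; the enum values here are illustrative):

```cpp
#include <cstdint>

enum class SizeX64 : uint8_t { byte, word, dword, qword, xmmword }; // toy mirror

struct RegisterX64 // toy mirror: operand size travels with the register value
{
    SizeX64 size;
    uint8_t index;
};

int main()
{
    RegisterX64 stolen = {SizeX64::qword, 3}; // rbx, as tagged by its previous user
    stolen.size = SizeX64::dword;             // same physical register, now denotes ebx
    return int(stolen.index);
}
```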
@@ -219,10 +227,16 @@ void IrRegAllocX64::preserve(IrInst& inst)
             spill.stackSlot = uint8_t(i);
             inst.spilled = true;
+
+            if (stats)
+                stats->spillsToSlot++;
         }
         else
         {
             inst.needsReload = true;
+
+            if (stats)
+                stats->spillsToRestore++;
         }
 
         spills.push_back(spill);
@@ -332,7 +346,9 @@ unsigned IrRegAllocX64::findSpillStackSlot(IrValueKind valueKind)
 IrOp IrRegAllocX64::getRestoreOp(const IrInst& inst) const
 {
-    if (IrOp location = function.findRestoreOp(inst); location.kind == IrOpKind::VmReg || location.kind == IrOpKind::VmConst)
+    // When restoring the value, we allow cross-block restore because we have committed to the target location at spill time
+    if (IrOp location = function.findRestoreOp(inst, /*limitToCurrentBlock*/ false);
+        location.kind == IrOpKind::VmReg || location.kind == IrOpKind::VmConst)
         return location;
 
     return IrOp();
@@ -340,11 +356,16 @@ IrOp IrRegAllocX64::getRestoreOp(const IrInst& inst) const
 bool IrRegAllocX64::hasRestoreOp(const IrInst& inst) const
 {
-    return getRestoreOp(inst).kind != IrOpKind::None;
+    // When checking if a value has a restore operation in order to spill it, we only allow it in the same block
+    IrOp location = function.findRestoreOp(inst, /*limitToCurrentBlock*/ true);
+
+    return location.kind == IrOpKind::VmReg || location.kind == IrOpKind::VmConst;
 }
 
 OperandX64 IrRegAllocX64::getRestoreAddress(const IrInst& inst, IrOp restoreOp)
 {
+    LUAU_ASSERT(restoreOp.kind != IrOpKind::None);
+
     switch (getCmdValueKind(inst.cmd))
     {
     case IrValueKind::Unknown:
View file: CodeGen/src/IrTranslateBuiltins.cpp
@@ -411,7 +411,7 @@ static BuiltinImplResult translateBuiltinBit32BinaryOp(
     IrOp falsey = build.block(IrBlockKind::Internal);
     IrOp truthy = build.block(IrBlockKind::Internal);
     IrOp exit = build.block(IrBlockKind::Internal);
-    build.inst(IrCmd::JUMP_EQ_INT, res, build.constInt(0), falsey, truthy);
+    build.inst(IrCmd::JUMP_CMP_INT, res, build.constInt(0), build.cond(IrCondition::Equal), falsey, truthy);
 
     build.beginBlock(falsey);
     build.inst(IrCmd::STORE_INT, build.vmReg(ra), build.constInt(0));
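A pattern to keep in mind for the rest of this file: the specialized integer jumps (JUMP_EQ_INT, JUMP_GE_UINT, JUMP_LT_INT) are all folded into a single JUMP_CMP_INT opcode that carries the comparison as an IrCondition operand. Side by side, taken verbatim from the hunk above:

```cpp
// Before: one IR opcode per comparison flavor.
build.inst(IrCmd::JUMP_EQ_INT, res, build.constInt(0), falsey, truthy);

// After: one opcode; the comparison travels as a condition operand, so new
// comparisons need no new opcodes and lowering can share one code path.
build.inst(IrCmd::JUMP_CMP_INT, res, build.constInt(0), build.cond(IrCondition::Equal), falsey, truthy);
```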
@@ -484,7 +484,7 @@ static BuiltinImplResult translateBuiltinBit32Shift(
     if (!knownGoodShift)
     {
         IrOp block = build.block(IrBlockKind::Internal);
-        build.inst(IrCmd::JUMP_GE_UINT, vbi, build.constInt(32), fallback, block);
+        build.inst(IrCmd::JUMP_CMP_INT, vbi, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block);
         build.beginBlock(block);
     }
@@ -549,36 +549,56 @@ static BuiltinImplResult translateBuiltinBit32Extract(
     IrOp vb = builtinLoadDouble(build, args);
 
     IrOp n = build.inst(IrCmd::NUM_TO_UINT, va);
-    IrOp f = build.inst(IrCmd::NUM_TO_INT, vb);
 
     IrOp value;
     if (nparams == 2)
     {
-        IrOp block = build.block(IrBlockKind::Internal);
-        build.inst(IrCmd::JUMP_GE_UINT, f, build.constInt(32), fallback, block);
-        build.beginBlock(block);
-
-        // TODO: this can be optimized using a bit-select instruction (bt on x86)
-        IrOp shift = build.inst(IrCmd::BITRSHIFT_UINT, n, f);
-        value = build.inst(IrCmd::BITAND_UINT, shift, build.constInt(1));
+        if (vb.kind == IrOpKind::Constant)
+        {
+            int f = int(build.function.doubleOp(vb));
+
+            if (unsigned(f) >= 32)
+                build.inst(IrCmd::JUMP, fallback);
+
+            // TODO: this pair can be optimized using a bit-select instruction (bt on x86)
+            if (f)
+                value = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constInt(f));
+
+            if ((f + 1) < 32)
+                value = build.inst(IrCmd::BITAND_UINT, value, build.constInt(1));
+        }
+        else
+        {
+            IrOp f = build.inst(IrCmd::NUM_TO_INT, vb);
+
+            IrOp block = build.block(IrBlockKind::Internal);
+            build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block);
+            build.beginBlock(block);
+
+            // TODO: this pair can be optimized using a bit-select instruction (bt on x86)
+            IrOp shift = build.inst(IrCmd::BITRSHIFT_UINT, n, f);
+            value = build.inst(IrCmd::BITAND_UINT, shift, build.constInt(1));
+        }
     }
     else
     {
+        IrOp f = build.inst(IrCmd::NUM_TO_INT, vb);
+
         builtinCheckDouble(build, build.vmReg(args.index + 1), pcpos);
         IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1));
         IrOp w = build.inst(IrCmd::NUM_TO_INT, vc);
 
         IrOp block1 = build.block(IrBlockKind::Internal);
-        build.inst(IrCmd::JUMP_LT_INT, f, build.constInt(0), fallback, block1);
+        build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(0), build.cond(IrCondition::Less), fallback, block1);
         build.beginBlock(block1);
 
         IrOp block2 = build.block(IrBlockKind::Internal);
-        build.inst(IrCmd::JUMP_LT_INT, w, build.constInt(1), fallback, block2);
+        build.inst(IrCmd::JUMP_CMP_INT, w, build.constInt(1), build.cond(IrCondition::Less), fallback, block2);
         build.beginBlock(block2);
 
         IrOp block3 = build.block(IrBlockKind::Internal);
         IrOp fw = build.inst(IrCmd::ADD_INT, f, w);
-        build.inst(IrCmd::JUMP_LT_INT, fw, build.constInt(33), block3, fallback);
+        build.inst(IrCmd::JUMP_CMP_INT, fw, build.constInt(33), build.cond(IrCondition::Less), block3, fallback);
         build.beginBlock(block3);
 
         IrOp shift = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1)));
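The guards in the hunk above mirror what bit32.extract requires before the interpreter fallback is taken. As a reading aid, the reference semantics in plain C++ (a sketch for orientation, not code from the Luau sources):

```cpp
#include <cassert>
#include <cstdint>

// Reference semantics for bit32.extract(n, f, w): take w bits starting at bit f.
// The translated IR jumps to the fallback whenever the assert below would fail.
uint32_t bit32_extract_ref(uint32_t n, int f, int w)
{
    assert(f >= 0 && w >= 1 && f + w <= 32);

    uint32_t m = ~(0xfffffffeu << (w - 1)); // mask of w consecutive low bits
    return (n >> f) & m;
}
```

The two-argument call is the w == 1 case; with a constant f, the new fast path above reduces it to at most one shift and one mask, with no extra basic blocks.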
@@ -615,10 +635,15 @@ static BuiltinImplResult translateBuiltinBit32ExtractK(
     uint32_t m = ~(0xfffffffeu << w1);
 
-    IrOp nf = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constInt(f));
-    IrOp and_ = build.inst(IrCmd::BITAND_UINT, nf, build.constInt(m));
-    IrOp value = build.inst(IrCmd::UINT_TO_NUM, and_);
+    IrOp result = n;
+
+    if (f)
+        result = build.inst(IrCmd::BITRSHIFT_UINT, result, build.constInt(f));
+
+    if ((f + w1 + 1) < 32)
+        result = build.inst(IrCmd::BITAND_UINT, result, build.constInt(m));
+
+    IrOp value = build.inst(IrCmd::UINT_TO_NUM, result);
 
     build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value);
 
     if (ra != arg)
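The K variant (field and width both constants) now elides work instead of always emitting a shift and a mask. The same strength reduction in plain C++, with w1 = width - 1 as in the hunk (a sketch for illustration):

```cpp
#include <cstdint>

// Assumes the K variant's preconditions: f >= 0, w1 >= 0, f + w1 + 1 <= 32.
uint32_t extract_k_ref(uint32_t n, int f, int w1)
{
    uint32_t m = ~(0xfffffffeu << w1); // mask of (w1 + 1) consecutive low bits

    uint32_t result = n;

    if (f)                    // shift only when the field does not start at bit 0
        result >>= f;

    if ((f + w1 + 1) < 32)    // mask only when bits remain above the field
        result &= m;

    return result;
}
```

In the extreme case of bit32.extract(n, 0, 32), both conditions are false and the result is n itself, converted back to a double by UINT_TO_NUM.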
@@ -673,7 +698,7 @@ static BuiltinImplResult translateBuiltinBit32Replace(
     if (nparams == 3)
     {
         IrOp block = build.block(IrBlockKind::Internal);
-        build.inst(IrCmd::JUMP_GE_UINT, f, build.constInt(32), fallback, block);
+        build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block);
         build.beginBlock(block);
 
         // TODO: this can be optimized using a bit-select instruction (btr on x86)
@@ -694,16 +719,16 @@ static BuiltinImplResult translateBuiltinBit32Replace(
         IrOp w = build.inst(IrCmd::NUM_TO_INT, vd);
 
         IrOp block1 = build.block(IrBlockKind::Internal);
-        build.inst(IrCmd::JUMP_LT_INT, f, build.constInt(0), fallback, block1);
+        build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(0), build.cond(IrCondition::Less), fallback, block1);
         build.beginBlock(block1);
 
         IrOp block2 = build.block(IrBlockKind::Internal);
-        build.inst(IrCmd::JUMP_LT_INT, w, build.constInt(1), fallback, block2);
+        build.inst(IrCmd::JUMP_CMP_INT, w, build.constInt(1), build.cond(IrCondition::Less), fallback, block2);
         build.beginBlock(block2);
 
         IrOp block3 = build.block(IrBlockKind::Internal);
         IrOp fw = build.inst(IrCmd::ADD_INT, f, w);
-        build.inst(IrCmd::JUMP_LT_INT, fw, build.constInt(33), block3, fallback);
+        build.inst(IrCmd::JUMP_CMP_INT, fw, build.constInt(33), build.cond(IrCondition::Less), block3, fallback);
         build.beginBlock(block3);
 
         IrOp shift1 = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1)));
@@ -748,6 +773,28 @@ static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, i
     return {BuiltinImplType::Full, 1};
 }
 
+static BuiltinImplResult translateBuiltinTableInsert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos)
+{
+    if (nparams != 2 || nresults > 0)
+        return {BuiltinImplType::None, -1};
+
+    build.loadAndCheckTag(build.vmReg(arg), LUA_TTABLE, build.vmExit(pcpos));
+
+    IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(arg));
+
+    build.inst(IrCmd::CHECK_READONLY, table, build.vmExit(pcpos));
+
+    IrOp pos = build.inst(IrCmd::ADD_INT, build.inst(IrCmd::TABLE_LEN, table), build.constInt(1));
+
+    IrOp setnum = build.inst(IrCmd::TABLE_SETNUM, table, pos);
+
+    IrOp va = build.inst(IrCmd::LOAD_TVALUE, args);
+    build.inst(IrCmd::STORE_TVALUE, setnum, va);
+
+    build.inst(IrCmd::BARRIER_TABLE_FORWARD, table, args, build.undef());
+
+    return {BuiltinImplType::Full, 0};
+}
+
 static BuiltinImplResult translateBuiltinStringLen(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos)
 {
     if (nparams < 1 || nresults > 1)
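For reference, the shape of what the new fast path computes: a two-argument table.insert(t, v) used as a statement, i.e. t[#t + 1] = v. A toy model in plain C++ (illustration only; the VM's real Table, TValue, and GC barrier machinery are not shown):

```cpp
#include <vector>

struct ToyValue { double number = 0; };
struct ToyTable
{
    bool readonly = false;
    std::vector<ToyValue> array;
};

// Returns false where the generated code would exit to the VM instead.
bool table_insert_fastpath(ToyTable& t, ToyValue v)
{
    if (t.readonly)                   // CHECK_READONLY: frozen tables take the fallback
        return false;

    size_t pos = t.array.size() + 1;  // TABLE_LEN + ADD_INT: append position
    t.array.resize(pos);              // TABLE_SETNUM: resolve (and grow to) the slot
    t.array[pos - 1] = v;             // STORE_TVALUE: write the value
    // BARRIER_TABLE_FORWARD has no analogue in the toy: it preserves the
    // incremental GC invariant when a black table stores a white value.
    return true;
}
```

Note the guard at the top of the translator: nparams != 2 || nresults > 0 means the three-argument positional form and any use of return values still go through the generic builtin path.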
@@ -849,6 +896,8 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg,
         return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults);
     case LBF_VECTOR:
         return translateBuiltinVector(build, nparams, ra, arg, args, nresults, pcpos);
+    case LBF_TABLE_INSERT:
+        return translateBuiltinTableInsert(build, nparams, ra, arg, args, nresults, pcpos);
     case LBF_STRING_LEN:
         return translateBuiltinStringLen(build, nparams, ra, arg, args, nresults, pcpos);
     default:

Some files were not shown because too many files have changed in this diff.