Sync to upstream/release/565 (#845)

We've made a few small changes to reduce the amount of stack we use when
typechecking nested method calls (eg `foo:bar():baz():quux()`).

We've also fixed a small bytecode compiler issue that caused us to emit
redundant jump instructions in code that conditionally uses `break` or
`continue`.

On the new solver, we've switched to a new, better way to handle
augmentations to unsealed tables. We've also made some substantial
improvements to type inference and error reporting on function calls.
These things should both be on par with the old solver now.

The main improvements to the native code generator have been elimination
of some redundant type tag checks. Also, we are starting to inline
particular fastcalls directly to IR.

---------

Co-authored-by: Arseny Kapoulkine <arseny.kapoulkine@gmail.com>
Co-authored-by: Vyacheslav Egorov <vegorov@roblox.com>
This commit is contained in:
Andy Friesen 2023-02-24 13:49:38 -08:00 committed by GitHub
parent fbd93cbc03
commit d2ab5df62b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
63 changed files with 3501 additions and 2430 deletions

View file

@ -159,6 +159,20 @@ struct SetPropConstraint
TypeId propType;
};
// result ~ setIndexer subjectType indexType propType
//
// If the subject is a table or table-like thing that already has an indexer,
// unify its indexType and propType with those from this constraint.
//
// If the table is a free or unsealed table, we augment it with a new indexer.
struct SetIndexerConstraint
{
TypeId resultType;
TypeId subjectType;
TypeId indexType;
TypeId propType;
};
// if negation:
// result ~ if isSingleton D then ~D else unknown where D = discriminantType
// if not negation:
@ -170,9 +184,19 @@ struct SingletonOrTopTypeConstraint
bool negated;
};
// resultType ~ unpack sourceTypePack
//
// Similar to PackSubtypeConstraint, but with one important difference: If the
// sourcePack is blocked, this constraint blocks.
struct UnpackConstraint
{
TypePackId resultPack;
TypePackId sourcePack;
};
using ConstraintV = Variant<SubtypeConstraint, PackSubtypeConstraint, GeneralizationConstraint, InstantiationConstraint, UnaryConstraint,
BinaryConstraint, IterableConstraint, NameConstraint, TypeAliasExpansionConstraint, FunctionCallConstraint, PrimitiveTypeConstraint,
HasPropConstraint, SetPropConstraint, SingletonOrTopTypeConstraint>;
HasPropConstraint, SetPropConstraint, SetIndexerConstraint, SingletonOrTopTypeConstraint, UnpackConstraint>;
struct Constraint
{

View file

@ -191,7 +191,7 @@ struct ConstraintGraphBuilder
Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional<TypeId> expectedType);
std::tuple<TypeId, TypeId, RefinementId> checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional<TypeId> expectedType);
TypePackId checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs);
std::vector<TypeId> checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs);
TypeId checkLValue(const ScopePtr& scope, AstExpr* expr);
@ -244,10 +244,31 @@ struct ConstraintGraphBuilder
**/
TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments);
/**
* Creates generic types given a list of AST definitions, resolving default
* types as required.
* @param scope the scope that the generics should belong to.
* @param generics the AST generics to create types for.
* @param useCache whether to use the generic type cache for the given
* scope.
* @param addTypes whether to add the types to the scope's
* privateTypeBindings map.
**/
std::vector<std::pair<Name, GenericTypeDefinition>> createGenerics(
const ScopePtr& scope, AstArray<AstGenericType> generics, bool useCache = false);
const ScopePtr& scope, AstArray<AstGenericType> generics, bool useCache = false, bool addTypes = true);
/**
* Creates generic type packs given a list of AST definitions, resolving
* default type packs as required.
* @param scope the scope that the generic packs should belong to.
* @param generics the AST generics to create type packs for.
* @param useCache whether to use the generic type pack cache for the given
* scope.
* @param addTypes whether to add the types to the scope's
* privateTypePackBindings map.
**/
std::vector<std::pair<Name, GenericTypePackDefinition>> createGenericPacks(
const ScopePtr& scope, AstArray<AstGenericTypePack> packs, bool useCache = false);
const ScopePtr& scope, AstArray<AstGenericTypePack> packs, bool useCache = false, bool addTypes = true);
Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack);

View file

@ -8,6 +8,7 @@
#include "Luau/Normalize.h"
#include "Luau/ToString.h"
#include "Luau/Type.h"
#include "Luau/TypeReduction.h"
#include "Luau/Variant.h"
#include <vector>
@ -19,7 +20,12 @@ struct DcrLogger;
// TypeId, TypePackId, or Constraint*. It is impossible to know which, but we
// never dereference this pointer.
using BlockedConstraintId = const void*;
using BlockedConstraintId = Variant<TypeId, TypePackId, const Constraint*>;
struct HashBlockedConstraintId
{
size_t operator()(const BlockedConstraintId& bci) const;
};
struct ModuleResolver;
@ -47,6 +53,7 @@ struct ConstraintSolver
NotNull<BuiltinTypes> builtinTypes;
InternalErrorReporter iceReporter;
NotNull<Normalizer> normalizer;
NotNull<TypeReduction> reducer;
// The entire set of constraints that the solver is trying to resolve.
std::vector<NotNull<Constraint>> constraints;
NotNull<Scope> rootScope;
@ -65,7 +72,7 @@ struct ConstraintSolver
// anything.
std::unordered_map<NotNull<const Constraint>, size_t> blockedConstraints;
// A mapping of type/pack pointers to the constraints they block.
std::unordered_map<BlockedConstraintId, std::vector<NotNull<const Constraint>>> blocked;
std::unordered_map<BlockedConstraintId, std::vector<NotNull<const Constraint>>, HashBlockedConstraintId> blocked;
// Memoized instantiations of type aliases.
DenseHashMap<InstantiationSignature, TypeId, HashInstantiationSignature> instantiatedAliases{{}};
@ -78,7 +85,8 @@ struct ConstraintSolver
DcrLogger* logger;
explicit ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger);
ModuleName moduleName, NotNull<TypeReduction> reducer, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles,
DcrLogger* logger);
// Randomize the order in which to dispatch constraints
void randomize(unsigned seed);
@ -112,7 +120,9 @@ struct ConstraintSolver
bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const HasPropConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const SetPropConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const SetIndexerConstraint& c, NotNull<const Constraint> constraint, bool force);
bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint);
bool tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint);
// for a, ... in some_table do
// also handles __iter metamethod
@ -123,6 +133,7 @@ struct ConstraintSolver
TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force);
std::optional<TypeId> lookupTableProp(TypeId subjectType, const std::string& propName);
std::optional<TypeId> lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set<TypeId>& seen);
void block(NotNull<const Constraint> target, NotNull<const Constraint> constraint);
/**

View file

@ -4,6 +4,7 @@
#include "Luau/Constraint.h"
#include "Luau/NotNull.h"
#include "Luau/Scope.h"
#include "Luau/Module.h"
#include "Luau/ToString.h"
#include "Luau/Error.h"
#include "Luau/Variant.h"
@ -34,11 +35,26 @@ struct TypeBindingSnapshot
std::string typeString;
};
struct ExprTypesAtLocation
{
Location location;
TypeId ty;
std::optional<TypeId> expectedTy;
};
struct AnnotationTypesAtLocation
{
Location location;
TypeId resolvedTy;
};
struct ConstraintGenerationLog
{
std::string source;
std::unordered_map<std::string, Location> constraintLocations;
std::vector<ErrorSnapshot> errors;
std::vector<ExprTypesAtLocation> exprTypeLocations;
std::vector<AnnotationTypesAtLocation> annotationTypeLocations;
};
struct ScopeSnapshot
@ -49,16 +65,11 @@ struct ScopeSnapshot
std::vector<ScopeSnapshot> children;
};
enum class ConstraintBlockKind
{
TypeId,
TypePackId,
ConstraintId,
};
using ConstraintBlockTarget = Variant<TypeId, TypePackId, NotNull<const Constraint>>;
struct ConstraintBlock
{
ConstraintBlockKind kind;
ConstraintBlockTarget target;
std::string stringification;
};
@ -71,16 +82,18 @@ struct ConstraintSnapshot
struct BoundarySnapshot
{
std::unordered_map<std::string, ConstraintSnapshot> constraints;
DenseHashMap<const Constraint*, ConstraintSnapshot> unsolvedConstraints{nullptr};
ScopeSnapshot rootScope;
DenseHashMap<const void*, std::string> typeStrings{nullptr};
};
struct StepSnapshot
{
std::string currentConstraint;
const Constraint* currentConstraint;
bool forced;
std::unordered_map<std::string, ConstraintSnapshot> unsolvedConstraints;
DenseHashMap<const Constraint*, ConstraintSnapshot> unsolvedConstraints{nullptr};
ScopeSnapshot rootScope;
DenseHashMap<const void*, std::string> typeStrings{nullptr};
};
struct TypeSolveLog
@ -95,8 +108,6 @@ struct TypeCheckLog
std::vector<ErrorSnapshot> errors;
};
using ConstraintBlockTarget = Variant<TypeId, TypePackId, NotNull<const Constraint>>;
struct DcrLogger
{
std::string compileOutput();
@ -104,6 +115,7 @@ struct DcrLogger
void captureSource(std::string source);
void captureGenerationError(const TypeError& error);
void captureConstraintLocation(NotNull<const Constraint> constraint, Location location);
void captureGenerationModule(const ModulePtr& module);
void pushBlock(NotNull<const Constraint> constraint, TypeId block);
void pushBlock(NotNull<const Constraint> constraint, TypePackId block);
@ -126,9 +138,10 @@ private:
TypeSolveLog solveLog;
TypeCheckLog checkLog;
ToStringOptions opts;
ToStringOptions opts{true};
std::vector<ConstraintBlock> snapshotBlocks(NotNull<const Constraint> constraint);
void captureBoundaryState(BoundarySnapshot& target, const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints);
};
} // namespace Luau

View file

@ -52,7 +52,7 @@ struct Scope
std::optional<TypeId> lookup(Symbol sym) const;
std::optional<TypeId> lookup(DefId def) const;
std::optional<std::pair<TypeId, Scope*>> lookupEx(Symbol sym);
std::optional<std::pair<Binding*, Scope*>> lookupEx(Symbol sym);
std::optional<TypeFun> lookupType(const Name& name);
std::optional<TypeFun> lookupImportedType(const Name& moduleAlias, const Name& name);

View file

@ -37,6 +37,11 @@ struct Symbol
AstLocal* local;
AstName global;
explicit operator bool() const
{
return local != nullptr || global.value != nullptr;
}
bool operator==(const Symbol& rhs) const
{
if (local)

View file

@ -246,6 +246,18 @@ struct WithPredicate
{
T type;
PredicateVec predicates;
WithPredicate() = default;
explicit WithPredicate(T type)
: type(type)
{
}
WithPredicate(T type, PredicateVec predicates)
: type(type)
, predicates(std::move(predicates))
{
}
};
using MagicFunction = std::function<std::optional<WithPredicate<TypePackId>>(
@ -853,4 +865,15 @@ bool hasTag(TypeId ty, const std::string& tagName);
bool hasTag(const Property& prop, const std::string& tagName);
bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work.
/*
* Use this to change the kind of a particular type.
*
* LUAU_NOINLINE so that the calling frame doesn't have to pay the stack storage for the new variant.
*/
template<typename T, typename... Args>
LUAU_NOINLINE T* emplaceType(Type* ty, Args&&... args)
{
return &ty->ty.emplace<T>(std::forward<Args>(args)...);
}
} // namespace Luau

View file

@ -146,10 +146,12 @@ struct TypeChecker
WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr);
WithPredicate<TypePackId> checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr);
WithPredicate<TypePackId> checkExprPackHelper2(
const ScopePtr& scope, const AstExprCall& expr, TypeId selfType, TypeId actualFunctionType, TypeId functionType, TypePackId retPack);
std::vector<std::optional<TypeId>> getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall);
std::optional<WithPredicate<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
std::unique_ptr<WithPredicate<TypePackId>> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors);
bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations,

View file

@ -12,11 +12,36 @@ namespace Luau
namespace detail
{
template<typename T>
struct ReductionContext
struct ReductionEdge
{
T type = nullptr;
bool irreducible = false;
};
struct TypeReductionMemoization
{
TypeReductionMemoization() = default;
TypeReductionMemoization(const TypeReductionMemoization&) = delete;
TypeReductionMemoization& operator=(const TypeReductionMemoization&) = delete;
TypeReductionMemoization(TypeReductionMemoization&&) = default;
TypeReductionMemoization& operator=(TypeReductionMemoization&&) = default;
DenseHashMap<TypeId, ReductionEdge<TypeId>> types{nullptr};
DenseHashMap<TypePackId, ReductionEdge<TypePackId>> typePacks{nullptr};
bool isIrreducible(TypeId ty);
bool isIrreducible(TypePackId tp);
TypeId memoize(TypeId ty, TypeId reducedTy);
TypePackId memoize(TypePackId tp, TypePackId reducedTp);
// Reducing A into B may have a non-irreducible edge A to B for which B is not irreducible, which means B could be reduced into C.
// Because reduction should always be transitive, A should point to C if A points to B and B points to C.
std::optional<ReductionEdge<TypeId>> memoizedof(TypeId ty) const;
std::optional<ReductionEdge<TypePackId>> memoizedof(TypePackId tp) const;
};
} // namespace detail
struct TypeReductionOptions
@ -42,29 +67,19 @@ struct TypeReduction
std::optional<TypePackId> reduce(TypePackId tp);
std::optional<TypeFun> reduce(const TypeFun& fun);
/// Creating a child TypeReduction will allow the parent TypeReduction to share its memoization with the child TypeReductions.
/// This is safe as long as the parent's TypeArena continues to outlive both TypeReduction memoization.
TypeReduction fork(NotNull<TypeArena> arena, const TypeReductionOptions& opts = {}) const;
private:
const TypeReduction* parent = nullptr;
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
NotNull<struct InternalErrorReporter> handle;
TypeReductionOptions options;
DenseHashMap<TypeId, detail::ReductionContext<TypeId>> memoizedTypes{nullptr};
DenseHashMap<TypePackId, detail::ReductionContext<TypePackId>> memoizedTypePacks{nullptr};
TypeReductionOptions options;
detail::TypeReductionMemoization memoization;
// Computes an *estimated length* of the cartesian product of the given type.
size_t cartesianProductSize(TypeId ty) const;
bool hasExceededCartesianProductLimit(TypeId ty) const;
bool hasExceededCartesianProductLimit(TypePackId tp) const;
std::optional<TypeId> memoizedof(TypeId ty) const;
std::optional<TypePackId> memoizedof(TypePackId tp) const;
};
} // namespace Luau

View file

@ -67,6 +67,12 @@ struct Unifier
UnifierSharedState& sharedState;
// When the Unifier is forced to unify two blocked types (or packs), they
// get added to these vectors. The ConstraintSolver can use this to know
// when it is safe to reattempt dispatching a constraint.
std::vector<TypeId> blockedTypes;
std::vector<TypePackId> blockedTypePacks;
Unifier(
NotNull<Normalizer> normalizer, Mode mode, NotNull<Scope> scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr);

View file

@ -320,6 +320,9 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block)
prepopulateGlobalScope(scope, block);
visitBlockWithoutChildScope(scope, block);
if (FFlag::DebugLuauLogSolverToJson)
logger->captureGenerationModule(module);
}
void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block)
@ -357,13 +360,11 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope,
for (const auto& [name, gen] : createGenerics(defnScope, alias->generics, /* useCache */ true))
{
initialFun.typeParams.push_back(gen);
defnScope->privateTypeBindings[name] = TypeFun{gen.ty};
}
for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks, /* useCache */ true))
{
initialFun.typePackParams.push_back(genPack);
defnScope->privateTypePackBindings[name] = genPack.tp;
}
if (alias->exported)
@ -503,13 +504,13 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local)
if (j - i < packTypes.head.size())
varTypes[j] = packTypes.head[j - i];
else
varTypes[j] = freshType(scope);
varTypes[j] = arena->addType(BlockedType{});
}
}
std::vector<TypeId> tailValues{varTypes.begin() + i, varTypes.end()};
TypePackId tailPack = arena->addTypePack(std::move(tailValues));
addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack});
addConstraint(scope, local->location, UnpackConstraint{tailPack, exprPack});
}
}
}
@ -686,6 +687,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct
Checkpoint start = checkpoint(this);
FunctionSignature sig = checkFunctionSignature(scope, function->func);
std::unordered_set<Constraint*> excludeList;
if (AstExprLocal* localName = function->name->as<AstExprLocal>())
{
std::optional<TypeId> existingFunctionTy = scope->lookup(localName->local);
@ -716,9 +719,20 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct
}
else if (AstExprIndexName* indexName = function->name->as<AstExprIndexName>())
{
Checkpoint check1 = checkpoint(this);
TypeId lvalueType = checkLValue(scope, indexName);
Checkpoint check2 = checkpoint(this);
forEachConstraint(check1, check2, this, [&excludeList](const ConstraintPtr& c) {
excludeList.insert(c.get());
});
// TODO figure out how to populate the location field of the table Property.
addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType});
if (get<FreeType>(lvalueType))
asMutable(lvalueType)->ty.emplace<BoundType>(generalizedType);
else
addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType});
}
else if (AstExprError* err = function->name->as<AstExprError>())
{
@ -735,8 +749,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct
std::unique_ptr<Constraint> c =
std::make_unique<Constraint>(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature});
forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) {
c->dependencies.push_back(NotNull{constraint.get()});
forEachConstraint(start, end, this, [&c, &excludeList](const ConstraintPtr& constraint) {
if (!excludeList.count(constraint.get()))
c->dependencies.push_back(NotNull{constraint.get()});
});
addConstraint(scope, std::move(c));
@ -763,16 +778,31 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block)
visitBlockWithoutChildScope(innerScope, block);
}
static void bindFreeType(TypeId a, TypeId b)
{
FreeType* af = getMutable<FreeType>(a);
FreeType* bf = getMutable<FreeType>(b);
LUAU_ASSERT(af || bf);
if (!bf)
asMutable(a)->ty.emplace<BoundType>(b);
else if (!af)
asMutable(b)->ty.emplace<BoundType>(a);
else if (subsumes(bf->scope, af->scope))
asMutable(a)->ty.emplace<BoundType>(b);
else if (subsumes(af->scope, bf->scope))
asMutable(b)->ty.emplace<BoundType>(a);
}
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign)
{
TypePackId varPackId = checkLValues(scope, assign->vars);
TypePack expectedPack = extendTypePack(*arena, builtinTypes, varPackId, assign->values.size);
std::vector<TypeId> varTypes = checkLValues(scope, assign->vars);
std::vector<std::optional<TypeId>> expectedTypes;
expectedTypes.reserve(expectedPack.head.size());
expectedTypes.reserve(varTypes.size());
for (TypeId ty : expectedPack.head)
for (TypeId ty : varTypes)
{
ty = follow(ty);
if (get<FreeType>(ty))
@ -781,9 +811,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign)
expectedTypes.push_back(ty);
}
TypePackId valuePack = checkPack(scope, assign->values, expectedTypes).tp;
TypePackId exprPack = checkPack(scope, assign->values, expectedTypes).tp;
TypePackId varPack = arena->addTypePack({varTypes});
addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId});
addConstraint(scope, assign->location, PackSubtypeConstraint{exprPack, varPack});
}
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign)
@ -865,11 +896,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia
asMutable(aliasTy)->ty.emplace<BoundType>(ty);
std::vector<TypeId> typeParams;
for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true))
for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true, /* addTypes */ false))
typeParams.push_back(tyParam.second.ty);
std::vector<TypePackId> typePackParams;
for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true))
for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true, /* addTypes */ false))
typePackParams.push_back(tpParam.second.tp);
addConstraint(scope, alias->type->location,
@ -1010,7 +1041,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction
for (auto& [name, generic] : generics)
{
genericTys.push_back(generic.ty);
scope->privateTypeBindings[name] = TypeFun{generic.ty};
}
std::vector<TypePackId> genericTps;
@ -1018,7 +1048,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction
for (auto& [name, generic] : genericPacks)
{
genericTps.push_back(generic.tp);
scope->privateTypePackBindings[name] = generic.tp;
}
ScopePtr funScope = scope;
@ -1161,7 +1190,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa
TypePackId expectedArgPack = arena->freshTypePack(scope.get());
TypePackId expectedRetPack = arena->freshTypePack(scope.get());
TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack});
TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack, std::nullopt, call->self});
TypeId instantiatedFnType = arena->addType(BlockedType{});
addConstraint(scope, call->location, InstantiationConstraint{instantiatedFnType, fnType});
@ -1264,7 +1293,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa
// TODO: How do expectedTypes play into this? Do they?
TypePackId rets = arena->addTypePack(BlockedTypePack{});
TypePackId argPack = arena->addTypePack(TypePack{args, argTail});
FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets);
FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self);
NotNull<Constraint> fcc = addConstraint(scope, call->func->location,
FunctionCallConstraint{
@ -1457,7 +1486,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* gl
Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName)
{
TypeId obj = check(scope, indexName->expr).ty;
TypeId result = freshType(scope);
TypeId result = arena->addType(BlockedType{});
std::optional<DefId> def = dfg->getDef(indexName);
if (def)
@ -1468,13 +1497,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName*
scope->dcrRefinements[*def] = result;
}
TableType::Props props{{indexName->index.value, Property{result}}};
const std::optional<TableIndexer> indexer;
TableType ttv{std::move(props), indexer, TypeLevel{}, scope.get(), TableState::Free};
TypeId expectedTableType = arena->addType(std::move(ttv));
addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType});
addConstraint(scope, indexName->expr->location, HasPropConstraint{result, obj, indexName->index.value});
if (def)
return Inference{result, refinementArena.proposition(*def, builtinTypes->truthyType)};
@ -1589,6 +1612,8 @@ std::tuple<TypeId, TypeId, RefinementId> ConstraintGraphBuilder::checkBinary(
else if (typeguard->type == "number")
discriminantTy = builtinTypes->numberType;
else if (typeguard->type == "boolean")
discriminantTy = builtinTypes->booleanType;
else if (typeguard->type == "thread")
discriminantTy = builtinTypes->threadType;
else if (typeguard->type == "table")
discriminantTy = builtinTypes->tableType;
@ -1596,8 +1621,8 @@ std::tuple<TypeId, TypeId, RefinementId> ConstraintGraphBuilder::checkBinary(
discriminantTy = builtinTypes->functionType;
else if (typeguard->type == "userdata")
{
// For now, we don't really care about being accurate with userdata if the typeguard was using typeof
discriminantTy = builtinTypes->neverType; // TODO: replace with top class type
// For now, we don't really care about being accurate with userdata if the typeguard was using typeof.
discriminantTy = builtinTypes->classType;
}
else if (!typeguard->isTypeof && typeguard->type == "vector")
discriminantTy = builtinTypes->neverType; // TODO: figure out a way to deal with this quirky type
@ -1649,18 +1674,15 @@ std::tuple<TypeId, TypeId, RefinementId> ConstraintGraphBuilder::checkBinary(
}
}
TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs)
std::vector<TypeId> ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray<AstExpr*> exprs)
{
std::vector<TypeId> types;
types.reserve(exprs.size);
for (size_t i = 0; i < exprs.size; ++i)
{
AstExpr* const expr = exprs.data[i];
for (AstExpr* expr : exprs)
types.push_back(checkLValue(scope, expr));
}
return arena->addTypePack(std::move(types));
return types;
}
/**
@ -1679,6 +1701,28 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr)
indexExpr->location, indexExpr->expr, syntheticIndex, constantString->location, indexExpr->expr->location.end, '.'};
return checkLValue(scope, &synthetic);
}
// An indexer is only interesting in an lvalue-ey way if it is at the
// tail of an expression.
//
// If the indexer is not at the tail, then we are not interested in
// augmenting the lhs data structure with a new indexer. Constraint
// generation can treat it as an ordinary lvalue.
//
// eg
//
// a.b.c[1] = 44 -- lvalue
// a.b[4].c = 2 -- rvalue
TypeId resultType = arena->addType(BlockedType{});
TypeId subjectType = check(scope, indexExpr->expr).ty;
TypeId indexType = check(scope, indexExpr->index).ty;
TypeId propType = arena->addType(BlockedType{});
addConstraint(scope, expr->location, SetIndexerConstraint{resultType, subjectType, indexType, propType});
module->astTypes[expr] = propType;
return propType;
}
else if (!expr->is<AstExprIndexName>())
return check(scope, expr).ty;
@ -1718,7 +1762,8 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr)
auto lookupResult = scope->lookupEx(sym);
if (!lookupResult)
return check(scope, expr).ty;
const auto [subjectType, symbolScope] = std::move(*lookupResult);
const auto [subjectBinding, symbolScope] = std::move(*lookupResult);
TypeId subjectType = subjectBinding->typeId;
TypeId propTy = freshType(scope);
@ -1739,14 +1784,17 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr)
module->astTypes[expr] = prevSegmentTy;
module->astTypes[e] = updatedType;
symbolScope->bindings[sym].typeId = updatedType;
std::optional<DefId> def = dfg->getDef(sym);
if (def)
if (!subjectType->persistent)
{
// This can fail if the user is erroneously trying to augment a builtin
// table like os or string.
symbolScope->dcrRefinements[*def] = updatedType;
symbolScope->bindings[sym].typeId = updatedType;
std::optional<DefId> def = dfg->getDef(sym);
if (def)
{
// This can fail if the user is erroneously trying to augment a builtin
// table like os or string.
symbolScope->dcrRefinements[*def] = updatedType;
}
}
return propTy;
@ -1904,13 +1952,11 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
for (const auto& [name, g] : genericDefinitions)
{
genericTypes.push_back(g.ty);
signatureScope->privateTypeBindings[name] = TypeFun{g.ty};
}
for (const auto& [name, g] : genericPackDefinitions)
{
genericTypePacks.push_back(g.tp);
signatureScope->privateTypePackBindings[name] = g.tp;
}
// Local variable works around an odd gcc 11.3 warning: <anonymous> may be used uninitialized
@ -2023,15 +2069,14 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
actualFunction.generics = std::move(genericTypes);
actualFunction.genericPacks = std::move(genericTypePacks);
actualFunction.argNames = std::move(argNames);
actualFunction.hasSelf = fn->self != nullptr;
TypeId actualFunctionType = arena->addType(std::move(actualFunction));
LUAU_ASSERT(actualFunctionType);
module->astTypes[fn] = actualFunctionType;
if (expectedType && get<FreeType>(*expectedType))
{
asMutable(*expectedType)->ty.emplace<BoundType>(actualFunctionType);
}
bindFreeType(*expectedType, actualFunctionType);
return {
/* signature */ actualFunctionType,
@ -2179,13 +2224,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b
for (const auto& [name, g] : genericDefinitions)
{
genericTypes.push_back(g.ty);
signatureScope->privateTypeBindings[name] = TypeFun{g.ty};
}
for (const auto& [name, g] : genericPackDefinitions)
{
genericTypePacks.push_back(g.tp);
signatureScope->privateTypePackBindings[name] = g.tp;
}
}
else
@ -2330,7 +2373,7 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const
}
std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGraphBuilder::createGenerics(
const ScopePtr& scope, AstArray<AstGenericType> generics, bool useCache)
const ScopePtr& scope, AstArray<AstGenericType> generics, bool useCache, bool addTypes)
{
std::vector<std::pair<Name, GenericTypeDefinition>> result;
for (const auto& generic : generics)
@ -2350,6 +2393,9 @@ std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGraphBuilder::crea
if (generic.defaultValue)
defaultTy = resolveType(scope, generic.defaultValue, /* inTypeArguments */ false);
if (addTypes)
scope->privateTypeBindings[generic.name.value] = TypeFun{genericTy};
result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}});
}
@ -2357,7 +2403,7 @@ std::vector<std::pair<Name, GenericTypeDefinition>> ConstraintGraphBuilder::crea
}
std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGraphBuilder::createGenericPacks(
const ScopePtr& scope, AstArray<AstGenericTypePack> generics, bool useCache)
const ScopePtr& scope, AstArray<AstGenericTypePack> generics, bool useCache, bool addTypes)
{
std::vector<std::pair<Name, GenericTypePackDefinition>> result;
for (const auto& generic : generics)
@ -2378,6 +2424,9 @@ std::vector<std::pair<Name, GenericTypePackDefinition>> ConstraintGraphBuilder::
if (generic.defaultValue)
defaultTy = resolveTypePack(scope, generic.defaultValue, /* inTypeArguments */ false);
if (addTypes)
scope->privateTypePackBindings[generic.name.value] = genericTy;
result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}});
}
@ -2394,11 +2443,9 @@ Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location lo
if (auto f = first(tp))
return Inference{*f, refinement};
TypeId typeResult = freshType(scope);
TypePack onePack{{typeResult}, freshTypePack(scope)};
TypePackId oneTypePack = arena->addTypePack(std::move(onePack));
addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack});
TypeId typeResult = arena->addType(BlockedType{});
TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get()));
addConstraint(scope, location, UnpackConstraint{resultPack, tp});
return Inference{typeResult, refinement};
}

View file

@ -22,6 +22,22 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false);
namespace Luau
{
size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const
{
size_t result = 0;
if (const TypeId* ty = get_if<TypeId>(&bci))
result = std::hash<TypeId>()(*ty);
else if (const TypePackId* tp = get_if<TypePackId>(&bci))
result = std::hash<TypePackId>()(*tp);
else if (Constraint const* const* c = get_if<const Constraint*>(&bci))
result = std::hash<const Constraint*>()(*c);
else
LUAU_ASSERT(!"Should be unreachable");
return result;
}
[[maybe_unused]] static void dumpBindings(NotNull<Scope> scope, ToStringOptions& opts)
{
for (const auto& [k, v] : scope->bindings)
@ -221,10 +237,12 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts)
}
ConstraintSolver::ConstraintSolver(NotNull<Normalizer> normalizer, NotNull<Scope> rootScope, std::vector<NotNull<Constraint>> constraints,
ModuleName moduleName, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles, DcrLogger* logger)
ModuleName moduleName, NotNull<TypeReduction> reducer, NotNull<ModuleResolver> moduleResolver, std::vector<RequireCycle> requireCycles,
DcrLogger* logger)
: arena(normalizer->arena)
, builtinTypes(normalizer->builtinTypes)
, normalizer(normalizer)
, reducer(reducer)
, constraints(std::move(constraints))
, rootScope(rootScope)
, currentModuleName(std::move(moduleName))
@ -326,6 +344,27 @@ void ConstraintSolver::run()
if (force)
printf("Force ");
printf("Dispatched\n\t%s\n", saveMe.c_str());
if (force)
{
printf("Blocked on:\n");
for (const auto& [bci, cv] : blocked)
{
if (end(cv) == std::find(begin(cv), end(cv), c))
continue;
if (auto bty = get_if<TypeId>(&bci))
printf("\tType %s\n", toString(*bty, opts).c_str());
else if (auto btp = get_if<TypePackId>(&bci))
printf("\tPack %s\n", toString(*btp, opts).c_str());
else if (auto cc = get_if<const Constraint*>(&bci))
printf("\tCons %s\n", toString(**cc, opts).c_str());
else
LUAU_ASSERT(!"Unreachable??");
}
}
dump(this, opts);
}
}
@ -411,8 +450,12 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
success = tryDispatch(*hpc, constraint);
else if (auto spc = get<SetPropConstraint>(*constraint))
success = tryDispatch(*spc, constraint, force);
else if (auto spc = get<SetIndexerConstraint>(*constraint))
success = tryDispatch(*spc, constraint, force);
else if (auto sottc = get<SingletonOrTopTypeConstraint>(*constraint))
success = tryDispatch(*sottc, constraint);
else if (auto uc = get<UnpackConstraint>(*constraint))
success = tryDispatch(*uc, constraint);
else
LUAU_ASSERT(false);
@ -424,26 +467,46 @@ bool ConstraintSolver::tryDispatch(NotNull<const Constraint> constraint, bool fo
bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull<const Constraint> constraint, bool force)
{
if (!recursiveBlock(c.subType, constraint))
return false;
if (!recursiveBlock(c.superType, constraint))
return false;
if (isBlocked(c.subType))
return block(c.subType, constraint);
else if (isBlocked(c.superType))
return block(c.superType, constraint);
unify(c.subType, c.superType, constraint->scope);
Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant};
u.useScopes = true;
u.tryUnify(c.subType, c.superType);
if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty())
{
for (TypeId bt : u.blockedTypes)
block(bt, constraint);
for (TypePackId btp : u.blockedTypePacks)
block(btp, constraint);
return false;
}
if (!u.errors.empty())
{
TypeId errorType = errorRecoveryType();
u.tryUnify(c.subType, errorType);
u.tryUnify(c.superType, errorType);
}
const auto [changedTypes, changedPacks] = u.log.getChanges();
u.log.commit();
unblock(changedTypes);
unblock(changedPacks);
// unify(c.subType, c.superType, constraint->scope);
return true;
}
bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force)
{
if (!recursiveBlock(c.subPack, constraint) || !recursiveBlock(c.superPack, constraint))
return false;
if (isBlocked(c.subPack))
return block(c.subPack, constraint);
else if (isBlocked(c.superPack))
@ -1183,8 +1246,26 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
TypeId instantiatedTy = arena->addType(BlockedType{});
TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result});
auto ic = pushConstraint(constraint->scope, constraint->location, InstantiationConstraint{instantiatedTy, fn});
auto sc = pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{instantiatedTy, inferredTy});
auto pushConstraintGreedy = [this, constraint](ConstraintV cv) -> Constraint* {
std::unique_ptr<Constraint> c = std::make_unique<Constraint>(constraint->scope, constraint->location, std::move(cv));
NotNull<Constraint> borrow{c.get()};
bool ok = tryDispatch(borrow, false);
if (ok)
return nullptr;
solverConstraints.push_back(std::move(c));
unsolvedConstraints.push_back(borrow);
return borrow;
};
// HACK: We don't want other constraints to act on the free type pack
// created above until after these two constraints are solved, so we try to
// dispatch them directly.
auto ic = pushConstraintGreedy(InstantiationConstraint{instantiatedTy, fn});
auto sc = pushConstraintGreedy(SubtypeConstraint{instantiatedTy, inferredTy});
// Anything that is blocked on this constraint must also be blocked on our
// synthesized constraints.
@ -1193,8 +1274,10 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
{
for (const auto& blockedConstraint : blockedIt->second)
{
block(ic, blockedConstraint);
block(sc, blockedConstraint);
if (ic)
block(NotNull{ic}, blockedConstraint);
if (sc)
block(NotNull{sc}, blockedConstraint);
}
}
@ -1230,6 +1313,8 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull<const Con
return true;
}
subjectType = reducer->reduce(subjectType).value_or(subjectType);
std::optional<TypeId> resultType = lookupTableProp(subjectType, c.prop);
if (!resultType)
{
@ -1360,11 +1445,18 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull<const Con
if (existingPropType)
{
unify(c.propType, *existingPropType, constraint->scope);
if (!isBlocked(c.propType))
unify(c.propType, *existingPropType, constraint->scope);
bind(c.resultType, c.subjectType);
return true;
}
if (get<AnyType>(subjectType) || get<ErrorType>(subjectType) || get<NeverType>(subjectType))
{
bind(c.resultType, subjectType);
return true;
}
if (get<FreeType>(subjectType))
{
TypeId ty = arena->freshType(constraint->scope);
@ -1381,21 +1473,27 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull<const Con
LUAU_ASSERT(ty);
bind(subjectType, ty);
bind(c.resultType, ty);
if (follow(c.resultType) != follow(ty))
bind(c.resultType, ty);
return true;
}
else if (auto ttv = getMutable<TableType>(subjectType))
{
if (ttv->state == TableState::Free)
{
LUAU_ASSERT(!subjectType->persistent);
ttv->props[c.path[0]] = Property{c.propType};
bind(c.resultType, c.subjectType);
return true;
}
else if (ttv->state == TableState::Unsealed)
{
LUAU_ASSERT(!subjectType->persistent);
std::optional<TypeId> augmented = updateTheTableType(NotNull{arena}, subjectType, c.path, c.propType);
bind(c.resultType, augmented.value_or(subjectType));
bind(subjectType, c.resultType);
return true;
}
else
@ -1411,16 +1509,62 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull<const Con
bind(c.resultType, subjectType);
return true;
}
else if (get<AnyType>(subjectType) || get<ErrorType>(subjectType) || get<NeverType>(subjectType))
{
bind(c.resultType, subjectType);
return true;
}
LUAU_ASSERT(0);
return true;
}
bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNull<const Constraint> constraint, bool force)
{
TypeId subjectType = follow(c.subjectType);
if (isBlocked(subjectType))
return block(subjectType, constraint);
if (auto ft = get<FreeType>(subjectType))
{
Scope* scope = ft->scope;
TableType* tt = &asMutable(subjectType)->ty.emplace<TableType>(TableState::Free, TypeLevel{}, scope);
tt->indexer = TableIndexer{c.indexType, c.propType};
asMutable(c.resultType)->ty.emplace<BoundType>(subjectType);
asMutable(c.propType)->ty.emplace<FreeType>(scope);
unblock(c.propType);
unblock(c.resultType);
return true;
}
else if (auto tt = get<TableType>(subjectType))
{
if (tt->indexer)
{
// TODO This probably has to be invariant.
unify(c.indexType, tt->indexer->indexType, constraint->scope);
asMutable(c.propType)->ty.emplace<BoundType>(tt->indexer->indexResultType);
asMutable(c.resultType)->ty.emplace<BoundType>(subjectType);
unblock(c.propType);
unblock(c.resultType);
return true;
}
else if (tt->state == TableState::Free || tt->state == TableState::Unsealed)
{
auto mtt = getMutable<TableType>(subjectType);
mtt->indexer = TableIndexer{c.indexType, c.propType};
asMutable(c.propType)->ty.emplace<FreeType>(tt->scope);
asMutable(c.resultType)->ty.emplace<BoundType>(subjectType);
unblock(c.propType);
unblock(c.resultType);
return true;
}
// Do not augment sealed or generic tables that lack indexers
}
asMutable(c.propType)->ty.emplace<BoundType>(builtinTypes->errorRecoveryType());
asMutable(c.resultType)->ty.emplace<BoundType>(builtinTypes->errorRecoveryType());
unblock(c.propType);
unblock(c.resultType);
return true;
}
bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull<const Constraint> constraint)
{
if (isBlocked(c.discriminantType))
@ -1439,6 +1583,69 @@ bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNul
return true;
}
bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull<const Constraint> constraint)
{
TypePackId sourcePack = follow(c.sourcePack);
TypePackId resultPack = follow(c.resultPack);
if (isBlocked(sourcePack))
return block(sourcePack, constraint);
if (isBlocked(resultPack))
{
asMutable(resultPack)->ty.emplace<BoundTypePack>(sourcePack);
unblock(resultPack);
return true;
}
TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack));
auto destIter = begin(resultPack);
auto destEnd = end(resultPack);
size_t i = 0;
while (destIter != destEnd)
{
if (i >= srcPack.head.size())
break;
TypeId srcTy = follow(srcPack.head[i]);
if (isBlocked(*destIter))
{
if (follow(srcTy) == *destIter)
{
// Cyclic type dependency. (????)
asMutable(*destIter)->ty.emplace<FreeType>(constraint->scope);
}
else
asMutable(*destIter)->ty.emplace<BoundType>(srcTy);
unblock(*destIter);
}
else
unify(*destIter, srcTy, constraint->scope);
++destIter;
++i;
}
// We know that resultPack does not have a tail, but we don't know if
// sourcePack is long enough to fill every value. Replace every remaining
// result TypeId with the error recovery type.
while (destIter != destEnd)
{
if (isBlocked(*destIter))
{
asMutable(*destIter)->ty.emplace<BoundType>(builtinTypes->errorRecoveryType());
unblock(*destIter);
}
++destIter;
}
return true;
}
bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull<const Constraint> constraint, bool force)
{
auto block_ = [&](auto&& t) {
@ -1628,10 +1835,20 @@ bool ConstraintSolver::tryDispatchIterableFunction(
std::optional<TypeId> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName)
{
std::unordered_set<TypeId> seen;
return lookupTableProp(subjectType, propName, seen);
}
std::optional<TypeId> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set<TypeId>& seen)
{
if (!seen.insert(subjectType).second)
return std::nullopt;
auto collectParts = [&](auto&& unionOrIntersection) -> std::pair<std::optional<TypeId>, std::vector<TypeId>> {
std::optional<TypeId> blocked;
std::vector<TypeId> parts;
std::vector<TypeId> freeParts;
for (TypeId expectedPart : unionOrIntersection)
{
expectedPart = follow(expectedPart);
@ -1644,6 +1861,29 @@ std::optional<TypeId> ConstraintSolver::lookupTableProp(TypeId subjectType, cons
else if (ttv->indexer && maybeString(ttv->indexer->indexType))
parts.push_back(ttv->indexer->indexResultType);
}
else if (get<FreeType>(expectedPart))
{
freeParts.push_back(expectedPart);
}
}
// If the only thing resembling a match is a single fresh type, we can
// confidently tablify it. If other types match or if there are more
// than one free type, we can't do anything.
if (parts.empty() && 1 == freeParts.size())
{
TypeId freePart = freeParts.front();
const FreeType* ft = get<FreeType>(freePart);
LUAU_ASSERT(ft);
Scope* scope = ft->scope;
TableType* tt = &asMutable(freePart)->ty.emplace<TableType>();
tt->state = TableState::Free;
tt->scope = scope;
TypeId propType = arena->freshType(scope);
tt->props[propName] = Property{propType};
parts.push_back(propType);
}
return {blocked, parts};
@ -1651,12 +1891,75 @@ std::optional<TypeId> ConstraintSolver::lookupTableProp(TypeId subjectType, cons
std::optional<TypeId> resultType;
if (auto ttv = get<TableType>(subjectType))
if (get<AnyType>(subjectType) || get<NeverType>(subjectType))
{
return subjectType;
}
else if (auto ttv = getMutable<TableType>(subjectType))
{
if (auto prop = ttv->props.find(propName); prop != ttv->props.end())
resultType = prop->second.type;
else if (ttv->indexer && maybeString(ttv->indexer->indexType))
resultType = ttv->indexer->indexResultType;
else if (ttv->state == TableState::Free)
{
resultType = arena->addType(FreeType{ttv->scope});
ttv->props[propName] = Property{*resultType};
}
}
else if (auto mt = get<MetatableType>(subjectType))
{
if (auto p = lookupTableProp(mt->table, propName, seen))
return p;
TypeId mtt = follow(mt->metatable);
if (get<BlockedType>(mtt))
return mtt;
else if (auto metatable = get<TableType>(mtt))
{
auto indexProp = metatable->props.find("__index");
if (indexProp == metatable->props.end())
return std::nullopt;
// TODO: __index can be an overloaded function.
TypeId indexType = follow(indexProp->second.type);
if (auto ft = get<FunctionType>(indexType))
{
std::optional<TypeId> ret = first(ft->retTypes);
if (ret)
return *ret;
else
return std::nullopt;
}
return lookupTableProp(indexType, propName, seen);
}
}
else if (auto ct = get<ClassType>(subjectType))
{
while (ct)
{
if (auto prop = ct->props.find(propName); prop != ct->props.end())
return prop->second.type;
else if (ct->parent)
ct = get<ClassType>(follow(*ct->parent));
else
break;
}
}
else if (auto pt = get<PrimitiveType>(subjectType); pt && pt->metatable)
{
const TableType* metatable = get<TableType>(follow(*pt->metatable));
LUAU_ASSERT(metatable);
auto indexProp = metatable->props.find("__index");
if (indexProp == metatable->props.end())
return std::nullopt;
return lookupTableProp(indexProp->second.type, propName, seen);
}
else if (auto utv = get<UnionType>(subjectType))
{
@ -1704,7 +2007,7 @@ void ConstraintSolver::block(NotNull<const Constraint> target, NotNull<const Con
if (FFlag::DebugLuauLogSolver)
printf("block Constraint %s on\t%s\n", toString(*target, opts).c_str(), toString(*constraint, opts).c_str());
block_(target, constraint);
block_(target.get(), constraint);
}
bool ConstraintSolver::block(TypeId target, NotNull<const Constraint> constraint)
@ -1715,7 +2018,7 @@ bool ConstraintSolver::block(TypeId target, NotNull<const Constraint> constraint
if (FFlag::DebugLuauLogSolver)
printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str());
block_(target, constraint);
block_(follow(target), constraint);
return false;
}
@ -1802,7 +2105,7 @@ void ConstraintSolver::unblock(NotNull<const Constraint> progressed)
if (FFlag::DebugLuauLogSolverToJson)
logger->popBlock(progressed);
return unblock_(progressed);
return unblock_(progressed.get());
}
void ConstraintSolver::unblock(TypeId progressed)
@ -1810,7 +2113,10 @@ void ConstraintSolver::unblock(TypeId progressed)
if (FFlag::DebugLuauLogSolverToJson)
logger->popBlock(progressed);
return unblock_(progressed);
unblock_(progressed);
if (auto bt = get<BoundType>(progressed))
unblock(bt->boundTo);
}
void ConstraintSolver::unblock(TypePackId progressed)

View file

@ -9,17 +9,39 @@
namespace Luau
{
template<typename T>
static std::string toPointerId(const T* ptr)
{
return std::to_string(reinterpret_cast<size_t>(ptr));
}
static std::string toPointerId(NotNull<const Constraint> ptr)
{
return std::to_string(reinterpret_cast<size_t>(ptr.get()));
}
namespace Json
{
template<typename T>
void write(JsonEmitter& emitter, const T* ptr)
{
write(emitter, toPointerId(ptr));
}
void write(JsonEmitter& emitter, NotNull<const Constraint> ptr)
{
write(emitter, toPointerId(ptr));
}
void write(JsonEmitter& emitter, const Location& location)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("beginLine", location.begin.line);
o.writePair("beginColumn", location.begin.column);
o.writePair("endLine", location.end.line);
o.writePair("endColumn", location.end.column);
o.finish();
ArrayEmitter a = emitter.writeArray();
a.writeValue(location.begin.line);
a.writeValue(location.begin.column);
a.writeValue(location.end.line);
a.writeValue(location.end.column);
a.finish();
}
void write(JsonEmitter& emitter, const ErrorSnapshot& snapshot)
@ -47,24 +69,43 @@ void write(JsonEmitter& emitter, const TypeBindingSnapshot& snapshot)
o.finish();
}
template<typename K, typename V>
void write(JsonEmitter& emitter, const DenseHashMap<const K*, V>& map)
{
ObjectEmitter o = emitter.writeObject();
for (const auto& [k, v] : map)
o.writePair(toPointerId(k), v);
o.finish();
}
void write(JsonEmitter& emitter, const ExprTypesAtLocation& tys)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("location", tys.location);
o.writePair("ty", toPointerId(tys.ty));
if (tys.expectedTy)
o.writePair("expectedTy", toPointerId(*tys.expectedTy));
o.finish();
}
void write(JsonEmitter& emitter, const AnnotationTypesAtLocation& tys)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("location", tys.location);
o.writePair("resolvedTy", toPointerId(tys.resolvedTy));
o.finish();
}
void write(JsonEmitter& emitter, const ConstraintGenerationLog& log)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("source", log.source);
emitter.writeComma();
write(emitter, "constraintLocations");
emitter.writeRaw(":");
ObjectEmitter locationEmitter = emitter.writeObject();
for (const auto& [id, location] : log.constraintLocations)
{
locationEmitter.writePair(id, location);
}
locationEmitter.finish();
o.writePair("errors", log.errors);
o.writePair("exprTypeLocations", log.exprTypeLocations);
o.writePair("annotationTypeLocations", log.annotationTypeLocations);
o.finish();
}
@ -78,26 +119,34 @@ void write(JsonEmitter& emitter, const ScopeSnapshot& snapshot)
o.finish();
}
void write(JsonEmitter& emitter, const ConstraintBlockKind& kind)
{
switch (kind)
{
case ConstraintBlockKind::TypeId:
return write(emitter, "type");
case ConstraintBlockKind::TypePackId:
return write(emitter, "typePack");
case ConstraintBlockKind::ConstraintId:
return write(emitter, "constraint");
default:
LUAU_ASSERT(0);
}
}
void write(JsonEmitter& emitter, const ConstraintBlock& block)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("kind", block.kind);
o.writePair("stringification", block.stringification);
auto go = [&o](auto&& t) {
using T = std::decay_t<decltype(t)>;
o.writePair("id", toPointerId(t));
if constexpr (std::is_same_v<T, TypeId>)
{
o.writePair("kind", "type");
}
else if constexpr (std::is_same_v<T, TypePackId>)
{
o.writePair("kind", "typePack");
}
else if constexpr (std::is_same_v<T, NotNull<const Constraint>>)
{
o.writePair("kind", "constraint");
}
else
static_assert(always_false_v<T>, "non-exhaustive possibility switch");
};
visit(go, block.target);
o.finish();
}
@ -114,7 +163,8 @@ void write(JsonEmitter& emitter, const BoundarySnapshot& snapshot)
{
ObjectEmitter o = emitter.writeObject();
o.writePair("rootScope", snapshot.rootScope);
o.writePair("constraints", snapshot.constraints);
o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints);
o.writePair("typeStrings", snapshot.typeStrings);
o.finish();
}
@ -125,6 +175,7 @@ void write(JsonEmitter& emitter, const StepSnapshot& snapshot)
o.writePair("forced", snapshot.forced);
o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints);
o.writePair("rootScope", snapshot.rootScope);
o.writePair("typeStrings", snapshot.typeStrings);
o.finish();
}
@ -146,11 +197,6 @@ void write(JsonEmitter& emitter, const TypeCheckLog& log)
} // namespace Json
static std::string toPointerId(NotNull<const Constraint> ptr)
{
return std::to_string(reinterpret_cast<size_t>(ptr.get()));
}
static ScopeSnapshot snapshotScope(const Scope* scope, ToStringOptions& opts)
{
std::unordered_map<Name, BindingSnapshot> bindings;
@ -230,6 +276,32 @@ void DcrLogger::captureSource(std::string source)
generationLog.source = std::move(source);
}
void DcrLogger::captureGenerationModule(const ModulePtr& module)
{
generationLog.exprTypeLocations.reserve(module->astTypes.size());
for (const auto& [expr, ty] : module->astTypes)
{
ExprTypesAtLocation tys;
tys.location = expr->location;
tys.ty = ty;
if (auto expectedTy = module->astExpectedTypes.find(expr))
tys.expectedTy = *expectedTy;
generationLog.exprTypeLocations.push_back(tys);
}
generationLog.annotationTypeLocations.reserve(module->astResolvedTypes.size());
for (const auto& [annot, ty] : module->astResolvedTypes)
{
AnnotationTypesAtLocation tys;
tys.location = annot->location;
tys.resolvedTy = ty;
generationLog.annotationTypeLocations.push_back(tys);
}
}
void DcrLogger::captureGenerationError(const TypeError& error)
{
std::string stringifiedError = toString(error);
@ -239,12 +311,6 @@ void DcrLogger::captureGenerationError(const TypeError& error)
});
}
void DcrLogger::captureConstraintLocation(NotNull<const Constraint> constraint, Location location)
{
std::string id = toPointerId(constraint);
generationLog.constraintLocations[id] = location;
}
void DcrLogger::pushBlock(NotNull<const Constraint> constraint, TypeId block)
{
constraintBlocks[constraint].push_back(block);
@ -284,44 +350,70 @@ void DcrLogger::popBlock(NotNull<const Constraint> block)
}
}
void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
static void snapshotTypeStrings(const std::vector<ExprTypesAtLocation>& interestedExprs,
const std::vector<AnnotationTypesAtLocation>& interestedAnnots, DenseHashMap<const void*, std::string>& map, ToStringOptions& opts)
{
solveLog.initialState.rootScope = snapshotScope(rootScope, opts);
solveLog.initialState.constraints.clear();
for (const ExprTypesAtLocation& tys : interestedExprs)
{
map[tys.ty] = toString(tys.ty, opts);
if (tys.expectedTy)
map[*tys.expectedTy] = toString(*tys.expectedTy, opts);
}
for (const AnnotationTypesAtLocation& tys : interestedAnnots)
{
map[tys.resolvedTy] = toString(tys.resolvedTy, opts);
}
}
void DcrLogger::captureBoundaryState(
BoundarySnapshot& target, const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
target.rootScope = snapshotScope(rootScope, opts);
target.unsolvedConstraints.clear();
for (NotNull<const Constraint> c : unsolvedConstraints)
{
std::string id = toPointerId(c);
solveLog.initialState.constraints[id] = {
target.unsolvedConstraints[c.get()] = {
toString(*c.get(), opts),
c->location,
snapshotBlocks(c),
};
}
snapshotTypeStrings(generationLog.exprTypeLocations, generationLog.annotationTypeLocations, target.typeStrings, opts);
}
void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
captureBoundaryState(solveLog.initialState, rootScope, unsolvedConstraints);
}
StepSnapshot DcrLogger::prepareStepSnapshot(
const Scope* rootScope, NotNull<const Constraint> current, bool force, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts);
std::string currentId = toPointerId(current);
std::unordered_map<std::string, ConstraintSnapshot> constraints;
DenseHashMap<const Constraint*, ConstraintSnapshot> constraints{nullptr};
for (NotNull<const Constraint> c : unsolvedConstraints)
{
std::string id = toPointerId(c);
constraints[id] = {
constraints[c.get()] = {
toString(*c.get(), opts),
c->location,
snapshotBlocks(c),
};
}
DenseHashMap<const void*, std::string> typeStrings{nullptr};
snapshotTypeStrings(generationLog.exprTypeLocations, generationLog.annotationTypeLocations, typeStrings, opts);
return StepSnapshot{
currentId,
current,
force,
constraints,
std::move(constraints),
scopeSnapshot,
std::move(typeStrings),
};
}
@ -332,18 +424,7 @@ void DcrLogger::commitStepSnapshot(StepSnapshot snapshot)
void DcrLogger::captureFinalSolverState(const Scope* rootScope, const std::vector<NotNull<const Constraint>>& unsolvedConstraints)
{
solveLog.finalState.rootScope = snapshotScope(rootScope, opts);
solveLog.finalState.constraints.clear();
for (NotNull<const Constraint> c : unsolvedConstraints)
{
std::string id = toPointerId(c);
solveLog.finalState.constraints[id] = {
toString(*c.get(), opts),
c->location,
snapshotBlocks(c),
};
}
captureBoundaryState(solveLog.finalState, rootScope, unsolvedConstraints);
}
void DcrLogger::captureTypeCheckError(const TypeError& error)
@ -370,21 +451,21 @@ std::vector<ConstraintBlock> DcrLogger::snapshotBlocks(NotNull<const Constraint>
if (const TypeId* ty = get_if<TypeId>(&target))
{
snapshot.push_back({
ConstraintBlockKind::TypeId,
*ty,
toString(*ty, opts),
});
}
else if (const TypePackId* tp = get_if<TypePackId>(&target))
{
snapshot.push_back({
ConstraintBlockKind::TypePackId,
*tp,
toString(*tp, opts),
});
}
else if (const NotNull<const Constraint>* c = get_if<NotNull<const Constraint>>(&target))
{
snapshot.push_back({
ConstraintBlockKind::ConstraintId,
*c,
toString(*(c->get()), opts),
});
}

View file

@ -899,8 +899,8 @@ ModulePtr check(
cgb.visit(sourceModule.root);
result->errors = std::move(cgb.errors);
ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, moduleResolver,
requireCycles, logger.get()};
ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name,
NotNull{result->reduction.get()}, moduleResolver, requireCycles, logger.get()};
if (options.randomizeConstraintResolutionSeed)
cs.randomize(*options.randomizeConstraintResolutionSeed);

View file

@ -1441,6 +1441,8 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor
if (!unionNormals(here, *tn))
return false;
}
else if (get<BlockedType>(there))
LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType");
else
LUAU_ASSERT(!"Unreachable");

View file

@ -183,7 +183,7 @@ struct PureQuantifier : Substitution
else if (ttv->state == TableState::Generic)
seenGenericType = true;
return ttv->state == TableState::Unsealed || (ttv->state == TableState::Free && subsumes(scope, ttv->scope));
return (ttv->state == TableState::Unsealed || ttv->state == TableState::Free) && subsumes(scope, ttv->scope);
}
return false;

View file

@ -31,12 +31,12 @@ std::optional<TypeId> Scope::lookup(Symbol sym) const
{
auto r = const_cast<Scope*>(this)->lookupEx(sym);
if (r)
return r->first;
return r->first->typeId;
else
return std::nullopt;
}
std::optional<std::pair<TypeId, Scope*>> Scope::lookupEx(Symbol sym)
std::optional<std::pair<Binding*, Scope*>> Scope::lookupEx(Symbol sym)
{
Scope* s = this;
@ -44,7 +44,7 @@ std::optional<std::pair<TypeId, Scope*>> Scope::lookupEx(Symbol sym)
{
auto it = s->bindings.find(sym);
if (it != s->bindings.end())
return std::pair{it->second.typeId, s};
return std::pair{&it->second, s};
if (s->parent)
s = s->parent.get();

View file

@ -1533,6 +1533,10 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
const std::string pathStr = c.path.size() == 1 ? "\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]";
return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType);
}
else if constexpr (std::is_same_v<T, SetIndexerConstraint>)
{
return tos(c.resultType) + " ~ setIndexer " + tos(c.subjectType) + " [ " + tos(c.indexType) + " ] " + tos(c.propType);
}
else if constexpr (std::is_same_v<T, SingletonOrTopTypeConstraint>)
{
std::string result = tos(c.resultType);
@ -1543,6 +1547,8 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts)
else
return result + " ~ if isSingleton D then D else unknown where D = " + discriminant;
}
else if constexpr (std::is_same_v<T, UnpackConstraint>)
return tos(c.resultPack) + " ~ unpack " + tos(c.sourcePack);
else
static_assert(always_false_v<T>, "Non-exhaustive constraint switch");
};

View file

@ -4,6 +4,7 @@
#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
#include "Luau/Clone.h"
#include "Luau/Common.h"
#include "Luau/DcrLogger.h"
#include "Luau/Error.h"
#include "Luau/Instantiation.h"
@ -329,11 +330,12 @@ struct TypeChecker2
for (size_t i = 0; i < count; ++i)
{
AstExpr* value = i < local->values.size ? local->values.data[i] : nullptr;
const bool isPack = value && (value->is<AstExprCall>() || value->is<AstExprVarargs>());
if (value)
visit(value, RValue);
if (i != local->values.size - 1 || value)
if (i != local->values.size - 1 || !isPack)
{
AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr;
@ -351,16 +353,19 @@ struct TypeChecker2
visit(var->annotation);
}
}
else
else if (value)
{
LUAU_ASSERT(value);
TypePackId valuePack = lookupPack(value);
TypePack valueTypes;
if (i < local->vars.size)
valueTypes = extendTypePack(module->internalTypes, builtinTypes, valuePack, local->vars.size - i);
TypePackId valueTypes = lookupPack(value);
auto it = begin(valueTypes);
Location errorLocation;
for (size_t j = i; j < local->vars.size; ++j)
{
if (it == end(valueTypes))
if (j - i >= valueTypes.head.size())
{
errorLocation = local->vars.data[j]->location;
break;
}
@ -368,14 +373,28 @@ struct TypeChecker2
if (var->annotation)
{
TypeId varType = lookupAnnotation(var->annotation);
ErrorVec errors = tryUnify(stack.back(), value->location, *it, varType);
ErrorVec errors = tryUnify(stack.back(), value->location, valueTypes.head[j - i], varType);
if (!errors.empty())
reportErrors(std::move(errors));
visit(var->annotation);
}
}
++it;
if (valueTypes.head.size() < local->vars.size - i)
{
reportError(
CountMismatch{
// We subtract 1 here because the final AST
// expression is not worth one value. It is worth 0
// or more depending on valueTypes.head
local->values.size - 1 + valueTypes.head.size(),
std::nullopt,
local->vars.size,
local->values.data[local->values.size - 1]->is<AstExprCall>() ? CountMismatch::FunctionResult
: CountMismatch::ExprListResult,
},
errorLocation);
}
}
}
@ -810,6 +829,95 @@ struct TypeChecker2
// TODO!
}
ErrorVec visitOverload(AstExprCall* call, NotNull<const FunctionType> overloadFunctionType, const std::vector<Location>& argLocs,
TypePackId expectedArgTypes, TypePackId expectedRetType)
{
ErrorVec overloadErrors =
tryUnify(stack.back(), call->location, overloadFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult);
size_t argIndex = 0;
auto inferredArgIt = begin(overloadFunctionType->argTypes);
auto expectedArgIt = begin(expectedArgTypes);
while (inferredArgIt != end(overloadFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes))
{
Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex];
ErrorVec argErrors = tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt);
for (TypeError e : argErrors)
overloadErrors.emplace_back(e);
++argIndex;
++inferredArgIt;
++expectedArgIt;
}
// piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad
ErrorVec argumentErrors = tryUnify(stack.back(), call->location, expectedArgTypes, overloadFunctionType->argTypes);
for (TypeError e : argumentErrors)
if (get<CountMismatch>(e) != nullptr)
overloadErrors.emplace_back(std::move(e));
return overloadErrors;
}
void reportOverloadResolutionErrors(AstExprCall* call, std::vector<TypeId> overloads, TypePackId expectedArgTypes,
const std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<std::pair<ErrorVec, const FunctionType*>> overloadsErrors)
{
if (overloads.size() == 1)
{
reportErrors(std::get<0>(overloadsErrors.front()));
return;
}
std::vector<TypeId> overloadTypes = overloadsThatMatchArgCount;
if (overloadsThatMatchArgCount.size() == 0)
{
reportError(GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location);
// If no overloads match argument count, just list all overloads.
overloadTypes = overloads;
}
else
{
// Report errors of the first argument-count-matching, but failing overload
TypeId overload = overloadsThatMatchArgCount[0];
// Remove the overload we are reporting errors about from the list of alternatives
overloadTypes.erase(std::remove(overloadTypes.begin(), overloadTypes.end(), overload), overloadTypes.end());
const FunctionType* ftv = get<FunctionType>(overload);
LUAU_ASSERT(ftv); // overload must be a function type here
auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [ftv](const std::pair<ErrorVec, const FunctionType*>& e) {
return ftv == std::get<1>(e);
});
LUAU_ASSERT(error != overloadsErrors.end());
reportErrors(std::get<0>(*error));
// If only one overload matched, we don't need this error because we provided the previous errors.
if (overloadsThatMatchArgCount.size() == 1)
return;
}
std::string s;
for (size_t i = 0; i < overloadTypes.size(); ++i)
{
TypeId overload = follow(overloadTypes[i]);
if (i > 0)
s += "; ";
if (i > 0 && i == overloadTypes.size() - 1)
s += "and ";
s += toString(overload);
}
if (overloadsThatMatchArgCount.size() == 0)
reportError(ExtraInformation{"Available overloads: " + s}, call->func->location);
else
reportError(ExtraInformation{"Other overloads are also not viable: " + s}, call->func->location);
}
void visit(AstExprCall* call)
{
visit(call->func, RValue);
@ -865,6 +973,10 @@ struct TypeChecker2
return;
}
}
else if (auto itv = get<IntersectionType>(functionType))
{
// We do nothing here because we'll flatten the intersection later, but we don't want to report it as a non-function.
}
else if (auto utv = get<UnionType>(functionType))
{
// Sometimes it's okay to call a union of functions, but only if all of the functions are the same.
@ -930,48 +1042,105 @@ struct TypeChecker2
TypePackId expectedArgTypes = arena->addTypePack(args);
const FunctionType* inferredFunctionType = get<FunctionType>(testFunctionType);
LUAU_ASSERT(inferredFunctionType); // testFunctionType should always be a FunctionType here
std::vector<TypeId> overloads = flattenIntersection(testFunctionType);
std::vector<std::pair<ErrorVec, const FunctionType*>> overloadsErrors;
overloadsErrors.reserve(overloads.size());
size_t argIndex = 0;
auto inferredArgIt = begin(inferredFunctionType->argTypes);
auto expectedArgIt = begin(expectedArgTypes);
while (inferredArgIt != end(inferredFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes))
std::vector<TypeId> overloadsThatMatchArgCount;
for (TypeId overload : overloads)
{
Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex];
reportErrors(tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt));
overload = follow(overload);
++argIndex;
++inferredArgIt;
++expectedArgIt;
const FunctionType* overloadFn = get<FunctionType>(overload);
if (!overloadFn)
{
reportError(CannotCallNonFunction{overload}, call->func->location);
return;
}
else
{
// We may have to instantiate the overload in order for it to typecheck.
if (std::optional<TypeId> instantiatedFunctionType = instantiation.substitute(overload))
{
overloadFn = get<FunctionType>(*instantiatedFunctionType);
}
else
{
overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overloadFn);
return;
}
}
ErrorVec overloadErrors = visitOverload(call, NotNull{overloadFn}, argLocs, expectedArgTypes, expectedRetType);
if (overloadErrors.empty())
return;
bool argMismatch = false;
for (auto error : overloadErrors)
{
CountMismatch* cm = get<CountMismatch>(error);
if (!cm)
continue;
if (cm->context == CountMismatch::Arg)
{
argMismatch = true;
break;
}
}
if (!argMismatch)
overloadsThatMatchArgCount.push_back(overload);
overloadsErrors.emplace_back(std::move(overloadErrors), overloadFn);
}
// piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad
ErrorVec errors = tryUnify(stack.back(), call->location, expectedArgTypes, inferredFunctionType->argTypes);
for (TypeError e : errors)
if (get<CountMismatch>(e) != nullptr)
reportError(std::move(e));
reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors);
}
reportErrors(tryUnify(stack.back(), call->location, inferredFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult));
void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context)
{
visit(expr, RValue);
TypeId leftType = lookupType(expr);
const NormalizedType* norm = normalizer.normalize(leftType);
if (!norm)
reportError(NormalizationTooComplex{}, location);
checkIndexTypeFromType(leftType, *norm, propName, location, context);
}
void visit(AstExprIndexName* indexName, ValueContext context)
{
visit(indexName->expr, RValue);
TypeId leftType = lookupType(indexName->expr);
const NormalizedType* norm = normalizer.normalize(leftType);
if (!norm)
reportError(NormalizationTooComplex{}, indexName->indexLocation);
checkIndexTypeFromType(leftType, *norm, indexName->index.value, indexName->location, context);
visitExprName(indexName->expr, indexName->location, indexName->index.value, context);
}
void visit(AstExprIndexExpr* indexExpr, ValueContext context)
{
if (auto str = indexExpr->index->as<AstExprConstantString>())
{
const std::string stringValue(str->value.data, str->value.size);
visitExprName(indexExpr->expr, indexExpr->location, stringValue, context);
return;
}
// TODO!
visit(indexExpr->expr, LValue);
visit(indexExpr->index, RValue);
NotNull<Scope> scope = stack.back();
TypeId exprType = lookupType(indexExpr->expr);
TypeId indexType = lookupType(indexExpr->index);
if (auto tt = get<TableType>(exprType))
{
if (tt->indexer)
reportErrors(tryUnify(scope, indexExpr->index->location, indexType, tt->indexer->indexType));
else
reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location);
}
}
void visit(AstExprFunction* fn)
@ -1879,8 +2048,17 @@ struct TypeChecker2
ty = *mtIndex;
}
if (getTableType(ty))
return bool(findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location));
if (auto tt = getTableType(ty))
{
if (findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location))
return true;
else if (tt->indexer && isPrim(tt->indexer->indexResultType, PrimitiveType::String))
return tt->indexer->indexResultType;
else
return false;
}
else if (const ClassType* cls = get<ClassType>(ty))
return bool(lookupClassProp(cls, prop));
else if (const UnionType* utv = get<UnionType>(ty))

View file

@ -1759,7 +1759,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit)
{
reportErrorCodeTooComplex(expr.location);
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
WithPredicate<TypeId> result;
@ -1767,23 +1767,23 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (auto a = expr.as<AstExprGroup>())
result = checkExpr(scope, *a->expr, expectedType);
else if (expr.is<AstExprConstantNil>())
result = {nilType};
result = WithPredicate{nilType};
else if (const AstExprConstantBool* bexpr = expr.as<AstExprConstantBool>())
{
if (forceSingleton || (expectedType && maybeSingleton(*expectedType)))
result = {singletonType(bexpr->value)};
result = WithPredicate{singletonType(bexpr->value)};
else
result = {booleanType};
result = WithPredicate{booleanType};
}
else if (const AstExprConstantString* sexpr = expr.as<AstExprConstantString>())
{
if (forceSingleton || (expectedType && maybeSingleton(*expectedType)))
result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))};
result = WithPredicate{singletonType(std::string(sexpr->value.data, sexpr->value.size))};
else
result = {stringType};
result = WithPredicate{stringType};
}
else if (expr.is<AstExprConstantNumber>())
result = {numberType};
result = WithPredicate{numberType};
else if (auto a = expr.as<AstExprLocal>())
result = checkExpr(scope, *a);
else if (auto a = expr.as<AstExprGlobal>())
@ -1837,7 +1837,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
// TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint
// ice("AstExprLocal exists but no binding definition for it?", expr.location);
reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}});
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr)
@ -1849,7 +1849,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}};
reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}});
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr)
@ -1859,26 +1859,26 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (get<TypePack>(varargPack))
{
if (std::optional<TypeId> ty = first(varargPack))
return {*ty};
return WithPredicate{*ty};
return {nilType};
return WithPredicate{nilType};
}
else if (get<FreeTypePack>(varargPack))
{
TypeId head = freshType(scope);
TypePackId tail = freshTypePack(scope);
*asMutable(varargPack) = TypePack{{head}, tail};
return {head};
return WithPredicate{head};
}
if (get<ErrorType>(varargPack))
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
else if (auto vtp = get<VariadicTypePack>(varargPack))
return {vtp->ty};
return WithPredicate{vtp->ty};
else if (get<Unifiable::Generic>(varargPack))
{
// TODO: Better error?
reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"});
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
else
ice("Unknown TypePack type in checkExpr(AstExprVarargs)!");
@ -1929,9 +1929,9 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
lhsType = stripFromNilAndReport(lhsType, expr.expr->location);
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, lhsType, name, expr.location, /* addErrors= */ true))
return {*ty};
return WithPredicate{*ty};
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
std::optional<TypeId> TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors)
@ -2138,7 +2138,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (std::optional<TypeId> refiTy = resolveLValue(scope, *lvalue))
return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}};
return {ty};
return WithPredicate{ty};
}
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional<TypeId> expectedType)
@ -2147,7 +2147,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
checkFunctionBody(funScope, funTy, expr);
return {quantify(funScope, funTy, expr.location)};
return WithPredicate{quantify(funScope, funTy, expr.location)};
}
TypeId TypeChecker::checkExprTable(
@ -2252,7 +2252,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit)
{
reportErrorCodeTooComplex(expr.location);
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
std::vector<std::pair<TypeId, TypeId>> fieldTypes(expr.items.size);
@ -2339,7 +2339,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
expectedIndexResultType = fieldTypes[i].second;
}
return {checkExprTable(scope, expr, fieldTypes, expectedType)};
return WithPredicate{checkExprTable(scope, expr, fieldTypes, expectedType)};
}
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr)
@ -2356,7 +2356,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
const bool operandIsAny = get<AnyType>(operandType) || get<ErrorType>(operandType) || get<NeverType>(operandType);
if (operandIsAny)
return {operandType};
return WithPredicate{operandType};
if (typeCouldHaveMetatable(operandType))
{
@ -2377,16 +2377,16 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (!state.errors.empty())
retType = errorRecoveryType(retType);
return {retType};
return WithPredicate{retType};
}
reportError(expr.location,
GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())});
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
reportErrors(tryUnify(operandType, numberType, scope, expr.location));
return {numberType};
return WithPredicate{numberType};
}
case AstExprUnary::Len:
{
@ -2396,7 +2396,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
// # operator is guaranteed to return number
if (get<AnyType>(operandType) || get<ErrorType>(operandType) || get<NeverType>(operandType))
return {numberType};
return WithPredicate{numberType};
DenseHashSet<TypeId> seen{nullptr};
@ -2420,7 +2420,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
if (!hasLength(operandType, seen, &recursionCount))
reportError(TypeError{expr.location, NotATable{operandType}});
return {numberType};
return WithPredicate{numberType};
}
default:
ice("Unknown AstExprUnary " + std::to_string(int(expr.op)));
@ -3014,7 +3014,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
WithPredicate<TypeId> rhs = checkExpr(scope, *expr.right);
// Intentionally discarding predicates with other operators.
return {checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)};
return WithPredicate{checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)};
}
}
@ -3045,7 +3045,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
// any type errors that may arise from it are going to be useless.
currentModule->errors.resize(oldSize);
return {errorRecoveryType(scope)};
return WithPredicate{errorRecoveryType(scope)};
}
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional<TypeId> expectedType)
@ -3061,12 +3061,12 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
WithPredicate<TypeId> falseType = checkExpr(falseScope, *expr.falseExpr, expectedType);
if (falseType.type == trueType.type)
return {trueType.type};
return WithPredicate{trueType.type};
std::vector<TypeId> types = reduceUnion({trueType.type, falseType.type});
if (types.empty())
return {neverType};
return {types.size() == 1 ? types[0] : addType(UnionType{std::move(types)})};
return WithPredicate{neverType};
return WithPredicate{types.size() == 1 ? types[0] : addType(UnionType{std::move(types)})};
}
WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprInterpString& expr)
@ -3074,7 +3074,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
for (AstExpr* expr : expr.expressions)
checkExpr(scope, *expr);
return {stringType};
return WithPredicate{stringType};
}
TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx)
@ -3704,7 +3704,7 @@ WithPredicate<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, cons
{
WithPredicate<TypePackId> result = checkExprPackHelper(scope, expr);
if (containsNever(result.type))
return {uninhabitableTypePack};
return WithPredicate{uninhabitableTypePack};
return result;
}
@ -3715,14 +3715,14 @@ WithPredicate<TypePackId> TypeChecker::checkExprPackHelper(const ScopePtr& scope
else if (expr.is<AstExprVarargs>())
{
if (!scope->varargPack)
return {errorRecoveryTypePack(scope)};
return WithPredicate{errorRecoveryTypePack(scope)};
return {*scope->varargPack};
return WithPredicate{*scope->varargPack};
}
else
{
TypeId type = checkExpr(scope, expr).type;
return {addTypePack({type})};
return WithPredicate{addTypePack({type})};
}
}
@ -3994,71 +3994,77 @@ WithPredicate<TypePackId> TypeChecker::checkExprPackHelper(const ScopePtr& scope
{
retPack = freshTypePack(free->level);
TypePackId freshArgPack = freshTypePack(free->level);
asMutable(actualFunctionType)->ty.emplace<FunctionType>(free->level, freshArgPack, retPack);
emplaceType<FunctionType>(asMutable(actualFunctionType), free->level, freshArgPack, retPack);
}
else
retPack = freshTypePack(scope->level);
// checkExpr will log the pre-instantiated type of the function.
// That's not nearly as interesting as the instantiated type, which will include details about how
// generic functions are being instantiated for this particular callsite.
currentModule->astOriginalCallTypes[expr.func] = follow(functionType);
currentModule->astTypes[expr.func] = actualFunctionType;
// We break this function up into a lambda here to limit our stack footprint.
// The vectors used by this function aren't allocated until the lambda is actually called.
auto the_rest = [&]() -> WithPredicate<TypePackId> {
// checkExpr will log the pre-instantiated type of the function.
// That's not nearly as interesting as the instantiated type, which will include details about how
// generic functions are being instantiated for this particular callsite.
currentModule->astOriginalCallTypes[expr.func] = follow(functionType);
currentModule->astTypes[expr.func] = actualFunctionType;
std::vector<TypeId> overloads = flattenIntersection(actualFunctionType);
std::vector<TypeId> overloads = flattenIntersection(actualFunctionType);
std::vector<std::optional<TypeId>> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self);
std::vector<std::optional<TypeId>> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self);
WithPredicate<TypePackId> argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes);
TypePackId argPack = argListResult.type;
WithPredicate<TypePackId> argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes);
TypePackId argPack = argListResult.type;
if (get<Unifiable::Error>(argPack))
return {errorRecoveryTypePack(scope)};
if (get<Unifiable::Error>(argPack))
return WithPredicate{errorRecoveryTypePack(scope)};
TypePack* args = nullptr;
if (expr.self)
{
argPack = addTypePack(TypePack{{selfType}, argPack});
argListResult.type = argPack;
}
args = getMutable<TypePack>(argPack);
LUAU_ASSERT(args);
TypePack* args = nullptr;
if (expr.self)
{
argPack = addTypePack(TypePack{{selfType}, argPack});
argListResult.type = argPack;
}
args = getMutable<TypePack>(argPack);
LUAU_ASSERT(args);
std::vector<Location> argLocations;
argLocations.reserve(expr.args.size + 1);
if (expr.self)
argLocations.push_back(expr.func->as<AstExprIndexName>()->expr->location);
for (AstExpr* arg : expr.args)
argLocations.push_back(arg->location);
std::vector<Location> argLocations;
argLocations.reserve(expr.args.size + 1);
if (expr.self)
argLocations.push_back(expr.func->as<AstExprIndexName>()->expr->location);
for (AstExpr* arg : expr.args)
argLocations.push_back(arg->location);
std::vector<OverloadErrorEntry> errors; // errors encountered for each overload
std::vector<OverloadErrorEntry> errors; // errors encountered for each overload
std::vector<TypeId> overloadsThatMatchArgCount;
std::vector<TypeId> overloadsThatDont;
std::vector<TypeId> overloadsThatMatchArgCount;
std::vector<TypeId> overloadsThatDont;
for (TypeId fn : overloads)
{
fn = follow(fn);
for (TypeId fn : overloads)
{
fn = follow(fn);
if (auto ret = checkCallOverload(
scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors))
return *ret;
}
if (auto ret = checkCallOverload(
scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors))
return *ret;
}
if (handleSelfCallMismatch(scope, expr, args, argLocations, errors))
return {retPack};
if (handleSelfCallMismatch(scope, expr, args, argLocations, errors))
return WithPredicate{retPack};
reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors);
reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors);
const FunctionType* overload = nullptr;
if (!overloadsThatMatchArgCount.empty())
overload = get<FunctionType>(overloadsThatMatchArgCount[0]);
if (!overload && !overloadsThatDont.empty())
overload = get<FunctionType>(overloadsThatDont[0]);
if (overload)
return {errorRecoveryTypePack(overload->retTypes)};
const FunctionType* overload = nullptr;
if (!overloadsThatMatchArgCount.empty())
overload = get<FunctionType>(overloadsThatMatchArgCount[0]);
if (!overload && !overloadsThatDont.empty())
overload = get<FunctionType>(overloadsThatDont[0]);
if (overload)
return WithPredicate{errorRecoveryTypePack(overload->retTypes)};
return {errorRecoveryTypePack(retPack)};
return WithPredicate{errorRecoveryTypePack(retPack)};
};
return the_rest();
}
std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const std::vector<TypeId>& overloads, size_t argumentCount, bool selfCall)
@ -4119,8 +4125,13 @@ std::vector<std::optional<TypeId>> TypeChecker::getExpectedTypesForCall(const st
return expectedTypes;
}
std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack,
TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult,
/*
* Note: We return a std::unique_ptr here rather than an optional to manage our stack consumption.
* If this was an optional, callers would have to pay the stack cost for the result. This is problematic
* for functions that need to support recursion up to 600 levels deep.
*/
std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn,
TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector<Location>* argLocations, const WithPredicate<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<TypeId>& overloadsThatDont, std::vector<OverloadErrorEntry>& errors)
{
LUAU_ASSERT(argLocations);
@ -4130,16 +4141,16 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
if (get<AnyType>(fn))
{
unify(anyTypePack, argPack, scope, expr.location);
return {{anyTypePack}};
return std::make_unique<WithPredicate<TypePackId>>(anyTypePack);
}
if (get<ErrorType>(fn))
{
return {{errorRecoveryTypePack(scope)}};
return std::make_unique<WithPredicate<TypePackId>>(errorRecoveryTypePack(scope));
}
if (get<NeverType>(fn))
return {{uninhabitableTypePack}};
return std::make_unique<WithPredicate<TypePackId>>(uninhabitableTypePack);
if (auto ftv = get<FreeType>(fn))
{
@ -4152,7 +4163,7 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
options.isFunctionCall = true;
unify(r, fn, scope, expr.location, options);
return {{retPack}};
return std::make_unique<WithPredicate<TypePackId>>(retPack);
}
std::vector<Location> metaArgLocations;
@ -4191,7 +4202,7 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
{
reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}});
unify(errorRecoveryTypePack(scope), retPack, scope, expr.func->location);
return {{errorRecoveryTypePack(retPack)}};
return std::make_unique<WithPredicate<TypePackId>>(errorRecoveryTypePack(retPack));
}
// When this function type has magic functions and did return something, we select that overload instead.
@ -4200,7 +4211,7 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
{
// TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458
if (std::optional<WithPredicate<TypePackId>> ret = ftv->magicFunction(*this, scope, expr, argListResult))
return *ret;
return std::make_unique<WithPredicate<TypePackId>>(std::move(*ret));
}
Unifier state = mkUnifier(scope, expr.location);
@ -4209,7 +4220,7 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, /*argLocations*/ {});
if (!state.errors.empty())
{
return {};
return nullptr;
}
checkArgumentList(scope, *expr.func, state, argPack, ftv->argTypes, *argLocations);
@ -4244,10 +4255,10 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
currentModule->astOverloadResolvedTypes[&expr] = fn;
// We select this overload
return {{retPack}};
return std::make_unique<WithPredicate<TypePackId>>(retPack);
}
return {};
return nullptr;
}
bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector<Location>& argLocations,
@ -4404,7 +4415,7 @@ WithPredicate<TypePackId> TypeChecker::checkExprList(const ScopePtr& scope, cons
};
if (exprs.size == 0)
return {pack};
return WithPredicate{pack};
TypePack* tp = getMutable<TypePack>(pack);
@ -4484,7 +4495,7 @@ WithPredicate<TypePackId> TypeChecker::checkExprList(const ScopePtr& scope, cons
log.commit();
if (uninhabitable)
return {uninhabitableTypePack};
return WithPredicate{uninhabitableTypePack};
return {pack, predicates};
}

View file

@ -16,11 +16,167 @@ LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false)
namespace Luau
{
namespace detail
{
bool TypeReductionMemoization::isIrreducible(TypeId ty)
{
ty = follow(ty);
// Only does shallow check, the TypeReducer itself already does deep traversal.
if (auto edge = types.find(ty); edge && edge->irreducible)
return true;
else if (get<FreeType>(ty) || get<BlockedType>(ty) || get<PendingExpansionType>(ty))
return false;
else if (auto tt = get<TableType>(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed))
return false;
else
return true;
}
bool TypeReductionMemoization::isIrreducible(TypePackId tp)
{
tp = follow(tp);
// Only does shallow check, the TypeReducer itself already does deep traversal.
if (auto edge = typePacks.find(tp); edge && edge->irreducible)
return true;
else if (get<FreeTypePack>(tp) || get<BlockedTypePack>(tp))
return false;
else if (auto vtp = get<VariadicTypePack>(tp))
return isIrreducible(vtp->ty);
else
return true;
}
TypeId TypeReductionMemoization::memoize(TypeId ty, TypeId reducedTy)
{
ty = follow(ty);
reducedTy = follow(reducedTy);
// The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible.
// We don't need to recurse much further than that, because we already record the irreducibility from
// the bottom up.
bool irreducible = isIrreducible(reducedTy);
if (auto it = get<IntersectionType>(reducedTy))
{
for (TypeId part : it)
irreducible &= isIrreducible(part);
}
else if (auto ut = get<UnionType>(reducedTy))
{
for (TypeId option : ut)
irreducible &= isIrreducible(option);
}
else if (auto tt = get<TableType>(reducedTy))
{
for (auto& [k, p] : tt->props)
irreducible &= isIrreducible(p.type);
if (tt->indexer)
{
irreducible &= isIrreducible(tt->indexer->indexType);
irreducible &= isIrreducible(tt->indexer->indexResultType);
}
for (auto ta : tt->instantiatedTypeParams)
irreducible &= isIrreducible(ta);
for (auto tpa : tt->instantiatedTypePackParams)
irreducible &= isIrreducible(tpa);
}
else if (auto mt = get<MetatableType>(reducedTy))
{
irreducible &= isIrreducible(mt->table);
irreducible &= isIrreducible(mt->metatable);
}
else if (auto ft = get<FunctionType>(reducedTy))
{
irreducible &= isIrreducible(ft->argTypes);
irreducible &= isIrreducible(ft->retTypes);
}
else if (auto nt = get<NegationType>(reducedTy))
irreducible &= isIrreducible(nt->ty);
types[ty] = {reducedTy, irreducible};
types[reducedTy] = {reducedTy, irreducible};
return reducedTy;
}
TypePackId TypeReductionMemoization::memoize(TypePackId tp, TypePackId reducedTp)
{
tp = follow(tp);
reducedTp = follow(reducedTp);
bool irreducible = isIrreducible(reducedTp);
TypePackIterator it = begin(tp);
while (it != end(tp))
{
irreducible &= isIrreducible(*it);
++it;
}
if (it.tail())
irreducible &= isIrreducible(*it.tail());
typePacks[tp] = {reducedTp, irreducible};
typePacks[reducedTp] = {reducedTp, irreducible};
return reducedTp;
}
std::optional<ReductionEdge<TypeId>> TypeReductionMemoization::memoizedof(TypeId ty) const
{
auto fetchContext = [this](TypeId ty) -> std::optional<ReductionEdge<TypeId>> {
if (auto edge = types.find(ty))
return *edge;
else
return std::nullopt;
};
TypeId currentTy = ty;
std::optional<ReductionEdge<TypeId>> lastEdge;
while (auto edge = fetchContext(currentTy))
{
lastEdge = edge;
if (edge->irreducible)
return edge;
else if (edge->type == currentTy)
return edge;
else
currentTy = edge->type;
}
return lastEdge;
}
std::optional<ReductionEdge<TypePackId>> TypeReductionMemoization::memoizedof(TypePackId tp) const
{
auto fetchContext = [this](TypePackId tp) -> std::optional<ReductionEdge<TypePackId>> {
if (auto edge = typePacks.find(tp))
return *edge;
else
return std::nullopt;
};
TypePackId currentTp = tp;
std::optional<ReductionEdge<TypePackId>> lastEdge;
while (auto edge = fetchContext(currentTp))
{
lastEdge = edge;
if (edge->irreducible)
return edge;
else if (edge->type == currentTp)
return edge;
else
currentTp = edge->type;
}
return lastEdge;
}
} // namespace detail
namespace
{
using detail::ReductionContext;
template<typename A, typename B, typename Thing>
std::pair<const A*, const B*> get2(const Thing& one, const Thing& two)
{
@ -34,9 +190,7 @@ struct TypeReducer
NotNull<TypeArena> arena;
NotNull<BuiltinTypes> builtinTypes;
NotNull<InternalErrorReporter> handle;
DenseHashMap<TypeId, ReductionContext<TypeId>>* memoizedTypes;
DenseHashMap<TypePackId, ReductionContext<TypePackId>>* memoizedTypePacks;
NotNull<detail::TypeReductionMemoization> memoization;
DenseHashSet<const void*>* cyclics;
int depth = 0;
@ -50,12 +204,6 @@ struct TypeReducer
TypeId functionType(TypeId ty);
TypeId negationType(TypeId ty);
bool isIrreducible(TypeId ty);
bool isIrreducible(TypePackId tp);
TypeId memoize(TypeId ty, TypeId reducedTy);
TypePackId memoize(TypePackId tp, TypePackId reducedTp);
using BinaryFold = std::optional<TypeId> (TypeReducer::*)(TypeId, TypeId);
using UnaryFold = TypeId (TypeReducer::*)(TypeId);
@ -64,12 +212,15 @@ struct TypeReducer
{
ty = follow(ty);
if (auto ctx = memoizedTypes->find(ty))
return {ctx->type, getMutable<T>(ctx->type)};
if (auto edge = memoization->memoizedof(ty))
return {edge->type, getMutable<T>(edge->type)};
// We specifically do not want to use [`detail::TypeReductionMemoization::memoize`] because that will
// potentially consider these copiedTy to be reducible, but we need this to resolve cyclic references
// without attempting to recursively reduce it, causing copies of copies of copies of...
TypeId copiedTy = arena->addType(*t);
(*memoizedTypes)[ty] = {copiedTy, true};
(*memoizedTypes)[copiedTy] = {copiedTy, true};
memoization->types[ty] = {copiedTy, true};
memoization->types[copiedTy] = {copiedTy, true};
return {copiedTy, getMutable<T>(copiedTy)};
}
@ -175,8 +326,13 @@ TypeId TypeReducer::reduce(TypeId ty)
{
ty = follow(ty);
if (auto ctx = memoizedTypes->find(ty); ctx && ctx->irreducible)
return ctx->type;
if (auto edge = memoization->memoizedof(ty))
{
if (edge->irreducible)
return edge->type;
else
ty = edge->type;
}
else if (cyclics->contains(ty))
return ty;
@ -196,15 +352,20 @@ TypeId TypeReducer::reduce(TypeId ty)
else
result = ty;
return memoize(ty, result);
return memoization->memoize(ty, result);
}
TypePackId TypeReducer::reduce(TypePackId tp)
{
tp = follow(tp);
if (auto ctx = memoizedTypePacks->find(tp); ctx && ctx->irreducible)
return ctx->type;
if (auto edge = memoization->memoizedof(tp))
{
if (edge->irreducible)
return edge->type;
else
tp = edge->type;
}
else if (cyclics->contains(tp))
return tp;
@ -237,11 +398,11 @@ TypePackId TypeReducer::reduce(TypePackId tp)
}
if (!didReduce)
return memoize(tp, tp);
return memoization->memoize(tp, tp);
else if (head.empty() && tail)
return memoize(tp, *tail);
return memoization->memoize(tp, *tail);
else
return memoize(tp, arena->addTypePack(TypePack{std::move(head), tail}));
return memoization->memoize(tp, arena->addTypePack(TypePack{std::move(head), tail}));
}
std::optional<TypeId> TypeReducer::intersectionType(TypeId left, TypeId right)
@ -832,111 +993,6 @@ TypeId TypeReducer::negationType(TypeId ty)
return ty; // for all T except the ones handled above, ~T ~ ~T
}
bool TypeReducer::isIrreducible(TypeId ty)
{
ty = follow(ty);
// Only does shallow check, the TypeReducer itself already does deep traversal.
if (auto ctx = memoizedTypes->find(ty); ctx && ctx->irreducible)
return true;
else if (get<FreeType>(ty) || get<BlockedType>(ty) || get<PendingExpansionType>(ty))
return false;
else if (auto tt = get<TableType>(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed))
return false;
else
return true;
}
bool TypeReducer::isIrreducible(TypePackId tp)
{
tp = follow(tp);
// Only does shallow check, the TypeReducer itself already does deep traversal.
if (auto ctx = memoizedTypePacks->find(tp); ctx && ctx->irreducible)
return true;
else if (get<FreeTypePack>(tp) || get<BlockedTypePack>(tp))
return false;
else if (auto vtp = get<VariadicTypePack>(tp))
return isIrreducible(vtp->ty);
else
return true;
}
TypeId TypeReducer::memoize(TypeId ty, TypeId reducedTy)
{
ty = follow(ty);
reducedTy = follow(reducedTy);
// The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible.
// We don't need to recurse much further than that, because we already record the irreducibility from
// the bottom up.
bool irreducible = isIrreducible(reducedTy);
if (auto it = get<IntersectionType>(reducedTy))
{
for (TypeId part : it)
irreducible &= isIrreducible(part);
}
else if (auto ut = get<UnionType>(reducedTy))
{
for (TypeId option : ut)
irreducible &= isIrreducible(option);
}
else if (auto tt = get<TableType>(reducedTy))
{
for (auto& [k, p] : tt->props)
irreducible &= isIrreducible(p.type);
if (tt->indexer)
{
irreducible &= isIrreducible(tt->indexer->indexType);
irreducible &= isIrreducible(tt->indexer->indexResultType);
}
for (auto ta : tt->instantiatedTypeParams)
irreducible &= isIrreducible(ta);
for (auto tpa : tt->instantiatedTypePackParams)
irreducible &= isIrreducible(tpa);
}
else if (auto mt = get<MetatableType>(reducedTy))
{
irreducible &= isIrreducible(mt->table);
irreducible &= isIrreducible(mt->metatable);
}
else if (auto ft = get<FunctionType>(reducedTy))
{
irreducible &= isIrreducible(ft->argTypes);
irreducible &= isIrreducible(ft->retTypes);
}
else if (auto nt = get<NegationType>(reducedTy))
irreducible &= isIrreducible(nt->ty);
(*memoizedTypes)[ty] = {reducedTy, irreducible};
(*memoizedTypes)[reducedTy] = {reducedTy, irreducible};
return reducedTy;
}
TypePackId TypeReducer::memoize(TypePackId tp, TypePackId reducedTp)
{
tp = follow(tp);
reducedTp = follow(reducedTp);
bool irreducible = isIrreducible(reducedTp);
TypePackIterator it = begin(tp);
while (it != end(tp))
{
irreducible &= isIrreducible(*it);
++it;
}
if (it.tail())
irreducible &= isIrreducible(*it.tail());
(*memoizedTypePacks)[tp] = {reducedTp, irreducible};
(*memoizedTypePacks)[reducedTp] = {reducedTp, irreducible};
return reducedTp;
}
struct MarkCycles : TypeVisitor
{
DenseHashSet<const void*> cyclics{nullptr};
@ -961,7 +1017,6 @@ struct MarkCycles : TypeVisitor
return !cyclics.find(follow(tp));
}
};
} // namespace
TypeReduction::TypeReduction(
@ -981,8 +1036,13 @@ std::optional<TypeId> TypeReduction::reduce(TypeId ty)
return ty;
else if (!options.allowTypeReductionsFromOtherArenas && ty->owningArena != arena)
return ty;
else if (auto memoized = memoizedof(ty))
return *memoized;
else if (auto edge = memoization.memoizedof(ty))
{
if (edge->irreducible)
return edge->type;
else
ty = edge->type;
}
else if (hasExceededCartesianProductLimit(ty))
return std::nullopt;
@ -991,7 +1051,7 @@ std::optional<TypeId> TypeReduction::reduce(TypeId ty)
MarkCycles finder;
finder.traverse(ty);
TypeReducer reducer{arena, builtinTypes, handle, &memoizedTypes, &memoizedTypePacks, &finder.cyclics};
TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics};
return reducer.reduce(ty);
}
catch (const RecursionLimitException&)
@ -1008,8 +1068,13 @@ std::optional<TypePackId> TypeReduction::reduce(TypePackId tp)
return tp;
else if (!options.allowTypeReductionsFromOtherArenas && tp->owningArena != arena)
return tp;
else if (auto memoized = memoizedof(tp))
return *memoized;
else if (auto edge = memoization.memoizedof(tp))
{
if (edge->irreducible)
return edge->type;
else
tp = edge->type;
}
else if (hasExceededCartesianProductLimit(tp))
return std::nullopt;
@ -1018,7 +1083,7 @@ std::optional<TypePackId> TypeReduction::reduce(TypePackId tp)
MarkCycles finder;
finder.traverse(tp);
TypeReducer reducer{arena, builtinTypes, handle, &memoizedTypes, &memoizedTypePacks, &finder.cyclics};
TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics};
return reducer.reduce(tp);
}
catch (const RecursionLimitException&)
@ -1039,13 +1104,6 @@ std::optional<TypeFun> TypeReduction::reduce(const TypeFun& fun)
return std::nullopt;
}
TypeReduction TypeReduction::fork(NotNull<TypeArena> arena, const TypeReductionOptions& opts) const
{
TypeReduction child{arena, builtinTypes, handle, opts};
child.parent = this;
return child;
}
size_t TypeReduction::cartesianProductSize(TypeId ty) const
{
ty = follow(ty);
@ -1093,24 +1151,4 @@ bool TypeReduction::hasExceededCartesianProductLimit(TypePackId tp) const
return false;
}
std::optional<TypeId> TypeReduction::memoizedof(TypeId ty) const
{
if (auto ctx = memoizedTypes.find(ty); ctx && ctx->irreducible)
return ctx->type;
else if (parent)
return parent->memoizedof(ty);
else
return std::nullopt;
}
std::optional<TypePackId> TypeReduction::memoizedof(TypePackId tp) const
{
if (auto ctx = memoizedTypePacks.find(tp); ctx && ctx->irreducible)
return ctx->type;
else if (parent)
return parent->memoizedof(tp);
else
return std::nullopt;
}
} // namespace Luau

View file

@ -520,7 +520,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
size_t errorCount = errors.size();
if (const UnionType* subUnion = log.getMutable<UnionType>(subTy))
if (log.getMutable<BlockedType>(subTy) && log.getMutable<BlockedType>(superTy))
{
blockedTypes.push_back(subTy);
blockedTypes.push_back(superTy);
}
else if (const UnionType* subUnion = log.getMutable<UnionType>(subTy))
{
tryUnifyUnionWithType(subTy, subUnion, superTy);
}

View file

@ -487,7 +487,7 @@ int main(int argc, char** argv)
if (args.size() < 4)
help(args);
for (int i = 1; i < args.size(); ++i)
for (size_t i = 1; i < args.size(); ++i)
{
if (args[i] == "--help")
help(args);

View file

@ -42,6 +42,7 @@ struct IrBuilder
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c);
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d);
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e);
IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f);
IrOp block(IrBlockKind kind); // Requested kind can be ignored if we are in an outlined sequence
IrOp blockAtInst(uint32_t index);
@ -57,6 +58,8 @@ struct IrBuilder
IrFunction function;
uint32_t activeBlockIdx = ~0u;
std::vector<uint32_t> instIndexToBlock; // Block index at the bytecode instruction
};

View file

@ -5,6 +5,7 @@
#include "Luau/RegisterX64.h"
#include "Luau/RegisterA64.h"
#include <optional>
#include <vector>
#include <stdint.h>
@ -186,6 +187,16 @@ enum class IrCmd : uint8_t
// A: int
INT_TO_NUM,
// Adjust stack top (L->top) to point at 'B' TValues *after* the specified register
// This is used to return muliple values
// A: Rn
// B: int (offset)
ADJUST_STACK_TO_REG,
// Restore stack top (L->top) to point to the function stack top (L->ci->top)
// This is used to recover after calling a variadic function
ADJUST_STACK_TO_TOP,
// Fallback functions
// Perform an arithmetic operation on TValues of any type
@ -329,7 +340,7 @@ enum class IrCmd : uint8_t
// Call specified function
// A: unsigned int (bytecode instruction index)
// B: Rn (function, followed by arguments)
// C: int (argument count or -1 to preserve all arguments up to stack top)
// C: int (argument count or -1 to use all arguments up to stack top)
// D: int (result count or -1 to preserve all results and adjust stack top)
// Note: return values are placed starting from Rn specified in 'B'
LOP_CALL,
@ -337,13 +348,13 @@ enum class IrCmd : uint8_t
// Return specified values from the function
// A: unsigned int (bytecode instruction index)
// B: Rn (value start)
// B: int (result count or -1 to return all values up to stack top)
// C: int (result count or -1 to return all values up to stack top)
LOP_RETURN,
// Perform a fast call of a built-in function
// A: unsigned int (bytecode instruction index)
// B: Rn (argument start)
// C: int (argument count or -1 preserve all arguments up to stack top)
// C: int (argument count or -1 use all arguments up to stack top)
// D: block (fallback)
// Note: return values are placed starting from Rn specified in 'B'
LOP_FASTCALL,
@ -560,6 +571,7 @@ struct IrInst
IrOp c;
IrOp d;
IrOp e;
IrOp f;
uint32_t lastUse = 0;
uint16_t useCount = 0;
@ -584,9 +596,10 @@ struct IrBlock
uint16_t useCount = 0;
// Start points to an instruction index in a stream
// End is implicit
// 'start' and 'finish' define an inclusive range of instructions which belong to this block inside the function
// When block has been constructed, 'finish' always points to the first and only terminating instruction
uint32_t start = ~0u;
uint32_t finish = ~0u;
Label label;
};
@ -633,6 +646,19 @@ struct IrFunction
return value.valueTag;
}
std::optional<uint8_t> asTagOp(IrOp op)
{
if (op.kind != IrOpKind::Constant)
return std::nullopt;
IrConst& value = constOp(op);
if (value.kind != IrConstKind::Tag)
return std::nullopt;
return value.valueTag;
}
bool boolOp(IrOp op)
{
IrConst& value = constOp(op);
@ -641,6 +667,19 @@ struct IrFunction
return value.valueBool;
}
std::optional<bool> asBoolOp(IrOp op)
{
if (op.kind != IrOpKind::Constant)
return std::nullopt;
IrConst& value = constOp(op);
if (value.kind != IrConstKind::Bool)
return std::nullopt;
return value.valueBool;
}
int intOp(IrOp op)
{
IrConst& value = constOp(op);
@ -649,6 +688,19 @@ struct IrFunction
return value.valueInt;
}
std::optional<int> asIntOp(IrOp op)
{
if (op.kind != IrOpKind::Constant)
return std::nullopt;
IrConst& value = constOp(op);
if (value.kind != IrConstKind::Int)
return std::nullopt;
return value.valueInt;
}
unsigned uintOp(IrOp op)
{
IrConst& value = constOp(op);
@ -657,6 +709,19 @@ struct IrFunction
return value.valueUint;
}
std::optional<unsigned> asUintOp(IrOp op)
{
if (op.kind != IrOpKind::Constant)
return std::nullopt;
IrConst& value = constOp(op);
if (value.kind != IrConstKind::Uint)
return std::nullopt;
return value.valueUint;
}
double doubleOp(IrOp op)
{
IrConst& value = constOp(op);
@ -665,11 +730,31 @@ struct IrFunction
return value.valueDouble;
}
std::optional<double> asDoubleOp(IrOp op)
{
if (op.kind != IrOpKind::Constant)
return std::nullopt;
IrConst& value = constOp(op);
if (value.kind != IrConstKind::Double)
return std::nullopt;
return value.valueDouble;
}
IrCondition conditionOp(IrOp op)
{
LUAU_ASSERT(op.kind == IrOpKind::Condition);
return IrCondition(op.index);
}
uint32_t getBlockIndex(const IrBlock& block)
{
// Can only be called with blocks from our vector
LUAU_ASSERT(&block >= blocks.data() && &block <= blocks.data() + blocks.size());
return uint32_t(&block - blocks.data());
}
};
} // namespace CodeGen

View file

@ -162,6 +162,8 @@ inline bool isPseudo(IrCmd cmd)
return cmd == IrCmd::NOP || cmd == IrCmd::SUBSTITUTE;
}
bool isGCO(uint8_t tag);
// Remove a single instruction
void kill(IrFunction& function, IrInst& inst);
@ -179,7 +181,7 @@ void replace(IrFunction& function, IrOp& original, IrOp replacement);
// Replace a single instruction
// Target instruction index instead of reference is used to handle introduction of a new block terminator
void replace(IrFunction& function, uint32_t instIdx, IrInst replacement);
void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst replacement);
// Replace instruction with a different value (using IrCmd::SUBSTITUTE)
void substitute(IrFunction& function, IrInst& inst, IrOp replacement);
@ -188,10 +190,13 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement);
void applySubstitutions(IrFunction& function, IrOp& op);
void applySubstitutions(IrFunction& function, IrInst& inst);
// Compare numbers using IR condition value
bool compare(double a, double b, IrCondition cond);
// Perform constant folding on instruction at index
// For most instructions, successful folding results in a IrCmd::SUBSTITUTE
// But it can also be successful on conditional control-flow, replacing it with an unconditional IrCmd::JUMP
void foldConstants(IrBuilder& build, IrFunction& function, uint32_t instIdx);
void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t instIdx);
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,16 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/IrData.h"
namespace Luau
{
namespace CodeGen
{
struct IrBuilder;
void constPropInBlockChains(IrBuilder& build);
} // namespace CodeGen
} // namespace Luau

View file

@ -7,6 +7,7 @@
#include "Luau/CodeBlockUnwind.h"
#include "Luau/IrAnalysis.h"
#include "Luau/IrBuilder.h"
#include "Luau/OptimizeConstProp.h"
#include "Luau/OptimizeFinalX64.h"
#include "Luau/UnwindBuilder.h"
#include "Luau/UnwindBuilderDwarf2.h"
@ -31,7 +32,7 @@
#endif
#endif
LUAU_FASTFLAGVARIABLE(DebugUseOldCodegen, false)
LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false)
namespace Luau
{
@ -40,12 +41,6 @@ namespace CodeGen
constexpr uint32_t kFunctionAlignment = 32;
struct InstructionOutline
{
int pcpos;
int length;
};
static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers)
{
if (build.logText)
@ -64,346 +59,6 @@ static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers)
emitContinueCallInVm(build);
}
static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i,
Label* labelarr, Label& next, Label& fallback)
{
int skip = 0;
switch (op)
{
case LOP_NOP:
break;
case LOP_LOADNIL:
emitInstLoadNil(build, pc);
break;
case LOP_LOADB:
emitInstLoadB(build, pc, i, labelarr);
break;
case LOP_LOADN:
emitInstLoadN(build, pc);
break;
case LOP_LOADK:
emitInstLoadK(build, pc);
break;
case LOP_LOADKX:
emitInstLoadKX(build, pc);
break;
case LOP_MOVE:
emitInstMove(build, pc);
break;
case LOP_GETGLOBAL:
emitInstGetGlobal(build, pc, i, fallback);
break;
case LOP_SETGLOBAL:
emitInstSetGlobal(build, pc, i, next, fallback);
break;
case LOP_NAMECALL:
emitInstNameCall(build, pc, i, proto->k, next, fallback);
break;
case LOP_CALL:
emitInstCall(build, helpers, pc, i);
break;
case LOP_RETURN:
emitInstReturn(build, helpers, pc, i);
break;
case LOP_GETTABLE:
emitInstGetTable(build, pc, fallback);
break;
case LOP_SETTABLE:
emitInstSetTable(build, pc, next, fallback);
break;
case LOP_GETTABLEKS:
emitInstGetTableKS(build, pc, i, fallback);
break;
case LOP_SETTABLEKS:
emitInstSetTableKS(build, pc, i, next, fallback);
break;
case LOP_GETTABLEN:
emitInstGetTableN(build, pc, fallback);
break;
case LOP_SETTABLEN:
emitInstSetTableN(build, pc, next, fallback);
break;
case LOP_JUMP:
emitInstJump(build, pc, i, labelarr);
break;
case LOP_JUMPBACK:
emitInstJumpBack(build, pc, i, labelarr);
break;
case LOP_JUMPIF:
emitInstJumpIf(build, pc, i, labelarr, /* not_ */ false);
break;
case LOP_JUMPIFNOT:
emitInstJumpIf(build, pc, i, labelarr, /* not_ */ true);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::LessEqual, fallback);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::Less, fallback);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLessEqual, fallback);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLess, fallback);
break;
case LOP_JUMPX:
emitInstJumpX(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKNIL:
emitInstJumpxEqNil(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKB:
emitInstJumpxEqB(build, pc, i, labelarr);
break;
case LOP_JUMPXEQKN:
emitInstJumpxEqN(build, pc, proto->k, i, labelarr);
break;
case LOP_JUMPXEQKS:
emitInstJumpxEqS(build, pc, i, labelarr);
break;
case LOP_ADD:
emitInstBinary(build, pc, TM_ADD, fallback);
break;
case LOP_SUB:
emitInstBinary(build, pc, TM_SUB, fallback);
break;
case LOP_MUL:
emitInstBinary(build, pc, TM_MUL, fallback);
break;
case LOP_DIV:
emitInstBinary(build, pc, TM_DIV, fallback);
break;
case LOP_MOD:
emitInstBinary(build, pc, TM_MOD, fallback);
break;
case LOP_POW:
emitInstBinary(build, pc, TM_POW, fallback);
break;
case LOP_ADDK:
emitInstBinaryK(build, pc, TM_ADD, fallback);
break;
case LOP_SUBK:
emitInstBinaryK(build, pc, TM_SUB, fallback);
break;
case LOP_MULK:
emitInstBinaryK(build, pc, TM_MUL, fallback);
break;
case LOP_DIVK:
emitInstBinaryK(build, pc, TM_DIV, fallback);
break;
case LOP_MODK:
emitInstBinaryK(build, pc, TM_MOD, fallback);
break;
case LOP_POWK:
emitInstPowK(build, pc, proto->k, fallback);
break;
case LOP_NOT:
emitInstNot(build, pc);
break;
case LOP_MINUS:
emitInstMinus(build, pc, fallback);
break;
case LOP_LENGTH:
emitInstLength(build, pc, fallback);
break;
case LOP_NEWTABLE:
emitInstNewTable(build, pc, i, next);
break;
case LOP_DUPTABLE:
emitInstDupTable(build, pc, i, next);
break;
case LOP_SETLIST:
emitInstSetList(build, pc, next);
break;
case LOP_GETUPVAL:
emitInstGetUpval(build, pc);
break;
case LOP_SETUPVAL:
emitInstSetUpval(build, pc, next);
break;
case LOP_CLOSEUPVALS:
emitInstCloseUpvals(build, pc, next);
break;
case LOP_FASTCALL:
// We want to lower next instruction at skip+2, but this instruction is only 1 long, so we need to add 1
skip = emitInstFastCall(build, pc, i, next) + 1;
break;
case LOP_FASTCALL1:
// We want to lower next instruction at skip+2, but this instruction is only 1 long, so we need to add 1
skip = emitInstFastCall1(build, pc, i, next) + 1;
break;
case LOP_FASTCALL2:
skip = emitInstFastCall2(build, pc, i, next);
break;
case LOP_FASTCALL2K:
skip = emitInstFastCall2K(build, pc, i, next);
break;
case LOP_FORNPREP:
emitInstForNPrep(build, pc, i, next, labelarr[i + 1 + LUAU_INSN_D(*pc)]);
break;
case LOP_FORNLOOP:
emitInstForNLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next);
break;
case LOP_FORGLOOP:
emitinstForGLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next, fallback);
break;
case LOP_FORGPREP_NEXT:
emitInstForGPrepNext(build, pc, labelarr[i + 1 + LUAU_INSN_D(*pc)], fallback);
break;
case LOP_FORGPREP_INEXT:
emitInstForGPrepInext(build, pc, labelarr[i + 1 + LUAU_INSN_D(*pc)], fallback);
break;
case LOP_AND:
emitInstAnd(build, pc);
break;
case LOP_ANDK:
emitInstAndK(build, pc);
break;
case LOP_OR:
emitInstOr(build, pc);
break;
case LOP_ORK:
emitInstOrK(build, pc);
break;
case LOP_GETIMPORT:
emitInstGetImport(build, pc, fallback);
break;
case LOP_CONCAT:
emitInstConcat(build, pc, i, next);
break;
case LOP_COVERAGE:
emitInstCoverage(build, i);
break;
default:
emitFallback(build, data, op, i);
break;
}
return skip;
}
static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauOpcode op, const Instruction* pc, int i, Label* labelarr)
{
switch (op)
{
case LOP_GETIMPORT:
emitSetSavedPc(build, i + 1);
emitInstGetImportFallback(build, LUAU_INSN_A(*pc), pc[1]);
break;
case LOP_GETTABLE:
emitInstGetTableFallback(build, pc, i);
break;
case LOP_SETTABLE:
emitInstSetTableFallback(build, pc, i);
break;
case LOP_GETTABLEN:
emitInstGetTableNFallback(build, pc, i);
break;
case LOP_SETTABLEN:
emitInstSetTableNFallback(build, pc, i);
break;
case LOP_NAMECALL:
// TODO: fast-paths that we've handled can be removed from the fallback
emitFallback(build, data, op, i);
break;
case LOP_JUMPIFEQ:
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false);
break;
case LOP_JUMPIFLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::LessEqual);
break;
case LOP_JUMPIFLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::Less);
break;
case LOP_JUMPIFNOTEQ:
emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true);
break;
case LOP_JUMPIFNOTLE:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLessEqual);
break;
case LOP_JUMPIFNOTLT:
emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLess);
break;
case LOP_ADD:
emitInstBinaryFallback(build, pc, i, TM_ADD);
break;
case LOP_SUB:
emitInstBinaryFallback(build, pc, i, TM_SUB);
break;
case LOP_MUL:
emitInstBinaryFallback(build, pc, i, TM_MUL);
break;
case LOP_DIV:
emitInstBinaryFallback(build, pc, i, TM_DIV);
break;
case LOP_MOD:
emitInstBinaryFallback(build, pc, i, TM_MOD);
break;
case LOP_POW:
emitInstBinaryFallback(build, pc, i, TM_POW);
break;
case LOP_ADDK:
emitInstBinaryKFallback(build, pc, i, TM_ADD);
break;
case LOP_SUBK:
emitInstBinaryKFallback(build, pc, i, TM_SUB);
break;
case LOP_MULK:
emitInstBinaryKFallback(build, pc, i, TM_MUL);
break;
case LOP_DIVK:
emitInstBinaryKFallback(build, pc, i, TM_DIV);
break;
case LOP_MODK:
emitInstBinaryKFallback(build, pc, i, TM_MOD);
break;
case LOP_POWK:
emitInstBinaryKFallback(build, pc, i, TM_POW);
break;
case LOP_MINUS:
emitInstMinusFallback(build, pc, i);
break;
case LOP_LENGTH:
emitInstLengthFallback(build, pc, i);
break;
case LOP_FORGLOOP:
emitinstForGLoopFallback(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]);
break;
case LOP_FORGPREP_NEXT:
case LOP_FORGPREP_INEXT:
emitInstForGPrepXnextFallback(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]);
break;
case LOP_GETGLOBAL:
// TODO: luaV_gettable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETGLOBAL:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
case LOP_GETTABLEKS:
// Full fallback required for LOP_GETTABLEKS because 'luaV_gettable' doesn't handle builtin vector field access
// It is also required to perform cached slot update
// TODO: extra fast-paths could be lowered before the full fallback
emitFallback(build, data, op, i);
break;
case LOP_SETTABLEKS:
// TODO: luaV_settable + cachedslot update instead of full fallback
emitFallback(build, data, op, i);
break;
default:
LUAU_ASSERT(!"Expected fallback for instruction");
}
}
static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options)
{
NativeProto* result = new NativeProto();
@ -423,153 +78,32 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat
build.logAppend("\n");
}
if (!FFlag::DebugUseOldCodegen)
{
build.align(kFunctionAlignment, AlignmentDataX64::Ud2);
Label start = build.setLabel();
IrBuilder builder;
builder.buildFunctionIr(proto);
optimizeMemoryOperandsX64(builder.function);
IrLoweringX64 lowering(build, helpers, data, proto, builder.function);
lowering.lower(options);
result->instTargets = new uintptr_t[proto->sizecode];
for (int i = 0; i < proto->sizecode; i++)
{
auto [irLocation, asmLocation] = builder.function.bcMapping[i];
result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation - start.location;
}
result->location = start.location;
if (build.logText)
build.logAppend("\n");
return result;
}
std::vector<Label> instLabels;
instLabels.resize(proto->sizecode);
std::vector<Label> instFallbacks;
instFallbacks.resize(proto->sizecode);
std::vector<InstructionOutline> instOutlines;
instOutlines.reserve(64);
build.align(kFunctionAlignment, AlignmentDataX64::Ud2);
Label start = build.setLabel();
for (int i = 0; i < proto->sizecode;)
IrBuilder builder;
builder.buildFunctionIr(proto);
if (!FFlag::DebugCodegenNoOpt)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
LUAU_ASSERT(nexti <= proto->sizecode);
build.setLabel(instLabels[i]);
if (options.annotator)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, i);
Label& next = nexti < proto->sizecode ? instLabels[nexti] : start; // Last instruction can't use 'next' label
int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), next, instFallbacks[i]);
if (skip != 0)
instOutlines.push_back({nexti, skip});
i = nexti + skip;
LUAU_ASSERT(i <= proto->sizecode);
constPropInBlockChains(builder);
}
size_t textSize = build.text.size();
uint32_t codeSize = build.getCodeSize();
optimizeMemoryOperandsX64(builder.function);
if (options.annotator && options.includeOutlinedCode)
build.logAppend("; outlined instructions\n");
IrLoweringX64 lowering(build, helpers, data, proto, builder.function);
for (auto [pcpos, length] : instOutlines)
{
int i = pcpos;
while (i < pcpos + length)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
LUAU_ASSERT(nexti <= proto->sizecode);
build.setLabel(instLabels[i]);
if (options.annotator && options.includeOutlinedCode)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, i);
Label& next = nexti < proto->sizecode ? instLabels[nexti] : start; // Last instruction can't use 'next' label
int skip = emitInst(build, data, helpers, proto, op, pc, i, instLabels.data(), next, instFallbacks[i]);
LUAU_ASSERT(skip == 0);
i = nexti;
}
if (i < proto->sizecode)
build.jmp(instLabels[i]);
}
if (options.annotator && options.includeOutlinedCode)
build.logAppend("; outlined code\n");
for (int i = 0, instid = 0; i < proto->sizecode; ++instid)
{
const Instruction* pc = &proto->code[i];
LuauOpcode op = LuauOpcode(LUAU_INSN_OP(*pc));
int nexti = i + getOpLength(op);
LUAU_ASSERT(nexti <= proto->sizecode);
if (instFallbacks[i].id == 0)
{
i = nexti;
continue;
}
if (options.annotator && options.includeOutlinedCode)
options.annotator(options.annotatorContext, build.text, proto->bytecodeid, instid);
build.setLabel(instFallbacks[i]);
emitInstFallback(build, data, op, pc, i, instLabels.data());
// Jump back to the next instruction handler
if (nexti < proto->sizecode)
build.jmp(instLabels[nexti]);
i = nexti;
}
// Truncate assembly output if we don't care for outlined code part
if (!options.includeOutlinedCode)
{
build.text.resize(textSize);
build.logAppend("; skipping %u bytes of outlined code\n", build.getCodeSize() - codeSize);
}
lowering.lower(options);
result->instTargets = new uintptr_t[proto->sizecode];
for (int i = 0; i < proto->sizecode; i++)
result->instTargets[i] = instLabels[i].location - start.location;
{
auto [irLocation, asmLocation] = builder.function.bcMapping[i];
result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation - start.location;
}
result->location = start.location;

View file

@ -5,33 +5,18 @@
#include "Luau/Bytecode.h"
#include "EmitCommonX64.h"
#include "IrTranslateBuiltins.h" // Used temporarily for shared definition of BuiltinImplResult
#include "NativeState.h"
#include "lstate.h"
// TODO: LBF_MATH_FREXP and LBF_MATH_MODF can work for 1 result case if second store is removed
namespace Luau
{
namespace CodeGen
{
BuiltinImplResult emitBuiltinAssert(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback)
{
if (nparams < 1 || nresults != 0)
return {BuiltinImplType::None, -1};
if (build.logText)
build.logAppend("; inlined LBF_ASSERT\n");
Label skip;
jumpIfFalsy(build, arg, fallback, skip);
// TODO: use of 'skip' causes a jump to a jump instruction that skips the fallback - can be optimized
build.setLabel(skip);
return {BuiltinImplType::UsesFallback, 0};
}
BuiltinImplResult emitBuiltinMathFloor(AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback)
{
if (nparams < 1 || nresults > 1)
@ -620,7 +605,8 @@ BuiltinImplResult emitBuiltin(AssemblyBuilderX64& build, int bfid, int nparams,
switch (bfid)
{
case LBF_ASSERT:
return emitBuiltinAssert(build, nparams, ra, arg, args, nresults, fallback);
// This builtin fast-path was already translated to IR
return {BuiltinImplType::None, -1};
case LBF_MATH_FLOOR:
return emitBuiltinMathFloor(build, nparams, ra, arg, args, nresults, fallback);
case LBF_MATH_CEIL:

View file

@ -9,18 +9,7 @@ namespace CodeGen
class AssemblyBuilderX64;
struct Label;
struct OperandX64;
enum class BuiltinImplType
{
None,
UsesFallback, // Uses fallback for unsupported cases
};
struct BuiltinImplResult
{
BuiltinImplType type;
int actualResultCount;
};
struct BuiltinImplResult;
BuiltinImplResult emitBuiltin(AssemblyBuilderX64& build, int bfid, int nparams, int ra, int arg, OperandX64 args, int nresults, Label& fallback);

View file

@ -151,14 +151,6 @@ inline void setLuauReg(AssemblyBuilderX64& build, RegisterX64 tmp, int ri, Opera
build.vmovups(luauReg(ri), tmp);
}
inline void setNodeValue(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 op, int ri)
{
LUAU_ASSERT(op.cat == CategoryX64::mem);
build.vmovups(tmp, luauReg(ri));
build.vmovups(op, tmp);
}
inline void jumpIfTagIs(AssemblyBuilderX64& build, int ri, lua_Type tag, Label& label)
{
build.cmp(luauRegTag(ri), tag);

View file

@ -7,6 +7,7 @@
#include "EmitBuiltinsX64.h"
#include "EmitCommonX64.h"
#include "NativeState.h"
#include "IrTranslateBuiltins.h" // Used temporarily until emitInstFastCallN is removed
#include "lobject.h"
#include "ltm.h"
@ -16,59 +17,6 @@ namespace Luau
namespace CodeGen
{
void emitInstLoadNil(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
build.mov(luauRegTag(ra), LUA_TNIL);
}
void emitInstLoadB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
build.mov(luauRegValue(ra), LUAU_INSN_B(*pc));
build.mov(luauRegTag(ra), LUA_TBOOLEAN);
if (int target = LUAU_INSN_C(*pc))
build.jmp(labelarr[pcpos + 1 + target]);
}
void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
build.vmovsd(xmm0, build.f64(double(LUAU_INSN_D(*pc))));
build.vmovsd(luauRegValue(ra), xmm0);
build.mov(luauRegTag(ra), LUA_TNUMBER);
}
void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
build.vmovups(xmm0, luauConstant(LUAU_INSN_D(*pc)));
build.vmovups(luauReg(ra), xmm0);
}
void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
build.vmovups(xmm0, luauConstant(aux));
build.vmovups(luauReg(ra), xmm0);
}
void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
build.vmovups(xmm0, luauReg(rb));
build.vmovups(luauReg(ra), xmm0);
}
void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
@ -429,363 +377,6 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Ins
build.jmp(qword[rdx + rax * 2]);
}
void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
build.jmp(labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]);
}
void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
emitInterrupt(build, pcpos);
build.jmp(labelarr[pcpos + 1 + LUAU_INSN_D(*pc)]);
}
void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_)
{
int ra = LUAU_INSN_A(*pc);
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 1];
if (not_)
jumpIfFalsy(build, ra, target, exit);
else
jumpIfTruthy(build, ra, target, exit);
}
void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = pc[1];
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2];
build.mov(eax, luauRegTag(ra));
build.cmp(eax, luauRegTag(rb));
build.jcc(ConditionX64::NotEqual, not_ ? target : exit);
// fast-path: number
build.cmp(eax, LUA_TNUMBER);
build.jcc(ConditionX64::NotEqual, fallback);
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), ConditionX64::NotEqual, not_ ? target : exit);
if (!not_)
build.jmp(target);
}
void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_)
{
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
emitSetSavedPc(build, pcpos + 1);
jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target);
}
void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = pc[1];
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
// fast-path: number
jumpIfTagIsNot(build, ra, LUA_TNUMBER, fallback);
jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback);
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), luauRegValue(rb), cond, target);
}
void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond)
{
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
emitSetSavedPc(build, pcpos + 1);
jumpOnAnyCmpFallback(build, LUAU_INSN_A(*pc), pc[1], cond, target);
}
void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
emitInterrupt(build, pcpos);
build.jmp(labelarr[pcpos + 1 + LUAU_INSN_E(*pc)]);
}
void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
bool not_ = (pc[1] & 0x80000000) != 0;
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
build.cmp(luauRegTag(ra), LUA_TNIL);
build.jcc(not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target);
}
void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
bool not_ = (aux & 0x80000000) != 0;
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2];
jumpIfTagIsNot(build, ra, LUA_TBOOLEAN, not_ ? target : exit);
build.test(luauRegValueInt(ra), 1);
build.jcc((aux & 0x1) ^ not_ ? ConditionX64::NotZero : ConditionX64::Zero, target);
}
void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
bool not_ = (aux & 0x80000000) != 0;
TValue kv = k[aux & 0xffffff];
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2];
jumpIfTagIsNot(build, ra, LUA_TNUMBER, not_ ? target : exit);
if (not_)
{
jumpOnNumberCmp(build, xmm0, luauRegValue(ra), build.f64(kv.value.n), ConditionX64::NotEqual, target);
}
else
{
// Compact equality check requires two labels, so it's not supported in generic 'jumpOnNumberCmp'
build.vmovsd(xmm0, luauRegValue(ra));
build.vucomisd(xmm0, build.f64(kv.value.n));
build.jcc(ConditionX64::Parity, exit); // We first have to check PF=1 for NaN operands, because it also sets ZF=1
build.jcc(ConditionX64::Zero, target); // Now that NaN is out of the way, we can check ZF=1 for equality
}
}
void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
bool not_ = (aux & 0x80000000) != 0;
Label& target = labelarr[pcpos + 1 + LUAU_INSN_D(*pc)];
Label& exit = labelarr[pcpos + 2];
jumpIfTagIsNot(build, ra, LUA_TSTRING, not_ ? target : exit);
build.mov(rax, luauRegValue(ra));
build.cmp(rax, luauConstantValue(aux & 0xffffff));
build.jcc(not_ ? ConditionX64::NotEqual : ConditionX64::Equal, target);
}
static void emitInstBinaryNumeric(AssemblyBuilderX64& build, int ra, int rb, int rc, OperandX64 opc, TMS tm, Label& fallback)
{
jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback);
if (rc != -1 && rc != rb)
jumpIfTagIsNot(build, rc, LUA_TNUMBER, fallback);
// fast-path: number
build.vmovsd(xmm0, luauRegValue(rb));
switch (tm)
{
case TM_ADD:
build.vaddsd(xmm0, xmm0, opc);
break;
case TM_SUB:
build.vsubsd(xmm0, xmm0, opc);
break;
case TM_MUL:
build.vmulsd(xmm0, xmm0, opc);
break;
case TM_DIV:
build.vdivsd(xmm0, xmm0, opc);
break;
case TM_MOD:
// This follows the implementation of 'luai_nummod' which is less precise than 'fmod' for better performance
build.vmovsd(xmm1, opc);
build.vdivsd(xmm2, xmm0, xmm1);
build.vroundsd(xmm2, xmm2, xmm2, RoundingModeX64::RoundToNegativeInfinity);
build.vmulsd(xmm1, xmm2, xmm1);
build.vsubsd(xmm0, xmm0, xmm1);
break;
case TM_POW:
build.vmovsd(xmm1, luauRegValue(rc));
build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]);
break;
default:
LUAU_ASSERT(!"unsupported binary op");
}
build.vmovsd(luauRegValue(ra), xmm0);
if (ra != rb && ra != rc)
build.mov(luauRegTag(ra), LUA_TNUMBER);
}
void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, TMS tm, Label& fallback)
{
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), LUAU_INSN_C(*pc), luauRegValue(LUAU_INSN_C(*pc)), tm, fallback);
}
void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm)
{
emitSetSavedPc(build, pcpos + 1);
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), tm);
}
void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, TMS tm, Label& fallback)
{
emitInstBinaryNumeric(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), -1, luauConstantValue(LUAU_INSN_C(*pc)), tm, fallback);
}
void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm)
{
emitSetSavedPc(build, pcpos + 1);
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstantAddress(LUAU_INSN_C(*pc)), tm);
}
void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
double kv = nvalue(&k[LUAU_INSN_C(*pc)]);
jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback);
// fast-path: number
build.vmovsd(xmm0, luauRegValue(rb));
// Specialize for a few constants, similar to how it's done in the VM
if (kv == 2.0)
{
build.vmulsd(xmm0, xmm0, xmm0);
}
else if (kv == 0.5)
{
build.vsqrtsd(xmm0, xmm0, xmm0);
}
else if (kv == 3.0)
{
build.vmulsd(xmm1, xmm0, xmm0);
build.vmulsd(xmm0, xmm0, xmm1);
}
else
{
build.vmovsd(xmm1, build.f64(kv));
build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]);
}
build.vmovsd(luauRegValue(ra), xmm0);
if (ra != rb)
build.mov(luauRegTag(ra), LUA_TNUMBER);
}
void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
Label saveone, savezero, exit;
jumpIfFalsy(build, rb, saveone, savezero);
build.setLabel(savezero);
build.mov(luauRegValueInt(ra), 0);
build.jmp(exit);
build.setLabel(saveone);
build.mov(luauRegValueInt(ra), 1);
build.setLabel(exit);
build.mov(luauRegTag(ra), LUA_TBOOLEAN);
}
void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
jumpIfTagIsNot(build, rb, LUA_TNUMBER, fallback);
// fast-path: number
build.vxorpd(xmm0, xmm0, xmm0);
build.vsubsd(xmm0, xmm0, luauRegValue(rb));
build.vmovsd(luauRegValue(ra), xmm0);
if (ra != rb)
build.mov(luauRegTag(ra), LUA_TNUMBER);
}
void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
callArithHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_B(*pc)), TM_UNM);
}
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
// fast-path: table without __len
build.mov(rArg1, luauRegValue(rb));
jumpIfMetatablePresent(build, rArg1, fallback);
// First argument (Table*) is already in rArg1
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]);
build.vcvtsi2sd(xmm0, xmm0, eax);
build.vmovsd(luauRegValue(ra), xmm0);
build.mov(luauRegTag(ra), LUA_TNUMBER);
}
void emitInstLengthFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
callLengthHelper(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc));
}
void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next)
{
int ra = LUAU_INSN_A(*pc);
int b = LUAU_INSN_B(*pc);
uint32_t aux = pc[1];
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(dwordReg(rArg2), aux);
build.mov(dwordReg(rArg3), b == 0 ? 0 : 1 << (b - 1));
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]);
build.mov(luauRegValue(ra), rax);
build.mov(luauRegTag(ra), LUA_TTABLE);
callCheckGc(build, pcpos, /* savepc = */ false, next);
}
void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next)
{
int ra = LUAU_INSN_A(*pc);
emitSetSavedPc(build, pcpos + 1);
build.mov(rArg1, rState);
build.mov(rArg2, luauConstantValue(LUAU_INSN_D(*pc)));
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_clone)]);
build.mov(luauRegValue(ra), rax);
build.mov(luauRegTag(ra), LUA_TTABLE);
callCheckGc(build, pcpos, /* savepc= */ false, next);
}
void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next)
{
int ra = LUAU_INSN_A(*pc);
@ -890,68 +481,7 @@ void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& ne
callBarrierTableFast(build, table, next);
}
void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
int up = LUAU_INSN_B(*pc);
build.mov(rax, sClosure);
build.add(rax, offsetof(Closure, l.uprefs) + sizeof(TValue) * up);
// uprefs[] is either an actual value, or it points to UpVal object which has a pointer to value
Label skip;
// TODO: jumpIfTagIsNot can be generalized to take OperandX64 and then we can use it here; let's wait until we see this more though
build.cmp(dword[rax + offsetof(TValue, tt)], LUA_TUPVAL);
build.jcc(ConditionX64::NotEqual, skip);
// UpVal.v points to the value (either on stack, or on heap inside each UpVal, but we can deref it unconditionally)
build.mov(rax, qword[rax + offsetof(TValue, value.gc)]);
build.mov(rax, qword[rax + offsetof(UpVal, v)]);
build.setLabel(skip);
build.vmovups(xmm0, xmmword[rax]);
build.vmovups(luauReg(ra), xmm0);
}
void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, Label& next)
{
int ra = LUAU_INSN_A(*pc);
int up = LUAU_INSN_B(*pc);
RegisterX64 upval = rax;
RegisterX64 tmp = rcx;
build.mov(tmp, sClosure);
build.mov(upval, qword[tmp + offsetof(Closure, l.uprefs) + sizeof(TValue) * up + offsetof(TValue, value.gc)]);
build.mov(tmp, qword[upval + offsetof(UpVal, v)]);
build.vmovups(xmm0, luauReg(ra));
build.vmovups(xmmword[tmp], xmm0);
callBarrierObject(build, tmp, upval, ra, next);
}
void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, Label& next)
{
int ra = LUAU_INSN_A(*pc);
// L->openupval != 0
build.mov(rax, qword[rState + offsetof(lua_State, openupval)]);
build.test(rax, rax);
build.jcc(ConditionX64::Zero, next);
// ra <= L->openuval->v
build.lea(rcx, addr[rBase + ra * sizeof(TValue)]);
build.cmp(rcx, qword[rax + offsetof(UpVal, v)]);
build.jcc(ConditionX64::Above, next);
build.mov(rArg2, rcx);
build.mov(rArg1, rState);
build.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]);
}
static int emitInstFastCallN(
static void emitInstFastCallN(
AssemblyBuilderX64& build, const Instruction* pc, bool customParams, int customParamCount, OperandX64 customArgs, int pcpos, Label& fallback)
{
int bfid = LUAU_INSN_A(*pc);
@ -966,8 +496,6 @@ static int emitInstFastCallN(
int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1;
OperandX64 args = customParams ? customArgs : luauRegAddress(ra + 2);
jumpIfUnsafeEnv(build, rax, fallback);
BuiltinImplResult br = emitBuiltin(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback);
if (br.type == BuiltinImplType::UsesFallback)
@ -986,7 +514,7 @@ static int emitInstFastCallN(
build.mov(qword[rState + offsetof(lua_State, top)], rax);
}
return skip; // Return fallback instruction sequence length
return;
}
// TODO: we can skip saving pc for some well-behaved builtins which we didn't inline
@ -1052,108 +580,29 @@ static int emitInstFastCallN(
build.mov(rax, qword[rax + offsetof(CallInfo, top)]);
build.mov(qword[rState + offsetof(lua_State, top)], rax);
}
return skip; // Return fallback instruction sequence length
}
int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
return emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 1, /* customArgs */ 0, pcpos, fallback);
}
int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
return emitInstFastCallN(build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauRegAddress(pc[1]), pcpos, fallback);
}
int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
return emitInstFastCallN(
build, pc, /* customParams */ true, /* customParamCount */ 2, /* customArgs */ luauConstantAddress(pc[1]), pcpos, fallback);
}
int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
return emitInstFastCallN(build, pc, /* customParams */ false, /* customParamCount */ 0, /* customArgs */ 0, pcpos, fallback);
}
void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopStart, Label& loopExit)
{
int ra = LUAU_INSN_A(*pc);
Label tryConvert;
jumpIfTagIsNot(build, ra + 0, LUA_TNUMBER, tryConvert);
jumpIfTagIsNot(build, ra + 1, LUA_TNUMBER, tryConvert);
jumpIfTagIsNot(build, ra + 2, LUA_TNUMBER, tryConvert);
// After successful conversion of arguments to number, we return here
Label retry = build.setLabel();
RegisterX64 limit = xmm0;
RegisterX64 step = xmm1;
RegisterX64 idx = xmm2;
RegisterX64 zero = xmm3;
build.vxorpd(zero, xmm0, xmm0);
build.vmovsd(limit, luauRegValue(ra + 0));
build.vmovsd(step, luauRegValue(ra + 1));
build.vmovsd(idx, luauRegValue(ra + 2));
Label reverse;
// step <= 0
jumpOnNumberCmp(build, noreg, step, zero, ConditionX64::LessEqual, reverse);
// TODO: target branches can probably be arranged better, but we need tests for NaN behavior preservation
// false: idx <= limit
jumpOnNumberCmp(build, noreg, idx, limit, ConditionX64::LessEqual, loopStart);
build.jmp(loopExit);
// true: limit <= idx
build.setLabel(reverse);
jumpOnNumberCmp(build, noreg, limit, idx, ConditionX64::LessEqual, loopStart);
build.jmp(loopExit);
// TOOD: place at the end of the function
build.setLabel(tryConvert);
emitSetSavedPc(build, pcpos + 1);
callPrepareForN(build, ra + 0, ra + 1, ra + 2);
build.jmp(retry);
}
void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit)
{
emitInterrupt(build, pcpos);
int ra = LUAU_INSN_A(*pc);
RegisterX64 limit = xmm0;
RegisterX64 step = xmm1;
RegisterX64 idx = xmm2;
RegisterX64 zero = xmm3;
build.vxorpd(zero, xmm0, xmm0);
build.vmovsd(limit, luauRegValue(ra + 0));
build.vmovsd(step, luauRegValue(ra + 1));
build.vmovsd(idx, luauRegValue(ra + 2));
build.vaddsd(idx, idx, step);
build.vmovsd(luauRegValue(ra + 2), idx);
Label reverse;
// step <= 0
jumpOnNumberCmp(build, noreg, step, zero, ConditionX64::LessEqual, reverse);
// false: idx <= limit
jumpOnNumberCmp(build, noreg, idx, limit, ConditionX64::LessEqual, loopRepeat);
build.jmp(loopExit);
// true: limit <= idx
build.setLabel(reverse);
jumpOnNumberCmp(build, noreg, limit, idx, ConditionX64::LessEqual, loopRepeat);
}
void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
@ -1248,46 +697,6 @@ void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc,
build.jcc(ConditionX64::NotZero, loopRepeat);
}
void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
// fast-path: pairs/next
jumpIfUnsafeEnv(build, rax, fallback);
jumpIfTagIsNot(build, ra + 1, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, ra + 2, LUA_TNIL, fallback);
build.mov(luauRegTag(ra), LUA_TNIL);
// setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
build.mov(luauRegValue(ra + 2), 0);
build.mov(luauRegTag(ra + 2), LUA_TLIGHTUSERDATA);
build.jmp(target);
}
void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
// fast-path: ipairs/inext
jumpIfUnsafeEnv(build, rax, fallback);
jumpIfTagIsNot(build, ra + 1, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, ra + 2, LUA_TNUMBER, fallback);
build.vxorpd(xmm0, xmm0, xmm0);
build.vmovsd(xmm1, luauRegValue(ra + 2));
jumpOnNumberCmp(build, noreg, xmm0, xmm1, ConditionX64::NotEqual, fallback);
build.mov(luauRegTag(ra), LUA_TNIL);
// setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
build.mov(luauRegValue(ra + 2), 0);
build.mov(luauRegTag(ra + 2), LUA_TLIGHTUSERDATA);
build.jmp(target);
}
void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& target)
{
int ra = LUAU_INSN_A(*pc);
@ -1375,169 +784,6 @@ void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc)
emitInstOrX(build, LUAU_INSN_A(*pc), LUAU_INSN_B(*pc), luauConstant(LUAU_INSN_C(*pc)));
}
void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int c = LUAU_INSN_C(*pc);
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
// unsigned(c) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], c);
build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback);
build.mov(rax, qword[table + offsetof(Table, array)]);
setLuauReg(build, xmm0, ra, xmmword[rax + c * sizeof(TValue)]);
}
void emitInstGetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
TValue n;
setnvalue(&n, LUAU_INSN_C(*pc) + 1);
callGetTable(build, LUAU_INSN_B(*pc), build.bytes(&n, sizeof(n)), LUAU_INSN_A(*pc));
}
void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, Label& next, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int c = LUAU_INSN_C(*pc);
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
// unsigned(c) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], c);
build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback);
jumpIfTableIsReadOnly(build, table, fallback);
// setobj2t(L, &h->array[c], ra);
build.mov(rax, qword[table + offsetof(Table, array)]);
build.vmovups(xmm0, luauReg(ra));
build.vmovups(xmmword[rax + c * sizeof(TValue)], xmm0);
callBarrierTable(build, rax, table, ra, next);
}
void emitInstSetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
TValue n;
setnvalue(&n, LUAU_INSN_C(*pc) + 1);
callSetTable(build, LUAU_INSN_B(*pc), build.bytes(&n, sizeof(n)), LUAU_INSN_A(*pc));
}
void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, rc, LUA_TNUMBER, fallback);
// fast-path: table with a number index
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
RegisterX64 intIndex = eax;
RegisterX64 fpIndex = xmm0;
build.vmovsd(fpIndex, luauRegValue(rc));
convertNumberToIndexOrJump(build, xmm1, fpIndex, intIndex, fallback);
// index - 1
build.dec(intIndex);
// unsigned(index - 1) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], intIndex);
build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback);
// setobj2s(L, ra, &h->array[unsigned(index - 1)]);
build.mov(rdx, qword[table + offsetof(Table, array)]);
build.shl(intIndex, kTValueSizeLog2);
setLuauReg(build, xmm0, ra, xmmword[rdx + rax]);
}
void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
callGetTable(build, LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc));
}
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, Label& next, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
jumpIfTagIsNot(build, rc, LUA_TNUMBER, fallback);
// fast-path: table with a number index
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
RegisterX64 intIndex = eax;
RegisterX64 fpIndex = xmm0;
build.vmovsd(fpIndex, luauRegValue(rc));
convertNumberToIndexOrJump(build, xmm1, fpIndex, intIndex, fallback);
// index - 1
build.dec(intIndex);
// unsigned(index - 1) < unsigned(h->sizearray)
build.cmp(dword[table + offsetof(Table, sizearray)], intIndex);
build.jcc(ConditionX64::BelowEqual, fallback);
jumpIfMetatablePresent(build, table, fallback);
jumpIfTableIsReadOnly(build, table, fallback);
// setobj2t(L, &h->array[unsigned(index - 1)], ra);
build.mov(rdx, qword[table + offsetof(Table, array)]);
build.shl(intIndex, kTValueSizeLog2);
build.vmovups(xmm0, luauReg(ra));
build.vmovups(xmmword[rdx + rax], xmm0);
callBarrierTable(build, rdx, table, ra, next);
}
void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos)
{
emitSetSavedPc(build, pcpos + 1);
callSetTable(build, LUAU_INSN_B(*pc), luauRegAddress(LUAU_INSN_C(*pc)), LUAU_INSN_A(*pc));
}
void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int k = LUAU_INSN_D(*pc);
jumpIfUnsafeEnv(build, rax, fallback);
// note: if import failed, k[] is nil; we could check this during codegen, but we instead use runtime fallback
// this allows us to handle ahead-of-time codegen smoothly when an import fails to resolve at runtime
build.cmp(luauConstantTag(k), LUA_TNIL);
build.jcc(ConditionX64::Equal, fallback);
build.vmovups(xmm0, luauConstant(k));
build.vmovups(luauReg(ra), xmm0);
}
void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux)
{
build.mov(rax, sClosure);
@ -1567,105 +813,6 @@ void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux)
build.mov(qword[rState + offsetof(lua_State, top)], rax);
}
void emitInstGetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
uint32_t aux = pc[1];
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
RegisterX64 node = rdx;
getTableNodeAtCachedSlot(build, rax, node, table, pcpos);
jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback);
setLuauReg(build, xmm0, ra, luauNodeValue(node));
}
void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
uint32_t aux = pc[1];
jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback);
RegisterX64 table = rcx;
build.mov(table, luauRegValue(rb));
// fast-path: set value at the expected slot
RegisterX64 node = rdx;
getTableNodeAtCachedSlot(build, rax, node, table, pcpos);
jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback);
jumpIfTableIsReadOnly(build, table, fallback);
setNodeValue(build, xmm0, luauNodeValue(node), ra);
callBarrierTable(build, rax, table, ra, next);
}
void emitInstGetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
RegisterX64 table = rcx;
build.mov(rax, sClosure);
build.mov(table, qword[rax + offsetof(Closure, env)]);
RegisterX64 node = rdx;
getTableNodeAtCachedSlot(build, rax, node, table, pcpos);
jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback);
setLuauReg(build, xmm0, ra, luauNodeValue(node));
}
void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback)
{
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
RegisterX64 table = rcx;
build.mov(rax, sClosure);
build.mov(table, qword[rax + offsetof(Closure, env)]);
RegisterX64 node = rdx;
getTableNodeAtCachedSlot(build, rax, node, table, pcpos);
jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback);
jumpIfTableIsReadOnly(build, table, fallback);
setNodeValue(build, xmm0, luauNodeValue(node), ra);
callBarrierTable(build, rax, table, ra, next);
}
void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next)
{
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
int rc = LUAU_INSN_C(*pc);
emitSetSavedPc(build, pcpos + 1);
// luaV_concat(L, c - b + 1, c)
build.mov(rArg1, rState);
build.mov(dwordReg(rArg2), rc - rb + 1);
build.mov(dwordReg(rArg3), rc);
build.call(qword[rNativeContext + offsetof(NativeContext, luaV_concat)]);
emitUpdateBase(build);
// setobj2s(L, ra, base + b)
build.vmovups(xmm0, luauReg(rb));
build.vmovups(luauReg(ra), xmm0);
callCheckGc(build, pcpos, /* savepc= */ false, next);
}
void emitInstCoverage(AssemblyBuilderX64& build, int pcpos)
{
build.mov(rcx, sCode);

View file

@ -14,78 +14,25 @@ namespace CodeGen
{
class AssemblyBuilderX64;
enum class ConditionX64 : uint8_t;
struct Label;
struct ModuleHelpers;
struct NativeState;
void emitInstLoadNil(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstLoadN(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstLoadKX(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMove(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback);
void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos);
void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos);
void emitInstJump(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpBack(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpIf(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfEq(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_, Label& fallback);
void emitInstJumpIfEqFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, bool not_);
void emitInstJumpIfCond(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond, Label& fallback);
void emitInstJumpIfCondFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr, ConditionX64 cond);
void emitInstJumpX(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqNil(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqB(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstJumpxEqN(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, int pcpos, Label* labelarr);
void emitInstJumpxEqS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label* labelarr);
void emitInstBinary(AssemblyBuilderX64& build, const Instruction* pc, TMS tm, Label& fallback);
void emitInstBinaryFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm);
void emitInstBinaryK(AssemblyBuilderX64& build, const Instruction* pc, TMS tm, Label& fallback);
void emitInstBinaryKFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, TMS tm);
void emitInstPowK(AssemblyBuilderX64& build, const Instruction* pc, const TValue* k, Label& fallback);
void emitInstNot(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstMinus(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstMinusFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstLength(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstLengthFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstNewTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next);
void emitInstDupTable(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next);
void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next);
void emitInstGetUpval(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstSetUpval(AssemblyBuilderX64& build, const Instruction* pc, Label& next);
void emitInstCloseUpvals(AssemblyBuilderX64& build, const Instruction* pc, Label& next);
int emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
int emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
int emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
int emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstForNPrep(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopStart, Label& loopExit);
void emitInstForNLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit);
void emitInstFastCall1(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstFastCall2(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstFastCall2K(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstFastCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitinstForGLoop(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat, Label& loopExit, Label& fallback);
void emitinstForGLoopFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& loopRepeat);
void emitInstForGPrepNext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback);
void emitInstForGPrepInext(AssemblyBuilderX64& build, const Instruction* pc, Label& target, Label& fallback);
void emitInstForGPrepXnextFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& target);
void emitInstAnd(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstAndK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstOr(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstOrK(AssemblyBuilderX64& build, const Instruction* pc);
void emitInstGetTableN(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstGetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSetTableN(AssemblyBuilderX64& build, const Instruction* pc, Label& next, Label& fallback);
void emitInstSetTableNFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetTable(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstGetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstSetTable(AssemblyBuilderX64& build, const Instruction* pc, Label& next, Label& fallback);
void emitInstSetTableFallback(AssemblyBuilderX64& build, const Instruction* pc, int pcpos);
void emitInstGetImport(AssemblyBuilderX64& build, const Instruction* pc, Label& fallback);
void emitInstGetImportFallback(AssemblyBuilderX64& build, int ra, uint32_t aux);
void emitInstGetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstSetTableKS(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback);
void emitInstGetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& fallback);
void emitInstSetGlobal(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next, Label& fallback);
void emitInstConcat(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, Label& next);
void emitInstCoverage(AssemblyBuilderX64& build, int pcpos);
} // namespace CodeGen

View file

@ -44,6 +44,7 @@ void updateUseCounts(IrFunction& function)
checkOp(inst.c);
checkOp(inst.d);
checkOp(inst.e);
checkOp(inst.f);
}
}
@ -68,6 +69,7 @@ void updateLastUseLocations(IrFunction& function)
checkOp(inst.c);
checkOp(inst.d);
checkOp(inst.e);
checkOp(inst.f);
}
}

View file

@ -269,17 +269,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
case LOP_FASTCALL:
{
int skip = LUAU_INSN_C(*pc);
IrOp fallback = block(IrBlockKind::Fallback);
IrOp next = blockAtInst(i + skip + 2);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
inst(IrCmd::LOP_FASTCALL, constUint(i), vmReg(LUAU_INSN_A(call)), constInt(LUAU_INSN_B(call) - 1), fallback);
inst(IrCmd::JUMP, next);
beginBlock(fallback);
translateFastCallN(*this, pc, i, false, 0, {}, next, IrCmd::LOP_FASTCALL);
activeFastcallFallback = true;
fastcallFallbackReturn = next;
@ -288,17 +280,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
case LOP_FASTCALL1:
{
int skip = LUAU_INSN_C(*pc);
IrOp fallback = block(IrBlockKind::Fallback);
IrOp next = blockAtInst(i + skip + 2);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
inst(IrCmd::LOP_FASTCALL1, constUint(i), vmReg(LUAU_INSN_A(call)), vmReg(LUAU_INSN_B(*pc)), fallback);
inst(IrCmd::JUMP, next);
beginBlock(fallback);
translateFastCallN(*this, pc, i, true, 1, constBool(false), next, IrCmd::LOP_FASTCALL1);
activeFastcallFallback = true;
fastcallFallbackReturn = next;
@ -307,17 +291,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
case LOP_FASTCALL2:
{
int skip = LUAU_INSN_C(*pc);
IrOp fallback = block(IrBlockKind::Fallback);
IrOp next = blockAtInst(i + skip + 2);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
inst(IrCmd::LOP_FASTCALL2, constUint(i), vmReg(LUAU_INSN_A(call)), vmReg(LUAU_INSN_B(*pc)), vmReg(pc[1]), fallback);
inst(IrCmd::JUMP, next);
beginBlock(fallback);
translateFastCallN(*this, pc, i, true, 2, vmReg(pc[1]), next, IrCmd::LOP_FASTCALL2);
activeFastcallFallback = true;
fastcallFallbackReturn = next;
@ -326,17 +302,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i)
case LOP_FASTCALL2K:
{
int skip = LUAU_INSN_C(*pc);
IrOp fallback = block(IrBlockKind::Fallback);
IrOp next = blockAtInst(i + skip + 2);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
inst(IrCmd::LOP_FASTCALL2K, constUint(i), vmReg(LUAU_INSN_A(call)), vmReg(LUAU_INSN_B(*pc)), vmConst(pc[1]), fallback);
inst(IrCmd::JUMP, next);
beginBlock(fallback);
translateFastCallN(*this, pc, i, true, 2, vmConst(pc[1]), next, IrCmd::LOP_FASTCALL2K);
activeFastcallFallback = true;
fastcallFallbackReturn = next;
@ -449,6 +417,7 @@ bool IrBuilder::isInternalBlock(IrOp block)
void IrBuilder::beginBlock(IrOp block)
{
IrBlock& target = function.blocks[block.index];
activeBlockIdx = block.index;
LUAU_ASSERT(target.start == ~0u || target.start == uint32_t(function.instructions.size()));
@ -511,36 +480,46 @@ IrOp IrBuilder::cond(IrCondition cond)
IrOp IrBuilder::inst(IrCmd cmd)
{
return inst(cmd, {}, {}, {}, {}, {});
return inst(cmd, {}, {}, {}, {}, {}, {});
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a)
{
return inst(cmd, a, {}, {}, {}, {});
return inst(cmd, a, {}, {}, {}, {}, {});
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b)
{
return inst(cmd, a, b, {}, {}, {});
return inst(cmd, a, b, {}, {}, {}, {});
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c)
{
return inst(cmd, a, b, c, {}, {});
return inst(cmd, a, b, c, {}, {}, {});
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d)
{
return inst(cmd, a, b, c, d, {});
return inst(cmd, a, b, c, d, {}, {});
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e)
{
return inst(cmd, a, b, c, d, e, {});
}
IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f)
{
uint32_t index = uint32_t(function.instructions.size());
function.instructions.push_back({cmd, a, b, c, d, e});
function.instructions.push_back({cmd, a, b, c, d, e, f});
LUAU_ASSERT(!inTerminatedBlock);
if (isBlockTerminator(cmd))
{
function.blocks[activeBlockIdx].finish = index;
inTerminatedBlock = true;
}
return {IrOpKind::Inst, index};
}

View file

@ -148,6 +148,10 @@ const char* getCmdName(IrCmd cmd)
return "NUM_TO_INDEX";
case IrCmd::INT_TO_NUM:
return "INT_TO_NUM";
case IrCmd::ADJUST_STACK_TO_REG:
return "ADJUST_STACK_TO_REG";
case IrCmd::ADJUST_STACK_TO_TOP:
return "ADJUST_STACK_TO_TOP";
case IrCmd::DO_ARITH:
return "DO_ARITH";
case IrCmd::DO_LEN:
@ -280,35 +284,20 @@ void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index)
ctx.result.append(getCmdName(inst.cmd));
if (inst.a.kind != IrOpKind::None)
{
append(ctx.result, " ");
toString(ctx, inst.a);
}
auto checkOp = [&ctx](IrOp op, const char* sep) {
if (op.kind != IrOpKind::None)
{
ctx.result.append(sep);
toString(ctx, op);
}
};
if (inst.b.kind != IrOpKind::None)
{
append(ctx.result, ", ");
toString(ctx, inst.b);
}
if (inst.c.kind != IrOpKind::None)
{
append(ctx.result, ", ");
toString(ctx, inst.c);
}
if (inst.d.kind != IrOpKind::None)
{
append(ctx.result, ", ");
toString(ctx, inst.d);
}
if (inst.e.kind != IrOpKind::None)
{
append(ctx.result, ", ");
toString(ctx, inst.e);
}
checkOp(inst.a, " ");
checkOp(inst.b, ", ");
checkOp(inst.c, ", ");
checkOp(inst.d, ", ");
checkOp(inst.e, ", ");
checkOp(inst.f, ", ");
}
void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index)
@ -421,7 +410,7 @@ std::string toString(IrFunction& function, bool includeDetails)
}
// To allow dumping blocks that are still being constructed, we can't rely on terminator and need a bounds check
for (uint32_t index = block.start; index < uint32_t(function.instructions.size()); index++)
for (uint32_t index = block.start; index <= block.finish && index < uint32_t(function.instructions.size()); index++)
{
IrInst& inst = function.instructions[index];
@ -440,13 +429,9 @@ std::string toString(IrFunction& function, bool includeDetails)
toString(ctx, inst, index);
ctx.result.append("\n");
}
if (isBlockTerminator(inst.cmd))
{
append(ctx.result, "\n");
break;
}
}
append(ctx.result, "\n");
}
return result;

View file

@ -7,6 +7,7 @@
#include "Luau/IrDump.h"
#include "Luau/IrUtils.h"
#include "EmitBuiltinsX64.h"
#include "EmitCommonX64.h"
#include "EmitInstructionX64.h"
#include "NativeState.h"
@ -20,18 +21,14 @@ namespace Luau
namespace CodeGen
{
static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11};
IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, NativeState& data, Proto* proto, IrFunction& function)
: build(build)
, helpers(helpers)
, data(data)
, proto(proto)
, function(function)
, regs(function)
{
freeGprMap.fill(true);
freeXmmMap.fill(true);
// In order to allocate registers during lowering, we need to know where instruction results are last used
updateLastUseLocations(function);
}
@ -95,11 +92,13 @@ void IrLoweringX64::lower(AssemblyOptions options)
uint32_t blockIndex = sortedBlocks[i];
IrBlock& block = function.blocks[blockIndex];
LUAU_ASSERT(block.start != ~0u);
if (block.kind == IrBlockKind::Dead)
continue;
LUAU_ASSERT(block.start != ~0u);
LUAU_ASSERT(block.finish != ~0u);
// If we want to skip fallback code IR/asm, we'll record when those blocks start once we see them
if (block.kind == IrBlockKind::Fallback && !seenFallback)
{
@ -116,7 +115,7 @@ void IrLoweringX64::lower(AssemblyOptions options)
build.setLabel(block.label);
for (uint32_t index = block.start; true; index++)
for (uint32_t index = block.start; index <= block.finish; index++)
{
LUAU_ASSERT(index < function.instructions.size());
@ -151,15 +150,11 @@ void IrLoweringX64::lower(AssemblyOptions options)
lowerInst(inst, index, next);
freeLastUseRegs(inst, index);
if (isBlockTerminator(inst.cmd))
{
if (options.includeIr)
build.logAppend("#\n");
break;
}
regs.freeLastUseRegs(inst, index);
}
if (options.includeIr)
build.logAppend("#\n");
}
if (outputEnabled && !options.includeOutlinedCode && seenFallback)
@ -183,7 +178,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
switch (inst.cmd)
{
case IrCmd::LOAD_TAG:
inst.regX64 = allocGprReg(SizeX64::dword);
inst.regX64 = regs.allocGprReg(SizeX64::dword);
if (inst.a.kind == IrOpKind::VmReg)
build.mov(inst.regX64, luauRegTag(inst.a.index));
@ -197,7 +192,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(!"Unsupported instruction form");
break;
case IrCmd::LOAD_POINTER:
inst.regX64 = allocGprReg(SizeX64::qword);
inst.regX64 = regs.allocGprReg(SizeX64::qword);
if (inst.a.kind == IrOpKind::VmReg)
build.mov(inst.regX64, luauRegValue(inst.a.index));
@ -207,7 +202,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(!"Unsupported instruction form");
break;
case IrCmd::LOAD_DOUBLE:
inst.regX64 = allocXmmReg();
inst.regX64 = regs.allocXmmReg();
if (inst.a.kind == IrOpKind::VmReg)
build.vmovsd(inst.regX64, luauRegValue(inst.a.index));
@ -219,12 +214,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::LOAD_INT:
LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg);
inst.regX64 = allocGprReg(SizeX64::dword);
inst.regX64 = regs.allocGprReg(SizeX64::dword);
build.mov(inst.regX64, luauRegValueInt(inst.a.index));
break;
case IrCmd::LOAD_TVALUE:
inst.regX64 = allocXmmReg();
inst.regX64 = regs.allocXmmReg();
if (inst.a.kind == IrOpKind::VmReg)
build.vmovups(inst.regX64, luauReg(inst.a.index));
@ -236,12 +231,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(!"Unsupported instruction form");
break;
case IrCmd::LOAD_NODE_VALUE_TV:
inst.regX64 = allocXmmReg();
inst.regX64 = regs.allocXmmReg();
build.vmovups(inst.regX64, luauNodeValue(regOp(inst.a)));
break;
case IrCmd::LOAD_ENV:
inst.regX64 = allocGprReg(SizeX64::qword);
inst.regX64 = regs.allocGprReg(SizeX64::qword);
build.mov(inst.regX64, sClosure);
build.mov(inst.regX64, qword[inst.regX64 + offsetof(Closure, env)]);
@ -249,7 +244,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::GET_ARR_ADDR:
if (inst.b.kind == IrOpKind::Inst)
{
inst.regX64 = allocGprRegOrReuse(SizeX64::qword, index, {inst.b});
inst.regX64 = regs.allocGprRegOrReuse(SizeX64::qword, index, {inst.b});
if (dwordReg(inst.regX64) != regOp(inst.b))
build.mov(dwordReg(inst.regX64), regOp(inst.b));
@ -259,7 +254,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
}
else if (inst.b.kind == IrOpKind::Constant)
{
inst.regX64 = allocGprRegOrReuse(SizeX64::qword, index, {inst.a});
inst.regX64 = regs.allocGprRegOrReuse(SizeX64::qword, index, {inst.a});
build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(Table, array)]);
@ -273,9 +268,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
break;
case IrCmd::GET_SLOT_NODE_ADDR:
{
inst.regX64 = allocGprReg(SizeX64::qword);
inst.regX64 = regs.allocGprReg(SizeX64::qword);
ScopedReg tmp{*this, SizeX64::qword};
ScopedRegX64 tmp{regs, SizeX64::qword};
getTableNodeAtCachedSlot(build, tmp.reg, inst.regX64, regOp(inst.a), uintOp(inst.b));
break;
@ -298,7 +293,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
if (inst.b.kind == IrOpKind::Constant)
{
ScopedReg tmp{*this, SizeX64::xmmword};
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, build.f64(doubleOp(inst.b)));
build.vmovsd(luauRegValue(inst.a.index), tmp.reg);
@ -336,7 +331,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.vmovups(luauNodeValue(regOp(inst.a)), regOp(inst.b));
break;
case IrCmd::ADD_INT:
inst.regX64 = allocGprRegOrReuse(SizeX64::dword, index, {inst.a});
inst.regX64 = regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.a});
if (inst.regX64 == regOp(inst.a) && intOp(inst.b) == 1)
build.inc(inst.regX64);
@ -346,7 +341,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.lea(inst.regX64, addr[regOp(inst.a) + intOp(inst.b)]);
break;
case IrCmd::SUB_INT:
inst.regX64 = allocGprRegOrReuse(SizeX64::dword, index, {inst.a});
inst.regX64 = regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.a});
if (inst.regX64 == regOp(inst.a) && intOp(inst.b) == 1)
build.dec(inst.regX64);
@ -356,34 +351,74 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.lea(inst.regX64, addr[regOp(inst.a) - intOp(inst.b)]);
break;
case IrCmd::ADD_NUM:
inst.regX64 = allocXmmRegOrReuse(index, {inst.a, inst.b});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b});
build.vaddsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
if (inst.a.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, memRegDoubleOp(inst.a));
build.vaddsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b));
}
else
{
build.vaddsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
}
break;
case IrCmd::SUB_NUM:
inst.regX64 = allocXmmRegOrReuse(index, {inst.a, inst.b});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b});
build.vsubsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
if (inst.a.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, memRegDoubleOp(inst.a));
build.vsubsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b));
}
else
{
build.vsubsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
}
break;
case IrCmd::MUL_NUM:
inst.regX64 = allocXmmRegOrReuse(index, {inst.a, inst.b});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b});
build.vmulsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
if (inst.a.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, memRegDoubleOp(inst.a));
build.vmulsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b));
}
else
{
build.vmulsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
}
break;
case IrCmd::DIV_NUM:
inst.regX64 = allocXmmRegOrReuse(index, {inst.a, inst.b});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b});
build.vdivsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
if (inst.a.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmovsd(tmp.reg, memRegDoubleOp(inst.a));
build.vdivsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b));
}
else
{
build.vdivsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b));
}
break;
case IrCmd::MOD_NUM:
{
inst.regX64 = allocXmmRegOrReuse(index, {inst.a, inst.b});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b});
RegisterX64 lhs = regOp(inst.a);
if (inst.b.kind == IrOpKind::Inst)
{
ScopedReg tmp{*this, SizeX64::xmmword};
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vdivsd(tmp.reg, lhs, memRegDoubleOp(inst.b));
build.vroundsd(tmp.reg, tmp.reg, tmp.reg, RoundingModeX64::RoundToNegativeInfinity);
@ -392,8 +427,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
}
else
{
ScopedReg tmp1{*this, SizeX64::xmmword};
ScopedReg tmp2{*this, SizeX64::xmmword};
ScopedRegX64 tmp1{regs, SizeX64::xmmword};
ScopedRegX64 tmp2{regs, SizeX64::xmmword};
build.vmovsd(tmp1.reg, memRegDoubleOp(inst.b));
build.vdivsd(tmp2.reg, lhs, tmp1.reg);
@ -405,9 +440,21 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
}
case IrCmd::POW_NUM:
{
inst.regX64 = allocXmmRegOrReuse(index, {inst.a, inst.b});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b});
RegisterX64 lhs = regOp(inst.a);
ScopedRegX64 tmp{regs, SizeX64::xmmword};
RegisterX64 lhs;
if (inst.a.kind == IrOpKind::Constant)
{
build.vmovsd(tmp.reg, memRegDoubleOp(inst.a));
lhs = tmp.reg;
}
else
{
lhs = regOp(inst.a);
}
if (inst.b.kind == IrOpKind::Inst)
{
@ -439,7 +486,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
}
else if (rhs == 3.0)
{
ScopedReg tmp{*this, SizeX64::xmmword};
ScopedRegX64 tmp{regs, SizeX64::xmmword};
build.vmulsd(tmp.reg, lhs, lhs);
build.vmulsd(inst.regX64, lhs, tmp.reg);
@ -472,7 +519,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
}
case IrCmd::UNM_NUM:
{
inst.regX64 = allocXmmRegOrReuse(index, {inst.a});
inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a});
RegisterX64 src = regOp(inst.a);
@ -491,15 +538,23 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::NOT_ANY:
{
// TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target
inst.regX64 = allocGprRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
inst.regX64 = regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.a, inst.b});
Label saveone, savezero, exit;
build.cmp(regOp(inst.a), LUA_TNIL);
build.jcc(ConditionX64::Equal, saveone);
if (inst.a.kind == IrOpKind::Constant)
{
// Other cases should've been constant folded
LUAU_ASSERT(tagOp(inst.a) == LUA_TBOOLEAN);
}
else
{
build.cmp(regOp(inst.a), LUA_TNIL);
build.jcc(ConditionX64::Equal, saveone);
build.cmp(regOp(inst.a), LUA_TBOOLEAN);
build.jcc(ConditionX64::NotEqual, savezero);
build.cmp(regOp(inst.a), LUA_TBOOLEAN);
build.jcc(ConditionX64::NotEqual, savezero);
}
build.cmp(regOp(inst.b), 0);
build.jcc(ConditionX64::Equal, saveone);
@ -566,7 +621,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
IrCondition cond = IrCondition(inst.c.index);
ScopedReg tmp{*this, SizeX64::xmmword};
ScopedRegX64 tmp{regs, SizeX64::xmmword};
// TODO: jumpOnNumberCmp should work on IrCondition directly
jumpOnNumberCmp(build, tmp.reg, memRegDoubleOp(inst.a), memRegDoubleOp(inst.b), getX64Condition(cond), labelOp(inst.d));
@ -586,14 +641,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
break;
}
case IrCmd::TABLE_LEN:
inst.regX64 = allocXmmReg();
inst.regX64 = regs.allocXmmReg();
build.mov(rArg1, regOp(inst.a));
build.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]);
build.vcvtsi2sd(inst.regX64, inst.regX64, eax);
break;
case IrCmd::NEW_TABLE:
inst.regX64 = allocGprReg(SizeX64::qword);
inst.regX64 = regs.allocGprReg(SizeX64::qword);
build.mov(rArg1, rState);
build.mov(dwordReg(rArg2), uintOp(inst.a));
@ -604,7 +659,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
build.mov(inst.regX64, rax);
break;
case IrCmd::DUP_TABLE:
inst.regX64 = allocGprReg(SizeX64::qword);
inst.regX64 = regs.allocGprReg(SizeX64::qword);
// Re-ordered to avoid register conflict
build.mov(rArg2, regOp(inst.a));
@ -616,18 +671,51 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
break;
case IrCmd::NUM_TO_INDEX:
{
inst.regX64 = allocGprReg(SizeX64::dword);
inst.regX64 = regs.allocGprReg(SizeX64::dword);
ScopedReg tmp{*this, SizeX64::xmmword};
ScopedRegX64 tmp{regs, SizeX64::xmmword};
convertNumberToIndexOrJump(build, tmp.reg, regOp(inst.a), inst.regX64, labelOp(inst.b));
break;
}
case IrCmd::INT_TO_NUM:
inst.regX64 = allocXmmReg();
inst.regX64 = regs.allocXmmReg();
build.vcvtsi2sd(inst.regX64, inst.regX64, regOp(inst.a));
break;
case IrCmd::ADJUST_STACK_TO_REG:
{
LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg);
if (inst.b.kind == IrOpKind::Constant)
{
ScopedRegX64 tmp{regs, SizeX64::qword};
build.lea(tmp.reg, addr[rBase + (inst.a.index + intOp(inst.b)) * sizeof(TValue)]);
build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg);
}
else if (inst.b.kind == IrOpKind::Inst)
{
ScopedRegX64 tmp(regs, regs.allocGprRegOrReuse(SizeX64::dword, index, {inst.b}));
build.shl(qwordReg(tmp.reg), kTValueSizeLog2);
build.lea(qwordReg(tmp.reg), addr[rBase + qwordReg(tmp.reg) + inst.a.index * sizeof(TValue)]);
build.mov(qword[rState + offsetof(lua_State, top)], qwordReg(tmp.reg));
}
else
{
LUAU_ASSERT(!"Unsupported instruction form");
}
break;
}
case IrCmd::ADJUST_STACK_TO_TOP:
{
ScopedRegX64 tmp{regs, SizeX64::qword};
build.mov(tmp.reg, qword[rState + offsetof(lua_State, ci)]);
build.mov(tmp.reg, qword[tmp.reg + offsetof(CallInfo, top)]);
build.mov(qword[rState + offsetof(lua_State, top)], tmp.reg);
break;
}
case IrCmd::DO_ARITH:
LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg);
LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg);
@ -702,8 +790,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg);
LUAU_ASSERT(inst.b.kind == IrOpKind::VmUpvalue);
ScopedReg tmp1{*this, SizeX64::qword};
ScopedReg tmp2{*this, SizeX64::xmmword};
ScopedRegX64 tmp1{regs, SizeX64::qword};
ScopedRegX64 tmp2{regs, SizeX64::xmmword};
build.mov(tmp1.reg, sClosure);
build.add(tmp1.reg, offsetof(Closure, l.uprefs) + sizeof(TValue) * inst.b.index);
@ -730,9 +818,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg);
Label next;
ScopedReg tmp1{*this, SizeX64::qword};
ScopedReg tmp2{*this, SizeX64::qword};
ScopedReg tmp3{*this, SizeX64::xmmword};
ScopedRegX64 tmp1{regs, SizeX64::qword};
ScopedRegX64 tmp2{regs, SizeX64::qword};
ScopedRegX64 tmp3{regs, SizeX64::xmmword};
build.mov(tmp1.reg, sClosure);
build.mov(tmp2.reg, qword[tmp1.reg + offsetof(Closure, l.uprefs) + sizeof(TValue) * inst.a.index + offsetof(TValue, value.gc)]);
@ -758,6 +846,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
{
jumpIfTagIsNot(build, inst.a.index, lua_Type(tagOp(inst.b)), labelOp(inst.c));
}
else if (inst.a.kind == IrOpKind::VmConst)
{
build.cmp(luauConstantTag(inst.a.index), tagOp(inst.b));
build.jcc(ConditionX64::NotEqual, labelOp(inst.c));
}
else
{
LUAU_ASSERT(!"Unsupported instruction form");
@ -771,7 +864,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
break;
case IrCmd::CHECK_SAFE_ENV:
{
ScopedReg tmp{*this, SizeX64::qword};
ScopedRegX64 tmp{regs, SizeX64::qword};
jumpIfUnsafeEnv(build, tmp.reg, labelOp(inst.a));
break;
@ -790,7 +883,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
{
LUAU_ASSERT(inst.b.kind == IrOpKind::VmConst);
ScopedReg tmp{*this, SizeX64::qword};
ScopedRegX64 tmp{regs, SizeX64::qword};
jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(inst.b.index), labelOp(inst.c));
break;
@ -810,7 +903,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg);
Label skip;
ScopedReg tmp{*this, SizeX64::qword};
ScopedRegX64 tmp{regs, SizeX64::qword};
callBarrierObject(build, tmp.reg, regOp(inst.a), inst.b.index, skip);
build.setLabel(skip);
@ -829,7 +922,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg);
Label skip;
ScopedReg tmp{*this, SizeX64::qword};
ScopedRegX64 tmp{regs, SizeX64::qword};
callBarrierTable(build, tmp.reg, regOp(inst.a), inst.b.index, skip);
build.setLabel(skip);
@ -838,8 +931,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
case IrCmd::SET_SAVEDPC:
{
// This is like emitSetSavedPc, but using register allocation instead of relying on rax/rdx
ScopedReg tmp1{*this, SizeX64::qword};
ScopedReg tmp2{*this, SizeX64::qword};
ScopedRegX64 tmp1{regs, SizeX64::qword};
ScopedRegX64 tmp2{regs, SizeX64::qword};
build.mov(tmp2.reg, sCode);
build.add(tmp2.reg, uintOp(inst.a) * sizeof(Instruction));
@ -852,8 +945,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next)
LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg);
Label next;
ScopedReg tmp1{*this, SizeX64::qword};
ScopedReg tmp2{*this, SizeX64::qword};
ScopedRegX64 tmp1{regs, SizeX64::qword};
ScopedRegX64 tmp2{regs, SizeX64::qword};
// L->openupval != 0
build.mov(tmp1.reg, qword[rState + offsetof(lua_State, openupval)]);
@ -1131,124 +1224,6 @@ Label& IrLoweringX64::labelOp(IrOp op) const
return blockOp(op).label;
}
RegisterX64 IrLoweringX64::allocGprReg(SizeX64 preferredSize)
{
LUAU_ASSERT(
preferredSize == SizeX64::byte || preferredSize == SizeX64::word || preferredSize == SizeX64::dword || preferredSize == SizeX64::qword);
for (RegisterX64 reg : kGprAllocOrder)
{
if (freeGprMap[reg.index])
{
freeGprMap[reg.index] = false;
return RegisterX64{preferredSize, reg.index};
}
}
LUAU_ASSERT(!"Out of GPR registers to allocate");
return noreg;
}
RegisterX64 IrLoweringX64::allocXmmReg()
{
for (size_t i = 0; i < freeXmmMap.size(); ++i)
{
if (freeXmmMap[i])
{
freeXmmMap[i] = false;
return RegisterX64{SizeX64::xmmword, uint8_t(i)};
}
}
LUAU_ASSERT(!"Out of XMM registers to allocate");
return noreg;
}
RegisterX64 IrLoweringX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list<IrOp> oprefs)
{
for (IrOp op : oprefs)
{
if (op.kind != IrOpKind::Inst)
continue;
IrInst& source = function.instructions[op.index];
if (source.lastUse == index && !source.reusedReg)
{
LUAU_ASSERT(source.regX64.size != SizeX64::xmmword);
LUAU_ASSERT(source.regX64 != noreg);
source.reusedReg = true;
return RegisterX64{preferredSize, source.regX64.index};
}
}
return allocGprReg(preferredSize);
}
RegisterX64 IrLoweringX64::allocXmmRegOrReuse(uint32_t index, std::initializer_list<IrOp> oprefs)
{
for (IrOp op : oprefs)
{
if (op.kind != IrOpKind::Inst)
continue;
IrInst& source = function.instructions[op.index];
if (source.lastUse == index && !source.reusedReg)
{
LUAU_ASSERT(source.regX64.size == SizeX64::xmmword);
LUAU_ASSERT(source.regX64 != noreg);
source.reusedReg = true;
return source.regX64;
}
}
return allocXmmReg();
}
void IrLoweringX64::freeReg(RegisterX64 reg)
{
if (reg.size == SizeX64::xmmword)
{
LUAU_ASSERT(!freeXmmMap[reg.index]);
freeXmmMap[reg.index] = true;
}
else
{
LUAU_ASSERT(!freeGprMap[reg.index]);
freeGprMap[reg.index] = true;
}
}
void IrLoweringX64::freeLastUseReg(IrInst& target, uint32_t index)
{
if (target.lastUse == index && !target.reusedReg)
{
// Register might have already been freed if it had multiple uses inside a single instruction
if (target.regX64 == noreg)
return;
freeReg(target.regX64);
target.regX64 = noreg;
}
}
void IrLoweringX64::freeLastUseRegs(const IrInst& inst, uint32_t index)
{
auto checkOp = [this, index](IrOp op) {
if (op.kind == IrOpKind::Inst)
freeLastUseReg(function.instructions[op.index], index);
};
checkOp(inst.a);
checkOp(inst.b);
checkOp(inst.c);
checkOp(inst.d);
checkOp(inst.e);
}
ConditionX64 IrLoweringX64::getX64Condition(IrCondition cond) const
{
// TODO: this function will not be required when jumpOnNumberCmp starts accepting an IrCondition
@ -1282,27 +1257,5 @@ ConditionX64 IrLoweringX64::getX64Condition(IrCondition cond) const
return ConditionX64::Count;
}
IrLoweringX64::ScopedReg::ScopedReg(IrLoweringX64& owner, SizeX64 size)
: owner(owner)
{
if (size == SizeX64::xmmword)
reg = owner.allocXmmReg();
else
reg = owner.allocGprReg(size);
}
IrLoweringX64::ScopedReg::~ScopedReg()
{
if (reg != noreg)
owner.freeReg(reg);
}
void IrLoweringX64::ScopedReg::free()
{
LUAU_ASSERT(reg != noreg);
owner.freeReg(reg);
reg = noreg;
}
} // namespace CodeGen
} // namespace Luau

View file

@ -4,8 +4,8 @@
#include "Luau/AssemblyBuilderX64.h"
#include "Luau/IrData.h"
#include <array>
#include <initializer_list>
#include "IrRegAllocX64.h"
#include <vector>
struct Proto;
@ -46,33 +46,8 @@ struct IrLoweringX64
IrBlock& blockOp(IrOp op) const;
Label& labelOp(IrOp op) const;
// Unscoped register allocation
RegisterX64 allocGprReg(SizeX64 preferredSize);
RegisterX64 allocXmmReg();
RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list<IrOp> oprefs);
RegisterX64 allocXmmRegOrReuse(uint32_t index, std::initializer_list<IrOp> oprefs);
void freeReg(RegisterX64 reg);
void freeLastUseReg(IrInst& target, uint32_t index);
void freeLastUseRegs(const IrInst& inst, uint32_t index);
ConditionX64 getX64Condition(IrCondition cond) const;
struct ScopedReg
{
ScopedReg(IrLoweringX64& owner, SizeX64 size);
~ScopedReg();
ScopedReg(const ScopedReg&) = delete;
ScopedReg& operator=(const ScopedReg&) = delete;
void free();
IrLoweringX64& owner;
RegisterX64 reg;
};
AssemblyBuilderX64& build;
ModuleHelpers& helpers;
NativeState& data;
@ -80,8 +55,7 @@ struct IrLoweringX64
IrFunction& function;
std::array<bool, 16> freeGprMap;
std::array<bool, 16> freeXmmMap;
IrRegAllocX64 regs;
};
} // namespace CodeGen

View file

@ -0,0 +1,181 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "IrRegAllocX64.h"
#include "Luau/CodeGen.h"
#include "Luau/DenseHash.h"
#include "Luau/IrAnalysis.h"
#include "Luau/IrDump.h"
#include "Luau/IrUtils.h"
#include "EmitCommonX64.h"
#include "EmitInstructionX64.h"
#include "NativeState.h"
#include "lstate.h"
#include <algorithm>
namespace Luau
{
namespace CodeGen
{
static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11};
IrRegAllocX64::IrRegAllocX64(IrFunction& function)
: function(function)
{
freeGprMap.fill(true);
freeXmmMap.fill(true);
}
RegisterX64 IrRegAllocX64::allocGprReg(SizeX64 preferredSize)
{
LUAU_ASSERT(
preferredSize == SizeX64::byte || preferredSize == SizeX64::word || preferredSize == SizeX64::dword || preferredSize == SizeX64::qword);
for (RegisterX64 reg : kGprAllocOrder)
{
if (freeGprMap[reg.index])
{
freeGprMap[reg.index] = false;
return RegisterX64{preferredSize, reg.index};
}
}
LUAU_ASSERT(!"Out of GPR registers to allocate");
return noreg;
}
RegisterX64 IrRegAllocX64::allocXmmReg()
{
for (size_t i = 0; i < freeXmmMap.size(); ++i)
{
if (freeXmmMap[i])
{
freeXmmMap[i] = false;
return RegisterX64{SizeX64::xmmword, uint8_t(i)};
}
}
LUAU_ASSERT(!"Out of XMM registers to allocate");
return noreg;
}
RegisterX64 IrRegAllocX64::allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list<IrOp> oprefs)
{
for (IrOp op : oprefs)
{
if (op.kind != IrOpKind::Inst)
continue;
IrInst& source = function.instructions[op.index];
if (source.lastUse == index && !source.reusedReg)
{
LUAU_ASSERT(source.regX64.size != SizeX64::xmmword);
LUAU_ASSERT(source.regX64 != noreg);
source.reusedReg = true;
return RegisterX64{preferredSize, source.regX64.index};
}
}
return allocGprReg(preferredSize);
}
RegisterX64 IrRegAllocX64::allocXmmRegOrReuse(uint32_t index, std::initializer_list<IrOp> oprefs)
{
for (IrOp op : oprefs)
{
if (op.kind != IrOpKind::Inst)
continue;
IrInst& source = function.instructions[op.index];
if (source.lastUse == index && !source.reusedReg)
{
LUAU_ASSERT(source.regX64.size == SizeX64::xmmword);
LUAU_ASSERT(source.regX64 != noreg);
source.reusedReg = true;
return source.regX64;
}
}
return allocXmmReg();
}
void IrRegAllocX64::freeReg(RegisterX64 reg)
{
if (reg.size == SizeX64::xmmword)
{
LUAU_ASSERT(!freeXmmMap[reg.index]);
freeXmmMap[reg.index] = true;
}
else
{
LUAU_ASSERT(!freeGprMap[reg.index]);
freeGprMap[reg.index] = true;
}
}
void IrRegAllocX64::freeLastUseReg(IrInst& target, uint32_t index)
{
if (target.lastUse == index && !target.reusedReg)
{
// Register might have already been freed if it had multiple uses inside a single instruction
if (target.regX64 == noreg)
return;
freeReg(target.regX64);
target.regX64 = noreg;
}
}
void IrRegAllocX64::freeLastUseRegs(const IrInst& inst, uint32_t index)
{
auto checkOp = [this, index](IrOp op) {
if (op.kind == IrOpKind::Inst)
freeLastUseReg(function.instructions[op.index], index);
};
checkOp(inst.a);
checkOp(inst.b);
checkOp(inst.c);
checkOp(inst.d);
checkOp(inst.e);
checkOp(inst.f);
}
ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, SizeX64 size)
: owner(owner)
{
if (size == SizeX64::xmmword)
reg = owner.allocXmmReg();
else
reg = owner.allocGprReg(size);
}
ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, RegisterX64 reg)
: owner(owner)
, reg(reg)
{
}
ScopedRegX64::~ScopedRegX64()
{
if (reg != noreg)
owner.freeReg(reg);
}
void ScopedRegX64::free()
{
LUAU_ASSERT(reg != noreg);
owner.freeReg(reg);
reg = noreg;
}
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,51 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/IrData.h"
#include "Luau/RegisterX64.h"
#include <array>
#include <initializer_list>
namespace Luau
{
namespace CodeGen
{
struct IrRegAllocX64
{
IrRegAllocX64(IrFunction& function);
RegisterX64 allocGprReg(SizeX64 preferredSize);
RegisterX64 allocXmmReg();
RegisterX64 allocGprRegOrReuse(SizeX64 preferredSize, uint32_t index, std::initializer_list<IrOp> oprefs);
RegisterX64 allocXmmRegOrReuse(uint32_t index, std::initializer_list<IrOp> oprefs);
void freeReg(RegisterX64 reg);
void freeLastUseReg(IrInst& target, uint32_t index);
void freeLastUseRegs(const IrInst& inst, uint32_t index);
IrFunction& function;
std::array<bool, 16> freeGprMap;
std::array<bool, 16> freeXmmMap;
};
struct ScopedRegX64
{
ScopedRegX64(IrRegAllocX64& owner, SizeX64 size);
ScopedRegX64(IrRegAllocX64& owner, RegisterX64 reg);
~ScopedRegX64();
ScopedRegX64(const ScopedRegX64&) = delete;
ScopedRegX64& operator=(const ScopedRegX64&) = delete;
void free();
IrRegAllocX64& owner;
RegisterX64 reg;
};
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,40 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "IrTranslateBuiltins.h"
#include "Luau/Bytecode.h"
#include "Luau/IrBuilder.h"
#include "lstate.h"
namespace Luau
{
namespace CodeGen
{
BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback)
{
if (nparams < 1 || nresults != 0)
return {BuiltinImplType::None, -1};
IrOp cont = build.block(IrBlockKind::Internal);
// TODO: maybe adding a guard like CHECK_TRUTHY can be useful
build.inst(IrCmd::JUMP_IF_FALSY, build.vmReg(arg), fallback, cont);
build.beginBlock(cont);
return {BuiltinImplType::UsesFallback, 0};
}
BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback)
{
switch (bfid)
{
case LBF_ASSERT:
return translateBuiltinAssert(build, nparams, ra, arg, args, nresults, fallback);
default:
return {BuiltinImplType::None, -1};
}
}
} // namespace CodeGen
} // namespace Luau

View file

@ -0,0 +1,27 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
namespace Luau
{
namespace CodeGen
{
struct IrBuilder;
struct IrOp;
enum class BuiltinImplType
{
None,
UsesFallback, // Uses fallback for unsupported cases
};
struct BuiltinImplResult
{
BuiltinImplType type;
int actualResultCount;
};
BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback);
} // namespace CodeGen
} // namespace Luau

View file

@ -6,6 +6,7 @@
#include "Luau/IrUtils.h"
#include "CustomExecUtils.h"
#include "IrTranslateBuiltins.h"
#include "lobject.h"
#include "ltm.h"
@ -68,7 +69,6 @@ void translateInstLoadK(IrBuilder& build, const Instruction* pc)
{
int ra = LUAU_INSN_A(*pc);
// TODO: per-component loads and stores might be preferable
IrOp load = build.inst(IrCmd::LOAD_TVALUE, build.vmConst(LUAU_INSN_D(*pc)));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load);
}
@ -78,7 +78,6 @@ void translateInstLoadKX(IrBuilder& build, const Instruction* pc)
int ra = LUAU_INSN_A(*pc);
uint32_t aux = pc[1];
// TODO: per-component loads and stores might be preferable
IrOp load = build.inst(IrCmd::LOAD_TVALUE, build.vmConst(aux));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load);
}
@ -88,7 +87,6 @@ void translateInstMove(IrBuilder& build, const Instruction* pc)
int ra = LUAU_INSN_A(*pc);
int rb = LUAU_INSN_B(*pc);
// TODO: per-component loads and stores might be preferable
IrOp load = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), load);
}
@ -146,10 +144,11 @@ void translateInstJumpIfEq(IrBuilder& build, const Instruction* pc, int pcpos, b
build.inst(IrCmd::JUMP_CMP_NUM, va, vb, build.cond(IrCondition::NotEqual), not_ ? target : next, not_ ? next : target);
FallbackStreamScope scope(build, fallback, next);
build.beginBlock(fallback);
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::JUMP_CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(not_ ? IrCondition::NotEqual : IrCondition::Equal), target, next);
build.beginBlock(next);
}
void translateInstJumpIfCond(IrBuilder& build, const Instruction* pc, int pcpos, IrCondition cond)
@ -173,10 +172,11 @@ void translateInstJumpIfCond(IrBuilder& build, const Instruction* pc, int pcpos,
build.inst(IrCmd::JUMP_CMP_NUM, va, vb, build.cond(cond), target, next);
FallbackStreamScope scope(build, fallback, next);
build.beginBlock(fallback);
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::JUMP_CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(cond), target, next);
build.beginBlock(next);
}
void translateInstJumpX(IrBuilder& build, const Instruction* pc, int pcpos)
@ -479,6 +479,61 @@ void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc)
build.inst(IrCmd::CLOSE_UPVALS, build.vmReg(ra));
}
void translateFastCallN(
IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next, IrCmd fallbackCmd)
{
int bfid = LUAU_INSN_A(*pc);
int skip = LUAU_INSN_C(*pc);
IrOp fallback = build.block(IrBlockKind::Fallback);
Instruction call = pc[skip + 1];
LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL);
int ra = LUAU_INSN_A(call);
int nparams = customParams ? customParamCount : LUAU_INSN_B(call) - 1;
int nresults = LUAU_INSN_C(call) - 1;
int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1;
IrOp args = customParams ? customArgs : build.vmReg(ra + 2);
build.inst(IrCmd::CHECK_SAFE_ENV, fallback);
BuiltinImplResult br = translateBuiltin(build, LuauBuiltinFunction(bfid), ra, arg, args, nparams, nresults, fallback);
if (br.type == BuiltinImplType::UsesFallback)
{
if (nresults == LUA_MULTRET)
build.inst(IrCmd::ADJUST_STACK_TO_REG, build.vmReg(ra), build.constInt(br.actualResultCount));
else if (nparams == LUA_MULTRET)
build.inst(IrCmd::ADJUST_STACK_TO_TOP);
}
else
{
switch (fallbackCmd)
{
case IrCmd::LOP_FASTCALL:
build.inst(IrCmd::LOP_FASTCALL, build.constUint(pcpos), build.vmReg(ra), build.constInt(nparams), fallback);
break;
case IrCmd::LOP_FASTCALL1:
build.inst(IrCmd::LOP_FASTCALL1, build.constUint(pcpos), build.vmReg(ra), build.vmReg(arg), fallback);
break;
case IrCmd::LOP_FASTCALL2:
build.inst(IrCmd::LOP_FASTCALL2, build.constUint(pcpos), build.vmReg(ra), build.vmReg(arg), build.vmReg(pc[1]), fallback);
break;
case IrCmd::LOP_FASTCALL2K:
build.inst(IrCmd::LOP_FASTCALL2K, build.constUint(pcpos), build.vmReg(ra), build.vmReg(arg), build.vmConst(pc[1]), fallback);
break;
default:
LUAU_ASSERT(!"unexpected command");
}
}
build.inst(IrCmd::JUMP, next);
// this will be filled with IR corresponding to instructions after FASTCALL until skip+1
build.beginBlock(fallback);
}
void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos)
{
int ra = LUAU_INSN_A(*pc);
@ -589,7 +644,6 @@ void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpo
build.inst(IrCmd::JUMP, target);
// FallbackStreamScope not used here because this instruction doesn't fallthrough to next instruction
build.beginBlock(fallback);
build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), target);
}
@ -622,7 +676,6 @@ void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcp
build.inst(IrCmd::JUMP, target);
// FallbackStreamScope not used here because this instruction doesn't fallthrough to next instruction
build.beginBlock(fallback);
build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), target);
}
@ -700,7 +753,6 @@ void translateInstGetTableN(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, build.constInt(c));
// TODO: per-component loads and stores might be preferable
IrOp arrElTval = build.inst(IrCmd::LOAD_TVALUE, arrEl);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), arrElTval);
@ -731,7 +783,6 @@ void translateInstSetTableN(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, build.constInt(c));
// TODO: per-component loads and stores might be preferable
IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra));
build.inst(IrCmd::STORE_TVALUE, arrEl, tva);
@ -771,7 +822,6 @@ void translateInstGetTable(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, index);
// TODO: per-component loads and stores might be preferable
IrOp arrElTval = build.inst(IrCmd::LOAD_TVALUE, arrEl);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), arrElTval);
@ -810,7 +860,6 @@ void translateInstSetTable(IrBuilder& build, const Instruction* pc, int pcpos)
IrOp arrEl = build.inst(IrCmd::GET_ARR_ADDR, vb, index);
// TODO: per-component loads and stores might be preferable
IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra));
build.inst(IrCmd::STORE_TVALUE, arrEl, tva);
@ -842,7 +891,6 @@ void translateInstGetImport(IrBuilder& build, const Instruction* pc, int pcpos)
build.beginBlock(fastPath);
// TODO: per-component loads and stores might be preferable
IrOp tvk = build.inst(IrCmd::LOAD_TVALUE, build.vmConst(k));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvk);
@ -871,7 +919,6 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback);
// TODO: per-component loads and stores might be preferable
IrOp tvn = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrSlotEl);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvn);
@ -900,7 +947,6 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback);
build.inst(IrCmd::CHECK_READONLY, vb, fallback);
// TODO: per-component loads and stores might be preferable
IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra));
build.inst(IrCmd::STORE_NODE_VALUE_TV, addrSlotEl, tva);
@ -925,7 +971,6 @@ void translateInstGetGlobal(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback);
// TODO: per-component loads and stores might be preferable
IrOp tvn = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrSlotEl);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvn);
@ -949,7 +994,6 @@ void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback);
build.inst(IrCmd::CHECK_READONLY, env, fallback);
// TODO: per-component loads and stores might be preferable
IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra));
build.inst(IrCmd::STORE_NODE_VALUE_TV, addrSlotEl, tva);
@ -971,7 +1015,6 @@ void translateInstConcat(IrBuilder& build, const Instruction* pc, int pcpos)
build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1));
build.inst(IrCmd::CONCAT, build.vmReg(rb), build.constUint(rc - rb + 1));
// TODO: per-component loads and stores might be preferable
IrOp tvb = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(rb));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvb);

View file

@ -15,6 +15,7 @@ namespace CodeGen
enum class IrCondition : uint8_t;
struct IrOp;
struct IrBuilder;
enum class IrCmd : uint8_t;
void translateInstLoadNil(IrBuilder& build, const Instruction* pc);
void translateInstLoadB(IrBuilder& build, const Instruction* pc, int pcpos);
@ -42,6 +43,8 @@ void translateInstDupTable(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstGetUpval(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstSetUpval(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc);
void translateFastCallN(
IrBuilder& build, const Instruction* pc, int pcpos, bool customParams, int customParamCount, IrOp customArgs, IrOp next, IrCmd fallbackCmd);
void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos);
void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpos);

View file

@ -14,20 +14,6 @@ namespace Luau
namespace CodeGen
{
static uint32_t getBlockEnd(IrFunction& function, uint32_t start)
{
LUAU_ASSERT(start < function.instructions.size());
uint32_t end = start;
// Find previous block terminator
while (!isBlockTerminator(function.instructions[end].cmd))
end++;
LUAU_ASSERT(end < function.instructions.size());
return end;
}
void addUse(IrFunction& function, IrOp op)
{
if (op.kind == IrOpKind::Inst)
@ -44,6 +30,12 @@ void removeUse(IrFunction& function, IrOp op)
removeUse(function, function.blocks[op.index]);
}
bool isGCO(uint8_t tag)
{
// mirrors iscollectable(o) from VM/lobject.h
return tag >= LUA_TSTRING;
}
void kill(IrFunction& function, IrInst& inst)
{
LUAU_ASSERT(inst.useCount == 0);
@ -55,12 +47,14 @@ void kill(IrFunction& function, IrInst& inst)
removeUse(function, inst.c);
removeUse(function, inst.d);
removeUse(function, inst.e);
removeUse(function, inst.f);
inst.a = {};
inst.b = {};
inst.c = {};
inst.d = {};
inst.e = {};
inst.f = {};
}
void kill(IrFunction& function, uint32_t start, uint32_t end)
@ -84,10 +78,9 @@ void kill(IrFunction& function, IrBlock& block)
block.kind = IrBlockKind::Dead;
uint32_t start = block.start;
uint32_t end = getBlockEnd(function, start);
kill(function, start, end);
kill(function, block.start, block.finish);
block.start = ~0u;
block.finish = ~0u;
}
void removeUse(IrFunction& function, IrInst& inst)
@ -117,7 +110,7 @@ void replace(IrFunction& function, IrOp& original, IrOp replacement)
original = replacement;
}
void replace(IrFunction& function, uint32_t instIdx, IrInst replacement)
void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst replacement)
{
IrInst& inst = function.instructions[instIdx];
@ -127,19 +120,18 @@ void replace(IrFunction& function, uint32_t instIdx, IrInst replacement)
addUse(function, replacement.c);
addUse(function, replacement.d);
addUse(function, replacement.e);
addUse(function, replacement.f);
// If we introduced an earlier terminating instruction, all following instructions become dead
if (!isBlockTerminator(inst.cmd) && isBlockTerminator(replacement.cmd))
{
uint32_t start = instIdx + 1;
// Block has has to be fully constructed before replacement is performed
LUAU_ASSERT(block.finish != ~0u);
LUAU_ASSERT(instIdx + 1 <= block.finish);
// If we are in the process of constructing a block, replacement might happen at the last instruction
if (start < function.instructions.size())
{
uint32_t end = getBlockEnd(function, start);
kill(function, instIdx + 1, block.finish);
kill(function, start, end);
}
block.finish = instIdx;
}
removeUse(function, inst.a);
@ -147,6 +139,7 @@ void replace(IrFunction& function, uint32_t instIdx, IrInst replacement)
removeUse(function, inst.c);
removeUse(function, inst.d);
removeUse(function, inst.e);
removeUse(function, inst.f);
inst = replacement;
}
@ -162,12 +155,14 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement)
removeUse(function, inst.c);
removeUse(function, inst.d);
removeUse(function, inst.e);
removeUse(function, inst.f);
inst.a = replacement;
inst.b = {};
inst.c = {};
inst.d = {};
inst.e = {};
inst.f = {};
}
void applySubstitutions(IrFunction& function, IrOp& op)
@ -203,9 +198,10 @@ void applySubstitutions(IrFunction& function, IrInst& inst)
applySubstitutions(function, inst.c);
applySubstitutions(function, inst.d);
applySubstitutions(function, inst.e);
applySubstitutions(function, inst.f);
}
static bool compare(double a, double b, IrCondition cond)
bool compare(double a, double b, IrCondition cond)
{
switch (cond)
{
@ -236,7 +232,7 @@ static bool compare(double a, double b, IrCondition cond)
return false;
}
void foldConstants(IrBuilder& build, IrFunction& function, uint32_t index)
void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t index)
{
IrInst& inst = function.instructions[index];
@ -311,27 +307,27 @@ void foldConstants(IrBuilder& build, IrFunction& function, uint32_t index)
if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant)
{
if (function.tagOp(inst.a) == function.tagOp(inst.b))
replace(function, index, {IrCmd::JUMP, inst.c});
replace(function, block, index, {IrCmd::JUMP, inst.c});
else
replace(function, index, {IrCmd::JUMP, inst.d});
replace(function, block, index, {IrCmd::JUMP, inst.d});
}
break;
case IrCmd::JUMP_EQ_INT:
if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant)
{
if (function.intOp(inst.a) == function.intOp(inst.b))
replace(function, index, {IrCmd::JUMP, inst.c});
replace(function, block, index, {IrCmd::JUMP, inst.c});
else
replace(function, index, {IrCmd::JUMP, inst.d});
replace(function, block, index, {IrCmd::JUMP, inst.d});
}
break;
case IrCmd::JUMP_CMP_NUM:
if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant)
{
if (compare(function.doubleOp(inst.a), function.doubleOp(inst.b), function.conditionOp(inst.c)))
replace(function, index, {IrCmd::JUMP, inst.d});
replace(function, block, index, {IrCmd::JUMP, inst.d});
else
replace(function, index, {IrCmd::JUMP, inst.e});
replace(function, block, index, {IrCmd::JUMP, inst.e});
}
break;
case IrCmd::NUM_TO_INDEX:
@ -347,11 +343,11 @@ void foldConstants(IrBuilder& build, IrFunction& function, uint32_t index)
if (double(arrIndex) == value)
substitute(function, inst, build.constInt(arrIndex));
else
replace(function, index, {IrCmd::JUMP, inst.b});
replace(function, block, index, {IrCmd::JUMP, inst.b});
}
else
{
replace(function, index, {IrCmd::JUMP, inst.b});
replace(function, block, index, {IrCmd::JUMP, inst.b});
}
}
break;
@ -365,7 +361,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, uint32_t index)
if (function.tagOp(inst.a) == function.tagOp(inst.b))
kill(function, inst);
else
replace(function, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path
replace(function, block, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path
}
break;
default:

View file

@ -0,0 +1,565 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/OptimizeConstProp.h"
#include "Luau/DenseHash.h"
#include "Luau/IrBuilder.h"
#include "Luau/IrUtils.h"
#include "lua.h"
namespace Luau
{
namespace CodeGen
{
// Data we know about the register value
struct RegisterInfo
{
uint8_t tag = 0xff;
IrOp value;
// Used to quickly invalidate links between SSA values and register memory
// It's a bit imprecise where value and tag both always invalidate together
uint32_t version = 0;
bool knownNotReadonly = false;
bool knownNoMetatable = false;
};
// Load instructions are linked to target register to carry knowledge about the target
// We track a register version at the point of the load so it's easy to break the link when register is updated
struct RegisterLink
{
uint8_t reg = 0;
uint32_t version = 0;
};
// Data we know about the current VM state
struct ConstPropState
{
uint8_t tryGetTag(IrOp op)
{
if (RegisterInfo* info = tryGetRegisterInfo(op))
return info->tag;
return 0xff;
}
void saveTag(IrOp op, uint8_t tag)
{
if (RegisterInfo* info = tryGetRegisterInfo(op))
info->tag = tag;
}
IrOp tryGetValue(IrOp op)
{
if (RegisterInfo* info = tryGetRegisterInfo(op))
return info->value;
return IrOp{IrOpKind::None, 0u};
}
void saveValue(IrOp op, IrOp value)
{
LUAU_ASSERT(value.kind == IrOpKind::Constant);
if (RegisterInfo* info = tryGetRegisterInfo(op))
info->value = value;
}
void invalidate(RegisterInfo& reg, bool invalidateTag, bool invalidateValue)
{
if (invalidateTag)
{
reg.tag = 0xff;
}
if (invalidateValue)
{
reg.value = {};
reg.knownNotReadonly = false;
reg.knownNoMetatable = false;
}
reg.version++;
}
void invalidateTag(IrOp regOp)
{
LUAU_ASSERT(regOp.kind == IrOpKind::VmReg);
invalidate(regs[regOp.index], /* invalidateTag */ true, /* invalidateValue */ false);
}
void invalidateValue(IrOp regOp)
{
LUAU_ASSERT(regOp.kind == IrOpKind::VmReg);
invalidate(regs[regOp.index], /* invalidateTag */ false, /* invalidateValue */ true);
}
void invalidate(IrOp regOp)
{
LUAU_ASSERT(regOp.kind == IrOpKind::VmReg);
invalidate(regs[regOp.index], /* invalidateTag */ true, /* invalidateValue */ true);
}
void invalidateRegistersFrom(uint32_t firstReg)
{
for (int i = int(firstReg); i <= maxReg; ++i)
invalidate(regs[i], /* invalidateTag */ true, /* invalidateValue */ true);
maxReg = int(firstReg) - 1;
}
void invalidateHeap()
{
for (int i = 0; i <= maxReg; ++i)
invalidateHeap(regs[i]);
}
void invalidateHeap(RegisterInfo& reg)
{
reg.knownNotReadonly = false;
reg.knownNoMetatable = false;
}
void invalidateAll()
{
// Invalidating registers also invalidates what we know about the heap (stored in RegisterInfo)
invalidateRegistersFrom(0u);
inSafeEnv = false;
}
void createRegLink(uint32_t instIdx, IrOp regOp)
{
LUAU_ASSERT(regOp.kind == IrOpKind::VmReg);
LUAU_ASSERT(!instLink.contains(instIdx));
instLink[instIdx] = RegisterLink{uint8_t(regOp.index), regs[regOp.index].version};
}
RegisterInfo* tryGetRegisterInfo(IrOp op)
{
if (op.kind == IrOpKind::VmReg)
{
maxReg = int(op.index) > maxReg ? int(op.index) : maxReg;
return &regs[op.index];
}
if (RegisterLink* link = tryGetRegLink(op))
{
maxReg = int(link->reg) > maxReg ? int(link->reg) : maxReg;
return &regs[link->reg];
}
return nullptr;
}
RegisterLink* tryGetRegLink(IrOp instOp)
{
if (instOp.kind != IrOpKind::Inst)
return nullptr;
if (RegisterLink* link = instLink.find(instOp.index))
{
// Check that the target register hasn't changed the value
if (link->version > regs[link->reg].version)
return nullptr;
return link;
}
return nullptr;
}
RegisterInfo regs[256];
// For range/full invalidations, we only want to visit a limited number of data that we have recorded
int maxReg = 0;
bool inSafeEnv = false;
bool checkedGc = false;
DenseHashMap<uint32_t, RegisterLink> instLink{~0u};
};
static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& function, IrBlock& block, IrInst& inst, uint32_t index)
{
switch (inst.cmd)
{
case IrCmd::LOAD_TAG:
if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff)
substitute(function, inst, build.constTag(tag));
else if (inst.a.kind == IrOpKind::VmReg)
state.createRegLink(index, inst.a);
break;
case IrCmd::LOAD_POINTER:
if (inst.a.kind == IrOpKind::VmReg)
state.createRegLink(index, inst.a);
break;
case IrCmd::LOAD_DOUBLE:
if (IrOp value = state.tryGetValue(inst.a); value.kind == IrOpKind::Constant)
substitute(function, inst, value);
else if (inst.a.kind == IrOpKind::VmReg)
state.createRegLink(index, inst.a);
break;
case IrCmd::LOAD_INT:
if (IrOp value = state.tryGetValue(inst.a); value.kind == IrOpKind::Constant)
substitute(function, inst, value);
else if (inst.a.kind == IrOpKind::VmReg)
state.createRegLink(index, inst.a);
break;
case IrCmd::LOAD_TVALUE:
if (inst.a.kind == IrOpKind::VmReg)
state.createRegLink(index, inst.a);
break;
case IrCmd::STORE_TAG:
if (inst.a.kind == IrOpKind::VmReg)
{
if (inst.b.kind == IrOpKind::Constant)
{
uint8_t value = function.tagOp(inst.b);
if (state.tryGetTag(inst.a) == value)
kill(function, inst);
else
state.saveTag(inst.a, value);
}
else
{
state.invalidateTag(inst.a);
}
}
break;
case IrCmd::STORE_POINTER:
if (inst.a.kind == IrOpKind::VmReg)
state.invalidateValue(inst.a);
break;
case IrCmd::STORE_DOUBLE:
if (inst.a.kind == IrOpKind::VmReg)
{
if (inst.b.kind == IrOpKind::Constant)
{
std::optional<double> oldValue = function.asDoubleOp(state.tryGetValue(inst.a));
double newValue = function.doubleOp(inst.b);
if (oldValue && *oldValue == newValue)
kill(function, inst);
else
state.saveValue(inst.a, inst.b);
}
else
{
state.invalidateValue(inst.a);
}
}
break;
case IrCmd::STORE_INT:
if (inst.a.kind == IrOpKind::VmReg)
{
if (inst.b.kind == IrOpKind::Constant)
{
std::optional<int> oldValue = function.asIntOp(state.tryGetValue(inst.a));
int newValue = function.intOp(inst.b);
if (oldValue && *oldValue == newValue)
kill(function, inst);
else
state.saveValue(inst.a, inst.b);
}
else
{
state.invalidateValue(inst.a);
}
}
break;
case IrCmd::STORE_TVALUE:
if (inst.a.kind == IrOpKind::VmReg)
{
state.invalidate(inst.a);
if (uint8_t tag = state.tryGetTag(inst.b); tag != 0xff)
state.saveTag(inst.a, tag);
if (IrOp value = state.tryGetValue(inst.b); value.kind != IrOpKind::None)
state.saveValue(inst.a, value);
}
break;
case IrCmd::JUMP_IF_TRUTHY:
if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff)
{
if (tag == LUA_TNIL)
replace(function, block, index, {IrCmd::JUMP, inst.c});
else if (tag != LUA_TBOOLEAN)
replace(function, block, index, {IrCmd::JUMP, inst.b});
}
break;
case IrCmd::JUMP_IF_FALSY:
if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff)
{
if (tag == LUA_TNIL)
replace(function, block, index, {IrCmd::JUMP, inst.b});
else if (tag != LUA_TBOOLEAN)
replace(function, block, index, {IrCmd::JUMP, inst.c});
}
break;
case IrCmd::JUMP_EQ_TAG:
{
uint8_t tagA = inst.a.kind == IrOpKind::Constant ? function.tagOp(inst.a) : state.tryGetTag(inst.a);
uint8_t tagB = inst.b.kind == IrOpKind::Constant ? function.tagOp(inst.b) : state.tryGetTag(inst.b);
if (tagA != 0xff && tagB != 0xff)
{
if (tagA == tagB)
replace(function, block, index, {IrCmd::JUMP, inst.c});
else
replace(function, block, index, {IrCmd::JUMP, inst.d});
}
break;
}
case IrCmd::JUMP_EQ_INT:
{
std::optional<int> valueA = function.asIntOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a));
std::optional<int> valueB = function.asIntOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b));
if (valueA && valueB)
{
if (*valueA == *valueB)
replace(function, block, index, {IrCmd::JUMP, inst.c});
else
replace(function, block, index, {IrCmd::JUMP, inst.d});
}
break;
}
case IrCmd::JUMP_CMP_NUM:
{
std::optional<double> valueA = function.asDoubleOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a));
std::optional<double> valueB = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b));
if (valueA && valueB)
{
if (compare(*valueA, *valueB, function.conditionOp(inst.c)))
replace(function, block, index, {IrCmd::JUMP, inst.d});
else
replace(function, block, index, {IrCmd::JUMP, inst.e});
}
break;
}
case IrCmd::GET_UPVALUE:
state.invalidate(inst.a);
break;
case IrCmd::CHECK_TAG:
{
uint8_t b = function.tagOp(inst.b);
if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff)
{
if (tag == b)
kill(function, inst);
else
replace(function, block, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path
}
else
{
state.saveTag(inst.a, b); // We can assume the tag value going forward
}
break;
}
case IrCmd::CHECK_READONLY:
if (RegisterInfo* info = state.tryGetRegisterInfo(inst.a))
{
if (info->knownNotReadonly)
kill(function, inst);
else
info->knownNotReadonly = true;
}
break;
case IrCmd::CHECK_NO_METATABLE:
if (RegisterInfo* info = state.tryGetRegisterInfo(inst.a))
{
if (info->knownNoMetatable)
kill(function, inst);
else
info->knownNoMetatable = true;
}
break;
case IrCmd::CHECK_SAFE_ENV:
if (state.inSafeEnv)
kill(function, inst);
else
state.inSafeEnv = true;
break;
case IrCmd::CHECK_GC:
// It is enough to perform a GC check once in a block
if (state.checkedGc)
kill(function, inst);
else
state.checkedGc = true;
break;
case IrCmd::BARRIER_OBJ:
case IrCmd::BARRIER_TABLE_FORWARD:
if (inst.b.kind == IrOpKind::VmReg)
{
if (uint8_t tag = state.tryGetTag(inst.b); tag != 0xff)
{
// If the written object is not collectable, barrier is not required
if (!isGCO(tag))
kill(function, inst);
}
}
break;
case IrCmd::LOP_FASTCALL:
case IrCmd::LOP_FASTCALL1:
case IrCmd::LOP_FASTCALL2:
case IrCmd::LOP_FASTCALL2K:
// TODO: classify fast call behaviors to avoid heap invalidation
state.invalidateHeap(); // Even a builtin method can change table properties
state.invalidateRegistersFrom(inst.b.index);
break;
case IrCmd::LOP_AND:
case IrCmd::LOP_ANDK:
case IrCmd::LOP_OR:
case IrCmd::LOP_ORK:
state.invalidate(inst.b);
break;
// These instructions don't have an effect on register/memory state we are tracking
case IrCmd::NOP:
case IrCmd::LOAD_NODE_VALUE_TV:
case IrCmd::LOAD_ENV:
case IrCmd::GET_ARR_ADDR:
case IrCmd::GET_SLOT_NODE_ADDR:
case IrCmd::STORE_NODE_VALUE_TV:
case IrCmd::ADD_INT:
case IrCmd::SUB_INT:
case IrCmd::ADD_NUM:
case IrCmd::SUB_NUM:
case IrCmd::MUL_NUM:
case IrCmd::DIV_NUM:
case IrCmd::MOD_NUM:
case IrCmd::POW_NUM:
case IrCmd::UNM_NUM:
case IrCmd::NOT_ANY:
case IrCmd::JUMP:
case IrCmd::JUMP_EQ_POINTER:
case IrCmd::TABLE_LEN:
case IrCmd::NEW_TABLE:
case IrCmd::DUP_TABLE:
case IrCmd::NUM_TO_INDEX:
case IrCmd::INT_TO_NUM:
case IrCmd::CHECK_ARRAY_SIZE:
case IrCmd::CHECK_SLOT_MATCH:
case IrCmd::BARRIER_TABLE_BACK:
case IrCmd::LOP_RETURN:
case IrCmd::LOP_COVERAGE:
case IrCmd::SET_UPVALUE:
case IrCmd::LOP_SETLIST: // We don't track table state that this can invalidate
case IrCmd::SET_SAVEDPC: // TODO: we may be able to remove some updates to PC
case IrCmd::CLOSE_UPVALS: // Doesn't change memory that we track
case IrCmd::CAPTURE:
case IrCmd::SUBSTITUTE:
case IrCmd::ADJUST_STACK_TO_REG: // Changes stack top, but not the values
case IrCmd::ADJUST_STACK_TO_TOP: // Changes stack top, but not the values
break;
// We don't model the following instructions, so we just clear all the knowledge we have built up
// Many of these call user functions that can change memory and captured registers
// Some of these might yield with similar effects
case IrCmd::JUMP_CMP_ANY:
case IrCmd::DO_ARITH:
case IrCmd::DO_LEN:
case IrCmd::GET_TABLE:
case IrCmd::SET_TABLE:
case IrCmd::GET_IMPORT:
case IrCmd::CONCAT:
case IrCmd::PREPARE_FORN:
case IrCmd::INTERRUPT: // TODO: it will be important to keep tag/value state, but we have to track register capture
case IrCmd::LOP_NAMECALL:
case IrCmd::LOP_CALL:
case IrCmd::LOP_FORGLOOP:
case IrCmd::LOP_FORGLOOP_FALLBACK:
case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK:
case IrCmd::FALLBACK_GETGLOBAL:
case IrCmd::FALLBACK_SETGLOBAL:
case IrCmd::FALLBACK_GETTABLEKS:
case IrCmd::FALLBACK_SETTABLEKS:
case IrCmd::FALLBACK_NAMECALL:
case IrCmd::FALLBACK_PREPVARARGS:
case IrCmd::FALLBACK_GETVARARGS:
case IrCmd::FALLBACK_NEWCLOSURE:
case IrCmd::FALLBACK_DUPCLOSURE:
case IrCmd::FALLBACK_FORGPREP:
// TODO: this is very conservative, some of there instructions can be tracked better
// TODO: non-captured register tags and values should not be cleared here
state.invalidateAll();
break;
}
}
static void constPropInBlock(IrBuilder& build, IrBlock& block, ConstPropState& state)
{
IrFunction& function = build.function;
for (uint32_t index = block.start; index <= block.finish; index++)
{
LUAU_ASSERT(index < function.instructions.size());
IrInst& inst = function.instructions[index];
applySubstitutions(function, inst);
foldConstants(build, function, block, index);
constPropInInst(state, build, function, block, inst, index);
}
}
static void constPropInBlockChain(IrBuilder& build, std::vector<uint8_t>& visited, IrBlock* block)
{
IrFunction& function = build.function;
ConstPropState state;
while (block)
{
uint32_t blockIdx = function.getBlockIndex(*block);
LUAU_ASSERT(!visited[blockIdx]);
visited[blockIdx] = true;
constPropInBlock(build, *block, state);
IrInst& termInst = function.instructions[block->finish];
IrBlock* nextBlock = nullptr;
// Unconditional jump into a block with a single user (current block) allows us to continue optimization
// with the information we have gathered so far (unless we have already visited that block earlier)
if (termInst.cmd == IrCmd::JUMP)
{
IrBlock& target = function.blockOp(termInst.a);
if (target.useCount == 1 && !visited[function.getBlockIndex(target)] && target.kind != IrBlockKind::Fallback)
nextBlock = &target;
}
block = nextBlock;
}
}
void constPropInBlockChains(IrBuilder& build)
{
IrFunction& function = build.function;
std::vector<uint8_t> visited(function.blocks.size(), false);
for (IrBlock& block : function.blocks)
{
if (block.kind == IrBlockKind::Fallback || block.kind == IrBlockKind::Dead)
continue;
if (visited[function.getBlockIndex(block)])
continue;
constPropInBlockChain(build, visited, &block);
}
}
} // namespace CodeGen
} // namespace Luau

View file

@ -17,7 +17,7 @@ static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block)
{
LUAU_ASSERT(block.kind != IrBlockKind::Dead);
for (uint32_t index = block.start; true; index++)
for (uint32_t index = block.start; index <= block.finish; index++)
{
LUAU_ASSERT(index < function.instructions.size());
IrInst& inst = function.instructions[index];
@ -90,9 +90,6 @@ static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block)
default:
break;
}
if (isBlockTerminator(inst.cmd))
break;
}
}

View file

@ -42,7 +42,7 @@
// Note: due to limitations of the versioning scheme, some bytecode blobs that carry version 2 are using features from version 3. Starting from version 3, version should be sufficient to indicate bytecode compatibility.
//
// Version 1: Baseline version for the open-source release. Supported until 0.521.
// Version 2: Adds Proto::linedefined. Currently supported.
// Version 2: Adds Proto::linedefined. Supported until 0.544.
// Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported.
// Bytecode opcode, part of the instruction header

View file

@ -25,6 +25,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25)
LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300)
LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5)
LUAU_FASTFLAGVARIABLE(LuauCompileTerminateBC, false)
namespace Luau
{
@ -132,14 +134,18 @@ struct Compiler
return uint8_t(upvals.size() - 1);
}
bool allPathsEndWithReturn(AstStat* node)
// true iff all execution paths through node subtree result in return/break/continue
// note: because this function doesn't visit loop nodes, it (correctly) only detects break/continue that refer to the outer control flow
bool alwaysTerminates(AstStat* node)
{
if (AstStatBlock* stat = node->as<AstStatBlock>())
return stat->body.size > 0 && allPathsEndWithReturn(stat->body.data[stat->body.size - 1]);
return stat->body.size > 0 && alwaysTerminates(stat->body.data[stat->body.size - 1]);
else if (node->is<AstStatReturn>())
return true;
else if (FFlag::LuauCompileTerminateBC && (node->is<AstStatBreak>() || node->is<AstStatContinue>()))
return true;
else if (AstStatIf* stat = node->as<AstStatIf>())
return stat->elsebody && allPathsEndWithReturn(stat->thenbody) && allPathsEndWithReturn(stat->elsebody);
return stat->elsebody && alwaysTerminates(stat->thenbody) && alwaysTerminates(stat->elsebody);
else
return false;
}
@ -213,7 +219,7 @@ struct Compiler
// valid function bytecode must always end with RETURN
// we elide this if we're guaranteed to hit a RETURN statement regardless of the control flow
if (!allPathsEndWithReturn(stat))
if (!alwaysTerminates(stat))
{
setDebugLineEnd(stat);
closeLocals(0);
@ -257,7 +263,7 @@ struct Compiler
f.costModel = modelCost(func->body, func->args.data, func->args.size, builtins);
// track functions that only ever return a single value so that we can convert multret calls to fixedret calls
if (allPathsEndWithReturn(func->body))
if (alwaysTerminates(func->body))
{
ReturnVisitor returnVisitor(this);
stat->visit(&returnVisitor);
@ -640,7 +646,7 @@ struct Compiler
}
// for the fallthrough path we need to ensure we clear out target registers
if (!usedFallthrough && !allPathsEndWithReturn(func->body))
if (!usedFallthrough && !alwaysTerminates(func->body))
{
for (size_t i = 0; i < targetCount; ++i)
bytecode.emitABC(LOP_LOADNIL, uint8_t(target + i), 0, 0);
@ -2435,9 +2441,9 @@ struct Compiler
if (stat->elsebody && elseJump.size() > 0)
{
// we don't need to skip past "else" body if "then" ends with return
// we don't need to skip past "else" body if "then" ends with return/break/continue
// this is important because, if "else" also ends with return, we may *not* have any statement to skip to!
if (allPathsEndWithReturn(stat->thenbody))
if (alwaysTerminates(stat->thenbody))
{
size_t elseLabel = bytecode.emitLabel();

View file

@ -70,6 +70,7 @@ target_sources(Luau.CodeGen PRIVATE
CodeGen/include/Luau/IrUtils.h
CodeGen/include/Luau/Label.h
CodeGen/include/Luau/OperandX64.h
CodeGen/include/Luau/OptimizeConstProp.h
CodeGen/include/Luau/OptimizeFinalX64.h
CodeGen/include/Luau/RegisterA64.h
CodeGen/include/Luau/RegisterX64.h
@ -92,9 +93,12 @@ target_sources(Luau.CodeGen PRIVATE
CodeGen/src/IrBuilder.cpp
CodeGen/src/IrDump.cpp
CodeGen/src/IrLoweringX64.cpp
CodeGen/src/IrRegAllocX64.cpp
CodeGen/src/IrTranslateBuiltins.cpp
CodeGen/src/IrTranslation.cpp
CodeGen/src/IrUtils.cpp
CodeGen/src/NativeState.cpp
CodeGen/src/OptimizeConstProp.cpp
CodeGen/src/OptimizeFinalX64.cpp
CodeGen/src/UnwindBuilderDwarf2.cpp
CodeGen/src/UnwindBuilderWin.cpp
@ -109,6 +113,8 @@ target_sources(Luau.CodeGen PRIVATE
CodeGen/src/Fallbacks.h
CodeGen/src/FallbacksProlog.h
CodeGen/src/IrLoweringX64.h
CodeGen/src/IrRegAllocX64.h
CodeGen/src/IrTranslateBuiltins.h
CodeGen/src/IrTranslation.h
CodeGen/src/NativeState.h
)

View file

@ -1677,6 +1677,8 @@ RETURN R0 0
TEST_CASE("LoopBreak")
{
ScopedFastFlag sff("LuauCompileTerminateBC", true);
// default codegen: compile breaks as unconditional jumps
CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break else end end"), R"(
L0: GETIMPORT R0 2 [math.random]
@ -1684,7 +1686,6 @@ CALL R0 0 1
LOADK R1 K3 [0.5]
JUMPIFNOTLT R0 R1 L1
RETURN R0 0
JUMP L1
L1: JUMPBACK L0
RETURN R0 0
)");
@ -1702,6 +1703,8 @@ L1: RETURN R0 0
TEST_CASE("LoopContinue")
{
ScopedFastFlag sff("LuauCompileTerminateBC", true);
// default codegen: compile continue as unconditional jumps
CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue else end break until false error()"), R"(
L0: GETIMPORT R0 2 [math.random]
@ -1710,7 +1713,6 @@ LOADK R1 K3 [0.5]
JUMPIFNOTLT R0 R1 L2
JUMP L1
JUMP L2
JUMP L2
L1: JUMPBACK L0
L2: GETIMPORT R0 5 [error]
CALL R0 0 0
@ -6808,4 +6810,59 @@ RETURN R0 0
)");
}
TEST_CASE("ElideJumpAfterIf")
{
ScopedFastFlag sff("LuauCompileTerminateBC", true);
// break refers to outer loop => we can elide unconditional branches
CHECK_EQ("\n" + compileFunction0(R"(
local foo, bar = ...
repeat
if foo then break
elseif bar then break
end
print(1234)
until foo == bar
)"),
R"(
GETVARARGS R0 2
L0: JUMPIFNOT R0 L1
RETURN R0 0
L1: JUMPIF R1 L2
GETIMPORT R2 1 [print]
LOADN R3 1234
CALL R2 1 0
JUMPIFEQ R0 R1 L2
JUMPBACK L0
L2: RETURN R0 0
)");
// break refers to inner loop => branches remain
CHECK_EQ("\n" + compileFunction0(R"(
local foo, bar = ...
repeat
if foo then while true do break end
elseif bar then while true do break end
end
print(1234)
until foo == bar
)"),
R"(
GETVARARGS R0 2
L0: JUMPIFNOT R0 L1
JUMP L2
JUMPBACK L2
JUMP L2
L1: JUMPIFNOT R1 L2
JUMP L2
JUMPBACK L2
L2: GETIMPORT R2 1 [print]
LOADN R3 1234
CALL R2 1 0
JUMPIFEQ R0 R1 L3
JUMPBACK L0
L3: RETURN R0 0
)");
}
TEST_SUITE_END();

View file

@ -31,7 +31,8 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code)
void ConstraintGraphBuilderFixture::solve(const std::string& code)
{
generateConstraints(code);
ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger};
ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull{mainModule->reduction.get()},
NotNull(&moduleResolver), {}, &logger};
cs.run();
}

View file

@ -3,6 +3,7 @@
#include "Luau/IrAnalysis.h"
#include "Luau/IrDump.h"
#include "Luau/IrUtils.h"
#include "Luau/OptimizeConstProp.h"
#include "Luau/OptimizeFinalX64.h"
#include "doctest.h"
@ -16,12 +17,18 @@ class IrBuilderFixture
public:
void constantFold()
{
for (size_t i = 0; i < build.function.instructions.size(); i++)
for (IrBlock& block : build.function.blocks)
{
IrInst& inst = build.function.instructions[i];
if (block.kind == IrBlockKind::Dead)
continue;
applySubstitutions(build.function, inst);
foldConstants(build, build.function, uint32_t(i));
for (size_t i = block.start; i <= block.finish; i++)
{
IrInst& inst = build.function.instructions[i];
applySubstitutions(build.function, inst);
foldConstants(build, build.function, block, uint32_t(i));
}
}
}
@ -96,7 +103,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptCheckTag")
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
optimizeMemoryOperandsX64(build.function);
@ -109,7 +116,7 @@ bb_0:
LOP_RETURN 0u
bb_fallback_1:
LOP_RETURN 0u
LOP_RETURN 1u
)");
}
@ -147,7 +154,6 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag1")
IrOp opA = build.inst(IrCmd::LOAD_TAG, build.vmReg(1));
IrOp opB = build.inst(IrCmd::LOAD_TAG, build.vmReg(2));
build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
@ -184,7 +190,6 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag2")
IrOp opB = build.inst(IrCmd::LOAD_TAG, build.vmReg(2));
build.inst(IrCmd::STORE_TAG, build.vmReg(6), opA);
build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
@ -223,7 +228,6 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3")
IrOp arrElem = build.inst(IrCmd::GET_ARR_ADDR, table, build.constInt(0));
IrOp opA = build.inst(IrCmd::LOAD_TAG, arrElem);
build.inst(IrCmd::JUMP_EQ_TAG, opA, build.constTag(0), trueBlock, falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
@ -261,7 +265,6 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptJumpCmpNum")
IrOp opA = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1));
IrOp opB = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2));
build.inst(IrCmd::JUMP_CMP_NUM, opA, opB, trueBlock, falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
@ -466,14 +469,14 @@ bb_3:
TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowCmpNum")
{
IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0));
auto compareFold = [this](IrOp lhs, IrOp rhs, IrCondition cond, bool result) {
IrOp instOp;
IrInst instExpected;
withTwoBlocks([&](IrOp a, IrOp b) {
instOp = build.inst(IrCmd::JUMP_CMP_NUM, lhs, rhs, build.cond(cond), a, b);
IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0));
instOp = build.inst(
IrCmd::JUMP_CMP_NUM, lhs.kind == IrOpKind::None ? nan : lhs, rhs.kind == IrOpKind::None ? nan : rhs, build.cond(cond), a, b);
instExpected = IrInst{IrCmd::JUMP, result ? a : b};
});
@ -482,6 +485,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowCmpNum")
checkEq(instOp, instExpected);
};
IrOp nan; // Empty operand is used to signal a placement of a 'nan'
compareFold(build.constDouble(1), build.constDouble(1), IrCondition::Equal, true);
compareFold(build.constDouble(1), build.constDouble(2), IrCondition::Equal, false);
compareFold(nan, nan, IrCondition::Equal, false);
@ -532,3 +537,661 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowCmpNum")
}
TEST_SUITE_END();
TEST_SUITE_BEGIN("ConstantPropagation");
TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTagsAndValues")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(10));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5));
// We know constants from those loads
build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(0)));
build.inst(IrCmd::STORE_INT, build.vmReg(4), build.inst(IrCmd::LOAD_INT, build.vmReg(1)));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)));
// We know that these overrides have no effect
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(10));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5));
// But we can invalidate them with unknown values
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.inst(IrCmd::LOAD_TAG, build.vmReg(6)));
build.inst(IrCmd::STORE_INT, build.vmReg(1), build.inst(IrCmd::LOAD_INT, build.vmReg(7)));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(8)));
// So now the constant stores have to be made
build.inst(IrCmd::STORE_TAG, build.vmReg(9), build.inst(IrCmd::LOAD_TAG, build.vmReg(0)));
build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::LOAD_INT, build.vmReg(1)));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(11), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)));
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
STORE_INT R1, 10i
STORE_DOUBLE R2, 0.5
STORE_TAG R3, tnumber
STORE_INT R4, 10i
STORE_DOUBLE R5, 0.5
%12 = LOAD_TAG R6
STORE_TAG R0, %12
%14 = LOAD_INT R7
STORE_INT R1, %14
%16 = LOAD_DOUBLE R8
STORE_DOUBLE R2, %16
%18 = LOAD_TAG R0
STORE_TAG R9, %18
%20 = LOAD_INT R1
STORE_INT R10, %20
%22 = LOAD_DOUBLE R2
STORE_DOUBLE R11, %22
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "PropagateThroughTvalue")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5));
IrOp tv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0));
build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), tv);
// We know constants from those loads
build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(1)));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)));
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
STORE_DOUBLE R0, 0.5
%2 = LOAD_TVALUE R0
STORE_TVALUE R1, %2
STORE_TAG R3, tnumber
STORE_DOUBLE R3, 0.5
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "SkipCheckTag")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "SkipOncePerBlockChecks")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::CHECK_SAFE_ENV);
build.inst(IrCmd::CHECK_SAFE_ENV);
build.inst(IrCmd::CHECK_GC);
build.inst(IrCmd::CHECK_GC);
build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); // Can make env unsafe
build.inst(IrCmd::CHECK_SAFE_ENV);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
CHECK_SAFE_ENV
CHECK_GC
DO_LEN R1, R2
CHECK_SAFE_ENV
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTableState")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0));
build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback);
build.inst(IrCmd::CHECK_READONLY, table, fallback);
build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback);
build.inst(IrCmd::CHECK_READONLY, table, fallback);
build.inst(IrCmd::DO_LEN, build.vmReg(1), build.vmReg(2)); // Can access all heap memory
build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback);
build.inst(IrCmd::CHECK_READONLY, table, fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
%0 = LOAD_POINTER R0
CHECK_NO_METATABLE %0, bb_fallback_1
CHECK_READONLY %0, bb_fallback_1
DO_LEN R1, R2
CHECK_NO_METATABLE %0, bb_fallback_1
CHECK_READONLY %0, bb_fallback_1
LOP_RETURN 0u
bb_fallback_1:
LOP_RETURN 1u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1));
build.inst(IrCmd::BARRIER_TABLE_FORWARD, table, build.vmReg(0));
IrOp something = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2));
build.inst(IrCmd::BARRIER_OBJ, something, build.vmReg(0));
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "ConcatInvalidation")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(10));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(0.5));
build.inst(IrCmd::CONCAT, build.vmReg(0), build.vmReg(3)); // Concat invalidates more than the target register
build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.inst(IrCmd::LOAD_TAG, build.vmReg(0)));
build.inst(IrCmd::STORE_INT, build.vmReg(4), build.inst(IrCmd::LOAD_INT, build.vmReg(1)));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)));
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
STORE_INT R1, 10i
STORE_DOUBLE R2, 0.5
CONCAT R0, R3
%4 = LOAD_TAG R0
STORE_TAG R3, %4
%6 = LOAD_INT R1
STORE_INT R4, %6
%8 = LOAD_DOUBLE R2
STORE_DOUBLE R5, %8
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5));
IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0));
build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback);
build.inst(IrCmd::CHECK_READONLY, table, fallback);
build.inst(IrCmd::LOP_FASTCALL1, build.constUint(0), build.vmReg(1), build.vmReg(2), fallback);
build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback);
build.inst(IrCmd::CHECK_READONLY, table, fallback);
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0))); // At least R0 wasn't touched
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_DOUBLE R0, 0.5
%1 = LOAD_POINTER R0
CHECK_NO_METATABLE %1, bb_fallback_1
CHECK_READONLY %1, bb_fallback_1
LOP_FASTCALL1 0u, R1, R2, bb_fallback_1
CHECK_NO_METATABLE %1, bb_fallback_1
CHECK_READONLY %1, bb_fallback_1
STORE_DOUBLE R1, 0.5
LOP_RETURN 0u
bb_fallback_1:
LOP_RETURN 1u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5));
build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10));
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_INT R0, 10i
STORE_DOUBLE R0, 0.5
STORE_INT R0, 10i
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType")
{
IrOp block = build.block(IrBlockKind::Internal);
build.beginBlock(block);
build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5));
build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10));
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_INT R0, 10i
STORE_DOUBLE R0, 0.5
STORE_INT R0, 10i
LOP_RETURN 0u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagation")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
IrOp unknown = build.inst(IrCmd::LOAD_TAG, build.vmReg(0));
build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback);
build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
%0 = LOAD_TAG R0
CHECK_TAG %0, tnumber, bb_fallback_1
LOP_RETURN 0u
bb_fallback_1:
LOP_RETURN 1u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagationConflicting")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
IrOp unknown = build.inst(IrCmd::LOAD_TAG, build.vmReg(0));
build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback);
build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnil), fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(0));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
%0 = LOAD_TAG R0
CHECK_TAG %0, tnumber, bb_fallback_1
JUMP bb_fallback_1
bb_fallback_1:
LOP_RETURN 1u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "TruthyTestRemoval")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp trueBlock = build.block(IrBlockKind::Internal);
IrOp falseBlock = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
IrOp unknown = build.inst(IrCmd::LOAD_TAG, build.vmReg(1));
build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback);
build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(1), trueBlock, falseBlock);
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
build.beginBlock(falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(2));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(3));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
%0 = LOAD_TAG R1
CHECK_TAG %0, tnumber, bb_fallback_3
JUMP bb_1
bb_1:
LOP_RETURN 1u
bb_fallback_3:
LOP_RETURN 3u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "FalsyTestRemoval")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp trueBlock = build.block(IrBlockKind::Internal);
IrOp falseBlock = build.block(IrBlockKind::Internal);
IrOp fallback = build.block(IrBlockKind::Fallback);
build.beginBlock(block);
IrOp unknown = build.inst(IrCmd::LOAD_TAG, build.vmReg(1));
build.inst(IrCmd::CHECK_TAG, unknown, build.constTag(tnumber), fallback);
build.inst(IrCmd::JUMP_IF_FALSY, build.vmReg(1), trueBlock, falseBlock);
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
build.beginBlock(falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(2));
build.beginBlock(fallback);
build.inst(IrCmd::LOP_RETURN, build.constUint(3));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
%0 = LOAD_TAG R1
CHECK_TAG %0, tnumber, bb_fallback_3
JUMP bb_2
bb_2:
LOP_RETURN 2u
bb_fallback_3:
LOP_RETURN 3u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "TagEqRemoval")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp trueBlock = build.block(IrBlockKind::Internal);
IrOp falseBlock = build.block(IrBlockKind::Internal);
build.beginBlock(block);
IrOp tag = build.inst(IrCmd::LOAD_TAG, build.vmReg(1));
build.inst(IrCmd::CHECK_TAG, tag, build.constTag(tboolean));
build.inst(IrCmd::JUMP_EQ_TAG, tag, build.constTag(tnumber), trueBlock, falseBlock);
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
build.beginBlock(falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(2));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
%0 = LOAD_TAG R1
CHECK_TAG %0, tboolean
JUMP bb_2
bb_2:
LOP_RETURN 2u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp trueBlock = build.block(IrBlockKind::Internal);
IrOp falseBlock = build.block(IrBlockKind::Internal);
build.beginBlock(block);
IrOp value = build.inst(IrCmd::LOAD_INT, build.vmReg(1));
build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(5));
build.inst(IrCmd::JUMP_EQ_INT, value, build.constInt(5), trueBlock, falseBlock);
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
build.beginBlock(falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(2));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_INT R1, 5i
JUMP bb_1
bb_1:
LOP_RETURN 1u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "NumCmpRemoval")
{
IrOp block = build.block(IrBlockKind::Internal);
IrOp trueBlock = build.block(IrBlockKind::Internal);
IrOp falseBlock = build.block(IrBlockKind::Internal);
build.beginBlock(block);
IrOp value = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1));
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(4.0));
build.inst(IrCmd::JUMP_CMP_NUM, value, build.constDouble(8.0), build.cond(IrCondition::Greater), trueBlock, falseBlock);
build.beginBlock(trueBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
build.beginBlock(falseBlock);
build.inst(IrCmd::LOP_RETURN, build.constUint(2));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_DOUBLE R1, 4
JUMP bb_2
bb_2:
LOP_RETURN 2u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "DataFlowsThroughDirectJumpToUniqueSuccessor")
{
IrOp block1 = build.block(IrBlockKind::Internal);
IrOp block2 = build.block(IrBlockKind::Internal);
build.beginBlock(block1);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::JUMP, block2);
build.beginBlock(block2);
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.inst(IrCmd::LOAD_TAG, build.vmReg(0)));
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
JUMP bb_1
bb_1:
STORE_TAG R1, tnumber
LOP_RETURN 1u
)");
}
TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUniqueSuccessor")
{
IrOp block1 = build.block(IrBlockKind::Internal);
IrOp block2 = build.block(IrBlockKind::Internal);
IrOp block3 = build.block(IrBlockKind::Internal);
build.beginBlock(block1);
build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber));
build.inst(IrCmd::JUMP, block2);
build.beginBlock(block2);
build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.inst(IrCmd::LOAD_TAG, build.vmReg(0)));
build.inst(IrCmd::LOP_RETURN, build.constUint(1));
build.beginBlock(block3);
build.inst(IrCmd::JUMP, block2);
updateUseCounts(build.function);
constPropInBlockChains(build);
CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"(
bb_0:
STORE_TAG R0, tnumber
JUMP bb_1
bb_1:
%2 = LOAD_TAG R0
STORE_TAG R1, %2
LOP_RETURN 1u
bb_2:
JUMP bb_1
)");
}
TEST_SUITE_END();

View file

@ -71,6 +71,14 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive")
TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table")
{
// Under DCR, we don't seal the outer occurrance of the table `Cyclic` which
// breaks this test. I'm not sure if that behaviour change is important or
// not, but it's tangental to the core purpose of this test.
ScopedFastFlag sff[] = {
{"DebugLuauDeferredConstraintResolution", false},
};
CheckResult result = check(R"(
local Cyclic = {}
function Cyclic.get()
@ -85,13 +93,13 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table")
* Assert that the return type of get() is the same as the outer table.
*/
TypeId counterType = requireType("Cyclic");
TypeId ty = requireType("Cyclic");
TypeArena dest;
CloneState cloneState;
TypeId counterCopy = clone(counterType, dest, cloneState);
TypeId cloneTy = clone(ty, dest, cloneState);
TableType* ttv = getMutable<TableType>(counterCopy);
TableType* ttv = getMutable<TableType>(cloneTy);
REQUIRE(ttv != nullptr);
CHECK_EQ(std::optional<std::string>{"Cyclic"}, ttv->syntheticName);
@ -105,11 +113,42 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table")
std::optional<TypeId> methodReturnType = first(ftv->retTypes);
REQUIRE(methodReturnType);
CHECK_EQ(methodReturnType, counterCopy);
CHECK_MESSAGE(methodReturnType == cloneTy, toString(methodType, {true}) << " should be pointer identical to " << toString(cloneTy, {true}));
CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type
CHECK_EQ(2, dest.types.size()); // One table and one function
}
TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table_2")
{
TypeArena src;
TypeId tableTy = src.addType(TableType{});
TableType* tt = getMutable<TableType>(tableTy);
REQUIRE(tt);
TypeId methodTy = src.addType(FunctionType{src.addTypePack({}), src.addTypePack({tableTy})});
tt->props["get"].type = methodTy;
TypeArena dest;
CloneState cloneState;
TypeId cloneTy = clone(tableTy, dest, cloneState);
TableType* ctt = getMutable<TableType>(cloneTy);
REQUIRE(ctt);
TypeId clonedMethodType = ctt->props["get"].type;
REQUIRE(clonedMethodType);
const FunctionType* cmf = get<FunctionType>(clonedMethodType);
REQUIRE(cmf);
std::optional<TypeId> cloneMethodReturnType = first(cmf->retTypes);
REQUIRE(bool(cloneMethodReturnType));
CHECK(*cloneMethodReturnType == cloneTy);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena")
{
CheckResult result = check(R"(

View file

@ -194,7 +194,7 @@ TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any")
CHECK_EQ(*typeChecker.anyType, *ttv->props["one"].type);
CHECK_EQ(*typeChecker.anyType, *ttv->props["two"].type);
CHECK_MESSAGE(get<FunctionType>(ttv->props["three"].type), "Should be a function: " << *ttv->props["three"].type);
CHECK_MESSAGE(get<FunctionType>(follow(ttv->props["three"].type)), "Should be a function: " << *ttv->props["three"].type);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_iterator_variables_are_any")

View file

@ -786,7 +786,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param")
TypeId parentTy = requireType("foo");
auto ttv = get<TableType>(follow(parentTy));
auto ftv = get<FunctionType>(ttv->props.at("method").type);
auto ftv = get<FunctionType>(follow(ttv->props.at("method").type));
CHECK_EQ("foo:method<a>(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv));
}
@ -803,12 +803,16 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param")
end
)");
TypeId parentTy = requireType("foo");
auto ttv = get<TableType>(follow(parentTy));
auto ftv = get<FunctionType>(ttv->props.at("method").type);
ToStringOptions opts;
opts.hideFunctionSelfArgument = true;
TypeId parentTy = requireType("foo");
auto ttv = get<TableType>(follow(parentTy));
REQUIRE_MESSAGE(ttv, "Expected a table but got " << toString(parentTy, opts));
TypeId methodTy = follow(ttv->props.at("method").type);
auto ftv = get<FunctionType>(methodTy);
REQUIRE_MESSAGE(ftv, "Expected a function but got " << toString(methodTy, opts));
CHECK_EQ("foo:method<a>(arg: string): ()", toStringNamedFunction("foo:method", *ftv, opts));
}

View file

@ -855,16 +855,8 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni
type FutureIntersection = A & B
)");
if (FFlag::DebugLuauDeferredConstraintResolution)
{
// To be quite honest, I don't know exactly why DCR fixes this.
LUAU_REQUIRE_NO_ERRORS(result);
}
else
{
// TODO: shared self causes this test to break in bizarre ways.
LUAU_REQUIRE_ERRORS(result);
}
// TODO: shared self causes this test to break in bizarre ways.
LUAU_REQUIRE_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok")

View file

@ -660,7 +660,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4")
)");
LUAU_REQUIRE_NO_ERRORS(result);
dumpErrors(result);
/*
* mergesort takes two arguments: an array of some type T and a function that takes two Ts.
@ -1424,9 +1423,11 @@ end
TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite")
{
CheckResult result = check(R"(
function string.len(): number
return 1
end
function string.len(): number
return 1
end
local s = string
)");
LUAU_REQUIRE_NO_ERRORS(result);
@ -1434,11 +1435,11 @@ end
// if 'string' library property was replaced with an internal module type, it will be freed and the next check will crash
frontend.clear();
result = check(R"(
print(string.len('hello'))
CheckResult result2 = check(R"(
print(string.len('hello'))
)");
LUAU_REQUIRE_NO_ERRORS(result);
LUAU_REQUIRE_NO_ERRORS(result2);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite_2")

View file

@ -1404,6 +1404,40 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns")
}
}
TEST_CASE_FIXTURE(BuiltinsFixture, "refine_boolean")
{
CheckResult result = check(R"(
local function f(x: number | boolean)
if typeof(x) == "boolean" then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("number", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "refine_thread")
{
CheckResult result = check(R"(
local function f(x: number | thread)
if typeof(x) == "thread" then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("thread", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("number", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil")
{
CheckResult result = check(R"(

View file

@ -347,8 +347,8 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_3")
const TableType* arg0Table = get<TableType>(follow(arg0));
REQUIRE(arg0Table != nullptr);
REQUIRE(arg0Table->props.find("bar") != arg0Table->props.end());
REQUIRE(arg0Table->props.find("baz") != arg0Table->props.end());
CHECK(arg0Table->props.count("bar"));
CHECK(arg0Table->props.count("baz"));
}
TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1")
@ -2482,12 +2482,18 @@ TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer")
TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer")
{
CheckResult result = check("local a = {} a[0] = 7 a[0] = 't'");
CheckResult result = check(R"(
local a = {}
a[0] = 7
a[0] = 't'
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{
typeChecker.numberType,
typeChecker.stringType,
}}));
CHECK((Location{Position{3, 15}, Position{3, 18}}) == result.errors[0].location);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK(tm->wantedType == typeChecker.numberType);
CHECK(tm->givenType == typeChecker.stringType);
}
TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer")
@ -2673,7 +2679,10 @@ TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick")
)");
ModulePtr module = getMainModule();
CHECK_GE(100, module->internalTypes.types.size());
if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_GE(500, module->internalTypes.types.size());
else
CHECK_GE(100, module->internalTypes.types.size());
}
TEST_CASE_FIXTURE(Fixture, "MixedPropertiesAndIndexers")

View file

@ -1,5 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Common.h"
#include "Luau/Scope.h"
#include "Luau/Symbol.h"
#include "Luau/TypeInfer.h"
#include "Luau/Type.h"
@ -9,6 +11,8 @@
using namespace Luau;
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution)
struct TryUnifyFixture : Fixture
{
TypeArena arena;
@ -254,7 +258,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_41095_concat_log_in_sealed_table_unifica
LUAU_REQUIRE_ERROR_COUNT(2, result);
CHECK_EQ(toString(result.errors[0]), "No overload for function accepts 0 arguments.");
CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()");
if (FFlag::DebugLuauDeferredConstraintResolution)
CHECK_EQ(toString(result.errors[1]), "Available overloads: <V>({V}, V) -> (); and <V>({V}, number, V) -> ()");
else
CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()");
}
TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly")

View file

@ -230,7 +230,7 @@ TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never")
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never")
TEST_CASE_FIXTURE(Fixture, "for_loop_over_never")
{
CheckResult result = check(R"(
for i, v in (5 :: never) do

View file

@ -5,23 +5,16 @@ AstQuery.last_argument_function_call_type
AstQuery::getDocumentationSymbolAtPosition.overloaded_class_method
AstQuery::getDocumentationSymbolAtPosition.overloaded_fn
AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop
AutocompleteTest.autocomplete_first_function_arg_expected_type
AutocompleteTest.autocomplete_oop_implicit_self
AutocompleteTest.autocomplete_string_singleton_equality
AutocompleteTest.do_compatible_self_calls
AutocompleteTest.do_wrong_compatible_self_calls
AutocompleteTest.type_correct_expected_return_type_suggestion
AutocompleteTest.type_correct_suggestion_for_overloads
BuiltinTests.aliased_string_format
BuiltinTests.assert_removes_falsy_types
BuiltinTests.assert_removes_falsy_types2
BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type
BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy
BuiltinTests.bad_select_should_not_crash
BuiltinTests.coroutine_wrap_anything_goes
BuiltinTests.debug_info_is_crazy
BuiltinTests.debug_traceback_is_crazy
BuiltinTests.dont_add_definitions_to_persistent_types
BuiltinTests.find_capture_types3
BuiltinTests.gmatch_definition
BuiltinTests.match_capture_types
BuiltinTests.match_capture_types2
@ -34,16 +27,14 @@ BuiltinTests.sort_with_bad_predicate
BuiltinTests.string_format_as_method
BuiltinTests.string_format_correctly_ordered_types
BuiltinTests.string_format_report_all_type_errors_at_correct_positions
BuiltinTests.string_format_tostring_specifier_type_constraint
BuiltinTests.string_format_use_correct_argument2
BuiltinTests.table_freeze_is_generic
BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload
BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload
BuiltinTests.table_pack
BuiltinTests.table_pack_reduce
BuiltinTests.table_pack_variadic
DefinitionTests.class_definition_overload_metamethods
DefinitionTests.class_definition_string_props
DefinitionTests.definition_file_classes
FrontendTest.environments
FrontendTest.nocheck_cycle_used_by_checked
GenericsTests.apply_type_function_nested_generics2
@ -52,6 +43,7 @@ GenericsTests.bound_tables_do_not_clone_original_fields
GenericsTests.check_mutual_generic_functions
GenericsTests.correctly_instantiate_polymorphic_member_functions
GenericsTests.do_not_infer_generic_functions
GenericsTests.dont_unify_bound_types
GenericsTests.generic_argument_count_too_few
GenericsTests.generic_argument_count_too_many
GenericsTests.generic_functions_should_be_memory_safe
@ -62,16 +54,13 @@ GenericsTests.infer_generic_function_function_argument_3
GenericsTests.infer_generic_function_function_argument_overloaded
GenericsTests.infer_generic_lib_function_function_argument
GenericsTests.instantiated_function_argument_names
GenericsTests.instantiation_sharing_types
GenericsTests.no_stack_overflow_from_quantifying
GenericsTests.self_recursive_instantiated_param
IntersectionTypes.select_correct_union_fn
IntersectionTypes.should_still_pick_an_overload_whose_arguments_are_unions
IntersectionTypes.overload_is_not_a_function
IntersectionTypes.table_intersection_write_sealed
IntersectionTypes.table_intersection_write_sealed_indirect
IntersectionTypes.table_write_sealed_indirect
ModuleTests.clone_self_property
ModuleTests.deepClone_cyclic_table
NonstrictModeTests.for_in_iterator_variables_are_any
NonstrictModeTests.function_parameters_are_any
NonstrictModeTests.inconsistent_module_return_types_are_ok
@ -85,7 +74,6 @@ NonstrictModeTests.offer_a_hint_if_you_use_a_dot_instead_of_a_colon
NonstrictModeTests.parameters_having_type_any_are_optional
NonstrictModeTests.table_dot_insert_and_recursive_calls
NonstrictModeTests.table_props_are_any
Normalize.cyclic_table_normalizes_sensibly
ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal
ProvisionalTests.bail_early_if_unification_is_too_complicated
ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack
@ -93,31 +81,28 @@ ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean
ProvisionalTests.free_options_cannot_be_unified_together
ProvisionalTests.generic_type_leak_to_module_interface_variadic
ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns
ProvisionalTests.pcall_returns_at_least_two_value_but_function_returns_nothing
ProvisionalTests.setmetatable_constrains_free_type_into_free_table
ProvisionalTests.specialization_binds_with_prototypes_too_early
ProvisionalTests.table_insert_with_a_singleton_argument
ProvisionalTests.typeguard_inference_incomplete
ProvisionalTests.weirditer_should_not_loop_forever
RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string
RefinementTest.discriminate_tag
RefinementTest.discriminate_from_isa_of_x
RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil
RefinementTest.narrow_property_of_a_bounded_variable
RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true
RefinementTest.refine_a_param_that_got_resolved_during_constraint_solving_stage
RefinementTest.refine_param_of_type_folder_or_part_without_using_typeof
RefinementTest.refine_unknowns
RefinementTest.type_guard_can_filter_for_intersection_of_tables
RefinementTest.type_narrow_for_all_the_userdata
RefinementTest.type_narrow_to_vector
RefinementTest.typeguard_cast_free_table_to_vector
RefinementTest.typeguard_in_assert_position
RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table
RefinementTest.x_is_not_instance_or_else_not_part
RuntimeLimits.typescript_port_of_Result_type
TableTests.a_free_shape_can_turn_into_a_scalar_directly
TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible
TableTests.accidentally_checked_prop_in_opposite_branch
TableTests.any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode
TableTests.call_method
TableTests.casting_tables_with_props_into_table_with_indexer3
TableTests.casting_tables_with_props_into_table_with_indexer4
TableTests.checked_prop_too_early
@ -135,7 +120,6 @@ TableTests.explicitly_typed_table_with_indexer
TableTests.found_like_key_in_table_function_call
TableTests.found_like_key_in_table_property_access
TableTests.found_multiple_like_keys
TableTests.function_calls_produces_sealed_table_given_unsealed_table
TableTests.fuzz_table_unify_instantiated_table
TableTests.generic_table_instantiation_potential_regression
TableTests.give_up_after_one_metatable_index_look_up
@ -144,21 +128,16 @@ TableTests.indexing_from_a_table_should_prefer_properties_when_possible
TableTests.inequality_operators_imply_exactly_matching_types
TableTests.infer_array_2
TableTests.inferred_return_type_of_free_table
TableTests.inferring_crazy_table_should_also_be_quick
TableTests.instantiate_table_cloning_3
TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound
TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound
TableTests.leaking_bad_metatable_errors
TableTests.less_exponential_blowup_please
TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred
TableTests.mixed_tables_with_implicit_numbered_keys
TableTests.nil_assign_doesnt_hit_indexer
TableTests.nil_assign_doesnt_hit_no_indexer
TableTests.okay_to_add_property_to_unsealed_tables_by_function_call
TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr
TableTests.only_ascribe_synthetic_names_at_module_scope
TableTests.oop_indexer_works
TableTests.oop_polymorphic
TableTests.open_table_unification_2
TableTests.quantify_even_that_table_was_never_exported_at_all
TableTests.quantify_metatables_of_metatables_of_table
TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table
@ -169,32 +148,21 @@ TableTests.shared_selfs
TableTests.shared_selfs_from_free_param
TableTests.shared_selfs_through_metatables
TableTests.table_call_metamethod_basic
TableTests.table_indexing_error_location
TableTests.table_insert_should_cope_with_optional_properties_in_nonstrict
TableTests.table_insert_should_cope_with_optional_properties_in_strict
TableTests.table_param_row_polymorphism_3
TableTests.table_simple_call
TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors
TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors
TableTests.table_unification_4
TableTests.unifying_tables_shouldnt_uaf2
TableTests.used_colon_instead_of_dot
TableTests.used_dot_instead_of_colon
ToString.exhaustive_toString_of_cyclic_table
TableTests.when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type
ToString.named_metatable_toStringNamedFunction
ToString.toStringDetailed2
ToString.toStringErrorPack
ToString.toStringNamedFunction_generic_pack
ToString.toStringNamedFunction_hide_self_param
ToString.toStringNamedFunction_include_self_param
ToString.toStringNamedFunction_map
TryUnifyTests.cli_41095_concat_log_in_sealed_table_unification
TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType
TryUnifyTests.result_of_failed_typepack_unification_is_constrained
TryUnifyTests.typepack_unification_should_trim_free_tails
TryUnifyTests.variadics_should_use_reversed_properly
TypeAliases.cannot_create_cyclic_type_with_unknown_module
TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any
TypeAliases.generic_param_remap
TypeAliases.mismatched_generic_type_param
TypeAliases.mutually_recursive_types_restriction_not_ok_1
@ -218,11 +186,9 @@ TypeInfer.it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict
TypeInfer.no_stack_overflow_from_isoptional
TypeInfer.no_stack_overflow_from_isoptional2
TypeInfer.tc_after_error_recovery_no_replacement_name_in_error
TypeInfer.tc_if_else_expressions_expected_type_3
TypeInfer.type_infer_recursion_limit_no_ice
TypeInfer.type_infer_recursion_limit_normalizer
TypeInferAnyError.for_in_loop_iterator_is_any2
TypeInferClasses.can_read_prop_of_base_class_using_string
TypeInferClasses.class_type_mismatch_with_name_conflict
TypeInferClasses.classes_without_overloaded_operators_cannot_be_added
TypeInferClasses.higher_order_function_arguments_are_contravariant
@ -232,6 +198,7 @@ TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_propert
TypeInferClasses.warn_when_prop_almost_matches
TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types
TypeInferFunctions.cannot_hoist_interior_defns_into_signature
TypeInferFunctions.check_function_before_lambda_that_uses_it
TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists
TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site
TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict
@ -243,10 +210,7 @@ TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer
TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict
TypeInferFunctions.improved_function_arg_mismatch_errors
TypeInferFunctions.infer_anonymous_function_arguments
TypeInferFunctions.infer_return_type_from_selected_overload
TypeInferFunctions.infer_that_function_does_not_return_a_table
TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count
TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count
TypeInferFunctions.luau_subtyping_is_np_hard
TypeInferFunctions.no_lossy_function_type
TypeInferFunctions.occurs_check_failure_in_function_return_type
@ -273,13 +237,11 @@ TypeInferModules.do_not_modify_imported_types_5
TypeInferModules.module_type_conflict
TypeInferModules.module_type_conflict_instantiated
TypeInferModules.type_error_of_unknown_qualified_type
TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works
TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory
TypeInferOOP.methods_are_topologically_sorted
TypeInferOOP.object_constructor_can_refer_to_method_of_self
TypeInferOperators.CallAndOrOfFunctions
TypeInferOperators.CallOrOfFunctions
TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable
TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable
TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators
TypeInferOperators.cli_38355_recursive_union
@ -303,25 +265,21 @@ TypeInferUnknownNever.math_operators_and_never
TypePackTests.detect_cyclic_typepacks2
TypePackTests.pack_tail_unification_check
TypePackTests.type_alias_backwards_compatible
TypePackTests.type_alias_default_export
TypePackTests.type_alias_default_mixed_self
TypePackTests.type_alias_default_type_chained
TypePackTests.type_alias_default_type_errors
TypePackTests.type_alias_default_type_pack_self_chained_tp
TypePackTests.type_alias_default_type_pack_self_tp
TypePackTests.type_alias_default_type_self
TypePackTests.type_alias_defaults_confusing_types
TypePackTests.type_alias_defaults_recursive_type
TypePackTests.type_alias_type_pack_multi
TypePackTests.type_alias_type_pack_variadic
TypePackTests.type_alias_type_packs_errors
TypePackTests.type_alias_type_packs_nested
TypePackTests.unify_variadic_tails_in_arguments
TypePackTests.unify_variadic_tails_in_arguments_free
TypePackTests.variadic_packs
TypeSingletons.function_call_with_singletons
TypeSingletons.function_call_with_singletons_mismatch
TypeSingletons.indexing_on_union_of_string_singletons
TypeSingletons.no_widening_from_callsites
TypeSingletons.overloaded_function_call_with_singletons
TypeSingletons.overloaded_function_call_with_singletons_mismatch
TypeSingletons.return_type_of_f_is_not_widened